In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt


batch_size = 32
test_batch_size = 1000
epochs = 10
momentum = 0.5
no_cuda = False
seed = 1
log_interval = 10

cuda = not no_cuda and torch.cuda.is_available()

torch.manual_seed(seed)

if cuda:
    torch.cuda.manual_seed(seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)

class MLPNetModified(nn.Module):
    def __init__(self, f1, f2, f3):
        super(MLPNetModified, self).__init__()
        self.f1 = f1
        self.f2 = f2
        self.f3 = f3
        self.fc1 = nn.Linear(28*28, 500)
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        half = int(len(x[0])/2)
        first_part = x[:, 0:half]
        second_part = x[:, half:]
        first_part = self.f1(first_part)
        second_part = self.f1(second_part)
        x = torch.cat((first_part, second_part), 1)
        x = self.fc2(x)
        half = int(len(x[0])/2)
        first_part = x[:, 0:half]
        second_part = x[:, half:]
        first_part = self.f2(first_part)
        second_part = self.f2(second_part)
        x = torch.cat((first_part, second_part), 1)
        x = self.fc3(x)
        half = int(len(x[0])/2)
        first_part = x[:, 0:half]
        second_part = x[:, half:]
        first_part = self.f3(first_part)
        second_part = self.f3(second_part)
        x = torch.cat((first_part, second_part), 1)
        return F.log_softmax(x)
    def name(self):
        return 'mlpnet'

def solve(f1, f2, f3 ,lr):
    print (str(f1).split()[1], str(f2).split()[1], str(f3).split()[1])
    model = MLPNetModified(f1, f2, f3)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    train_loss = []
    test_losses = []
    test_accuracy = []
    parameters_list = []
    def train(epoch):
        model.train()
        loss_to_print = 0
        for data, target in train_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            loss_to_print += loss.data[0]
                # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #     epoch, batch_idx * len(data), len(train_loader.dataset),
                #     100. * batch_idx / len(train_loader), loss.data[0]))
        train_loss.append(loss_to_print)
        print (epoch, loss_to_print)
        p= model.state_dict() 
        parameters_list.append(p['fc1.weight'].numpy())
        print(parameters_list)
        print(p['fc1.weight'][0][0])
        return train_loss
    def test(epoch):
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        test_loss /= len(test_loader.dataset)
        if (epoch == epochs):
            print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))
        test_losses.append(test_loss)
        test_accuracy.append(100. * correct / len(test_loader.dataset))
        return test_losses
    for epoch in range(1, epochs + 1):
        TRAIN_LOSS = train(epoch)
        TEST_LOSS  = test(epoch)
    return TRAIN_LOSS,TEST_LOSS,parameters_list


train_plots_relu0, test_plots_relu0, parameters0 = solve(F.relu, F.relu, F.relu, 0.35)
# train_plots_relu1, test_plots_relu1, parameters1 = solve(F.relu, F.relu, F.relu, 0.65)
# train_plots_relu2, test_plots_relu2, parameters2 = solve(F.relu, F.relu, F.relu, 0.75)
# train_plots_relu3, test_plots_relu3, parameters3 = solve(F.relu, F.relu, F.relu, 0.85)
# train_plots_relu4, test_plots_relu4, parameters4 = solve(F.relu, F.relu, F.relu, 0.95)



relu relu relu
1 864.5455937159131
[array([[ 0.00633786,  0.0477781 ,  0.02800229, ...,  0.02833484,
         0.03359372, -0.00181577],
       [-0.0115374 ,  0.00215372, -0.03080347, ..., -0.02232358,
         0.03740259, -0.00312513],
       [ 0.01169637,  0.00373681, -0.00696185, ...,  0.02827242,
         0.01802573,  0.02345314],
       ..., 
       [ 0.03231156,  0.02374377,  0.0605252 , ...,  0.0235883 ,
         0.02649865,  0.03225754],
       [-0.0022875 ,  0.03523339, -0.00022792, ..., -0.02167422,
         0.01760476,  0.024813  ],
       [ 0.01805836,  0.00750022,  0.00028992, ..., -0.01455383,
         0.00124341, -0.00679247]], dtype=float32)]
0.0063378578051924706
2 594.3029355729359
[array([[ 0.03556556,  0.07700585,  0.05722999, ...,  0.05756254,
         0.06282146,  0.02741193],
       [-0.01180676,  0.00188436, -0.03107285, ..., -0.02259295,
         0.03713322, -0.00339449],
       [ 0.01063481,  0.00267523, -0.00802342, ...,  0.02721085,
         0.01696416,  0.02

6 4314.121509075165
[array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  0.06783859, ...,  0.03090179,
         0.03381211,  0.03957099],
       [ 0.00018054,  0.03770133,  0.00224012, ..., -0.01920622,
         0.02007276,  0.02728099],
       [ 0.02464637,  0.01408821,  0.00687792, ..., -0.00796581,
         0.00783141, -0.00020448]], dtype=float32), array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  0

8 4317.130795240402
[array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  0.06783859, ...,  0.03090179,
         0.03381211,  0.03957099],
       [ 0.00018054,  0.03770133,  0.00224012, ..., -0.01920622,
         0.02007276,  0.02728099],
       [ 0.02464637,  0.01408821,  0.00687792, ..., -0.00796581,
         0.00783141, -0.00020448]], dtype=float32), array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  0

10 4317.130795240402
[array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  0.06783859, ...,  0.03090179,
         0.03381211,  0.03957099],
       [ 0.00018054,  0.03770133,  0.00224012, ..., -0.01920622,
         0.02007276,  0.02728099],
       [ 0.02464637,  0.01408821,  0.00687792, ..., -0.00796581,
         0.00783141, -0.00020448]], dtype=float32), array([[-0.01362379,  0.02781671,  0.00804072, ...,  0.00837328,
         0.01363217, -0.02177739],
       [-0.00320895,  0.01048217, -0.02247501, ..., -0.01399512,
         0.04573105,  0.00520331],
       [ 0.0101801 ,  0.00222053, -0.00847813, ...,  0.02675613,
         0.01650946,  0.02193686],
       ..., 
       [ 0.03962502,  0.03105725,  