In [207]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import itertools
import copy
import numpy as np

In [208]:
#IMPORTANT
#ORDER OF LAYERS IN __INIT__ MUST BE SAME AS FORWARD PASS
#FOR THE RIGHT INDEX NUMBERS OF LAYERS


#TO CHECK:

#ARE RIGHT ROWS AND COLUMNS REMOVED?
#WHAT ABOUT BIASES?!


class Net(nn.Module):
    def __init__(self):
        
        #SEE NOTE ABOVE
        
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        output = F.log_softmax(x, dim=1)
        return output
    
    def add_unit(self, index_layer):
        
        #PICK LAYERS ACCORDING TO INDEX, SEE NOTE ABOVE CLASS
        
        layer_names_list = [x[0] for x in self.named_children()]
        layer_name_1 = layer_names_list[index_layer]
        layer_name_2 = layer_names_list[index_layer+1]
        
        #get current weights of layer 1
        current_weights_1 = self._modules[layer_name_1].weight.data
        current_bias_1 = self._modules[layer_name_1].bias.data
        current_num_nodes_1 = current_weights_1.shape[1]
        
        #get current weights of layer 2
        current_weights_2 = self._modules[layer_name_2].weight.data
        current_bias_2 = self._modules[layer_name_2].bias.data
        current_num_nodes_2 = current_weights_2.shape[0]

        #make weights for node layer 1 (add choice of init here)
        add_weights_1 = torch.zeros([1,current_num_nodes_1])
        add_bias_1 = torch.zeros([1,1])
        
        #What is this exactly?
        nn.init.xavier_uniform_(add_weights_1, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(add_bias_1, gain=nn.init.calculate_gain('relu'))
        
        #make new connection weights for all nodes of layer after
        add_weights_2 = torch.zeros([current_num_nodes_2,1])
        nn.init.xavier_uniform_(add_weights_2, gain=nn.init.calculate_gain('relu'))

        #merge weights
        new_weights_1 = torch.cat([current_weights_1,add_weights_1],dim=0) #add bottom row
        new_bias_1 = torch.cat([current_bias_1,add_bias_1[0]]) #add bottom row
        new_weights_2 = torch.cat([add_weights_2,current_weights_2],dim=1) #add first column

        #update layer shape - also adds bias automatically (lets do it manually?!)
                
        self._modules[layer_name_1] = nn.Linear(new_weights_1.shape[1],new_weights_1.shape[0])
        self._modules[layer_name_2] = nn.Linear(new_weights_2.shape[1],new_weights_2.shape[0])
        
        #set weight data to new values        
        self._modules[layer_name_1].weight.data = torch.tensor(new_weights_1, requires_grad=True)
        self._modules[layer_name_2].weight.data = torch.tensor(new_weights_2, requires_grad=True)
        
        self._modules[layer_name_1].bias.data = torch.tensor(new_bias_1, requires_grad=True)
        self._modules[layer_name_2].bias.data = torch.tensor(current_bias_2, requires_grad=True)
    
        
    def remove_unit(self, index_layer, index_node):
        
        layer_names_list = [x[0] for x in self.named_children()]
        layer_name_1 = layer_names_list[index_layer]
        layer_name_2 = layer_names_list[index_layer+1]
        
        current_weights_1 = self._modules[layer_name_1].weight.data
        current_bias_1 = self._modules[layer_name_1].bias.data
        
        current_weights_2 = self._modules[layer_name_2].weight.data
        current_bias_2 = self._modules[layer_name_2].bias.data
        
        new_weights_1 = np.delete(current_weights_1,index_node,axis=0)
        new_bias_1 = np.delete(current_bias_1,index_node)
        new_weights_2 = np.delete(current_weights_2,index_node,axis=1)
        
        #update layer shape - also adds bias automatically (lets do it manually?!)
                
        self._modules[layer_name_1] = nn.Linear(new_weights_1.shape[1],new_weights_1.shape[0])
        self._modules[layer_name_2] = nn.Linear(new_weights_2.shape[1],new_weights_2.shape[0])
        
        #set weight data to new values        
        self._modules[layer_name_1].weight.data = torch.tensor(new_weights_1, requires_grad=True)
        self._modules[layer_name_1].bias.data = torch.tensor(new_bias_1, requires_grad=True)
        self._modules[layer_name_2].weight.data = torch.tensor(new_weights_2, requires_grad=True)  
        self._modules[layer_name_2].bias.data = current_bias_2
        
    def print_model(self):
        
        print(self.fc1.weight.data)
        

In [217]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    i = 0
    
    model.print_model()
    
    for batch_idx, (data, target) in enumerate(train_loader):
    #for batch_idx, (data,target) in itertools.islice(train_loader, stop=10000):
    
        #Take only 1000 samples
        i += 1
        if i >= 1000/64:
            break
            
        #model.add_unit(0) 
        #model.remove_unit(0,0)
        
        print('Model sizes:')
        print(model.fc1.weight.data.shape)
        print(model.fc2.weight.data.shape)
        
            
        #normal training    
        data, target = data.to(device), target.to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        
        #return previously saved state, thus resetting to before training
        #model.load_state_dict(new_state_dict,strict=False)
    
    model.print_model()

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

class Args:
    batch_size= 64
    test_batch_size = 1000
    epochs = 2
    lr = 1
    gamma = 0.7
    no_cuda = False
    seed = 1
    log_interval = 10
    save_model = False

args=Args()

def main():
    args = Args()
    
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_dataset = datasets.MNIST('../data' , train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))    
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    
    
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        

In [220]:
model = Net()

print('Model before adding node:')

#print(model.fc1.weight.data)
#print(model.fc2.weight.data)
print(model.fc1.bias.data)
#print(model.fc2.bias.data)

model.add_unit(0)

print('Model after adding node:')

#print(model.fc1.weight.data)
#print(model.fc2.weight.data)
print(model.fc1.bias.data)
#print(model.fc2.bias.data)

print('Model after removing node:')

model.remove_unit(0,0)

#print(model.fc1.weight.data)
#print(model.fc2.weight.data)
print(model.fc1.bias.data)
#print(model.fc2.bias.data)


Model before adding node:
tensor([ 0.0056,  0.0128,  0.0074, -0.0289, -0.0141,  0.0229, -0.0174, -0.0181,
         0.0271, -0.0349, -0.0050,  0.0325,  0.0232,  0.0343, -0.0229, -0.0356,
        -0.0077,  0.0218,  0.0222,  0.0259,  0.0079,  0.0079, -0.0089, -0.0253,
         0.0263,  0.0235, -0.0084, -0.0011, -0.0108, -0.0195, -0.0200, -0.0019,
        -0.0174,  0.0132,  0.0086, -0.0148, -0.0143, -0.0327,  0.0068, -0.0225,
         0.0341,  0.0184,  0.0125, -0.0280,  0.0027,  0.0259, -0.0079, -0.0147,
         0.0098,  0.0205, -0.0018, -0.0232,  0.0181,  0.0126,  0.0093, -0.0048,
        -0.0142, -0.0063, -0.0140, -0.0320,  0.0154, -0.0347, -0.0350, -0.0341,
        -0.0076, -0.0266,  0.0303,  0.0333, -0.0353, -0.0116,  0.0337,  0.0005,
        -0.0164, -0.0093,  0.0304,  0.0124,  0.0340, -0.0161,  0.0298,  0.0251,
         0.0028, -0.0249,  0.0277,  0.0356, -0.0065,  0.0338,  0.0027,  0.0170,
         0.0264, -0.0146,  0.0139, -0.0211, -0.0205,  0.0292,  0.0170, -0.0286,
         0.006



Model after adding node:
tensor([ 5.5931e-03,  1.2797e-02,  7.3695e-03, -2.8897e-02, -1.4114e-02,
         2.2940e-02, -1.7367e-02, -1.8120e-02,  2.7107e-02, -3.4926e-02,
        -5.0315e-03,  3.2486e-02,  2.3173e-02,  3.4305e-02, -2.2901e-02,
        -3.5554e-02, -7.6800e-03,  2.1779e-02,  2.2154e-02,  2.5926e-02,
         7.8792e-03,  7.9319e-03, -8.9259e-03, -2.5263e-02,  2.6281e-02,
         2.3486e-02, -8.4360e-03, -1.0504e-03, -1.0813e-02, -1.9501e-02,
        -2.0023e-02, -1.9388e-03, -1.7414e-02,  1.3173e-02,  8.6280e-03,
        -1.4803e-02, -1.4284e-02, -3.2729e-02,  6.8203e-03, -2.2544e-02,
         3.4093e-02,  1.8426e-02,  1.2466e-02, -2.7986e-02,  2.6731e-03,
         2.5878e-02, -7.9318e-03, -1.4717e-02,  9.7808e-03,  2.0521e-02,
        -1.8385e-03, -2.3220e-02,  1.8139e-02,  1.2592e-02,  9.2745e-03,
        -4.7574e-03, -1.4242e-02, -6.3438e-03, -1.3961e-02, -3.2031e-02,
         1.5428e-02, -3.4723e-02, -3.4992e-02, -3.4062e-02, -7.5591e-03,
        -2.6556e-02,  3.02



tensor([ 1.2797e-02,  7.3695e-03, -2.8897e-02, -1.4114e-02,  2.2940e-02,
        -1.7367e-02, -1.8120e-02,  2.7107e-02, -3.4926e-02, -5.0315e-03,
         3.2486e-02,  2.3173e-02,  3.4305e-02, -2.2901e-02, -3.5554e-02,
        -7.6800e-03,  2.1779e-02,  2.2154e-02,  2.5926e-02,  7.8792e-03,
         7.9319e-03, -8.9259e-03, -2.5263e-02,  2.6281e-02,  2.3486e-02,
        -8.4360e-03, -1.0504e-03, -1.0813e-02, -1.9501e-02, -2.0023e-02,
        -1.9388e-03, -1.7414e-02,  1.3173e-02,  8.6280e-03, -1.4803e-02,
        -1.4284e-02, -3.2729e-02,  6.8203e-03, -2.2544e-02,  3.4093e-02,
         1.8426e-02,  1.2466e-02, -2.7986e-02,  2.6731e-03,  2.5878e-02,
        -7.9318e-03, -1.4717e-02,  9.7808e-03,  2.0521e-02, -1.8385e-03,
        -2.3220e-02,  1.8139e-02,  1.2592e-02,  9.2745e-03, -4.7574e-03,
        -1.4242e-02, -6.3438e-03, -1.3961e-02, -3.2031e-02,  1.5428e-02,
        -3.4723e-02, -3.4992e-02, -3.4062e-02, -7.5591e-03, -2.6556e-02,
         3.0266e-02,  3.3312e-02, -3.5343e-02, -1.1

In [219]:
main()

tensor([[ 0.0184, -0.0158, -0.0069,  ...,  0.0068, -0.0041,  0.0025],
        [-0.0274, -0.0224, -0.0309,  ..., -0.0029,  0.0013, -0.0167],
        [ 0.0282, -0.0095, -0.0340,  ..., -0.0141,  0.0056, -0.0335],
        ...,
        [-0.0265,  0.0014, -0.0012,  ...,  0.0290, -0.0258, -0.0296],
        [ 0.0270, -0.0221,  0.0240,  ..., -0.0223, -0.0174,  0.0129],
        [-0.0289,  0.0196, -0.0356,  ...,  0.0096, -0.0082, -0.0165]])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Si



tensor([[ 1.3641e-02, -2.0525e-02, -1.1685e-02,  ...,  2.0176e-03,
         -8.8556e-03, -2.2576e-03],
        [-2.4533e-02, -1.9602e-02, -2.8083e-02,  ..., -8.2588e-05,
          4.0989e-03, -1.3843e-02],
        [ 2.3490e-02, -1.4160e-02, -3.8729e-02,  ..., -1.8764e-02,
          9.4096e-04, -3.8213e-02],
        ...,
        [-2.9572e-02, -1.6730e-03, -4.2347e-03,  ...,  2.5927e-02,
         -2.8908e-02, -3.2716e-02],
        [ 2.6773e-02, -2.2268e-02,  2.3859e-02,  ..., -2.2453e-02,
         -1.7536e-02,  1.2714e-02],
        [-2.8904e-02,  1.9560e-02, -3.5654e-02,  ...,  9.5554e-03,
         -8.2646e-03, -1.6501e-02]])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.Size([128, 784])
torch.Size([10, 128])
Model sizes:
torch.S