In [13]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F

torch.set_default_dtype(torch.float32)

In [18]:
from torchsummary import summary

## Declare classes

In [45]:
class SimpleNetwork(torch.nn.Module):
    def __init__(self, inp_size, layers=[100]):
        super().__init__()
        self.layers = []
        
        self.inp = torch.nn.Linear(inp_size, layers[0])
        self.output = torch.nn.Linear(layers[0], 10)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        out = x
        out = self.inp(out)
        out = self.relu(out)
        out = self.output(out)
        return out
    
class InsertableNet(SimpleNetwork):
    def __init__(self, weights, inp_size, layers=[100]):
        super().__init__(inp_size, layers)
        input_w_size = inp_size*layers[0]
        input_b_size = layers[0]

        hidden_w_size = layers[0]*10
        hidden_b_size = 10

        self.inp_weights = weights[:input_w_size].reshape((layers[0], inp_size))
        self.inp_bias = weights[input_w_size:input_w_size+input_b_size]

        self.output_weights = weights[input_w_size+input_b_size:input_w_size+input_b_size+hidden_w_size].reshape((10, layers[0]))
        self.output_bias = weights[input_w_size+input_b_size+hidden_w_size:input_w_size+input_b_size+hidden_w_size+hidden_b_size]

    def forward(self, data):
        out = F.linear(data, self.inp_weights, self.inp_bias)
        out = self.relu(out)
        out = F.linear(out, self.output_weights, self.output_bias)
        return out
    
class MaskedNetwork(SimpleNetwork):
    def __init__(self, input_size, mask_size, layers=[10]):
        super().__init__(mask_size, layers=layers)
        template = np.zeros(input_size)
        mask = np.random.choice(len(template), mask_size, False)
        template[mask] = 1
        self.mask = torch.from_numpy(template).to(torch.bool)
        
    def forward(self, x):
        data = x[:, self.mask]
        return super().forward(data)

In [30]:
class Hypernetwork(torch.nn.Module):
    def __init__(self, inp_size=784, mask_size=20, node_hidden_size=20, layers=[256, 512]):
        super().__init__()
        self.mask_size = mask_size
        self.input_size = inp_size
        self.node_hidden_size = node_hidden_size
        
        input_w_size = mask_size*node_hidden_size
        input_b_size = node_hidden_size

        hidden_w_size = node_hidden_size*10
        hidden_b_size = 10
            
        self.out_size = input_w_size+input_b_size+hidden_w_size+hidden_b_size
        
        self.input = torch.nn.Linear(inp_size, layers[0])
        self.hidden1 = torch.nn.Linear(layers[0], layers[1])
        self.out = torch.nn.Linear(layers[1], self.out_size)
        
        self.relu = torch.nn.ReLU()
        self.template = np.zeros(inp_size)
        
    def forward(self, data, mask=None):
        if mask is None:
            mask = np.random.choice(len(self.template), self.mask_size, False)
            tmp = self.template.copy()
            tmp[mask] = 1
            mask = torch.from_numpy(tmp).to(torch.float32)
        
        weights = self.craft_network(mask)
        nn = InsertableNet(weights, self.mask_size, layers=[self.node_hidden_size])
        masked_data = data[:, mask.to(torch.bool)]
        return nn(masked_data)
        
    def craft_network(self, mask):
        out = self.input(mask)
        out = self.relu(out)
        out = self.hidden1(out)
        out = self.relu(out)
#         out = self.hidden2(out)
#         out = self.relu(out)
        out = self.out(out)
        return out

## Load data

In [31]:
mods = [transforms.ToTensor(), 
        transforms.Normalize((0.1307,), (0.3081,)),    #mean and std of MNIST
        transforms.Lambda(lambda x: torch.flatten(x))]
mods = transforms.Compose(mods)

trainset = datasets.MNIST(root='./data/train', train=True, download=True, transform=mods)
testset = datasets.MNIST(root='./data/test', train=False, download=True, transform=mods)

In [32]:
batch_size = 64

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=1)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=1)

## Craft model

In [52]:
hypernet = Hypernetwork(mask_size=20)
summary(hypernet, (784, ))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                       [-1]         200,960
              ReLU-2                       [-1]               0
            Linear-3                       [-1]         131,584
              ReLU-4                       [-1]               0
            Linear-5                       [-1]         323,190
Total params: 655,734
Trainable params: 655,734
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 2.50
Estimated Total Size (MB): 2.50
----------------------------------------------------------------


In [58]:
input_size = 784
mask_size = 20

template = np.zeros(input_size)
mask = np.random.choice(len(template), mask_size, False)
template[mask] = 1
mask = torch.from_numpy(template).to(torch.float32)

In [59]:
mask.sum()

tensor(20.)

In [60]:
epochs = 2
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(hypernet.parameters(), lr=3e-4)

In [None]:
for epoch in range(epochs): 
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = hypernet(inputs, mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        correct += (outputs.argmax(1)==labels).float().sum()
        total += outputs.shape[0]
        running_loss += loss.item()
        if i % 100 == 0:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f acc: %.3f%%' %
                  (epoch + 1, i, running_loss / 100, correct/total*100))
            running_loss = 0.0
            correct = 0
            total=0

    correct = 0
    for data in testloader:
        images, labels = data
        outputs = hypernet(images)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
    print(f"Test acc: {correct/len(testset)*100}")

[1,     0] loss: 0.023 acc: 10.938%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f72c4fea4c0>
Traceback (most recent call last):
  File "/home/ginterhauser/miniconda3/envs/image_processing/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ginterhauser/miniconda3/envs/image_processing/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ginterhauser/miniconda3/envs/image_processing/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
