In [1]:
from dotenv import load_dotenv
load_dotenv()

import os

In [2]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

torch.set_default_dtype(torch.float32)

In [3]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [4]:
from ipynb.fs.defs.hypernet_training import SimpleNetwork, Hypernetwork, get_dataset, train_slow_step, test_model, InsertableNet, SimpleNetwork, train_regular

In [5]:
DEVICE = 'cuda:1'

## Subclass hypernetwork

In [6]:
class HypernetWithFE(Hypernetwork):
    def __init__(self, feature_extractor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_extractor = feature_extractor
   
    def forward(self, data, mask=None):
        extracted = self.feature_extractor(data)
        return super().forward(extracted, mask)

In [7]:
extractor = torch.nn.Sequential(
    torch.nn.Linear(784, 100)
)
HypernetWithFE(extractor, inp_size=100)

HypernetWithFE(
  (input): Linear(in_features=100, out_features=64, bias=True)
  (hidden1): Linear(in_features=64, out_features=256, bias=True)
  (hidden2): Linear(in_features=256, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=630, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (relu): ReLU()
  (feature_extractor): Sequential(
    (0): Linear(in_features=784, out_features=100, bias=True)
  )
)

## Train predictor

In [None]:
epochs = 1000
results = defaultdict(lambda: defaultdict(list))
size = 100

for mask_size in [10, 20, 30]:
    for masks_no in [5, 15, 30, 50]:
        for i in range(5):
            criterion = torch.nn.CrossEntropyLoss()
            extractor = torch.nn.Sequential(
                torch.nn.Linear(784, 100)
            ).to(DEVICE)
            
            hypernet_pred = HypernetWithFE(feature_extractor=extractor, inp_size=100, mask_size=mask_size, node_hidden_size=20, test_nodes=masks_no, device=DEVICE).to(DEVICE)    
            hypernet_pred = hypernet_pred.train()
            optimizer = torch.optim.Adam(hypernet_pred.parameters(), lr=3e-4, weight_decay=1e-5)

            trainloader, testloader = get_dataset(size, test_batch_size=512)
            res = train_slow_step(hypernet_pred, optimizer, criterion, (trainloader, testloader), size, epochs, masks_no, tag='hypernet-e2e-fe', device=DEVICE, test_every=10)
            results[masks_no][mask_size].append(res)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/3280db8ee0544394b2a59acb28de0af5

100%|█████████████████████████████████████████████████████| 1000/1000 [06:37<00:00,  2.52it/s, loss=7.69, test_acc=62.9]
COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/49d1d5f0527b4ab697e3c1f607444ac0

100%|█████████████████████████████████████████████████████| 1000/1000 [06:35<00:00,  2.53it/s, loss=12.7, test_acc=59.8]
COMET INFO: Uploading 1 metrics, params and output messages
COMET INFO: Waiting for completion of the file uploads (may take several seconds)
COMET INFO: The Python SDK has 10800 seconds to finish before aborting...
COMET INFO: All files uploaded, waiting for confirmation they have been all received
COMET INFO: Experiment is l

In [9]:
print("Test accuracy")
for key in results.keys():
    def _pad(x):
        res = [subitem[0] for subitem in x]
        res += [res[-1]]*(10-len(res))
        return res
        
    test_acc_df = pd.DataFrame({i: _pad(j) for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))

Test accuracy
5
10    60.789
20    63.331
30    64.887
dtype: float64
15
10    61.504
20    61.490
30    63.883
dtype: float64
30
10    62.253
20    64.379
30    64.290
dtype: float64
50
10    60.334
20    59.854
30    62.110
dtype: float64


In [15]:
print("Test loss")
for key in results.keys():
    test_acc_df = pd.DataFrame({i: [subitem[1] for subitem in j] for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))

Test loss
8
50    2.651092
70    2.489010
dtype: float64


ValueError: All arrays must be of the same length