In [None]:
from dotenv import load_dotenv
load_dotenv()

import os

In [None]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

torch.set_default_dtype(torch.float32)

In [None]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [None]:
from ipynb.fs.defs.hypernet_training import SimpleNetwork, Hypernetwork, get_dataset, train_slow_step, test_model, InsertableNet, SimpleNetwork, train_regular

In [None]:
DEVICE = 'cuda:1'

## Prepare feature extractor

### Hypernet feature extractor

In [11]:
mask_size = 400
masks_no = 8
epochs = 300

batch_size = 32

size = 100

criterion = torch.nn.CrossEntropyLoss()
hypernet = Hypernetwork(mask_size=mask_size, node_hidden_size=100, test_nodes=masks_no, device=DEVICE).to(DEVICE)   
hypernet = hypernet.train()
optimizer = torch.optim.Adam(hypernet.parameters(), lr=3e-4, weight_decay=1e-5)

trainloader, testloader = get_dataset(size, True, masks_no, mask_size, shared_mask=True)
train_slow_step(hypernet, optimizer, criterion, (trainloader, testloader), size, epochs, masks_no, device=DEVICE, tag="decoupled-head")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/c7a33ea1f32f45ff8fe8a630a5be4bed

100%|███████████████████████████████████████████████████████| 300/300 [03:29<00:00,  1.43it/s, loss=2.74, test_acc=69.2]
COMET INFO: Uploading 1 metrics, params and output messages


(70.36, 2.0875586384716325)

In [12]:
test_model(hypernet, testloader, DEVICE)

69.26

In [13]:
class Extractor:
    def __init__(self, hypernet, device='cpu'):
        nns = []
        weights = hypernet.craft_network(hypernet.test_mask)
        for w in weights:
            nns.append(InsertableNet(w.detach().to(device), hypernet.mask_size, layers=[hypernet.node_hidden_size]).to(device))
        self.nns = nns
        
    def extract(self, data):
        embeddings = []
        for mask, nn in zip(hypernet.test_mask, self.nns):
            masked = data[:, mask.to(torch.bool)]
            embeddings.append(F.linear(masked, nn.inp_weights, nn.inp_bias))

        embeddings = torch.stack(embeddings, axis=-1).mean(axis=-1)
#         embeddings = torch.concat(embeddings, dim=1)
        return embeddings

In [14]:
extractor = Extractor(hypernet.to('cuda:1'), 'cuda:1')
for inputs, labels, _ in trainloader:
    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE)
    
    extracted = extractor.extract(inputs)
    print(extracted.shape)
    break

torch.Size([32, 100])


### Dense feature extractor 

In [6]:
mask_size = 400
masks_no = 8
epochs = 300

batch_size = 32

size = 100

criterion = torch.nn.CrossEntropyLoss()
network = SimpleNetwork(784).to(DEVICE)
network = network.train()
optimizer = torch.optim.Adam(network.parameters(), lr=3e-4)

trainloader, testloader = get_dataset(size, batch_size=batch_size, test_batch_size=128)
train_regular(network, optimizer, criterion, (trainloader, testloader), size, epochs, device=DEVICE, name="decoupled-head-dense")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/705281a99dd343a3888cca7bf0737343

100%|███████████████████████████████████████████████████████| 300/300 [02:14<00:00,  2.22it/s, loss=1.18, test_acc=69.3]
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/wwydmanski/hypernetwork/705281a99dd343a3888cca7bf0737343
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     loss [121]         : (0.0014174256939440966, 2.350419521331787)
COMET INFO:     test_accuracy [60] : (28.299999999999997, 69.38)
COMET INFO:     test_loss [60]     : (0.986374250971354, 2.2277075877556434)
COMET INFO:   Parameters:
COMET INFO:     check_val_every_n_epoch : 5
COMET INFO:     max_epochs              : 300
COMET INFO:     training_size           : 100
COMET INFO:   Uploads:
COMET IN

(69.38, 1.175550303589075)

In [7]:
test_model(network, testloader, DEVICE)

UnboundLocalError: local variable 'images' referenced before assignment

In [8]:
class DenseExtractor:
    def __init__(self, network, device='cpu'):
        self.nn = network.to(device)
        
    def extract(self, data):
        return self.nn.inp(data)

In [10]:
extractor = DenseExtractor(network.to('cuda:1'), 'cuda:1')
for inputs, labels in trainloader:
    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE)
    
    extracted = extractor.extract(inputs)
    print(extracted.shape)
    break

torch.Size([32, 100])


## Train predictor

In [17]:
def train_extracted(hypernet, optimizer, criterion, loaders, data_size, epochs, masks_no, device='cuda:0'):
    experiment = Experiment(api_key=os.environ.get("COMET_KEY"), project_name="hypernetwork", display_summary_level=0)
    experiment.add_tag("decoupled-head-predictor")
    experiment.log_parameter("test_nodes", hypernet.test_nodes)
    experiment.log_parameter("mask_size", hypernet.mask_size)
    experiment.log_parameter("training_size", data_size)
    experiment.log_parameter("input_size", hypernet.input_size)
    experiment.log_parameter("masks_no", masks_no)
    experiment.log_parameter("max_epochs", epochs)
    experiment.log_parameter("check_val_every_n_epoch", 5)
    
    trainloader, testloader = loaders
    train_loss = []
    test_loss = []
    test_accs = []
    mask_idx = 0
    with trange(epochs) as t:
        for epoch in t:
            total_loss = 0
            running_loss = 0.0
            correct = 0
            total = 0
            hypernet.train()
            for i, data in enumerate(trainloader):
                try:
                    inputs, labels, _ = data
                except ValueError:
                    inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                inputs = extractor.extract(inputs).detach().to(device)
                masks = []
                for i in range(len(inputs)):
                    masks.append(hypernet.test_mask[mask_idx])
                masks = torch.stack(masks).to(device)
                mask_idx = (mask_idx+1) % len(hypernet.test_mask)

                optimizer.zero_grad()

                outputs = hypernet(inputs, masks)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                correct += (outputs.argmax(1)==labels).float().sum()
                total += outputs.shape[0]
                running_loss += loss.item()
                train_loss.append(loss.item())
                if i>0 and i % 100 == 0:
                    total_loss += running_loss/100

                    running_loss = 0.0
                    correct = 0
                    total=0

            total_loss = 0
            correct = 0
            denom = 0

            hypernet.eval()
            if epoch%5==0:
                for i, data in enumerate(testloader):
                    try:
                        images, labels, _ = data
                    except ValueError:
                        images, labels = data
                    images = images.to(device)
                    images = extractor.extract(images).detach().to(device)
                    
                    labels = labels.to(device)

                    denom += len(labels)

                    outputs = hypernet(images)
                    _, predicted = torch.max(outputs.data, 1)
                    correct += (predicted == labels).sum().item()
                    total_loss += criterion(outputs, labels).item()

                test_loss.append(total_loss/i)
                test_accs.append(correct/denom*100)

                t.set_postfix(test_acc=correct/denom*100, loss=total_loss/i)
                experiment.log_metric("test_accuracy", correct/len(testloader.dataset)*100, step=epoch)
                experiment.log_metric("test_loss", test_loss[-1], step=epoch)

    experiment.end()
    return max(test_accs), test_loss[np.argmax(test_accs)]

In [None]:
epochs = 1000
results = defaultdict(lambda: defaultdict(list))
size = 100

for mask_size in [15, 20, 25]:
    for masks_no in [10, 15, 20, 25, 30]:
        for i in range(5):
            criterion = torch.nn.CrossEntropyLoss()
            hypernet_pred = Hypernetwork(inp_size=100, mask_size=mask_size, node_hidden_size=20, test_nodes=masks_no, device='cuda:1').to(DEVICE)    
            hypernet_pred = hypernet_pred.train()
            optimizer = torch.optim.Adam(hypernet_pred.parameters(), lr=3e-4, weight_decay=1e-5)

            trainloader, testloader = get_dataset(size, True, masks_no, mask_size, shared_mask=True)
            print(masks_no)
            res = train_extracted(hypernet_pred, optimizer, criterion, (trainloader, testloader), size, epochs, masks_no, device=DEVICE)
            results[masks_no][mask_size].append(res)

10


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/c483d245945f4602b0d092261fc31117

 44%|███████████████████████▊                              | 440/1000 [05:12<04:28,  2.09it/s, loss=2.53, test_acc=68.4]

In [14]:
print("Test accuracy")
for key in results.keys():
    def _pad(x):
        res = [subitem[0] for subitem in x]
        res += [res[-1]]*(10-len(res))
        return res
        
    test_acc_df = pd.DataFrame({i: _pad(j) for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))

Test accuracy
8
50    67.483
70    67.753
dtype: float64
50
50    64.079
70    65.040
dtype: float64
15
50    67.211
dtype: float64
4
50    66.865
dtype: float64


In [15]:
print("Test loss")
for key in results.keys():
    test_acc_df = pd.DataFrame({i: [subitem[1] for subitem in j] for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))

Test loss
8
50    2.651092
70    2.489010
dtype: float64


ValueError: All arrays must be of the same length