In [1]:
from dotenv import load_dotenv
load_dotenv()

import os

In [2]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

torch.set_default_dtype(torch.float32)

In [3]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [4]:
from ipynb.fs.defs.hypernet_training import SimpleNetwork, Hypernetwork, get_dataset, train_slow_step, test_model, InsertableNet

In [5]:
DEVICE = 'cuda:1'

## Train feature extractor hypernet

In [13]:
mask_size = 400
masks_no = 8
epochs = 300

batch_size = 32

size = 100

criterion = torch.nn.CrossEntropyLoss()
hypernet = Hypernetwork(mask_size=mask_size, node_hidden_size=20, test_nodes=masks_no, device=DEVICE).to(DEVICE)   
hypernet = hypernet.train()
optimizer = torch.optim.Adam(hypernet.parameters(), lr=3e-4, weight_decay=1e-5)

trainloader, testloader = get_dataset(size, True, masks_no, mask_size, shared_mask=True)
train_slow_step(hypernet, optimizer, criterion, (trainloader, testloader), size, epochs, masks_no, device=DEVICE, tag="decoupled-head")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/b3f1358f7d9149729601cfecbe608352

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [03:32<00:00,  1.41it/s, loss=4.51, test_acc=68.8]
COMET INFO: Uploading 1 metrics, params and output messages


(69.78999999999999, 2.0666567343167768)

In [14]:
test_model(hypernet, testloader, DEVICE)

68.87

## Prepare feature extractor

In [15]:
class Extractor:
    def __init__(self, hypernet, device='cpu'):
        nns = []
        weights = hypernet.craft_network(hypernet.test_mask)
        for w in weights:
            nns.append(InsertableNet(w.detach().to(device), hypernet.mask_size, layers=[hypernet.node_hidden_size]).to(device))
        self.nns = nns
        
    def extract(self, data):
        embeddings = []
        for mask, nn in zip(hypernet.test_mask, self.nns):
            masked = data[:, mask.to(torch.bool)]
            embeddings.append(F.linear(masked, nn.inp_weights, nn.inp_bias))

#         embeddings = torch.stack(embeddings, axis=-1).mean(axis=-1)
        embeddings = torch.concat(embeddings, dim=1)
        return embeddings

In [16]:
extractor = Extractor(hypernet.to('cuda:1'), 'cuda:1')
for inputs, labels, _ in trainloader:
    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE)
    
    extracted = extractor.extract(inputs)
    print(extracted.shape)
    break

torch.Size([32, 160])


## Train predictor

In [17]:
def train_extracted(hypernet, optimizer, criterion, loaders, data_size, epochs, masks_no, device='cuda:0'):
    experiment = Experiment(api_key=os.environ.get("COMET_KEY"), project_name="hypernetwork", display_summary_level=0)
    experiment.add_tag("decoupled-head-predictor")
    experiment.log_parameter("test_nodes", hypernet.test_nodes)
    experiment.log_parameter("mask_size", hypernet.mask_size)
    experiment.log_parameter("training_size", data_size)
    experiment.log_parameter("input_size", hypernet.input_size)
    experiment.log_parameter("masks_no", masks_no)

    trainloader, testloader = loaders
    train_loss = []
    test_loss = []
    test_accs = []
    mask_idx = 0
    with trange(epochs) as t:
        for epoch in t:
            total_loss = 0
            running_loss = 0.0
            correct = 0
            total = 0
            hypernet.train()
            for i, data in enumerate(trainloader):
                inputs, labels, _ = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                inputs = extractor.extract(inputs).detach().to(device)
                masks = []
                for i in range(len(inputs)):
                    masks.append(hypernet.test_mask[mask_idx])
                masks = torch.stack(masks).to(device)
                mask_idx = (mask_idx+1) % len(hypernet.test_mask)

                optimizer.zero_grad()

                outputs = hypernet(inputs, masks)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                correct += (outputs.argmax(1)==labels).float().sum()
                total += outputs.shape[0]
                running_loss += loss.item()
                train_loss.append(loss.item())
                if i>0 and i % 100 == 0:
                    total_loss += running_loss/100

                    running_loss = 0.0
                    correct = 0
                    total=0

            total_loss = 0
            correct = 0
            denom = 0

            hypernet.eval()
            if epoch%5==0:
                for i, data in enumerate(testloader):
                    images, labels, masks = data
                    images = images.to(device)
                    images = extractor.extract(images).detach().to(device)
                    
                    labels = labels.to(device)

                    denom += len(labels)

                    outputs = hypernet(images)
                    _, predicted = torch.max(outputs.data, 1)
                    correct += (predicted == labels).sum().item()
                    total_loss += criterion(outputs, labels).item()

                test_loss.append(total_loss/i)
                test_accs.append(correct/denom*100)

                t.set_postfix(test_acc=correct/denom*100, loss=total_loss/i)
                experiment.log_metric("test_accuracy", correct/len(testloader.dataset)*100, step=epoch)
                experiment.log_metric("test_loss", test_loss[-1], step=epoch)

    experiment.end()
    return max(test_accs), test_loss[np.argmax(test_accs)]

In [None]:
mask_size = 100
epochs = 400
results = defaultdict(lambda: defaultdict(list))
size = 100

for masks_no in [8, 50, 15, 200]:
    for i in range(5):
        criterion = torch.nn.CrossEntropyLoss()
        hypernet_pred = Hypernetwork(inp_size=160, mask_size=mask_size, node_hidden_size=10, test_nodes=masks_no, device='cuda:1').to(DEVICE)    
        hypernet_pred = hypernet_pred.train()
        optimizer = torch.optim.Adam(hypernet_pred.parameters(), lr=3e-4, weight_decay=1e-5)

        trainloader, testloader = get_dataset(size, True, masks_no, mask_size, shared_mask=True)
        print(masks_no)
        res = train_extracted(hypernet_pred, optimizer, criterion, (trainloader, testloader), size, epochs, masks_no, device=DEVICE)
        results[masks_no][size].append(res)



100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/2e7686bc7fe84d149f314e5b067f7e95

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [04:49<00:00,  1.38it/s, loss=6.43, test_acc=62.9]
COMET INFO: Uploading 1 metrics, params and output messages


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/54f03f1006ad421bacc90b089240c7ad

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [04:48<00:00,  1.39it/s, loss=4.77, test_acc=67.8]
COMET INFO: Uploading 1 metrics, params and output messages


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/7fda395ed0d2488eb241cea6528486e6

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [04:51<00:00,  1.37it/s, loss=4.96, test_acc=66.8]
COMET INFO: Uploading 1 metrics, params and output messages


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/2f2f85f540f1479d9e02fda237906038

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [04:48<00:00,  1.39it/s, loss=5.43, test_acc=66.2]
COMET INFO: Uploading 1 metrics, params and output messages
COMET INFO: Waiting for completion of the file uploads (may take several seconds)
COMET INFO: The Python SDK has 10800 seconds to finish before aborting...
COMET INFO: All files uploaded, waiting for confirmation they have been all received


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/81739f0f1e89476e92c6dac96df5537f

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [04:49<00:00,  1.38it/s, loss=4.7, test_acc=67]
COMET INFO: Uploading 1 metrics, params and output messages


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/8e84ebc91a5443bdb9fb92e215ebeed9

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [13:08<00:00,  1.97s/it, loss=1.44, test_acc=63.4]
COMET INFO: Uploading 1 metrics, params and output messages


100


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/wwydmanski/hypernetwork/aa39cf17378b4b5392f2130c618d9522

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                              | 280/400 [09:11<02:02,  1.02s/it, loss=1.28, test_acc=63.7]

In [None]:
print("Test accuracy")
for key in results.keys():
    def _pad(x):
        res = [subitem[0] for subitem in x]
        res += [res[-1]]*(10-len(res))
        return res
        
    test_acc_df = pd.DataFrame({i: _pad(j) for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))

In [None]:
print("Test loss")
for key in results.keys():
    test_acc_df = pd.DataFrame({i: [subitem[1] for subitem in j] for i, j in results[key].items()})
    print(key)
    print(test_acc_df.mean(axis=0))