In [25]:
from dotenv import load_dotenv
load_dotenv()

import os

In [26]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

torch.set_default_dtype(torch.float32)

In [27]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [28]:
import tabular_hypernet as hp

In [29]:
class DatasetUpsampler:
    """Expose ``dataset`` as if it contained ``desired_len`` items by cycling over it.

    Item ``idx`` of the wrapper is ``dataset[idx % len(dataset)]``, so a small
    dataset can be iterated as many times as needed to match a larger one.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        desired_len: the length reported by the wrapper.

    Raises:
        ValueError: if ``dataset`` is empty while ``desired_len`` is positive
            (there would be nothing to cycle over).
    """

    def __init__(self, dataset, desired_len):
        self.desired_len = desired_len
        self.dataset = dataset
        self.real_len = len(dataset)
        if self.real_len == 0 and desired_len > 0:
            # Fail at construction time instead of with a ZeroDivisionError
            # on the first item access.
            raise ValueError("cannot upsample an empty dataset")

    def __len__(self):
        return self.desired_len

    def __getitem__(self, idx):
        # IndexError is required here: Python's fallback iteration protocol
        # (``for x in obj`` / ``list(obj)`` without ``__iter__``) stops on it.
        if idx >= self.desired_len:
            raise IndexError(
                f"index {idx} is out of range for upsampled dataset of length {self.desired_len}"
            )
        return self.dataset[idx % self.real_len]
        

In [30]:
def get_dataset(size=60000, masked=False, mask_no=200, mask_size=700, shared_mask=False, batch_size=32, test_batch_size=32):
    """Build the labeled, unlabeled and test MNIST loaders used for semi-supervised training.

    The first ``size // 10`` training images form the labeled split and the
    following ``size - size // 10`` images form the unlabeled split. The labeled
    split is wrapped in ``DatasetUpsampler`` so that both training loaders
    report the same number of samples.

    Args:
        size: total number of training images to use (labeled + unlabeled).
        masked: if true, wrap both datasets in ``MaskedDataset``.
            NOTE(review): ``MaskedDataset`` is not defined in the visible part
            of this notebook - confirm where it comes from before using this flag.
        mask_no, mask_size: forwarded to ``MaskedDataset`` when ``masked`` is true.
        shared_mask: when masked, make the test set reuse the training masks.
        batch_size: batch size of both training loaders.
        test_batch_size: batch size of the test loader.

    Returns:
        ``(sup_trainloader, unsup_trainloader, testloader)``.
    """
    # Tensor conversion, normalisation with the MNIST mean / std, then
    # flattening of each image into a vector.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: torch.flatten(x)),
    ])

    trainset = datasets.MNIST(root='./data/train', train=True, download=True, transform=preprocess)
    testset = datasets.MNIST(root='./data/test', train=False, download=True, transform=preprocess)

    if masked:
        trainset = MaskedDataset(trainset, mask_no, mask_size)
        testset = MaskedDataset(testset, mask_no, mask_size)
        if shared_mask:
            testset.masks = trainset.masks

    n_labeled = size // 10
    n_unlabeled = size - n_labeled

    # Labeled split: the first n_labeled images, cycled up to n_unlabeled items
    # so it is balanced against the unlabeled split.
    labeled_split = DatasetUpsampler(
        data_utils.Subset(trainset, torch.arange(n_labeled)),
        n_unlabeled,
    )
    # Unlabeled split: the images right after the labeled ones.
    unlabeled_split = data_utils.Subset(trainset, torch.arange(n_labeled, size))

    sup_trainloader = data_utils.DataLoader(
        labeled_split, batch_size=batch_size, shuffle=True, num_workers=1)
    unsup_trainloader = data_utils.DataLoader(
        unlabeled_split, batch_size=batch_size, shuffle=True, num_workers=1)
    # The test loader keeps the dataset order.
    testloader = data_utils.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=1)

    return sup_trainloader, unsup_trainloader, testloader

In [31]:
class TrainDataLoaderSemi:
    """Iterate a labeled and an unlabeled loader in lock-step.

    Each iteration step yields a ``(sup_batch, unsup_batch)`` pair produced by
    ``zip`` over the two wrapped loaders.

    Args:
        sup_trainloader: iterable of labeled batches.
        unsup_trainloader: iterable of unlabeled batches.
    """

    def __init__(self, sup_trainloader, unsup_trainloader):
        self.sup_trainloader = sup_trainloader
        self.unsup_trainloader = unsup_trainloader

    def __len__(self):
        """Return the common number of batches.

        Raises:
            ValueError: if the two loaders report different lengths; ``zip``
                would otherwise silently truncate to the shorter one.
        """
        sup_len = len(self.sup_trainloader)
        unsup_len = len(self.unsup_trainloader)
        if sup_len != unsup_len:
            raise ValueError(
                f"loader lengths differ: supervised={sup_len}, unsupervised={unsup_len}"
            )
        return unsup_len

    def __iter__(self):
        return zip(self.sup_trainloader, self.unsup_trainloader)
    

In [32]:
# Smoke test of get_dataset: size=500 gives 500 // 10 = 50 labeled images
# (upsampled to 450 items) and 450 unlabeled images; batch_size=1 makes the
# loader lengths equal to the sample counts.
sup_trainloader, unsup_trainloader, testloader = get_dataset(500, batch_size=1)


In [33]:
# Number of unlabeled batches (450 expected for size=500, batch_size=1).
len(unsup_trainloader)

450

In [34]:
# Number of labeled batches; must match the unlabeled loader after upsampling.
len(sup_trainloader)

450

In [35]:
# Pair the two training loaders so they can be iterated in lock-step.
a = TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

In [36]:
# Common number of batches of the paired loader.
len(a)

450

In [37]:
# Compare the labeled batch seen at step 2 with the one seen at step 52
# (50 steps later, i.e. one full pass over the 50 real labeled samples).
# NOTE(review): both training loaders are created with shuffle=True, so this
# equality is not guaranteed on a fresh run - confirm the recorded output.
for n, (i, j) in enumerate(a):
    # i is the labeled batch (inputs, labels); j is the unlabeled batch.
    if n==2: 
        tmp = i[0]
        rr = i[1]
    if n==52:
        print(torch.all(tmp==i[0]))
        print(torch.all(rr==i[1]))
        break

tensor(True)
tensor(True)


In [38]:
import torch
import numpy as np
from tabular_hypernet.modules import InsertableNet
from tabular_hypernet.training_utils import get_dataloader, train_model

torch.set_default_dtype(torch.float32)


class Hypernetwork(torch.nn.Module):
    """Hypernetwork that turns a binary feature mask into the weights of a target network.

    A mask is a float vector of length ``inp_size`` with ``mask_size`` ones.
    ``craft_network`` maps it to a flat vector of ``self.out_size`` values, sized
    for one hidden layer of ``node_hidden_size`` units
    (``mask_size*node_hidden_size + node_hidden_size + node_hidden_size*out_size + out_size``).
    That vector is handed to ``InsertableNet`` (imported from
    ``tabular_hypernet.modules``), which is applied to the masked input features.

    Args:
        inp_size: number of input features.
        out_size: number of outputs of each target network.
        mask_size: number of features selected by every mask.
        node_hidden_size: hidden width passed to ``InsertableNet``.
        layers: widths of the three hidden layers of the hypernetwork itself
            (the default list is only read, never mutated).
        test_nodes: number of fixed masks created for evaluation.
        device: device on which masks and results are allocated.
    """

    def __init__(self, inp_size=784, out_size=10, mask_size=20, node_hidden_size=20, layers=[64, 256, 128], test_nodes=100, device='cuda:0'):
        super().__init__()
        self.target_outsize = out_size
        self.device = device

        self.mask_size = mask_size
        self.input_size = inp_size
        self.node_hidden_size = node_hidden_size

        # Number of values needed for the target network's weights and biases.
        input_w_size = mask_size*node_hidden_size
        input_b_size = node_hidden_size

        hidden_w_size = node_hidden_size*out_size
        hidden_b_size = out_size

        self.out_size = input_w_size+input_b_size+hidden_w_size+hidden_b_size

        self.input = torch.nn.Linear(inp_size, layers[0])
        self.hidden1 = torch.nn.Linear(layers[0], layers[1])
        self.hidden2 = torch.nn.Linear(layers[1], layers[2])
        self.out = torch.nn.Linear(layers[2], self.out_size)

        self.dropout = torch.nn.Dropout()

        self.relu = torch.relu
        # All-zero row that every mask starts from.
        self.template = np.zeros(inp_size)
        self.test_nodes = test_nodes
        self.test_mask = self._create_mask(test_nodes)

        # Cache of the target networks built from ``test_mask``; it is rebuilt
        # whenever ``_retrained`` is set.
        self._retrained = True
        self._test_nets = None

    def to(self, device):
        """Move the module to ``device`` and draw a fresh set of test masks there."""
        super().to(device)
        self.device = device
        self.test_mask = self._create_mask(self.test_nodes)
        # The cached test networks were built from the previous masks (and
        # device); drop them so they cannot be paired with the new masks.
        self._retrained = True
        self._test_nets = None
        return self

    def forward(self, data, mask=None):
        """ Get a hypernet prediction.
        During training we use a single target network per sample.
        During eval, we create a network for each mask and average their results.

        Args:
            data - prediction input, indexed as ``data[sample, feature]``.
            mask - either None or a 0/1 tensor with one row of ``inp_size``
                entries per mask. In training mode row ``i`` is applied to
                sample ``i`` (random masks are drawn when None); in eval mode
                every row is applied to the whole batch (``self.test_mask``
                is used when None).

        Returns:
            Tensor of shape ``(len(data), out_size)``.
        """
        if self.training:
            self._retrained = True
            if mask is None:
                mask = self._indices_to_mask(self._random_mask_indices(len(data)))

            # If we have a few identical masks in a row
            # we only need to build the target network
            # for the first one
            recalculate = [True]*len(mask)
            for i in range(1, len(mask)):
                if torch.equal(mask[i-1], mask[i]):
                    recalculate[i] = False

            weights = self.craft_network(mask)
            mask = mask.to(torch.bool)

            res = torch.zeros((len(data), self.target_outsize)).to(self.device)
            for i in range(len(data)):
                if recalculate[i]:
                    target_net = InsertableNet(weights[i], self.mask_size, self.target_outsize, layers=[self.node_hidden_size])
                res[i] = target_net(data[i, mask[i]])
            return res
        else:
            if mask is None:
                mask = self.test_mask
                nets = self._get_test_nets()
            else:
                nets = self.__craft_nets(mask)
            mask = mask.to(torch.bool)

            res = torch.zeros((len(data), self.target_outsize)).to(self.device)
            for i in range(len(mask)):
                res += nets[i](data[:, mask[i]])
            # Average over the masks that were actually used; a caller-supplied
            # mask may have a different number of rows than ``test_nodes``.
            res /= len(mask)
            return res

    def _get_test_nets(self):
        """Return the target networks for ``test_mask``, rebuilding them if stale."""
        if self._retrained:
            nets = self.__craft_nets(self.test_mask)
            self._test_nets = nets
            self._retrained = False
        return self._test_nets

    def __craft_nets(self, mask):
        """Build one ``InsertableNet`` per row of ``mask``."""
        nets = []
        weights = self.craft_network(mask.to(torch.float32))
        for i in range(len(mask)):
            nets.append(InsertableNet(weights[i], self.mask_size, self.target_outsize, layers=[self.node_hidden_size]))
        return nets

    def _random_mask_indices(self, count):
        """Draw ``count`` sets of ``mask_size`` distinct feature indices (NumPy global RNG)."""
        return np.array([np.random.choice((len(self.template)), self.mask_size, False) for _ in range(count)])

    def _indices_to_mask(self, indices):
        """Turn rows of feature indices into a float32 0/1 mask tensor on ``self.device``."""
        dense = np.array([self.template.copy() for _ in range(len(indices))])
        for row, chosen in enumerate(indices):
            dense[row, chosen] = 1
        return torch.from_numpy(dense).to(torch.float32).to(self.device)

    def _create_mask(self, count):
        """Create ``count`` random masks, printing the count and index-array shape."""
        # The prints are kept because the notebook's recorded output shows them.
        print('count', count)
        masks = self._random_mask_indices(count)
        print('masks', masks.shape)
        return self._indices_to_mask(masks)

    def craft_network(self, mask):
        """Map mask rows to flat target-network weight vectors of size ``self.out_size``."""
        out = self.input(mask)
        out = self.relu(out)

        out = self.hidden1(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.hidden2(out)
        out = self.relu(out)

        out = self.out(out)
        return out


In [39]:
# train_semisl logs ``criterion.unsup_target_wrapper.__name__``; check that
# torch.argmax exposes a ``__name__`` attribute.
print(torch.argmax.__name__)

argmax


In [40]:
class TabSSLCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy loss combining a labeled term with a cross-prediction term.

    The labeled term sums the cross-entropy of two predictions against the
    labels. The unlabeled term sums the cross-entropy of each unlabeled
    prediction against the wrapped other prediction; it is weighted by ``beta``.

    Args:
        beta: weight of the unlabeled (self-supervised) term.
        unsup_target_wrapper: callable invoked as ``wrapper(outputs, dim=1)``
            to turn raw outputs into cross-entropy targets.

    After each call the two terms are available as ``supervised_loss`` and
    ``self_supervised_loss``.
    """

    def __init__(self, beta=0.1, unsup_target_wrapper=torch.nn.functional.softmax):
        super().__init__()

        # One criterion instance per term, registered in a fixed order.
        self.y_f1, self.y_f2, self.f1_f2, self.f2_f1 = (
            torch.nn.CrossEntropyLoss() for _ in range(4)
        )

        self.beta = beta
        self.unsup_target_wrapper = unsup_target_wrapper

    def forward(self, sup_input, unsup_input):
        """Return ``supervised_loss + beta * self_supervised_loss``.

        Args:
            sup_input: ``(outputs1, outputs2, labels)`` for the labeled batch.
            unsup_input: ``(outputs1, outputs2)`` for the unlabeled batch.
        """
        labeled_out1, labeled_out2, labels = sup_input
        unlabeled_out1, unlabeled_out2 = unsup_input
        to_target = self.unsup_target_wrapper

        self.supervised_loss = self.y_f1(labeled_out1, labels) + self.y_f2(labeled_out2, labels)

        first_vs_second = self.f1_f2(unlabeled_out1, to_target(unlabeled_out2, dim=1))
        second_vs_first = self.f2_f1(unlabeled_out2, to_target(unlabeled_out1, dim=1))
        self.self_supervised_loss = first_vs_second + second_vs_first

        return self.supervised_loss + self.beta * self.self_supervised_loss

In [41]:
# Instantiate the notebook-local Hypernetwork on the GPU (50 test masks of 700
# features each) and switch it to evaluation mode.
hypernet = Hypernetwork(mask_size=700, node_hidden_size=100, test_nodes=50).cuda()
hypernet.eval()

count 50
masks (50, 700)


Hypernetwork(
  (input): Linear(in_features=784, out_features=64, bias=True)
  (hidden1): Linear(in_features=64, out_features=256, bias=True)
  (hidden2): Linear(in_features=256, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=71110, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [42]:
# One row per test mask, one column per input feature.
hypernet.test_mask.shape

torch.Size([50, 784])

In [43]:
def _repeat_test_mask(hypernet, mask_idx, batch_len, device):
    """Return test mask number ``mask_idx`` stacked ``batch_len`` times, on ``device``."""
    return torch.stack([hypernet.test_mask[mask_idx]] * batch_len).to(device)


def train_semisl(hypernet, optimizer, criterion, loaders, data_size, epochs, masks_no,
                    changing_beta=None,
                    log_to_comet=True,
                    experiment=None,
                    tag="semi-slow-step-hypernet",
                    device='cuda:0',
                    project_name="semi-hypernetwork",
                    test_every=5):
    """ Train hypernetwork using 2 masks per iteration, one for x1 (sup & unsup), another for x2 (sup & unsup)

    Every training step takes the next two masks from ``hypernet.test_mask``
    (cycling), predicts the labeled and unlabeled batches under each of them
    and optimises ``criterion``. Every ``test_every`` epochs the model is
    evaluated on the test loader.

    Args:
        hypernet: model called as ``hypernet(inputs, masks)`` for training and
            ``hypernet(inputs)`` for evaluation; must expose ``test_mask``.
        optimizer: optimizer over the model parameters.
        criterion: called as ``criterion((out1, out2, labels), (uout1, uout2))``;
            must expose ``beta``, ``supervised_loss`` and ``self_supervised_loss``.
        loaders: ``(trainloader, testloader)``. The train loader yields
            ``((sup_inputs, sup_labels), (unsup_inputs, _))``; the test loader
            yields batches whose first two items are inputs and labels.
        data_size, masks_no: only logged as experiment parameters.
        epochs: number of training epochs.
        changing_beta: optional ``f(epoch, criterion)`` called at the start of
            every epoch.
        log_to_comet: log parameters and metrics to a Comet experiment.
        experiment: existing Comet experiment; when None and ``log_to_comet``
            is true a new one is created from the ``COMET_KEY`` environment variable.
        tag, project_name: Comet tag and project name.
        device: device the batches are moved to.
        test_every: evaluation period in epochs.

    Returns:
        ``(best_test_accuracy_percent, test_loss_at_that_evaluation)``, where
        the test loss is the mean cross-entropy per test sample.

    Raises:
        ValueError: if no evaluation was run (e.g. ``epochs == 0``) or the
            test loader yields no samples.
    """
    if log_to_comet:
        if experiment is None:
            experiment = Experiment(api_key=os.environ.get("COMET_KEY"), project_name=project_name, display_summary_level=0)
        experiment.add_tag(tag)
        experiment.log_parameter("test_nodes", hypernet.test_nodes)
        experiment.log_parameter("mask_size", hypernet.mask_size)
        experiment.log_parameter("node_hidden_size", hypernet.node_hidden_size)
        experiment.log_parameter("lr", optimizer.defaults['lr'])
        experiment.log_parameter("training_size", data_size)
        experiment.log_parameter("masks_no", masks_no)
        experiment.log_parameter("max_epochs", epochs)
        experiment.log_parameter("check_val_every_n_epoch", test_every)
        experiment.log_parameter("unsupervised_target_wrapper", criterion.unsup_target_wrapper.__name__)

    trainloader, testloader = loaders
    test_loss = []
    test_accs = []
    mask_idx = 0
    n_masks = len(hypernet.test_mask)

    with trange(epochs) as t:
        for epoch in t:
            running_loss = 0.0
            supervised_train_loss = 0.
            unsupervised_train_loss = 0.
            train_denom = 0

            hypernet.train()

            if changing_beta:
                changing_beta(epoch, criterion)

            for sup_data, unsup_data in trainloader:
                sup_inputs, sup_labels = sup_data
                unsup_inputs, _ = unsup_data

                sup_inputs = sup_inputs.to(device)
                sup_labels = sup_labels.to(device)
                unsup_inputs = unsup_inputs.to(device)

                ## f1: the same mask for every labeled and unlabeled sample.
                # Masks are sized per batch so the two batches may differ in length.
                sup_outputs1 = hypernet(
                    sup_inputs, _repeat_test_mask(hypernet, mask_idx, len(sup_inputs), device))
                unsup_outputs1 = hypernet(
                    unsup_inputs, _repeat_test_mask(hypernet, mask_idx, len(unsup_inputs), device))
                mask_idx = (mask_idx+1) % n_masks

                ## f2: the next mask
                sup_outputs2 = hypernet(
                    sup_inputs, _repeat_test_mask(hypernet, mask_idx, len(sup_inputs), device))
                unsup_outputs2 = hypernet(
                    unsup_inputs, _repeat_test_mask(hypernet, mask_idx, len(unsup_inputs), device))
                mask_idx = (mask_idx+1) % n_masks

                optimizer.zero_grad()

                loss = criterion((sup_outputs1, sup_outputs2, sup_labels), (unsup_outputs1, unsup_outputs2))
                loss.backward()
                optimizer.step()

                # Accumulate plain floats: adding the graph-attached tensors
                # would keep autograd history alive for the whole epoch.
                running_loss += loss.item()
                supervised_train_loss += float(criterion.supervised_loss)
                unsupervised_train_loss += float(criterion.self_supervised_loss)
                train_denom += 1

            if epoch%test_every==0:
                if log_to_comet:
                    n_steps = max(train_denom, 1)  # avoid 0/0 on an empty train loader
                    experiment.log_metric("beta_coef", criterion.beta, step=epoch)
                    experiment.log_metric('sup_train_loss', supervised_train_loss/n_steps, step=epoch)
                    experiment.log_metric('self_sup_train_loss', unsupervised_train_loss/n_steps, step=epoch)
                    experiment.log_metric('train_loss', running_loss/n_steps, step=epoch)

                # eval
                total_loss = 0.0
                correct = 0
                denom = 0

                test_criterion = torch.nn.CrossEntropyLoss()
                hypernet.eval()

                # No gradients are needed for evaluation.
                with torch.no_grad():
                    for batch in testloader:
                        # Batches may carry extra items after (inputs, labels).
                        images = batch[0].to(device)
                        labels = batch[1].to(device)
                        batch_len = len(labels)

                        outputs = hypernet(images)
                        _, predicted = torch.max(outputs, 1)
                        correct += (predicted == labels).sum().item()
                        # The criterion returns the batch mean; weight it by the
                        # batch size so total_loss/denom is the per-sample mean.
                        total_loss += test_criterion(outputs, labels).item() * batch_len
                        denom += batch_len

                if denom == 0:
                    raise ValueError("testloader yielded no samples")

                test_loss.append(total_loss/denom)
                test_accs.append(correct/denom*100)

                t.set_postfix(test_acc=test_accs[-1], loss=test_loss[-1])

                if log_to_comet:
                    experiment.log_metric("test_accuracy", test_accs[-1], step=epoch)
                    experiment.log_metric("test_loss", test_loss[-1], step=epoch)

    # There is nothing to close when logging is disabled and no experiment was passed.
    if experiment is not None:
        experiment.end()

    best = int(np.argmax(test_accs))
    return test_accs[best], test_loss[best]

In [44]:
# The Comet API key is a secret: it has to come from the environment (for
# example the .env file read by load_dotenv() at the top of the notebook) and
# must never be hardcoded in a cell or echoed into the saved output.
# NOTE: the key that was previously committed in this cell should be treated
# as leaked and rotated.
comet_key_configured = bool(os.environ.get("COMET_KEY"))
comet_key_configured

'<redacted>'

In [45]:
def changing_beta(epoch, criterion):
    """Raise ``criterion.beta`` to 0.5 at epoch 200; leave it untouched otherwise."""
    if epoch != 200:
        return
    criterion.beta = 0.5

In [46]:
#beta_lst=torch.arange(0.1, 1, 0.9/epochs)

In [47]:
#beta_lst.size()

In [None]:
# Experiment 1: constant beta = 1, once per unlabeled-target wrapper.
epochs = 400


# Hypernetwork configuration: features per mask and number of test masks.
mask_size = 100
masks_no = 200

# (best test accuracy, test loss) tuples returned by train_semisl, keyed by training size.
results = defaultdict(list)
size = 1000

# One full training run per wrapper (softmax targets vs. argmax targets).
for f in [torch.nn.functional.softmax, torch.argmax]:
    criterion = TabSSLCrossEntropyLoss(beta=1., unsup_target_wrapper=f)

    # Note: this uses Hypernetwork from the tabular_hypernet package (hp),
    # not the Hypernetwork class defined earlier in this notebook.
    hypernet = hp.Hypernetwork(mask_size=mask_size, node_hidden_size=100, test_nodes=masks_no).cuda()

    hypernet = hypernet.train()
    optimizer = torch.optim.Adam(hypernet.parameters(), lr=3e-4)

    # dataset & loaders
    sup_trainloader, unsup_trainloader, testloader = get_dataset(size)
    trainloader = TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

    # changing_beta=None keeps beta constant for the whole run.
    results[size].append(train_semisl(hypernet, 
                                      optimizer, 
                                      criterion, 
                                      (trainloader, testloader), 
                                      size,
                                      epochs,
                                      masks_no,
                                      changing_beta=None,
                                      log_to_comet=True,
                                      tag='semisl_initial_experiment'))

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/abulenok/semi-hypernetwork/f40c3433c66b447aad7abc6e8411d3af

 70%|█████████████████████████████████████                | 280/400 [25:05<08:04,  4.04s/it, loss=2.73, test_acc=32.5]

In [None]:
def changing_beta(epoch, criterion):
    """Step ``criterion.beta`` up at fixed epochs: 100 -> 0.3, 190 -> 0.5, 270 -> 0.7, 360 -> 0.9.

    At any other epoch ``beta`` is left unchanged. This definition replaces any
    earlier ``changing_beta`` in the kernel namespace.
    """
    beta_by_epoch = {100: 0.3, 190: 0.5, 270: 0.7, 360: 0.9}
    if epoch in beta_by_epoch:
        criterion.beta = beta_by_epoch[epoch]

In [None]:
# Experiment 2: beta starts at 0.1 and is stepped up by changing_beta, once per
# unlabeled-target wrapper.
epochs = 400


# Hypernetwork configuration: features per mask and number of test masks.
mask_size = 100
masks_no = 200

# Resets the results collected by any previous experiment cell.
results = defaultdict(list)
size = 1000

# One full training run per wrapper (softmax targets vs. argmax targets).
for f in [torch.nn.functional.softmax, torch.argmax]:
    criterion = TabSSLCrossEntropyLoss(beta=0.1, unsup_target_wrapper=f)

    # Note: this uses Hypernetwork from the tabular_hypernet package (hp),
    # not the Hypernetwork class defined earlier in this notebook.
    hypernet = hp.Hypernetwork(mask_size=mask_size, node_hidden_size=100, test_nodes=masks_no).cuda()

    hypernet = hypernet.train()
    optimizer = torch.optim.Adam(hypernet.parameters(), lr=3e-4)

    # dataset & loaders
    sup_trainloader, unsup_trainloader, testloader = get_dataset(size)
    trainloader = TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

    # changing_beta is called by train_semisl at the start of every epoch.
    results[size].append(train_semisl(hypernet, 
                                      optimizer, 
                                      criterion, 
                                      (trainloader, testloader), 
                                      size,
                                      epochs,
                                      masks_no,
                                      changing_beta=changing_beta,
                                      log_to_comet=True,
                                      tag='semisl_initial_experiment'))