In [2]:
%cd hypernet

/home/z1157095/hypernet-cnn/hypernet


In [3]:
from dotenv import load_dotenv
load_dotenv()

import os

In [4]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

torch.set_default_dtype(torch.float32)

In [5]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [6]:
import tabular_hypernet as hp

In [7]:
hp.training_utils.get_dataset

<function tabular_hypernet.training_utils.get_dataset(size=60000, masked=False, mask_no=200, mask_size=700, shared_mask=False, batch_size=32, test_batch_size=32)>

In [8]:
os.environ['COMET_KEY'] = 'UXrV5UxyhTK3cyQNG6BDuc4bE'
os.environ.get("COMET_KEY")

'UXrV5UxyhTK3cyQNG6BDuc4bE'

In [9]:
hp.semisl.get_dataset

<function tabular_hypernet.semisl.get_dataset(size=(100, 900), mask_no=200, mask_size=700, batch_size=32, test_batch_size=32, shuffle_train=True)>

In [10]:
class TabSSLCrossEntropyLoss(torch.nn.Module):
    def __init__(self, beta=0.1, unsup_target_wrapper=torch.nn.functional.softmax, threshold=None):
        super(TabSSLCrossEntropyLoss, self).__init__()
        
        self.y_f1 = torch.nn.CrossEntropyLoss()
        self.y_f2 = torch.nn.CrossEntropyLoss()
        
        self.f1_f2 = torch.nn.CrossEntropyLoss()
        self.f2_f1 = torch.nn.CrossEntropyLoss()
        
        self.beta = beta
        self.unsup_target_wrapper = unsup_target_wrapper
        self.threshold = threshold
        
    def is_observ_above_threshold(self, data):
        mask = torch.any(data >= self.threshold, dim=1)
        
        return mask
        
    
    def forward(self, sup_input, unsup_input):
        sup_outputs1, sup_outputs2, sup_labels = sup_input
        unsup_outputs1, unsup_outputs2 = unsup_input
        
        self.supervised_loss = self.y_f1(sup_outputs1, sup_labels) + self.y_f2(sup_outputs2, sup_labels)
        
        self.self_supervised_loss = 0
        if self.beta:
            if self.threshold:
                unsup_outputs1_target = torch.nn.functional.softmax(unsup_outputs1, dim=1)
                mask1 = self.is_observ_above_threshold(unsup_outputs1_target)

                if len(unsup_outputs1_target[mask1]):
                    unsup_outputs1_target = torch.argmax(unsup_outputs1_target[mask1], dim=1)
                    self.self_supervised_loss += self.f2_f1(unsup_outputs2[mask1], unsup_outputs1_target)

                unsup_outputs2_target = torch.nn.functional.softmax(unsup_outputs2, dim=1)
                mask2 = self.is_observ_above_threshold(unsup_outputs2_target)

                if len(unsup_outputs2_target[mask2]):
                    unsup_outputs2_target = torch.argmax(unsup_outputs2_target[mask2], dim=1)
                    self.self_supervised_loss += self.f1_f2(unsup_outputs1[mask2], unsup_outputs2_target)

            else:
                self.self_supervised_loss = self.f1_f2(unsup_outputs1, self.unsup_target_wrapper(unsup_outputs2, dim=1)) \
                                    + self.f2_f1(unsup_outputs2, self.unsup_target_wrapper(unsup_outputs1, dim=1))      
        
        return self.supervised_loss + self.beta * self.self_supervised_loss


In [11]:
epochs = 1


mask_size = 100
masks_no = 50

results = defaultdict(list)
size = (100, 59900)
seed = 5

In [12]:
for beta in [0.]:
    for threshold in [0.]:
        for lr in [3e-3]:
            criterion = TabSSLCrossEntropyLoss(beta=beta, threshold=threshold)
            
            np.random.seed(seed)
            hypernet = hp.Hypernetwork(mask_size=mask_size, node_hidden_size=100, test_nodes=masks_no).cuda()
            
            hypernet = hypernet.train()
            
            
            
            optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)

            # dataset & loaders
            sup_trainloader, unsup_trainloader, testloader = hp.semisl.get_dataset(size, batch_size=32, test_batch_size=64)
            trainloader = hp.semisl.TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

            results[size].append(hp.semisl.train_semisl(hypernet,
                                              optimizer,
                                              criterion,
                                              (trainloader, testloader), 
                                              size,
                                              epochs,
                                              masks_no,
                                              changing_beta=None,
                                              log_to_comet=False,
                                              tags=['fixed masks_GS_threshold,beta,lr'],
                                              description="""
                                              Grid search with threshold, beta and lr. Fixed masks.
                                              """,
                                            log_params={'seed': seed}
                                            ))
            

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [02:15<00:00, 135.01s/it, loss=1.82, test_acc=68.9]


In [13]:
print(hypernet)

Hypernetwork(
  (input): Linear(in_features=784, out_features=64, bias=True)
  (hidden1): Linear(in_features=64, out_features=256, bias=True)
  (hidden2): Linear(in_features=256, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=11110, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)


In [14]:
print(hypernet._get_test_nets())

[InsertableNet(
  (inp): Linear(in_features=100, out_features=100, bias=True)
  (output): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
  (sigmoid): Sigmoid()
  (out_act): Sigmoid()
), InsertableNet(
  (inp): Linear(in_features=100, out_features=100, bias=True)
  (output): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
  (sigmoid): Sigmoid()
  (out_act): Sigmoid()
), InsertableNet(
  (inp): Linear(in_features=100, out_features=100, bias=True)
  (output): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
  (sigmoid): Sigmoid()
  (out_act): Sigmoid()
), InsertableNet(
  (inp): Linear(in_features=100, out_features=100, bias=True)
  (output): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
  (sigmoid): Sigmoid()
  (out_act): Sigmoid()
), InsertableNet(
  (inp): Linear(in_features=100,