In [1]:
# Work from the package directory. Guarded so that re-running this cell does
# not try to descend into a nested `hypernet/hypernet`.
import os
if os.path.basename(os.getcwd()) != "hypernet":
    os.chdir("hypernet")

/home/z1157095/hypernet-cnn/hypernet


In [2]:
from dotenv import load_dotenv
# Load variables from a `.env` file into os.environ (presumably this is where
# COMET_KEY, read via os.environ.get in the training function, comes from).
load_dotenv()

import random
import os

In [3]:
from comet_ml import Experiment, Optimizer

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils
import pandas as pd
from collections import defaultdict

# float32 is already PyTorch's default dtype; set explicitly so newly created
# tensors/parameters are float32 even if something earlier changed the default.
torch.set_default_dtype(torch.float32)

In [4]:
from torchsummary import summary
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [5]:
import tabular_hypernet as hp
from tabular_hypernet.semisl import SSLCELossWithThreshold

In [6]:
# Leftover interactive inspection of get_dataset's signature; the runs below build
# their data via hp.semisl.get_train_test_sets / hp.semisl.get_dataloaders instead.
hp.training_utils.get_dataset

<function tabular_hypernet.training_utils.get_dataset(size=60000, masked=False, mask_no=200, mask_size=700, shared_mask=False, batch_size=32, test_batch_size=32)>

In [7]:
# The Comet API key must come from the environment (expected to be populated
# from `.env` by `load_dotenv()`); never hard-code it or echo it into output.
# NOTE(review): the key previously hard-coded in this cell was committed in
# plain text and should be revoked/rotated in Comet.
comet_key_set = bool(os.environ.get("COMET_KEY"))
comet_key_set

In [8]:
class SSLCEPseudoLabel(torch.nn.Module):
    """Two-branch cross-entropy loss for semi-supervised training with pseudo-labels.

    Both branch outputs (f1 and f2) are scored with cross-entropy against the
    true labels on the labelled batch and against the supplied pseudo-labels on
    the unlabelled batch. The returned loss is
    ``supervised_loss + beta * self_supervised_loss``.

    After every call, the two unweighted components are stored on the instance
    as ``supervised_loss`` and ``self_supervised_loss`` (not detached).
    """

    def __init__(self, beta=0.1):
        super().__init__()

        # Labelled batch: f1 vs. labels, f2 vs. labels.
        self.y_f1 = torch.nn.CrossEntropyLoss()
        self.y_f2 = torch.nn.CrossEntropyLoss()

        # Unlabelled batch: f1 vs. pseudo-labels, f2 vs. pseudo-labels.
        self.f1_f2 = torch.nn.CrossEntropyLoss()
        self.f2_f1 = torch.nn.CrossEntropyLoss()

        # Weight of the pseudo-label term; a plain attribute so schedules can change it.
        self.beta = beta

    def forward(self, sup_input, unsup_input):
        """Compute the combined loss.

        Args:
            sup_input: ``(logits_f1, logits_f2, labels)`` for the labelled batch.
            unsup_input: ``(logits_f1, logits_f2, pseudo_labels)`` for the unlabelled batch.

        Returns:
            ``supervised_loss + beta * self_supervised_loss``.
        """
        lab_f1, lab_f2, lab_targets = sup_input
        unl_f1, unl_f2, unl_targets = unsup_input

        sup_f1_term = self.y_f1(lab_f1, lab_targets)
        sup_f2_term = self.y_f2(lab_f2, lab_targets)
        self.supervised_loss = sup_f1_term + sup_f2_term

        unl_f1_term = self.f1_f2(unl_f1, unl_targets)
        unl_f2_term = self.f2_f1(unl_f2, unl_targets)
        self.self_supervised_loss = unl_f1_term + unl_f2_term

        weighted_unsup = self.beta * self.self_supervised_loss
        return self.supervised_loss + weighted_unsup

In [9]:
def permute(t):
    """Return a copy of ``t`` with its first dimension shuffled (uses torch's global RNG)."""
    return t[torch.randperm(t.shape[0])]

In [10]:
def _log_run_config(experiment, hypernet, optimizer, trainloader, testloader, data_size,
                    epochs, masks_no, test_every, tags, description, log_params):
    """Attach tags, hyper-parameters and the optional description to a Comet experiment."""
    experiment.add_tags(tags)
    experiment.log_parameter("test_nodes", hypernet.test_nodes)
    experiment.log_parameter("mask_size", hypernet.mask_size)
    experiment.log_parameter("lr", optimizer.defaults['lr'])
    experiment.log_parameter("training_size", sum(data_size))
    experiment.log_parameter("sup_train_size", data_size[0])
    experiment.log_parameter("masks_no", masks_no)
    experiment.log_parameter("max_epochs", epochs)
    experiment.log_parameter("check_val_every_n_epoch", test_every)
    experiment.log_parameter("train_batch_size", trainloader.batch_size)
    experiment.log_parameter("test_batch_size", testloader.batch_size)

    for name, value in log_params.items():
        experiment.log_parameter(name, value)

    if description:
        experiment.log_text(description)


def _next_mask(hypernet, mask_idx, shuffled_masks):
    """Return ``(mask, next_idx)`` for cursor ``mask_idx`` into ``hypernet.test_mask``.

    When the cursor runs past the end it is reset to 0; with ``shuffled_masks``
    the mask bank on ``hypernet`` is also replaced by a fresh permutation then.
    """
    mask = hypernet.test_mask[mask_idx]
    mask_idx += 1
    if mask_idx >= len(hypernet.test_mask):
        if shuffled_masks:
            hypernet.test_mask = permute(hypernet.test_mask)
        mask_idx = 0
    return mask, mask_idx


def _tile_mask(mask, n, device):
    """Stack ``n`` copies of ``mask`` (one per sample of a batch) and move them to ``device``."""
    return torch.stack([mask] * n).to(device)


def _pseudo_labels(hypernet, inputs):
    """Eval-mode argmax predictions for ``inputs``, without an autograd graph; restores train mode."""
    hypernet.eval()
    with torch.no_grad():
        _, predicted = torch.max(hypernet(inputs), 1)
    hypernet.train()
    return predicted


def _evaluate(hypernet, testloader, device):
    """Return ``(mean per-sample cross-entropy, accuracy in %)`` over ``testloader``; leaves eval mode on."""
    # Sum-reduced so dividing by the sample count gives a true per-sample mean,
    # independent of batch size.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    total_loss = 0.0
    correct = 0
    n_samples = 0

    hypernet.eval()
    with torch.no_grad():
        for batch in testloader:
            # Batches are (images, labels) or (images, labels, extra); extras are ignored.
            images, labels, *_ = batch
            images = images.to(device)
            labels = labels.to(device)

            outputs = hypernet(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total_loss += criterion(outputs, labels).item()
            n_samples += len(labels)

    if n_samples == 0:
        raise ValueError("testloader yielded no samples")
    return total_loss / n_samples, correct / n_samples * 100


def train_semisl_N(hypernet, optimizer, criterion, loaders, data_size, epochs, masks_no,
                    changing_beta=None,
                    log_to_comet=True,
                    experiment=None,
                    shuffled_masks=True,
                    device='cuda:0',
                    tags=None,
                    project_name="semi-hypernetwork",
                    test_every=5,
                    description=None,
                    log_params=None
                ):
    """Train a hypernetwork semi-supervisedly using 2 masks per iteration.

    Each training step draws one mask for branch f1 and the next mask for branch
    f2 from ``hypernet.test_mask``; each mask is applied to both the labelled and
    the unlabelled batch. Pseudo-labels for the unlabelled batch are the
    hypernet's eval-mode argmax predictions, computed before the update.

    Args:
        hypernet: model exposing ``test_mask``, ``test_nodes`` and ``mask_size``;
            called as ``hypernet(x)`` (in eval mode) and ``hypernet(x, masks)``.
        optimizer: torch optimizer over ``hypernet``'s parameters.
        criterion: called as ``criterion((sup1, sup2, labels), (unsup1, unsup2, pseudo))``;
            must expose ``beta``, ``supervised_loss`` and ``self_supervised_loss``.
        loaders: ``(trainloader, testloader)``; the trainloader yields
            ``((sup_x, sup_y), (unsup_x, _))`` pairs.
        data_size: ``(supervised_size, ...)``; its sum is logged as ``training_size``.
        epochs: number of training epochs.
        masks_no: number of masks (logged only).
        changing_beta: optional ``f(epoch, criterion)`` called at the start of each epoch.
        log_to_comet: log parameters/metrics to Comet (key from ``COMET_KEY``).
        experiment: existing Comet experiment to reuse; created if ``None``.
        shuffled_masks: reshuffle the mask bank after each full pass instead of wrapping.
        device: device the batches and masks are moved to.
        tags: Comet tags (default ``["semi-slow-step-hypernet"]``).
        project_name: Comet project for a newly created experiment.
        test_every: evaluate (and log) on epochs where ``epoch % test_every == 0``.
        description: optional free text logged to Comet.
        log_params: extra ``{name: value}`` parameters logged to Comet.

    Returns:
        ``(best test accuracy in %, mean per-sample test loss at the first
        evaluation that reached it)``.
    """
    trainloader, testloader = loaders
    if tags is None:
        tags = ["semi-slow-step-hypernet"]
    if log_params is None:
        log_params = {}

    print('!! log to comet is', log_to_comet, '\n')

    if log_to_comet:
        if experiment is None:
            experiment = Experiment(api_key=os.environ.get("COMET_KEY"), project_name=project_name, display_summary_level=0)
        _log_run_config(experiment, hypernet, optimizer, trainloader, testloader, data_size,
                        epochs, masks_no, test_every, tags, description, log_params)

    test_loss = []
    test_accs = []
    mask_idx = 0  # cursor into hypernet.test_mask, carried across batches and epochs

    with trange(epochs) as t:
        for epoch in t:
            running_loss = 0.0
            supervised_train_loss = 0.0
            unsupervised_train_loss = 0.0
            train_denom = 0

            hypernet.train()

            if changing_beta:
                changing_beta(epoch, criterion)

            for (sup_inputs, sup_labels), (unsup_inputs, _) in trainloader:
                sup_inputs = sup_inputs.to(device)
                sup_labels = sup_labels.to(device)
                unsup_inputs = unsup_inputs.to(device)

                unsup_predicted = _pseudo_labels(hypernet, unsup_inputs)

                # Masks are tiled to each batch's own length: the labelled and
                # unlabelled batches can differ in size (e.g. final partial batches).
                # f1: one mask applied to both batches.
                mask1, mask_idx = _next_mask(hypernet, mask_idx, shuffled_masks)
                sup_outputs1 = hypernet(sup_inputs, _tile_mask(mask1, len(sup_inputs), device))
                unsup_outputs1 = hypernet(unsup_inputs, _tile_mask(mask1, len(unsup_inputs), device))

                # f2: the next mask in the bank, applied the same way.
                mask2, mask_idx = _next_mask(hypernet, mask_idx, shuffled_masks)
                sup_outputs2 = hypernet(sup_inputs, _tile_mask(mask2, len(sup_inputs), device))
                unsup_outputs2 = hypernet(unsup_inputs, _tile_mask(mask2, len(unsup_inputs), device))

                optimizer.zero_grad()
                loss = criterion(
                    (sup_outputs1, sup_outputs2, sup_labels),
                    (unsup_outputs1, unsup_outputs2, unsup_predicted),
                )
                loss.backward()
                optimizer.step()

                # Accumulate Python floats: summing the loss tensors themselves would
                # keep every batch's autograd graph alive until the end of the epoch.
                running_loss += loss.item()
                supervised_train_loss += float(criterion.supervised_loss)
                unsupervised_train_loss += float(criterion.self_supervised_loss)
                train_denom += 1

            if epoch % test_every == 0:
                n_batches = max(train_denom, 1)  # guard against an empty trainloader
                if log_to_comet:
                    experiment.log_metric("beta_coef", criterion.beta, step=epoch)
                    experiment.log_metric('sup_train_loss', supervised_train_loss / n_batches, step=epoch)
                    experiment.log_metric('self_sup_train_loss', unsupervised_train_loss / n_batches, step=epoch)
                    experiment.log_metric('train_loss', running_loss / n_batches, step=epoch)

                epoch_loss, epoch_acc = _evaluate(hypernet, testloader, device)
                test_loss.append(epoch_loss)
                test_accs.append(epoch_acc)

                t.set_postfix(test_acc=epoch_acc, loss=epoch_loss)

                if log_to_comet:
                    experiment.log_metric("test_accuracy", epoch_acc, step=epoch)
                    experiment.log_metric("test_loss", epoch_loss, step=epoch)

    if experiment:
        experiment.end()

    return max(test_accs), test_loss[np.argmax(test_accs)]

In [11]:
def changing_beta(epoch, criterion):
    """Step schedule for the pseudo-label weight, applied in place to ``criterion.beta``.

    Only epochs 0, 1 and 90 change the value: beta is 0 at epoch 0, 0.1 from
    epoch 1 and 1.0 from epoch 90. Any other epoch leaves beta untouched, so
    this must be called every epoch starting at 0 for the schedule to hold.
    """
    for trigger_epoch, new_beta in ((0, 0), (1, 0.1), (90, 1.)):
        if epoch == trigger_epoch:
            criterion.beta = new_beta
            break
        

### Setup for training

In [12]:
seed = 5  # global seed; re-applied before each model is built and logged to Comet as 'seed'

In [13]:
# Seed every RNG source in use (NumPy, PyTorch, Python's `random`) with the shared seed.
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    seed_fn(seed)

In [14]:
# Input width of the first target layer, (mask_size, 100); presumably the number
# of input features each mask keeps — confirm against hp.Hypernetwork.
mask_size = 100

In [15]:
# Train/test data built once and passed to hp.semisl.get_dataloaders in each run
# (structure of the returned object is not visible here).
dataset = hp.semisl.get_train_test_sets()

In [16]:
# Configuration shared by the experiment cells below.
epochs, masks_no = 100, 100   # training epochs; size of the mask bank (passed as test_nodes)
results = defaultdict(list)   # size -> return values of the training calls, in run order
size = (100, 59900)           # (supervised, presumably unlabelled) train sizes; sum logged as training_size

In [17]:
# Reference numbers from an earlier run (presumably test accuracy in % and test loss): 68.6, 1.78

### Test pseudolabeling

In [18]:
# Pseudo-label run: SSLCEPseudoLabel loss trained with train_semisl_N and the
# changing_beta schedule. The loop is a single-value learning-rate "sweep".
for lr in [3e-5]:

    # beta starts at 0; the changing_beta schedule raises it during training.
    criterion = SSLCEPseudoLabel(beta=0)

    # Re-seed immediately before building the model so the initial weights are
    # reproducible regardless of which cells ran before.
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # dataset: reuses the module-level `dataset` built in an earlier cell.


    # 784 inputs — presumably flattened 28x28 images; confirm.
    # NOTE(review): mask_size only enters via target_architecture and is not passed
    # to hp.Hypernetwork, while hypernet.mask_size is what gets logged to Comet —
    # confirm the two agree.
    hypernet = hp.Hypernetwork(
        architecture=torch.nn.Sequential(
            torch.nn.Linear(784, 64), 
            torch.nn.ReLU(),
            torch.nn.Linear(64, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
        ),
        target_architecture=[(mask_size, 100), (100, 10)],
        test_nodes=masks_no,
    ).cuda()


    # Explicitly put the model in training mode.
    hypernet = hypernet.train()



    optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)

    # loaders
    sup_trainloader, unsup_trainloader, testloader = hp.semisl.get_dataloaders(dataset=dataset, size=size, batch_size=32, test_batch_size=64)
    trainloader = hp.semisl.TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

    # NOTE: the threshold-loss run in the next cell appends to the same
    # results[size] list, so entries are distinguished only by order; re-running
    # this cell appends another entry.
    results[size].append(
                        train_semisl_N(
                            hypernet,
                            optimizer,
                            criterion,
                            (trainloader, testloader), 
                            size,
                            epochs,
                            masks_no,
                            changing_beta=changing_beta,
                            log_to_comet=True,
                            tags=['pseudolabel_initial'],
                            description=
                            """
                            test pseudolabeling. Labels produced in hypernet.eval mode
                            """,
                            log_params={'seed': seed, 'pseudolabels': True}
                        )
                        )


torch.Size([1, 128])




!! log to comet is True 



COMET INFO: Experiment is live on comet.ml https://www.comet.com/abulenok/semi-hypernetwork/93479305d81f4e9287301b75f6fb2d11

100%|█████████████████████████████████████████████████████████████████| 100/100 [1:55:17<00:00, 69.17s/it, loss=1.78, test_acc=70.1]
COMET ERROR: Error sending a notification, make sure you have opted-in for notifications
COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...


In [19]:
# Comparison run: SSLCELossWithThreshold with beta=0 and no beta schedule,
# trained with the library's hp.semisl.train_semisl. The loop is a single-value
# learning-rate "sweep".
for lr in [3e-5]:

    criterion = SSLCELossWithThreshold(beta=0)

    # Re-seed immediately before building the model so the initial weights should
    # match those of the SSLCEPseudoLabel run.
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # NOTE(review): mask_size only enters via target_architecture and is not passed
    # to hp.Hypernetwork — confirm hypernet.mask_size agrees.
    hypernet = hp.Hypernetwork(
        architecture=torch.nn.Sequential(
            torch.nn.Linear(784, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
        ),
        target_architecture=[(mask_size, 100), (100, 10)],
        test_nodes=masks_no,
    ).cuda()

    # Explicitly put the model in training mode.
    hypernet = hypernet.train()

    optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)

    # loaders
    sup_trainloader, unsup_trainloader, testloader = hp.semisl.get_dataloaders(dataset=dataset, size=size, batch_size=32, test_batch_size=64)
    trainloader = hp.semisl.TrainDataLoaderSemi(sup_trainloader, unsup_trainloader)

    # Appends to the same results[size] list as the SSLCEPseudoLabel run; entries
    # are distinguished only by order.
    results[size].append(
        hp.semisl.train_semisl(
            hypernet,
            optimizer,
            criterion,
            (trainloader, testloader),
            size,
            epochs,
            masks_no,
            changing_beta=None,
            log_to_comet=True,
            tags=['pseudolabel_initial'],
            # Describe this run's actual configuration; the previous text was copied
            # from the pseudo-label run and contradicted 'pseudolabels': False.
            description=(
                "SSLCELossWithThreshold(beta=0) trained with hp.semisl.train_semisl, "
                "changing_beta=None (no beta schedule); counterpart to the "
                "SSLCEPseudoLabel pseudo-label run."
            ),
            log_params={'seed': seed, 'pseudolabels': False}
        )
    )




torch.Size([1, 128])
!! log to comet is True 



COMET INFO: Experiment is live on comet.ml https://www.comet.com/abulenok/semi-hypernetwork/0309012d15074e1a9b2ea3e5a84abdfa

100%|███████████████████████████████████████████████████████████████████| 100/100 [31:47<00:00, 19.08s/it, loss=1.74, test_acc=72.7]
COMET ERROR: Error sending a notification, make sure you have opted-in for notifications
COMET INFO: Uploading 1 metrics, params and output messages
