In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sys
sys.path.append('/workspace/experiment')
from data.cifar10 import SplitCifar10, train_transform, val_transform
from data.capsule_split import get_splits

import os


In [4]:
weight_path = '/workspace/experiment/log/NoiseInjection/cifar10_s0/0'
os.path.exists(weight_path)

checkpoint = torch.load(weight_path+'/epoch=99-step=11799.ckpt')

In [7]:
class NoiseLayer(nn.Module):
    def __init__(self, alpha, num_classes):
        super(NoiseLayer, self).__init__()
        self.alpha = alpha
        self.num_classes = torch.arange(num_classes)
        
    def calculate_class_mean(self, 
                           x: torch.Tensor, 
                           y: torch.Tensor):
        """calculate the variance of each classes' noise

        Args:
            x (torch.Tensor): [input tensor]
            y (torch.Tensor): [target tensor]

        Returns:
            [Tensor]: [returns class dependent noise variance]
        """
        self.num_classes = self.num_classes.type_as(y)
        idxs = y.unsqueeze(0) == self.num_classes.unsqueeze(1)
        mean = []
        std = []
        for i in range(self.num_classes.shape[0]):
            x_ = x[idxs[i]]
            mean.append(x_.mean(0))
            std.append(x_.std(0))
        
        return torch.stack(mean), torch.stack(std)
    
    def forward(self, x, y):
        batch_size = x.size(0)
        class_mean, class_var = self.calculate_class_mean(x, y)
        
        # class_noise = torch.normal(mean=class_mean, std=class_var).type_as(x).detach()
        class_noise = torch.normal(mean=0., std=class_var).type_as(x).detach()

        index = torch.randperm(batch_size).type_as(y)
        newY = y[index]
        mask = y != newY
        if x.dim() == 2:
            mask = mask.unsqueeze(1).expand_as(x).type_as(x)
        else:
            mask = mask[...,None,None,None].expand_as(x).type_as(x)
        
        return x + self.alpha * class_noise[newY], newY

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.05)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class classifier32(nn.Module):
    def __init__(self, num_classes=2, alpha=0.5, **kwargs):
        super(self.__class__, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3,       64,     3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(64,      64,     3, 1, 1, bias=False)
        self.conv3 = nn.Conv2d(64,     128,     3, 2, 1, bias=False)

        self.conv4 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv5 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv6 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        self.conv7 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv8 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv9 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)

        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(128)
        self.bn6 = nn.BatchNorm2d(128)

        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128, num_classes)
        self.dr1 = nn.Dropout2d(0.2)
        self.dr2 = nn.Dropout2d(0.2)
        self.dr3 = nn.Dropout2d(0.2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.apply(weights_init)
        self.noiseLayer = NoiseLayer(alpha, num_classes)

    def forward(self, x, y, return_features=[], noise=[]):
        batch_size = len(x)
        out_feat = []

        x = self.dr1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv3(x)
        x = self.bn3(x)
        l1 = nn.LeakyReLU(0.2)(x)
        
        if 0 in return_features:
            out_feat.append(l1)
            
        if 0 in noise:
            l1, ny = self.noiseLayer(l1, y)

        x = self.dr2(l1)
        x = self.conv4(x)
        x = self.bn4(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv6(x)
        x = self.bn6(x)
        l2 = nn.LeakyReLU(0.2)(x)
        
        if 1 in return_features:
            out_feat.append(l2)
            
        if 1 in noise:
            l2, ny = self.noiseLayer(l2, y)

        x = self.dr3(l2)
        x = self.conv7(x)
        x = self.bn7(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv8(x)
        x = self.bn8(x)
        x = nn.LeakyReLU(0.2)(x)
        x = self.conv9(x)
        x = self.bn9(x)
        l3 = nn.LeakyReLU(0.2)(x)
        
        l3 = self.avgpool(l3)
        l3 = l3.view(batch_size, -1)
        
        if 2 in return_features:
            out_feat.append(l3)
        
        if 2 in noise:
            l3, ny = self.noiseLayer(l3, y)
        
        y = self.fc1(l3)
        
        if len(return_features) > 0:
            if len(return_features) == 1:
                out_feat = out_feat[0]

        if len(noise) > 0:
            return y, out_feat, ny
        
        return y, out_feat

In [5]:
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'hparams_name', 'hyper_parameters'])

In [6]:
checkpoint['state_dict']

odict_keys(['model.conv1.weight', 'model.conv2.weight', 'model.conv3.weight', 'model.conv4.weight', 'model.conv5.weight', 'model.conv6.weight', 'model.conv7.weight', 'model.conv8.weight', 'model.conv9.weight', 'model.bn1.weight', 'model.bn1.bias', 'model.bn1.running_mean', 'model.bn1.running_var', 'model.bn1.num_batches_tracked', 'model.bn2.weight', 'model.bn2.bias', 'model.bn2.running_mean', 'model.bn2.running_var', 'model.bn2.num_batches_tracked', 'model.bn3.weight', 'model.bn3.bias', 'model.bn3.running_mean', 'model.bn3.running_var', 'model.bn3.num_batches_tracked', 'model.bn4.weight', 'model.bn4.bias', 'model.bn4.running_mean', 'model.bn4.running_var', 'model.bn4.num_batches_tracked', 'model.bn5.weight', 'model.bn5.bias', 'model.bn5.running_mean', 'model.bn5.running_var', 'model.bn5.num_batches_tracked', 'model.bn6.weight', 'model.bn6.bias', 'model.bn6.running_mean', 'model.bn6.running_var', 'model.bn6.num_batches_tracked', 'model.bn7.weight', 'model.bn7.bias', 'model.bn7.running_m

In [16]:
model = classifier32(num_classes=6, alpha=0.5)
model_state_dict = model.state_dict()
for name, param in checkpoint['state_dict'].items():
    n = name.replace('model.', '')
    if n in model_state_dict.keys():
        model_state_dict[n].copy_(param)

In [19]:
dummyFC = nn.Linear(128, 6)
dummyFC_SD = dummyFC.state_dict()
dummyFC_SD['weight'].copy_(checkpoint['state_dict']['dummyFC.weight'])
dummyFC_SD['bias'].copy_(checkpoint['state_dict']['dummyFC.bias'])

tensor([-0.0103, -0.0271, -0.0646, -0.0680, -0.0014, -0.0277])

In [29]:
split = get_splits('cifar10', 0)
known_data = SplitCifar10('/datasets', train=False, transform=val_transform, split=split['known_classes'])
open_data = SplitCifar10('/datasets', train=False, transform=val_transform, split=split['unknown_classes']) 

In [30]:
known_loader = DataLoader(known_data, batch_size=128, shuffle=False, num_workers=4)
open_loader = DataLoader(open_data, batch_size=128, shuffle=False, num_workers=4)

In [31]:
tInput, tTarget = next(iter(known_loader))
model.cuda()
dummyFC.cuda()
tInput, tTarget = tInput.cuda(), tTarget.cuda()

In [32]:
logit, features = model(tInput, tTarget, return_features=[2,])

In [33]:
dummy_output = dummyFC(features)

In [39]:
logit.mean(0)

tensor([13.0603, 13.2038, 13.8284, 13.4362, 14.1509, 13.9361], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [40]:
dummy_output.mean(0)

tensor([-0.1048,  0.2672,  0.1681, -0.1234,  0.4077, -0.0413], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [41]:
oInput, oTarget = next(iter(open_loader))
oInput, oTarget = oInput.cuda(), oTarget.cuda()
ologit, ofeatures = model(oInput, oTarget, return_features=[2,])
oDummy_output = dummyFC(ofeatures)

In [42]:
oDummy_output.mean(0)

tensor([-0.1784,  0.2836,  0.1552, -0.1031,  0.4289, -0.0826], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [44]:
ologit.mean(0)

tensor([13.3796, 13.6405, 14.2513, 13.9078, 14.6790, 14.2687], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [62]:
a = torch.norm(ologit, dim=1)

In [63]:
b = torch.norm(oDummy_output, dim=1)

In [64]:
mrl = nn.MarginRankingLoss(margin=5)

In [68]:
mrl(oDummy_output, ologit, torch.ones(1).cuda())

tensor(18.9372, device='cuda:0', grad_fn=<MeanBackward0>)