In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sys
sys.path.append('/workspace/experiment')
from data.cifar10 import SplitCifar10, train_transform, val_transform
from data.capsule_split import get_splits


In [14]:
def weights_init(m):
    """Initialise convolution and batch-norm parameters in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once per
    sub-module. Dispatch is a substring match on the module's class name:

    * name contains ``'Conv'``      -> ``weight ~ N(0.0, 0.05)``
    * name contains ``'BatchNorm'`` -> ``weight ~ N(1.0, 0.02)``, ``bias = 0``

    Every other module is left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # A matched module is not guaranteed to own these parameters: BatchNorm
    # built with ``affine=False`` has ``weight``/``bias`` set to None, and a
    # user-defined container may simply have "Conv" in its name. Skip instead
    # of raising AttributeError in the middle of ``model.apply``.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.05)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

class classifier32(nn.Module):
    """Nine-layer convolutional classifier with optional feature taps.

    The network is three stages of three ``Conv2d -> BatchNorm2d ->
    LeakyReLU(0.2)`` layers. ``Dropout2d(0.2)`` is applied before each stage
    (the first one directly on the input), and the last convolution of each
    stage has stride 2, so every stage halves the spatial size (a 32x32 input
    becomes 16x16, then 8x8, then 4x4). The final map is globally
    average-pooled to a 128-d vector and passed to a linear layer.

    Args:
        num_classes: Number of output logits.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes=2, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever once this class is subclassed, because ``self.__class__`` is
        # then the subclass and the lookup lands back on this __init__.
        super().__init__()
        self.num_classes = num_classes

        # Sub-modules are registered in the original order so that
        # ``state_dict`` keys, ``repr`` and the visiting order of
        # ``self.apply`` stay exactly as before.
        self.conv1 = nn.Conv2d(3,       64,     3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(64,      64,     3, 1, 1, bias=False)
        self.conv3 = nn.Conv2d(64,     128,     3, 2, 1, bias=False)

        self.conv4 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv5 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv6 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        self.conv7 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv8 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv9 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)

        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(128)
        self.bn6 = nn.BatchNorm2d(128)

        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128, num_classes)
        self.dr1 = nn.Dropout2d(0.2)
        self.dr2 = nn.Dropout2d(0.2)
        self.dr3 = nn.Dropout2d(0.2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.apply(weights_init)

    def _conv_bn_act(self, x, conv, bn):
        """Apply ``conv``, then ``bn``, then LeakyReLU with negative slope 0.2."""
        # Functional form: avoids constructing a throw-away nn.LeakyReLU
        # module for every activation on every forward pass.
        return nn.functional.leaky_relu(bn(conv(x)), 0.2)

    def forward(self, x, return_features=[]):
        """Compute class logits and, optionally, intermediate features.

        Args:
            x: Input batch with 3 channels (``conv1`` expects 3 input channels).
            return_features: Iterable of tap indices to return. ``0`` is the
                output of stage 1, ``1`` the output of stage 2, ``2`` the
                pooled and flattened 128-d vector that feeds ``fc1``. Other
                values match no tap. ``None`` is treated like an empty request.

        Returns:
            ``(logits, features)``. ``features`` is ``[]`` when nothing was
            requested, the single tensor when exactly one index was requested,
            and otherwise a list of the matched taps ordered by depth in the
            network (not by the order of ``return_features``).

        Raises:
            IndexError: If exactly one index is requested and it matches no tap.
        """
        # The default list is shared between calls but is only ever read.
        requested = () if return_features is None else return_features
        out_feat = []

        # Stage 1
        x = self.dr1(x)
        x = self._conv_bn_act(x, self.conv1, self.bn1)
        x = self._conv_bn_act(x, self.conv2, self.bn2)
        l1 = self._conv_bn_act(x, self.conv3, self.bn3)
        if 0 in requested:
            out_feat.append(l1)

        # Stage 2
        x = self.dr2(l1)
        x = self._conv_bn_act(x, self.conv4, self.bn4)
        x = self._conv_bn_act(x, self.conv5, self.bn5)
        l2 = self._conv_bn_act(x, self.conv6, self.bn6)
        if 1 in requested:
            out_feat.append(l2)

        # Stage 3
        x = self.dr3(l2)
        x = self._conv_bn_act(x, self.conv7, self.bn7)
        x = self._conv_bn_act(x, self.conv8, self.bn8)
        l3 = self._conv_bn_act(x, self.conv9, self.bn9)

        # Global average pool to (batch, 128, 1, 1), then drop the 1x1 dims.
        l3 = self.avgpool(l3).flatten(1)
        if 2 in requested:
            out_feat.append(l3)

        y = self.fc1(l3)

        if len(requested) > 0:
            if len(requested) == 1:
                # A single requested tap is returned bare, not in a list.
                out_feat = out_feat[0]
            return y, out_feat

        return y, []

In [15]:
# Class split for CIFAR-10. The second argument (0) is passed to get_splits as-is;
# presumably it selects one of several predefined splits — TODO confirm in data.capsule_split.
splits = get_splits('cifar10', 0)

In [16]:
# Training portion of CIFAR-10 restricted to the split's known classes, using the training transform.
# NOTE(review): '/datasets' is a hard-coded absolute path.
dataset = SplitCifar10('/datasets', train=True, transform=train_transform,
                       split=splits['known_classes'])
# shuffle=False: batches follow the dataset's own sample order.
loader = DataLoader(dataset, batch_size=32, num_workers=8, shuffle=False)

In [17]:
# One output logit per known class; the bare expression on the last line displays the module tree.
model = classifier32(len(splits['known_classes']))
model

classifier32(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (conv7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv9): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [18]:
# Pull a single batch (inputs, targets) from the loader to probe the model with.
# The former second line rebound both names to themselves and had no effect, so it is removed.
tInput, tTarget = next(iter(loader))

In [19]:
# Forward one batch and request all three feature taps. With more than one index
# requested the model returns the taps as a list, unpacked here into l0, l1, l2
# (l2 is the pooled, flattened vector that feeds the final linear layer).
logits, (l0, l1, l2) = model(tInput, return_features=[0, 1, 2])

In [87]:
class NoiseLayer(nn.Module):
    """Perturb each sample with Gaussian noise drawn from another class's statistics.

    For every class, the per-feature mean and standard deviation of the
    samples of that class in the current batch are computed, and one noise
    tensor per class is drawn from ``N(mean, std)``. The batch labels are then
    randomly permuted; each sample receives ``alpha`` times the noise of the
    label it was paired with, but only when that label differs from its own.

    NOTE(review): the noise is added regardless of ``self.training`` —
    confirm whether evaluation is supposed to be a pass-through.

    Args:
        alpha: Scale factor applied to the noise.
        num_classes: Number of classes; labels must lie in ``[0, num_classes)``
            because they are used to index the per-class noise tensor.
    """

    def __init__(self, alpha, num_classes):
        super(NoiseLayer, self).__init__()
        self.alpha = alpha
        # Kept as a tensor of class ids 0..num_classes-1 so that code reading
        # this attribute keeps working; only its length is used below.
        self.num_classes = torch.arange(num_classes)

    def calculate_class_mean(self,
                           x: torch.Tensor,
                           y: torch.Tensor):
        """Compute the per-class mean and standard deviation of ``x``.

        Statistics are taken over the batch dimension, separately for the
        samples of each class. A class with no sample in the batch gets a zero
        mean, and a class with fewer than two samples gets a zero standard
        deviation. Without this, both would be NaN (mean of an empty
        selection, unbiased std of a single value) and ``torch.normal`` would
        reject them.

        Args:
            x (torch.Tensor): input tensor, batch dimension first.
            y (torch.Tensor): 1-D tensor of integer class labels, one per sample.

        Returns:
            tuple[Tensor, Tensor]: ``(mean, std)``, each of shape
            ``(num_classes, *x.shape[1:])``.
        """
        num_classes = self.num_classes.shape[0]
        feature_shape = x.shape[1:]

        means = []
        stds = []
        for class_id in range(num_classes):
            members = x[y == class_id]
            count = members.shape[0]
            if count == 0:
                means.append(x.new_zeros(feature_shape))
            else:
                means.append(members.mean(0))
            if count < 2:
                stds.append(x.new_zeros(feature_shape))
            else:
                stds.append(members.std(0))

        return torch.stack(means), torch.stack(stds)

    def forward(self, x, y):
        """Return ``x`` plus scaled noise taken from randomly paired classes.

        Args:
            x: Input tensor of any rank >= 1, batch dimension first.
            y: 1-D tensor of integer class labels in ``[0, num_classes)``.

        Returns:
            Tensor of the same shape as ``x``. Samples whose randomly paired
            label equals their own label are returned unchanged.
        """
        batch_size = x.size(0)

        # The noise is a constant with respect to the graph (the original
        # detached it), so do not record the statistics for autograd at all.
        with torch.no_grad():
            class_mean, class_std = self.calculate_class_mean(x, y)
            class_noise = torch.normal(mean=class_mean, std=class_std).type_as(x)

        # Pair every sample with the label of a random other position in the batch.
        index = torch.randperm(batch_size).to(y.device)
        newY = y[index]

        # 1.0 where the paired label differs from the sample's own label, else 0.0.
        # Reshaped to (batch, 1, ..., 1) so it broadcasts over x of any rank.
        mask = (y != newY).type_as(x).view(batch_size, *([1] * (x.dim() - 1)))

        return x + self.alpha * class_noise[newY.long()] * mask
        
        
        
            

In [88]:
# Noise layer with alpha=0.1 and 6 classes.
# NOTE(review): 6 is hard-coded; presumably it should equal len(splits['known_classes']) — confirm.
NT = NoiseLayer(0.1, 6)

In [95]:
# Per-class statistics of the pooled feature vector. calculate_class_mean returns
# (mean, std), so the second value is a standard deviation despite the name class_var.
class_mean, class_var = NT.calculate_class_mean(l2, tTarget)

In [96]:
# Run the full layer. Despite the variable name, the result is the perturbed input
# returned by forward (input plus scaled, masked noise), not the raw noise tensor.
class_noise = NT(l2, tTarget)

In [97]:
# Sanity check: the layer's output should have the same shape as its input l2.
class_noise.shape

torch.Size([32, 128])

tensor([[[[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False,

In [56]:
# Manual re-run of the sampling step on the first feature tap: per-class (mean, std),
# then one Gaussian draw per class with those parameters.
class_mean, class_std = NT.calculate_class_mean(l0, tTarget)
class_noise = torch.normal(class_mean, class_std).type_as(l0)

In [31]:
# Random permutation of batch positions, used to pair each sample with another sample's label.
index = torch.randperm(l0.size(0))
newT = tTarget[index]
# True where the paired label differs from the sample's own label.
mask = tTarget != newT


In [39]:
# Append three singleton dims to the per-sample mask and expand it to l0's shape
# (this indexing pattern requires l0 to be 4-D).
mask[...,None,None,None].expand_as(l0)

tensor([[[[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],

         [[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],

         [[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,

In [36]:
# Add to each sample the noise tensor of its paired label (unscaled and unmasked in this check).
l0 + class_noise[newT]

tensor([[[[-2.7810e-03, -1.5414e+00, -1.5510e-01,  ...,  3.3025e-01,
            3.3629e-01,  5.6767e-02],
          [ 7.6202e-01, -1.2339e+00, -6.7145e-01,  ...,  3.1326e-01,
            7.9746e-01, -1.2930e+00],
          [ 8.2759e-01, -8.1531e-02,  1.9091e-02,  ...,  7.0440e-01,
           -2.7701e-01, -9.8359e-01],
          ...,
          [ 6.1855e-01,  1.1218e+00,  1.9784e+00,  ..., -3.4991e-01,
            5.4372e-01, -6.6887e-01],
          [ 7.4629e-01,  2.3558e+00,  4.8192e-01,  ...,  2.3959e+00,
            3.8946e-01,  8.4480e-01],
          [-2.0766e-01,  9.8926e-01,  8.4654e-01,  ...,  1.0619e+00,
            2.5451e-01, -1.3099e-01]],

         [[-6.2811e-01,  4.7220e-01, -5.6474e-01,  ..., -1.0992e-01,
           -3.7443e-01, -2.2504e-01],
          [-4.5362e-01, -9.0952e-01,  1.5988e-01,  ...,  9.0601e-01,
            2.4684e+00, -1.0118e-01],
          [ 2.5926e-01,  9.9467e-01,  1.3104e+00,  ...,  1.6846e+00,
            6.5253e-01,  1.0658e+00],
          ...,
     

In [26]:
# Display the permutation drawn by torch.randperm.
index

tensor([13,  0, 12, 25,  8,  3, 24,  9, 23, 11,  7,  6, 17, 26, 30, 31, 10, 22,
         2, 19,  5, 18,  4, 28, 21, 16,  1, 20, 14, 27, 15, 29])

In [19]:
# Check which device the first feature tap lives on.
l0.device

device(type='cuda', index=0)

RuntimeError: CUDA error: device-side assert triggered