In [1]:
import torch
import torch.nn as nn

In [19]:
class NoiseLayer(nn.Module):
    """Blend features with Gaussian noise drawn from per-class batch statistics.

    For every class id the layer estimates the element-wise mean and standard
    deviation of the features of the batch samples carrying that label, draws
    one noise tensor per class from that Gaussian, shuffles the labels with a
    random permutation, and mixes each sample with the noise tensor of its
    shuffled label: ``(1 - alpha) * x + alpha * noise[shuffled_label]``.

    Args:
        alpha: Mixing weight of the noise term.
        num_classes: Number of classes; labels are compared against the ids
            ``0 .. num_classes - 1``.
    """

    def __init__(self, alpha, num_classes):
        super().__init__()
        self.alpha = alpha
        # Despite its name this attribute holds the tensor of class ids
        # (0 .. num_classes - 1); the name is kept for existing callers.
        self.num_classes = torch.arange(num_classes)

    def calculate_class_mean(self,
                             x: torch.Tensor,
                             y: torch.Tensor):
        """Compute the per-class mean and standard deviation of ``x``.

        Statistics are taken over the batch dimension (dim 0), separately for
        each class id and element-wise over the remaining dimensions.

        Args:
            x (torch.Tensor): Features, batch-first.
            y (torch.Tensor): Labels, one per row of ``x``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(mean, std)``, each of shape
            ``(num_classes, *x.shape[1:])``. A class with no sample in the
            batch gets mean 0 and std 0; a class with exactly one sample gets
            that sample as mean and std 0 (the unbiased estimator is undefined
            there and would otherwise yield NaN).
        """
        # Cast locally instead of overwriting the attribute on every call.
        class_ids = self.num_classes.to(device=y.device, dtype=y.dtype)
        # Row c is True exactly at the batch positions labelled c.
        membership = y.unsqueeze(0) == class_ids.unsqueeze(1)
        zeros = x.new_zeros(x.shape[1:])

        means = []
        stds = []
        for member_mask in membership:
            samples = x[member_mask]
            count = samples.shape[0]
            means.append(samples.mean(0) if count > 0 else zeros)
            stds.append(samples.std(0) if count > 1 else zeros)

        return torch.stack(means), torch.stack(stds)

    def forward(self, x, y):
        """Mix ``x`` with class-conditional noise of randomly shuffled labels.

        Args:
            x: Features, batch-first, of any rank.
            y: Integer labels, one per row of ``x``.

        Returns:
            tuple: ``(mixed, newY)`` where ``mixed`` has the shape of ``x`` and
            ``newY`` is ``y`` reordered by a random permutation; sample ``i``
            is mixed with the noise tensor of class ``newY[i]``.
        """
        class_mean, class_std = self.calculate_class_mean(x, y)

        # One noise tensor per class; detached so no gradient flows through
        # the batch statistics.
        class_noise = torch.normal(mean=class_mean, std=class_std).type_as(x).detach()

        # Keep the permutation as int64 indices and only move it to y's device.
        permutation = torch.randperm(x.size(0)).to(y.device)
        newY = y[permutation]

        mixed = (1 - self.alpha) * x + self.alpha * class_noise[newY.long()]
        return mixed, newY
   

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class BasicBlock(nn.Module):
    """Two 3x3 conv/norm stages with a residual shortcut.

    ``groups``, ``base_width`` and ``dilation`` are accepted but not used by
    this block. ``downsample``, when given, is applied to the block input to
    produce the shortcut branch.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # First stage carries the stride; the second keeps the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)


class ResNet34(nn.Module):
    """ResNet-34-style feature extractor with optional class-noise injection.

    The body is four stages of 3/4/6/3 ``BasicBlock`` units. The stem (7x7
    convolution and 3x3 max-pool) runs at stride 1, and only ``layer2`` and
    ``layer3`` downsample. There is no classification head: ``forward``
    returns the intermediate feature maps and the pooled feature vector.

    Args:
        in_channels: Number of channels of the input images.
        inplanes: Number of output channels of the stem convolution.
        alpha: Mixing weight passed to the shared ``NoiseLayer``.
        num_classes: Number of classes passed to the shared ``NoiseLayer``.
    """

    def __init__(self, in_channels=3, inplanes=64, alpha=0.5, num_classes=10,):
        super().__init__()
        self.inplanes = inplanes
        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7,
                               stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # layer1 consumes the stem output, so its input width must follow
        # ``inplanes`` rather than a fixed 64 (identical when inplanes == 64).
        self.layer1 = self._make_layer(self.inplanes, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # A single noise layer instance is reused at every injection point.
        self.NL = NoiseLayer(alpha=alpha, num_classes=num_classes)

        # Init layers
        self._init_layers()

    def _init_layers(self):
        """Kaiming-initialise conv weights; reset norm layers to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, inplanes, planes, blocks, stride=1, downsample=None):
        """Stack ``blocks`` BasicBlocks; only the first changes width or stride.

        A 1x1 conv + batch-norm projection replaces ``downsample`` whenever the
        first block changes the stride or the channel count.
        """
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride),
                nn.BatchNorm2d(planes),
            )
        layers = [BasicBlock(inplanes, planes, stride, downsample)]
        layers.extend(BasicBlock(planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, y, noise=()):
        """Extract features, optionally injecting class noise after chosen stages.

        Args:
            x: Input image batch.
            y: Labels of the batch, forwarded to the noise layer.
            noise: Collection of stage indices (0-3) after which the noise
                layer is applied; empty or ``None`` disables injection.

        Returns:
            dict: ``'x_l1'``, ``'x_l2'``, ``'x_l3'`` hold the outputs of the
            first three stages (after noise injection where requested),
            ``'x_f'`` the pooled and flattened output of the last stage, and
            ``'ny'`` the shuffled labels returned by the last noise injection
            that ran, or ``None`` if none ran. Every injection is fed the
            original ``y``.
        """
        # Immutable default instead of a shared list; None is accepted too.
        noise = () if noise is None else noise
        ny = None

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        if 0 in noise:
            x1, ny = self.NL(x1, y)
        x2 = self.layer2(x1)
        if 1 in noise:
            x2, ny = self.NL(x2, y)
        x3 = self.layer3(x2)
        if 2 in noise:
            x3, ny = self.NL(x3, y)
        x = self.layer4(x3)
        if 3 in noise:
            x, ny = self.NL(x, y)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = {
            'x_l1': x1,
            'x_l2': x2,
            'x_l3': x3,
            'x_f': x,
            'ny': ny
        }
        return out


In [20]:
# Build the network for 3-channel input with a 64-channel stem; alpha and
# num_classes keep their defaults.
res = ResNet34(in_channels=3, inplanes=64)

In [21]:
# One random 3x32x32 image with a random label in [0, 10); noise=[1] sends the
# output of layer2 through the noise layer.
out = res(
    x=torch.randn(1, 3, 32, 32),
    y=torch.randint(0, 10, (1,)),
    noise=[1],
)

In [24]:
# Shape of the pooled, flattened feature vector.
out['x_f'].size()

torch.Size([1, 512])