In [1]:
import torch
import torch.nn as nn

In [11]:
def weights_init(m):
    """Re-initialise conv and batch-norm layers in place; intended for ``module.apply``.

    Layers are matched by class name, so every class whose name contains
    ``'Conv'`` (``Conv2d``, ``ConvTranspose2d``, ...) or ``'BatchNorm'`` is handled:

    * conv weights are drawn from N(0, 0.05); conv biases are left untouched;
    * batch-norm weights are drawn from N(1, 0.02) and biases are set to 0.

    Parameters that are ``None`` (e.g. ``BatchNorm2d(affine=False)``) are skipped
    instead of raising ``AttributeError``. All other modules are left unchanged.

    Args:
        m: the module being visited.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.05)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

class NoiseEncoder(nn.Module):
    """Conv -> BatchNorm -> LeakyReLU(0.2) unit with an optional class-conditional noise path.

    All three forward paths share ``self.conv``:

    * ``forward``: conv -> ``bn`` -> activation.
    * ``forward_clean``: like ``forward``, but first records the per-class,
      element-wise std of the conv output in ``self.buffer``.
    * ``forward_noise``: conv, then adds Gaussian noise whose std is looked up
      per sample in ``self.buffer`` (scaled by ``alpha``), then ``bn_noise`` ->
      activation.

    Args:
        channel_in, channel_out, kernel_size, stride, padding, bias: passed
            straight to ``nn.Conv2d``.
        num_classes: number of per-class std entries computed by
            ``cal_class_per_std``.
        alpha: multiplier applied to the sampled noise in ``forward_noise``.
    """

    def __init__(self,
                 channel_in:  int,
                 channel_out: int,
                 kernel_size: int,
                 stride:      int,
                 padding:     int,
                 bias:        bool=False,
                 num_classes: int=6,
                 alpha: float=1.0):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses forever
        # as soon as this class is subclassed.
        super().__init__()
        self.conv = nn.Conv2d(channel_in, channel_out, kernel_size, stride, padding, bias=bias)
        self.bn = nn.BatchNorm2d(channel_out)
        # Separate normalisation layer used only by the noisy path.
        self.bn_noise = nn.BatchNorm2d(channel_out)
        self.activation = nn.LeakyReLU(0.2)
        self.num_classes = num_classes
        # Per-class std of the conv output; None until forward_clean() has run.
        # Registered as a buffer so that, once set, it follows the module's device.
        self.register_buffer('buffer', None)
        self.alpha = alpha

    def forward(self, x):
        """Plain path: conv -> bn -> LeakyReLU."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

    def forward_clean(self, x, class_mask):
        """Plain path that also refreshes ``self.buffer`` from this batch.

        ``class_mask[i]`` is used to index the batch dimension of the conv output
        for class ``i`` (NOTE(review): presumably a boolean mask of shape
        (num_classes, batch) or a per-class index tensor — confirm with caller).
        """
        x = self.conv(x)
        self.cal_class_per_std(x, class_mask)
        x = self.bn(x)
        x = self.activation(x)
        return x

    def forward_noise(self, x, newY):
        """Noisy path: conv + alpha * N(0, buffer[newY]) -> bn_noise -> LeakyReLU.

        ``newY`` holds one class index per sample and selects the noise std rows
        from ``self.buffer``. ``forward_clean`` must have been called first;
        otherwise ``self.buffer`` is None and indexing it fails.
        """
        x = self.conv(x)
        noise = torch.normal(mean=0, std=self.buffer[newY]).type_as(x)
        x = x + self.alpha * noise
        x = self.bn_noise(x)
        x = self.activation(x)
        return x

    def cal_class_per_std(self, x, idxs):
        """Store in ``self.buffer`` the element-wise std of ``x`` over each class's samples.

        The result has shape ``(num_classes, *x.shape[1:])`` and is detached from
        the graph. A class with fewer than two samples in the batch has an
        undefined (NaN) unbiased std. ``torch.normal`` would reject that in
        ``forward_noise``, so such classes get a zero std instead (no noise).
        """
        stds = []
        for cls in range(self.num_classes):
            members = x[idxs[cls]].detach()
            if members.size(0) < 2:
                stds.append(x.new_zeros(x.shape[1:]))
            else:
                stds.append(members.std(0))
        self.buffer = torch.stack(stds)

class classifier32(nn.Module):
    """Three-stage convolutional classifier built from ``NoiseEncoder`` units.

    Each stage is ``Dropout2d(0.2)`` followed by three 3x3 ``NoiseEncoder`` layers.
    The last layer of each stage has stride 2. The features are then
    average-pooled to 1x1 and passed to a linear head.

    Args:
        num_classes: forwarded to every ``NoiseEncoder`` and used to size the head.
        alpha: noise scale forwarded to every ``NoiseEncoder``. The default 1.0
            equals ``NoiseEncoder``'s own default, so default construction is
            unchanged.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes=2, alpha=1.0, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` breaks subclasses.
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.block1 = self._make_block(3, 64, 128)
        self.block2 = self._make_block(128, 128, 128)
        self.block3 = self._make_block(128, 128, 128)
        # NOTE(review): the head has num_classes * (num_classes - 1) outputs rather
        # than num_classes; presumably required by the training objective — confirm.
        self.fc = nn.Linear(128, num_classes * (num_classes - 1))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): neither attribute is used inside this class; kept in case
        # outside code reads them.
        self.buffer = None
        self.position = 0
        self.apply(weights_init)

    def _make_block(self, channel_in, channel_mid, channel_out):
        """Build one stage: Dropout2d(0.2), two stride-1 and one stride-2 NoiseEncoder."""
        def unit(c_in, c_out, stride):
            return NoiseEncoder(c_in, c_out, 3, stride, 1, bias=False,
                                num_classes=self.num_classes, alpha=self.alpha)

        return nn.Sequential(
            nn.Dropout2d(0.2),
            unit(channel_in, channel_mid, 1),
            unit(channel_mid, channel_mid, 1),
            unit(channel_mid, channel_out, 2),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes * (num_classes - 1))``.

        ``nn.Sequential`` calls each ``NoiseEncoder``'s plain ``forward``, so
        neither the clean nor the noisy path is used here.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        logit = self.fc(x)
        return logit
 
class Generator(nn.Module):
    """Map latent noise to a sigmoid-squashed feature map.

    The first layers follow the DCGAN pattern, but the network ends in a 3x3
    ``Conv2d`` plus ``Sigmoid`` rather than an image layer. For an input of shape
    ``(batch, nz, 1, 1)`` the output is ``(batch, ngf * 2, 8, 8)`` with values in
    [0, 1], i.e. 128 x 8 x 8 with the defaults.

    Args:
        nz: number of channels of the latent input.
        ngf: base width; intermediate widths are ``ngf * 8`` and ``ngf * 4``.
    """

    def __init__(self, nz=100, ngf=64):
        # Zero-argument super(): ``super(self.__class__, self)`` breaks subclasses.
        super().__init__()
        self.main = nn.Sequential(
            # (nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2) x 8 x 8 (same spatial size), squashed to [0, 1]
            nn.Conv2d(ngf * 4, ngf * 2, 3, 1, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the latent tensor ``x`` through ``self.main``."""
        return self.main(x)
    
class Discriminator(nn.Module):
    """Small conv critic that maps a feature map to one probability per sample.

    Three 3x3 conv -> BatchNorm -> LeakyReLU(0.2) stages (the last one with
    stride 2) are followed by global average pooling, a linear layer and a sigmoid.

    Args:
        in_channel: channels of the input feature map.
        out_channel: width of every conv stage.

    ``forward`` returns a tensor of shape ``(batch,)`` with values in [0, 1].
    """

    def __init__(self, in_channel, out_channel):
        # Zero-argument super(): ``super(self.__class__, self)`` breaks subclasses.
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),

            # Later stages consume the previous stage's out_channel features.
            # Using in_channel here only worked when in_channel == out_channel.
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),

            nn.Conv2d(out_channel, out_channel, 3, 2, 1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(out_channel, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return one sigmoid score per sample, shape ``(batch,)``."""
        output = self.main(x)
        output = self.avgpool(output)
        output = output.view(output.size(0), -1)
        output = self.classifier(output).flatten()
        output = self.activation(output)
        return output

In [4]:
# Seed once before the first random operation (weight init here, random tensors
# in later cells) so Restart & Run All reproduces the same numbers.
torch.manual_seed(0)
model = classifier32()

In [5]:
# Dummy input batch: 32 samples, 3 channels, 32x32 spatial.
x = torch.randn((32, 3, 32, 32))

In [6]:
# Shape probe: pass the dummy batch through the first stage only.
l1 = model.block1(x)

In [7]:
l1.size()

torch.Size([32, 128, 16, 16])

In [9]:
# Continue the shape probe through the second stage.
l2 = model.block2(l1)

In [10]:
l2.size()

torch.Size([32, 128, 8, 8])

In [12]:
# Generator with its default sizes spelled out.
G = Generator(nz=100, ngf=64)

In [15]:
# Standard-normal latent batch shaped (batch, nz, 1, 1) for the Generator.
# torch.randn replaces the legacy torch.FloatTensor(...).normal_ constructor.
noise = torch.randn(32, 100, 1, 1)

In [17]:
# Map the latent batch to a feature map; its shape is inspected in the next cell.
feature = G(noise)

In [18]:
feature.size()

torch.Size([32, 128, 8, 8])

In [39]:
# 5 random integer labels in [0, 10).
y = torch.randint(low=0, high=10, size=(5,))

In [40]:
# (5, 11) boolean one-hot mask of the labels; column 10 is never set because y < 10.
idxs = y[:, None] == torch.arange(11)[None, :]

In [49]:
# Start the soft-target logits at zero: one row per label, 11 columns.
fake_distribution = torch.zeros((5, 11))

In [50]:
# Put a logit of 0.8 at each row's label position (trailing ';' hides the returned tensor).
fake_distribution.masked_fill_(idxs, 0.8);

In [52]:
# Give every row a logit of 1 in the last column (trailing ';' hides the returned view).
fake_distribution[:, -1].fill_(1);

In [47]:
# This cell used to set the label entries to 1. It ran (In[47]) before the zeros
# cell (In[49]), so its effect was discarded. Run top-to-bottom, it would change
# the softmax shown in the next cell. Check the intended 0.8 targets instead.
assert bool(torch.all(fake_distribution[idxs] == 0.8))

In [54]:
torch.softmax(fake_distribution, dim=-1)

tensor([[0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.1596, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717,
         0.1596, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949]])

In [55]:
def kldiv(P, Q):
    """KL divergence KL(P || Q) = sum_i P_i * log(P_i / Q_i), reduced over the last dim.

    Uses ``torch.xlogy`` so entries with ``P_i == 0`` contribute exactly 0 (the
    usual ``0 * log 0 = 0`` convention) instead of NaN. Entries with ``P_i > 0``
    and ``Q_i == 0`` still give ``inf``.

    Args:
        P, Q: probability tensors of broadcast-compatible shapes.

    Returns:
        Tensor with the last dimension summed out.
    """
    return (torch.xlogy(P, P) - torch.xlogy(P, Q)).sum(-1)

In [56]:
# Random logits with the same (5, 11) layout as the soft targets.
test_logit = torch.randn((5, 11))

In [77]:
# Sanity check: KL(P || P) averaged over rows should be 0.
torch.mean(kldiv(test_logit.softmax(dim=-1), test_logit.softmax(dim=-1)))

tensor(0.)

In [72]:
# Library KL loss: input is log-probabilities, target is plain probabilities (log_target=False).
kld = nn.KLDivLoss(reduction='batchmean', log_target=False)

In [76]:
# KLDivLoss expects log-probabilities as input; log_softmax is the numerically
# stable way to get them (softmax().log() can underflow to -inf).
kld(test_logit.log_softmax(-1), test_logit.softmax(-1))

tensor(0.)

In [75]:
# Only a cell's last expression is displayed, so the discarded
# `test_logit.softmax(-1)` line was dropped; this shows the soft targets.
fake_distribution.softmax(-1)

tensor([[0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.1596, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717,
         0.1596, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949],
        [0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.1596, 0.0717, 0.0717, 0.0717,
         0.0717, 0.1949]])