In [1]:
import PI
meta_params = {
    'num_of_train_dataset': 1000,
    'num_of_test_dataset': 100,
    'is_flatten': False
}

PI = PI.PIInterface(meta_params)

import torch
import torch.nn as nn
import torch.optim as optim
from MNIST_models import *

model = load_model('store/MNIST_CNN.pt')
PI.set_model(model)
# print('train acc:', PI.eval_model('train'))
# print('test acc:', PI.eval_model('test'))

1000 100
(100, 1, 28, 28)




In [2]:
'''
Each convolutional layer is followed by batch normalization (BN) 
and activation function (ReLU for generator and Leaky ReLU for discriminator)

For Generator:

For Discriminator: softmax + MSE(||softmax(D(x))-label(one hot)||^2)


Implemented:
shuffle=True (unsure)
mini-batch size: 256
Inputs 32 * 32 (padding 2)
concat layer for dense block in generator 
softmax for discriminator 

Pending:
draw samples during training 
compute acc 
'''
import torchvision
import torchvision.transforms as transforms

# create loader with mini-batch size 256 and padding images to 32*32
trans = transforms.Compose([
    transforms.Pad(padding=2),
    transforms.ToTensor()
])
dataset=torchvision.datasets.MNIST(root='.data', train=True, transform=trans, download=True)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [3]:
import torch.nn as nn
import torch.optim as optim 

class Generator(nn.Module):
    def __init__(self, num_of_features):
        super().__init__()
        self.relu = nn.ReLU()
        
        # input layer
        self.input_deconv = nn.ConvTranspose2d(1, num_of_features, 4)
        
        # dense block 
        self.dense_conv1 = nn.Conv2d(num_of_features, num_of_features, 1)
        self.dense_conv2 = nn.Conv2d(num_of_features, num_of_features, 3, padding=1)
        self.dense_bn = nn.BatchNorm2d(num_of_features)
        
        # transition layer
        self.trans_conv = nn.Conv2d(num_of_features, num_of_features, 1)
        self.trans_deconv = nn.ConvTranspose2d(num_of_features, num_of_features, 2, stride=2)
        self.trans_bn = nn.BatchNorm2d(num_of_features)
        
        # output layer 
        self.output_conv = nn.Conv2d(num_of_features, 1, 1)
        self.output_bn = nn.BatchNorm2d(1)
        
    def dense_block(self, x):
        x2 = self.dense_conv1(x)
        x2 = self.relu(self.dense_bn(x2))
        x3 = self.dense_conv2(x2)
        x3 = self.relu(self.dense_bn(x3))
        return x3 + x
    
    def transition_layer(self, x):
        x2 = self.trans_conv(x)
        x2 = self.relu(self.trans_bn(x2))
        x3 = self.trans_deconv(x2)
        return x3
    
    def output_layer(self, x):
        x2 = self.output_conv(x)
        x2 = self.relu(self.output_bn(x2))
        return x2
    
    def forward(self, x):
        x = self.input_deconv(x)
        
        x = self.dense_block(x)
        x = self.transition_layer(x)
        
        x = self.dense_block(x)
        x = self.transition_layer(x)
        
        x = self.dense_block(x)
        x = self.transition_layer(x)
        
        x = self.output_layer(x)
        return x 

class Discriminator(nn.Module):
    def __init__(self, num_of_features):
        super().__init__()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        
        # input layer
        self.input_conv = nn.Conv2d(1, num_of_features, 1)
        self.input_bn = nn.BatchNorm2d(num_of_features)
        
        # dense block 
        self.dense_conv1 = nn.Conv2d(num_of_features, num_of_features, 1)
        self.dense_conv2 = nn.Conv2d(num_of_features, num_of_features, 3, padding=1)
        self.dense_bn = nn.BatchNorm2d(num_of_features)
        
        # transition layer
        self.trans_conv = nn.Conv2d(num_of_features, num_of_features, 1)
        self.trans_pool = nn.MaxPool2d(2, stride=2)
        self.trans_bn = nn.BatchNorm2d(num_of_features)
        
        # output layer 
        self.output_fc = nn.Linear(32*4*4, 2)
    
    def input_layer(self, x):
        x2 = self.input_conv(x)
        x2 = self.relu(self.input_bn(x2))
        return x2
    
    def dense_block(self, x):
        x2 = self.dense_conv1(x)
        x2 = self.relu(self.dense_bn(x2))
        x3 = self.dense_conv2(x2)
        x3 = self.relu(self.dense_bn(x3))
        return x3 + x
    
    def transition_layer(self, x):
        x2 = self.trans_conv(x)
        x2 = self.relu(self.trans_bn(x2))
        x3 = self.trans_pool(x2)
        return x3
    
    def forward(self, x):
        x = self.input_layer(x)
        
        x = self.dense_block(x)
        x = self.transition_layer(x)

        x = self.dense_block(x)
        x = self.transition_layer(x)
        
        x = self.dense_block(x)
        x = self.transition_layer(x)
        
        x = x.view(-1, 32*4*4)
        x = self.output_fc(x)
        
        return self.softmax(x)
    
num_of_features = 32
draw_interval = 5
G = Generator(num_of_features)
D = Discriminator(num_of_features)

G_optim = torch.optim.Adam(G.parameters(), lr=1e-3)
D_optim = torch.optim.Adam(D.parameters(), lr=1e-4)
loss_func = nn.MSELoss()
import matplotlib.pyplot as plt 

# pretrained G 
# pending 

# Game between G and D 
for i, (X, Y) in enumerate(loader):
    if i < 10:
        noises = torch.randn((256, 1, 1, 1))
        G_X = G(noises)

        S_D_X = D(X)
        labels = torch.zeros(S_D_X.shape)
        labels[:, 0] = 1.
        DX_loss = loss_func(S_D_X, labels)

        S_D_GX = D(G_X)
        labels = torch.zeros(S_D_GX.shape)
        labels[:, 1] = 1.
        DGX_loss = loss_func(S_D_X, labels)

        D_loss = DX_loss + DGX_loss
        D_optim.zero_grad()
        D_loss.backward(retain_graph=True)
        D_optim.step()

        G_optim.zero_grad()
        G_loss = 1 - DGX_loss
        G_loss.backward()
        G_optim.step()

        print(i+1, 'Gen loss:', round(G_loss.item(), 3), 'Dis loss:', round(D_loss.item(), 3), 'Dis (DX) loss:', round(DX_loss.item(), 3), 'Dis (DGX) loss:', round(DGX_loss.item(), 3))

#         if i % draw_interval == 0:
#             img_counter = 1
#             img_bound = 4

#             for x in X:
#                 x = x[:, 2:-2, 2:-2]
#                 img = x.reshape(28, 28)
                
#                 plt.subplot(2, 2, img_counter)
#                 img_counter += 1
                
#                 if img_counter == img_bound: break 
            
            




1 Gen loss: 0.8 Dis loss: 0.529 Dis (DX) loss: 0.329 Dis (DGX) loss: 0.2
2 Gen loss: 0.773 Dis loss: 0.523 Dis (DX) loss: 0.295 Dis (DGX) loss: 0.227
3 Gen loss: 0.722 Dis loss: 0.521 Dis (DX) loss: 0.243 Dis (DGX) loss: 0.278
4 Gen loss: 0.711 Dis loss: 0.521 Dis (DX) loss: 0.232 Dis (DGX) loss: 0.289
5 Gen loss: 0.708 Dis loss: 0.523 Dis (DX) loss: 0.231 Dis (DGX) loss: 0.292
6 Gen loss: 0.717 Dis loss: 0.522 Dis (DX) loss: 0.239 Dis (DGX) loss: 0.283
7 Gen loss: 0.717 Dis loss: 0.519 Dis (DX) loss: 0.236 Dis (DGX) loss: 0.283
8 Gen loss: 0.749 Dis loss: 0.519 Dis (DX) loss: 0.267 Dis (DGX) loss: 0.251
9 Gen loss: 0.733 Dis loss: 0.521 Dis (DX) loss: 0.253 Dis (DGX) loss: 0.267
10 Gen loss: 0.753 Dis loss: 0.52 Dis (DX) loss: 0.273 Dis (DGX) loss: 0.247
