In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np

%matplotlib notebook
import matplotlib.pyplot as plt

#debug use
import pdb

In [2]:
# Seed torch's global RNG so weight initialisation, DataLoader shuffling and
# noise sampling are repeatable from run to run.
manualSeed = 999
torch.manual_seed(seed=manualSeed)

<torch._C.Generator at 0x18030c2aa70>

In [3]:
# Preprocessing: upsample the MNIST digits to 64x64 (the input size the DCGAN
# networks below are built for), convert to a tensor, then map [0, 1] to
# [-1, 1] so real images share the range of the generator's tanh output.
_trans_steps = [
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
trans = transforms.Compose(_trans_steps)

In [4]:
# MNIST train/test splits, stored under the working directory (downloaded on
# first use) and preprocessed with `trans`.
_mnist_kwargs = dict(root="./", transform=trans, download=True)
train_set = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_set = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

In [5]:
# Mini-batch loaders: the training split is reshuffled every epoch, the test
# split is iterated in a fixed order.
batch_size = 128

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [6]:
def normal_init(m, mean, std):
    """Initialise a (transposed) 2-D convolution in place.

    Weights are drawn from a normal distribution N(mean, std**2) and the bias,
    when the layer has one, is set to zero. Modules of any other type are left
    untouched, so this can be applied blindly to every child of a network.

    Args:
        m: the module to initialise.
        mean: mean of the normal distribution for the weights.
        std: standard deviation of the normal distribution for the weights.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        # nn.init functions run under no_grad, replacing the legacy `.data` hack.
        nn.init.normal_(m.weight, mean=mean, std=std)
        # Layers built with bias=False have `m.bias is None`; skip them rather
        # than raising AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [14]:
class DNet(nn.Module):
    """DCGAN discriminator for single-channel images.

    Four 4x4 stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8
    -> 4) while widening the channels to d, 2d, 4d, 8d; a final 4x4 valid
    convolution collapses the 4x4 map to a single value per image. Hidden
    layers use LeakyReLU(0.2), with batch norm on every hidden layer except the
    first, and the result goes through a sigmoid so it can be fed to
    ``nn.BCELoss``.

    NOTE(review): input is assumed to be (N, 1, 64, 64); larger inputs produce a
    larger spatial output map instead of one score per image.

    Args:
        d: base channel width.
    """

    def __init__(self, d=64):
        super(DNet, self).__init__()
        # Convolutions are registered before the batch norms; weight_init walks
        # the children in registration order, so this order is kept.
        self.conv1 = nn.Conv2d(1, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
        self.conv5 = nn.Conv2d(d * 8, 1, 4, 1, 0)

        self.bn2 = nn.BatchNorm2d(d * 2)
        self.bn3 = nn.BatchNorm2d(d * 4)
        self.bn4 = nn.BatchNorm2d(d * 8)

    def forward(self, x):
        """Return sigmoid scores; shape (N, 1, 1, 1) for a (N, 1, 64, 64) input."""
        features = F.leaky_relu(self.conv1(x), 0.2)
        normed_stages = (
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in normed_stages:
            features = F.leaky_relu(norm(conv(features)), 0.2)
        logits = self.conv5(features)
        return torch.sigmoid(logits)

    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child module, in registration order."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    
class GNet(nn.Module):
    """DCGAN generator: maps latent noise to a single-channel 64x64 image.

    Five 4x4 transposed convolutions upsample a (N, nz, 1, 1) noise tensor
    through spatial sizes 1 -> 4 -> 8 -> 16 -> 32 -> 64 with channel widths
    8d, 4d, 2d, d, 1. The first four are followed by batch norm + ReLU, and the
    last by tanh, so outputs lie in [-1, 1].

    Args:
        d: base channel width.
        nz: number of latent input channels (default 100, the original
            hard-coded value).
    """

    def __init__(self, d=64, nz=100):
        super(GNet, self).__init__()
        # Deconvolutions are registered before the batch norms; weight_init
        # walks the children in registration order, so this order is kept.
        self.deconv1 = nn.ConvTranspose2d(nz, d*8, 4, stride=1, padding=0)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(d, 1, 4, stride=2, padding=1)

        self.bn1 = nn.BatchNorm2d(d*8)
        self.bn2 = nn.BatchNorm2d(d*4)
        self.bn3 = nn.BatchNorm2d(d*2)
        self.bn4 = nn.BatchNorm2d(d*1)

    def forward(self, x):
        """Generate images of shape (N, 1, 64, 64) from noise of shape (N, nz, 1, 1)."""
        x = F.relu(self.bn1(self.deconv1(x)))
        x = F.relu(self.bn2(self.deconv2(x)))
        x = F.relu(self.bn3(self.deconv3(x)))
        x = F.relu(self.bn4(self.deconv4(x)))
        return torch.tanh(self.deconv5(x))

    # Single definition: an earlier, broken zero-argument copy that referenced
    # undefined `mean`/`std` was shadowed by this one and has been removed.
    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child module, in registration order."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


In [15]:
# Build the discriminator first, then the generator (tuple elements are
# evaluated left to right, so RNG consumption matches building them one by
# one), re-initialise their conv weights ~ N(0, 0.02**2) (generator first),
# and use binary cross-entropy on the discriminator's sigmoid output.
D, G = DNet(), GNet()
for _net in (G, D):
    _net.weight_init(mean=0.0, std=0.02)
loss_func = nn.BCELoss()

In [16]:
# One Adam optimiser per network with lr=2e-4 and betas=(0.5, 0.999).
# NOTE(review): the models are only moved to the GPU later, in the training
# cell; PyTorch recommends moving a model before constructing its optimiser.
_adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
D_optimz = optim.Adam(D.parameters(), **_adam_kwargs)
G_optimz = optim.Adam(G.parameters(), **_adam_kwargs)

In [19]:
# Alternating GAN training: for every batch, one discriminator step on real +
# fake images, then one generator step trying to make D output "real" (1).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D.to(device)
G.to(device)
# Ensure batch-norm layers use batch statistics, even if an inference cell
# switched a network to eval mode earlier.
D.train()
G.train()

NUM_EPOCHS = 20
NOISE_DIM = 100  # must match GNet's latent input channels
LOG_EVERY = 100

iter_cnt = 0
for epoch in range(NUM_EPOCHS):
    for batch_x, _ in train_loader:
        batch_x = batch_x.to(device)
        n_batch = batch_x.shape[0]

        # BCELoss targets: 1 for real images, 0 for generated ones.
        y_real = torch.ones(n_batch, device=device)
        y_fake = torch.zeros(n_batch, device=device)

        #---------------
        # D net training
        #---------------
        # view(-1) instead of squeeze(): squeeze() would turn a batch of one
        # into a 0-d tensor that no longer matches the 1-d targets.
        D_out = D(batch_x).view(-1)
        D_real_loss = loss_func(D_out, y_real)

        z = torch.randn(n_batch, NOISE_DIM, 1, 1, device=device)
        # detach(): the discriminator update must not backpropagate into G.
        G_out = G(z).detach()
        D_out = D(G_out).view(-1)
        D_fake_loss = loss_func(D_out, y_fake)

        D_loss = D_real_loss + D_fake_loss

        D_optimz.zero_grad()
        D_loss.backward()
        D_optimz.step()

        #---------------
        # G net training
        #---------------
        z = torch.randn(n_batch, NOISE_DIM, 1, 1, device=device)

        G_out = G(z)
        D_out = D(G_out).view(-1)
        # Non-saturating generator loss: label fakes as real.
        G_loss = loss_func(D_out, y_real)

        G_optimz.zero_grad()
        G_loss.backward()
        G_optimz.step()

        if iter_cnt % LOG_EVERY == 0:
            print("Iter ", iter_cnt, " D_Loss ", D_loss.item(), " G_Loss ", G_loss.item())

        iter_cnt += 1

Iter  0  D_Loss  0.15480080246925354  G_Loss  3.2910728454589844
Iter  100  D_Loss  0.5700128674507141  G_Loss  2.5970005989074707
Iter  200  D_Loss  0.1091262698173523  G_Loss  3.0037841796875
Iter  300  D_Loss  0.23417961597442627  G_Loss  3.381990671157837
Iter  400  D_Loss  0.0933799296617508  G_Loss  3.5035383701324463
Iter  500  D_Loss  0.3483656644821167  G_Loss  2.208047866821289
Iter  600  D_Loss  0.07351014018058777  G_Loss  3.3296594619750977
Iter  700  D_Loss  0.1822715401649475  G_Loss  3.428154468536377
Iter  800  D_Loss  0.5098050236701965  G_Loss  2.201408863067627
Iter  900  D_Loss  0.09959743171930313  G_Loss  3.256833553314209
Iter  1000  D_Loss  0.15817564725875854  G_Loss  3.657238721847534
Iter  1100  D_Loss  0.06941404193639755  G_Loss  4.3300580978393555
Iter  1200  D_Loss  0.31714528799057007  G_Loss  3.5159683227539062
Iter  1300  D_Loss  0.04931776970624924  G_Loss  3.886507034301758
Iter  1400  D_Loss  0.7073432803153992  G_Loss  1.1119482517242432
Iter  150

In [25]:
# Move the generator to the CPU for sampling. Module.to() converts G in place
# and returns G itself, so `test_net` is an alias of G (the training cell moves
# it back to its training device).
test_net = G.to("cpu")

In [35]:
# Draw 10 samples from the generator and show the first one in grayscale.
# Inference only: no autograd graph, and batch norm in eval mode so this small
# batch does not overwrite the running statistics. The previous train/eval
# mode is restored afterwards so later training cells are unaffected.
was_training = test_net.training
test_net.eval()
with torch.no_grad():
    noise = torch.randn(10, 100, 1, 1)
    images = test_net(noise)
test_net.train(was_training)
plt.imshow(images[0].squeeze().numpy(), cmap="gray")

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1801b17f588>

In [24]:
# Show a real, preprocessed MNIST digit for comparison with the generated
# sample. Take it straight from the dataset instead of reusing `batch_x`
# leaked from the training loop, so this cell works on a fresh kernel.
real_img, _ = train_set[0]
plt.imshow(real_img.squeeze().numpy(), cmap="gray")

<matplotlib.image.AxesImage at 0x18000a36cc0>