In [0]:
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable


In [0]:
class Generator(nn.Module):
  """DCGAN generator: maps a latent tensor to an image.

  Five transposed convolutions upsample a (N, nz, 1, 1) input to an
  (N, nc, 64, 64) output (spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64).

  Args:
    d: base width; the hidden layers use d*8, d*4, d*2 and d channels.
    nz: number of latent input channels (100 is what this notebook samples).
    nc: number of output image channels (1 for grayscale MNIST).
  """
  def __init__(self, d=128, nz=100, nc=1):
    super(Generator, self).__init__()
    self.deconv1 = nn.ConvTranspose2d(nz, d*8, 4, 1, 0)
    self.bn1 = nn.BatchNorm2d(d*8)
    self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
    self.bn2 = nn.BatchNorm2d(d*4)
    self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
    self.bn3 = nn.BatchNorm2d(d*2)
    self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
    self.bn4 = nn.BatchNorm2d(d)
    self.deconv5 = nn.ConvTranspose2d(d, nc, 4, 2, 1)

  def forward(self, x):
    """Upsample latent codes ``x`` of shape (N, nz, 1, 1) to images in [-1, 1]."""
    x = F.relu(self.bn1(self.deconv1(x)))
    x = F.relu(self.bn2(self.deconv2(x)))
    x = F.relu(self.bn3(self.deconv3(x)))
    x = F.relu(self.bn4(self.deconv4(x)))
    # tanh output (a "GAN hacks" recommendation) matches data normalised to [-1, 1].
    # torch.tanh replaces the deprecated F.tanh.
    return torch.tanh(self.deconv5(x))

In [0]:
class Discriminator(nn.Module):
  """DCGAN discriminator: scores images as real (-> 1) or fake (-> 0).

  Four stride-2 convolutions downsample a (N, nc, 64, 64) input to 4x4.
  A final 4x4 convolution then gives one sigmoid probability per image,
  with output shape (N, 1, 1, 1).

  Args:
    d: base width; the layers use d, d*2, d*4 and d*8 channels.
    nc: number of input image channels (1 for grayscale MNIST).
  """
  def __init__(self, d=128, nc=1):
    super(Discriminator, self).__init__()
    self.conv1 = nn.Conv2d(nc, d, 4, 2, 1)
    self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
    self.bn2 = nn.BatchNorm2d(d*2)
    self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
    self.bn3 = nn.BatchNorm2d(d*4)
    self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
    self.bn4 = nn.BatchNorm2d(d*8)
    self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0)

  def forward(self, x):
    """Return the probability that each image in ``x`` is real."""
    # No BatchNorm on the first layer, which sees raw pixels.
    x = F.leaky_relu(self.conv1(x), 0.2)
    x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
    x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
    x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2)
    # torch.sigmoid replaces the deprecated F.sigmoid; nn.BCELoss expects probabilities.
    return torch.sigmoid(self.conv5(x))

In [0]:
# Seed torch's global RNG (all devices) so the fixed noise below, and every
# random draw made after this cell, is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# One fixed 5x5 grid of latent vectors, reused at every epoch so samples can be
# compared on identical inputs. torch.randn already returns a tensor with
# requires_grad=False; Variable(..., volatile=True) is deprecated and has no effect.
fixed_z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
if torch.cuda.is_available():
    fixed_z_ = fixed_z_.cuda()

  


In [0]:
def show_result(num_epoch, show = False, save = True, path = 'result.png', isFix=False):
    """Plot a 5x5 grid of samples from the module-level generator ``G``.

    Args:
        num_epoch: epoch number written under the grid.
        show: display the figure with ``plt.show()``; otherwise it is closed.
        save: write the figure to ``path`` (previously ignored; now honoured).
        path: output image file.
        isFix: sample from the module-level ``fixed_z_`` batch instead of fresh
            N(0, 1) noise, so grids from different epochs share their inputs.
    """
    size_figure_grid = 5
    n_samples = size_figure_grid * size_figure_grid
    device = next(G.parameters()).device

    # Sample in eval mode without building an autograd graph. Restore the
    # previous mode even if the forward pass raises.
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            if isFix:
                z_ = fixed_z_.to(device)
            else:
                z_ = torch.randn((n_samples, 100)).view(-1, 100, 1, 1).to(device)
            test_images = G(z_).cpu()
    finally:
        G.train(was_training)

    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    # product() yields (0, 0), (0, 1), ... so k indexes the samples row by row.
    for k, (i, j) in enumerate(itertools.product(range(size_figure_grid), repeat=2)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        ax[i, j].imshow(test_images[k, 0].numpy(), cmap='gray')

    fig.text(0.5, 0.04, 'Epoch {0}'.format(num_epoch), ha='center')
    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)

In [0]:
def show_train_hist(hist, show = False, save = True, path = 'Train_hist.png'):
    """Plot the discriminator and generator loss curves stored in ``hist``.

    Args:
        hist: dict with equal-length sequences under 'D_losses' and 'G_losses';
            the x axis is the index of each entry.
        show: display the figure with ``plt.show()``; otherwise it is closed.
        save: write the figure to ``path``.
        path: output image file.
    """
    # Draw on a dedicated figure so curves are never added to whatever
    # pyplot figure happens to be current.
    fig, ax = plt.subplots()
    steps = range(len(hist['D_losses']))

    ax.plot(steps, hist['D_losses'], label='D_loss')
    ax.plot(steps, hist['G_losses'], label='G_loss')

    ax.set_xlabel('Iter')
    ax.set_ylabel('Loss')

    ax.legend(loc=4)
    ax.grid(True)
    fig.tight_layout()

    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)


In [0]:
# Training parameters:
#   batch_size - images per optimisation step (DataLoader batch size)
#   lr         - Adam learning rate, shared by the G and D optimizers
#   epochs     - number of full passes over the training set
batch_size, lr, epochs = 128, 0.0002, 2

In [0]:
img_size = 64
# MNIST is single-channel, so Normalize takes one mean/std value. A 3-channel
# mean/std fails to broadcast in place on a [1, H, W] tensor in current
# torchvision. (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1], the
# range of the generator's tanh output.
transform = transforms.Compose([
    transforms.Resize(img_size),  # replaces the deprecated transforms.Scale
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

  "please use transforms.Resize instead.")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
# Defining Model
def _dcgan_weights_init(module):
    """DCGAN initialisation for one submodule (used with ``nn.Module.apply``).

    Conv/transposed-conv weights are drawn from N(0, 0.02) and their biases are
    zeroed. BatchNorm scales are drawn from N(1, 0.02) and its shifts are zeroed.
    Other module types are left untouched.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, 0.0, 0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 1.0, 0.02)
        nn.init.zeros_(module.bias)


G = Generator(128)
D = Discriminator(128)
G.apply(_dcgan_weights_init)
D.apply(_dcgan_weights_init)
G.cuda()
D.cuda()


Discriminator(
  (conv1): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))
)

In [0]:
# Binary cross-entropy on the discriminator's sigmoid probabilities
# (targets: 1 for real images, 0 for generated ones).
loss = nn.BCELoss()

# One Adam optimizer per network, with identical hyper-parameters (G's first).
G_optimizer, D_optimizer = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999)) for net in (G, D)
)

In [0]:
# Training history: every key starts with its own empty list.
train_hist = {
    key: []
    for key in ('D_losses', 'G_losses', 'per_epoch_ptimes', 'total_ptime')
}

In [0]:
# Alternating DCGAN training: one discriminator step, then one generator step,
# per mini-batch. Per-epoch mean losses and timings go into train_hist.
num_iters = 0
print("Training Starts!!")
start_time = time.time()
device = next(D.parameters()).device
for epoch in range(epochs):
  D_losses = []
  G_losses = []
  epoch_start_time = time.time()
  for x_, _ in train_loader:
    mini_batch = x_.size(0)
    x_ = x_.to(device)
    y_real_ = torch.ones(mini_batch, device=device)
    y_fake_ = torch.zeros(mini_batch, device=device)

    #  Train Discriminator   D:G == 1:1
    D.zero_grad()
    # view(-1) instead of squeeze(): for a batch of 1, squeeze() gives a 0-d
    # tensor that no longer matches the (1,) target shape BCELoss expects.
    D_real_loss = loss(D(x_).view(-1), y_real_)

    z_sampled = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).to(device)
    G_result = G(z_sampled)
    # detach(): the discriminator update must not backpropagate into G.
    D_fake_loss = loss(D(G_result.detach()).view(-1), y_fake_)

    D_total_loss = D_real_loss + D_fake_loss
    D_total_loss.backward()
    D_optimizer.step()
    D_losses.append(D_total_loss.item())

    #   Train Generator
    G.zero_grad()
    z_sampled = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).to(device)
    G_result = G(z_sampled)
    # Non-saturating generator loss: score fresh fakes against the "real" target.
    G_loss = loss(D(G_result).view(-1), y_real_)
    G_loss.backward()
    G_optimizer.step()
    G_losses.append(G_loss.item())

    num_iters += 1
    if num_iters % 20 == 0:
      print("Iterations occured :",num_iters," D_loss :",D_total_loss.item()," G_loss :",G_loss.item())

  per_epoch_time = time.time() - epoch_start_time
  mean_D_loss = torch.mean(torch.FloatTensor(D_losses)).item()
  mean_G_loss = torch.mean(torch.FloatTensor(G_losses)).item()

  print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), epochs, per_epoch_time, mean_D_loss, mean_G_loss))
  # The two grids need distinct files; previously both used '<epoch>.png', so
  # the random grid was always overwritten. The fixed-noise grid keeps the
  # '<epoch>.png' name that the GIF cell reads.
  p = 'random_' + str(epoch + 1) + '.png'
  fixed_p = str(epoch + 1) + '.png'
  show_result((epoch+1), save=True, path=p, isFix=False)
  show_result((epoch+1), save=True, path=fixed_p, isFix=True)
  # Store plain floats rather than 0-d tensors so the pickled history is simple.
  train_hist['D_losses'].append(mean_D_loss)
  train_hist['G_losses'].append(mean_G_loss)
  train_hist['per_epoch_ptimes'].append(per_epoch_time)
  




Training Starts!!




Iterations occured : 20  D_loss : 0.015724992379546165  G_loss : 6.936120510101318
Iterations occured : 40  D_loss : 0.005871393717825413  G_loss : 6.826488494873047
Iterations occured : 60  D_loss : 0.00560090783983469  G_loss : 7.0915045738220215
Iterations occured : 80  D_loss : 1.317993402481079  G_loss : 2.5987672805786133
Iterations occured : 100  D_loss : 0.41862428188323975  G_loss : 2.426618814468384
Iterations occured : 120  D_loss : 0.6542184352874756  G_loss : 2.273446559906006
Iterations occured : 140  D_loss : 1.296753168106079  G_loss : 1.817238688468933
Iterations occured : 160  D_loss : 0.8057482838630676  G_loss : 4.284806728363037
Iterations occured : 180  D_loss : 1.4548671245574951  G_loss : 1.9282777309417725
Iterations occured : 200  D_loss : 1.162239670753479  G_loss : 1.8851103782653809
Iterations occured : 220  D_loss : 0.8248867988586426  G_loss : 1.6100919246673584
Iterations occured : 240  D_loss : 0.8248025178909302  G_loss : 0.45100805163383484
Iterations

  This is separate from the ipykernel package so we can avoid doing imports until


Iterations occured : 480  D_loss : 1.0464916229248047  G_loss : 2.5303244590759277
Iterations occured : 500  D_loss : 0.6727755069732666  G_loss : 2.981663703918457
Iterations occured : 520  D_loss : 0.8999950885772705  G_loss : 1.108676552772522
Iterations occured : 540  D_loss : 1.2794677019119263  G_loss : 3.2648234367370605
Iterations occured : 560  D_loss : 0.9696221947669983  G_loss : 3.142409563064575
Iterations occured : 580  D_loss : 1.0331957340240479  G_loss : 2.346543550491333
Iterations occured : 600  D_loss : 0.9195282459259033  G_loss : 0.8211212754249573
Iterations occured : 620  D_loss : 0.7833235263824463  G_loss : 2.170616388320923
Iterations occured : 640  D_loss : 0.7825809717178345  G_loss : 1.0260231494903564
Iterations occured : 660  D_loss : 1.1929975748062134  G_loss : 3.477729320526123
Iterations occured : 680  D_loss : 0.8346942663192749  G_loss : 2.6266255378723145
Iterations occured : 700  D_loss : 0.8791630864143372  G_loss : 2.3139796257019043
Iterations

In [0]:
end_time = time.time()
total_time = end_time - start_time
train_hist['total_ptime'].append(total_time)
print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), epochs, total_time))
print("Training finish!... save training results")

# Write every artifact to disk before any browser download starts, so an
# interrupted or hanging files.download() cannot stop the rest from being saved.
torch.save(G.state_dict(), "generator_param.pkl")
torch.save(D.state_dict(), "discriminator_param.pkl")
with open('train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)
show_train_hist(train_hist, save=True, path='MNIST_DCGAN_train_hist.png')

# google.colab exists only inside Colab; elsewhere the files simply stay on disk.
try:
    from google.colab import files
except ImportError:
    files = None

if files is not None:
    for artifact in ("./generator_param.pkl", "./discriminator_param.pkl", './train_hist.pkl'):
        files.download(artifact)

    

Avg per epoch ptime: 821.55, total 7 epochs ptime: 5768.17
Training finish!... save training results


KeyboardInterrupt: ignored

In [0]:
# IPython automagic (same as %ls): lists the working directory, presumably to
# check which result images / .pkl files were written.
ls

In [0]:

# Stitch the per-epoch sample grids ('1.png', '2.png', ...) into an animated GIF
# (5 frames per second) and send it to the browser.
images = [imageio.imread(str(e + 1) + '.png') for e in range(epochs)]
imageio.mimsave('generation_animation.gif', images, fps=5)
files.download('./generation_animation.gif')