In [None]:
import os
import numpy as np
import math
import sys
from time import time
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import matplotlib.pyplot as plt

In [None]:
# MNIST training split. ToTensor yields values in [0, 1]; Normalize with
# mean 0.5 / std 0.5 then maps them to [-1, 1].
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5,)),
])

# Downloaded into data/MNIST on first use.
mnist_train = torchvision.datasets.MNIST(
    "data/MNIST",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of 128 images; the final batch may be smaller because
# drop_last is left at its default.
dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True)

In [None]:
class Opt(object):
    """Plain container of run settings, exposed as class attributes."""

    # --- optimisation schedule ---
    n_epochs = 40      # number of passes over the dataloader
    lr = 5e-5          # learning rate handed to both RMSprop optimizers
    batchSize = 128    # number of latent vectors drawn per step
    Diters = 5         # critic updates per generator update (read as opt.Diters)
    n_critic = 5       # same value as Diters; not referenced in the visible cells

    # --- model / data shape ---
    latent_dim = 100   # channel count of the generator's latent input
    imageSize = 32     # NOTE(review): Generator/Discriminator default to image_size=28
    channels = 1       # not referenced in the visible cells

    # --- runtime / output ---
    n_cpu = 1                   # presumably a DataLoader worker count; not referenced in the visible cells
    experiment = "experiments"  # directory that sample images are written into
    sample_interval = 400       # not referenced in the visible cells

In [None]:
# Weight initialization
def weights_init(m):
    """Re-initialise one module in place, chosen by its class name.

    Intended to be passed to ``nn.Module.apply`` so it visits every submodule.

    * class name contains ``'Conv'``: weights drawn from N(0, 0.02); any bias
      is left untouched.
    * otherwise, class name contains ``'BatchNorm'``: weights drawn from
      N(1, 0.02) and bias set to 0.
    * anything else is left as is.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator: latent tensor in, Tanh image out.

    The layers live in one flat ``nn.Sequential`` stored as ``self.g``:

    1. a stride-1 4x4 transposed conv + BatchNorm + ReLU that expands a
       ``(N, in_channel, 1, 1)`` input to a 4x4 map with ``feature_map`` channels;
    2. one stride-2 stage with kernel ``kernal_size`` that halves the channels;
    3. extra stride-2 4x4 stages, each halving the channels, added while a
       nominal size counter (doubled per stage) is below ``image_size // 2``;
    4. a stride-2 4x4 transposed conv to ``out_channel`` channels, then Tanh.

    The counter is only nominal: with the default ``kernal_size=3`` stage 2
    produces a 7x7 map rather than 8x8, which is what lets the default
    configuration end at 28x28. With ``kernal_size=4, image_size=32`` the
    output is 32x32.

    Args:
        in_channel: channels of the latent input.
        out_channel: channels of the generated image.
        feature_map: channel width after the first layer.
        kernal_size: kernel size of the first stride-2 stage.
        image_size: target side length used to decide how many stages to add.
    """

    def __init__(self, in_channel=100, out_channel=1, feature_map=128,
                 kernal_size=3, image_size=28):
        super(Generator, self).__init__()

        def up_stage(n_in, n_out, kernel):
            # stride-2, padding-1 transposed conv (no bias) -> BatchNorm -> in-place ReLU
            return [
                nn.ConvTranspose2d(n_in, n_out, kernel, 2, 1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(True),
            ]

        width = feature_map

        # Project the 1x1 latent to a 4x4 feature map (this conv keeps its bias).
        stack = [
            nn.ConvTranspose2d(in_channel, width, 4, 1, 0),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
        ]

        # First upsampling stage, the only one that honours kernal_size.
        stack.extend(up_stage(width, width // 2, kernal_size))
        width = width // 2
        nominal_size = 4 * 2

        # Keep upsampling until one more doubling would reach image_size.
        while nominal_size < image_size // 2:
            stack.extend(up_stage(width, width // 2, 4))
            width = width // 2
            nominal_size = nominal_size * 2

        # Output stage: last doubling, squashed to [-1, 1] by Tanh.
        stack.append(nn.ConvTranspose2d(width, out_channel, 4, 2, 1, bias=False))
        stack.append(nn.Tanh())

        self.g = nn.Sequential(*stack)

    def forward(self, z):
        """Run the latent batch ``z`` through the layer stack."""
        return self.g(z)

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution critic: image in, one unbounded score per sample out.

    The layers live in one flat ``nn.Sequential`` stored as ``self.d``:

    1. a stride-2 4x4 conv (with bias) + BatchNorm + LeakyReLU(0.2) to
       ``feature_map`` channels;
    2. extra stride-2 4x4 stages, each doubling the channels, added while a
       nominal size counter (``image_size / 2``, halved per stage) is above 8;
    3. one stride-2 stage with kernel ``kernal_size`` that doubles the channels;
    4. a stride-1 4x4 conv without padding down to ``out_channel`` channels,
       with no activation after it.

    With the defaults a 28x28 input shrinks 28 -> 14 -> 7 -> 4 -> 1, so the
    output has shape ``(N, out_channel, 1, 1)``.

    Args:
        in_channel: channels of the input image.
        out_channel: channels of the score map.
        feature_map: channel width after the first layer.
        kernal_size: kernel size of the last stride-2 stage.
        image_size: input side length used to decide how many stages to add.
    """

    def __init__(self, in_channel=1, out_channel=1, feature_map=32,
                 kernal_size=3, image_size=28):
        super(Discriminator, self).__init__()

        def down_stage(n_in, n_out, kernel, use_bias):
            # stride-2, padding-1 conv -> BatchNorm -> in-place LeakyReLU(0.2)
            return [
                nn.Conv2d(n_in, n_out, kernel, 2, 1, bias=use_bias),
                nn.BatchNorm2d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        width = feature_map

        # Entry stage is the only conv here that keeps its bias term.
        stack = down_stage(in_channel, width, 4, True)

        # True division on purpose: the comparison below sees fractional
        # sizes for odd image_size values.
        nominal_size = image_size / 2

        while nominal_size > 8:
            stack.extend(down_stage(width, width * 2, 4, False))
            width = width * 2
            nominal_size = nominal_size / 2

        # Last downsampling stage, the only one that honours kernal_size.
        stack.extend(down_stage(width, width * 2, kernal_size, False))
        width = width * 2

        # Collapse the remaining map to a single score per sample.
        stack.append(nn.Conv2d(width, out_channel, 4, 1, 0, bias=False))

        self.d = nn.Sequential(*stack)

    def forward(self, image):
        """Run the image batch through the layer stack and return the scores."""
        return self.d(image)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build both networks with their default constructor arguments and move them
# to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# weights_init is defined in this notebook but was never applied, so the
# networks started from PyTorch's default initialisation. Module.apply visits
# every submodule, giving conv layers N(0, 0.02) weights and BatchNorm layers
# N(1, 0.02) weights with zero bias.
generator.apply(weights_init)
discriminator.apply(weights_init)

In [None]:
# Run settings and a boolean CUDA flag.
opt = Opt()
cuda = torch.cuda.is_available()

# Legacy tensor-type alias matching the flag above; the training loop uses it
# to cast incoming image batches.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# One RMSprop optimizer per network, both at opt.lr.
generator_optimizer = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)

In [None]:
# WGAN training with weight clipping: for every mini-batch the critic
# (discriminator) is updated opt.Diters times, then the generator once.
# Sample grids are written to opt.experiment every 300 generator steps.

# Not read anywhere in this cell; retained only in case a later cell expects
# these names to exist.
batches_done = 0
saved_imgs = []
one = torch.FloatTensor([1])
mone = one * -1

# save_image does not create missing directories, so without this the run
# would fail with FileNotFoundError at the first checkpoint.
os.makedirs(opt.experiment, exist_ok=True)

# One fixed latent batch with the same (N, latent_dim, 1, 1) layout as the
# training noise, so sample grids from different iterations are comparable.
fixed_noise = torch.randn(opt.batchSize, opt.latent_dim, 1, 1, device=device)

gen_iterations = 0
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(device)

        # ---- critic updates (all opt.Diters steps reuse this real batch) ----
        for _ in range(opt.Diters):
            noise = torch.randn(opt.batchSize, opt.latent_dim, 1, 1, device=device)
            # detach: the generator must not receive gradients from the critic loss
            fake = generator(noise).detach()
            errD_fake = torch.mean(discriminator(fake))
            errD_real = torch.mean(discriminator(real_imgs))
            # minimising (fake - real) pushes real scores up and fake scores down
            errD = errD_fake - errD_real
            discriminator.zero_grad()
            errD.backward()
            discriminator_optimizer.step()
            # clip every critic parameter to [-0.01, 0.01] after the step
            for p in discriminator.parameters():
                p.data.clamp_(-0.01, 0.01)

        # ---- generator update (reuses the last noise batch drawn above) ----
        gen_fake = generator(noise)
        errG = -torch.mean(discriminator(gen_fake))
        generator.zero_grad()
        errG.backward()
        generator_optimizer.step()
        gen_iterations += 1

        print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
            % (epoch, opt.n_epochs, i, len(dataloader), gen_iterations,
            errD.item(), errG.item(), errD_real.item(), errD_fake.item()))

        # Interval is hard-coded here; opt.sample_interval is not consulted.
        if gen_iterations % 300 == 0:
            # x * 0.5 + 0.5 maps the [-1, 1] image range back to [0, 1] for saving
            save_image(real_imgs.mul(0.5).add(0.5),
                       '{0}/real_samples.png'.format(opt.experiment))
            # The generator is not switched to eval(), so BatchNorm still uses
            # batch statistics for these samples.
            with torch.no_grad():
                fake = generator(fixed_noise)
            save_image(fake.mul(0.5).add(0.5),
                       '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))