In [None]:
# reference: https://github.com/pytorch/examples/blob/master/dcgan/main.py
#
import os
import itertools
import argparse
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision 
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [15]:
# --- Model size ---------------------------------------------------------
nz = 100    # dimension of the noise vector fed to the generator
nf = 128    # base number of conv feature maps (passed to both G and D)
nc = 1      # number of image channels

# --- Locations ----------------------------------------------------------
data_root = '../data'        # dataset download / cache directory
figure_root = './figures'    # generated sample grids are written here
model_root = './models'      # per-epoch checkpoints are written here

# --- Data and optimisation ----------------------------------------------
image_size = 64              # the G/D layer stacks below are laid out for 64x64 images
batch_size = 64
lr = 0.0002                  # Adam learning rate
beta1 = 0.5                  # Adam beta1 (beta2 is fixed to 0.999 where the optimizers are built)
n_epochs = 10
n_batches_print = 100        # log and save a sample grid every this many batches

# --- Reproducibility ----------------------------------------------------
# Seed PyTorch's RNG so weight initialisation, data shuffling and the noise
# samples are repeatable across "Restart Kernel -> Run All".
manual_seed = 42
torch.manual_seed(manual_seed)

# --- Optional warm start ------------------------------------------------
# Paths to saved state_dicts; leave empty to train from scratch.
load_weights_generator = ''
load_weights_discriminator = ''

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps a noise vector to an image.

    Five transposed convolutions (kernel 4, no bias) grow the input from
    (nz) x 1 x 1 to (nc) x 64 x 64. The four hidden stages are followed by
    batch norm and ReLU; the output stage is followed by Tanh only.
    """

    def __init__(self, nz, nf, nc):
        """
            nz      dimension of noise
            nf      number of feature maps entering the last transposed conv
            nc      number of channels in the image

            The setting quoted for LSUN in the DCGAN paper is nz=100, nf=128, nc=3
        """
        super().__init__()

        # A single in-place ReLU module is reused after every hidden stage.
        relu = nn.ReLU(True)

        # Channel counts at the boundaries of the hidden stages:
        #   (nz) x 1 x 1 -> (8*nf) x 4 x 4 -> (4*nf) x 8 x 8
        #                -> (2*nf) x 16 x 16 -> (nf) x 32 x 32
        widths = [nz, 8 * nf, 4 * nf, 2 * nf, nf]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            if stage == 0:
                # stride=1, padding=0: H_out = H_in + 3   (1 -> 4)
                stride, padding = 1, 0
            else:
                # stride=2, padding=1: H_out = 2 * H_in   (doubles)
                stride, padding = 2, 1
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                relu,
            ]

        # Output stage: (nf) x 32 x 32 -> (nc) x 64 x 64, no batch norm, Tanh.
        layers += [
            nn.ConvTranspose2d(nf, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """
            z       (N, nz, 1, 1)
                noise vector
            Returns (N, nc, h, w)
                image generated from model distribution
        """
        generated = self.model(z)
        return generated
    
    
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image to a single probability.

    Five convolutions (kernel 4, no bias) shrink the input from
    (nc) x 64 x 64 to 1 x 1 x 1. The first stage has no batch norm, the
    three middle stages use batch norm, and all four are followed by
    LeakyReLU(0.2); the last stage ends in a Sigmoid.
    """

    def __init__(self, nc, nf):
        """
            nc      number of channels in the image
            nf      dimension of features of first conv layer

            The setting quoted for LSUN in the DCGAN paper is nc=3
        """
        super().__init__()

        # A single in-place LeakyReLU module is reused after every hidden stage.
        leaky = nn.LeakyReLU(0.2, inplace=True)

        def halve(c_in, c_out):
            """ stride=2, padding=1, kernel=4: H_out = floor(H_in / 2) """
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)

        # (nc) x 64 x 64 -> (nf) x 32 x 32, no batch norm on the input stage
        layers = [halve(nc, nf), leaky]

        # (nf) x 32 x 32 -> (2*nf) x 16 x 16 -> (4*nf) x 8 x 8 -> (8*nf) x 4 x 4
        for mult in (1, 2, 4):
            c_in, c_out = mult * nf, 2 * mult * nf
            layers.extend([halve(c_in, c_out), nn.BatchNorm2d(c_out), leaky])

        # stride=1, padding=0, kernel=4: H_out = H_in - 3   (4 -> 1)
        # (8*nf) x 4 x 4 -> 1 x 1 x 1, squashed to a probability
        layers.extend([
            nn.Conv2d(8 * nf, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """
            x        (N, nc, h, w)
            Returns  (N,)
                classification probability that x comes from data distribution
        """
        probs = self.model(x)
        # Collapse the trailing singleton dimensions to one value per sample.
        as_column = probs.view(-1, 1)
        return as_column.squeeze(1)
        
        


# custom weights initialization called on G/D
def weights_init(m):
    """Re-initialise one module in place; meant to be passed to `Module.apply`.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [None]:
# MNIST, downloaded into `data_root` on first use. Each image is resized to
# `image_size`, converted to a tensor and normalised with mean 0.5 / std 0.5.
trainset = datasets.MNIST(
    root=data_root,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

# Shuffled minibatches of `batch_size` images, loaded by 8 worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Sanity check: pull one minibatch from the loader (x[0] holds the images).
x = next(iter(trainloader))

# Show the first image of the batch, moved to channels-last for matplotlib.
im = x[0][0]
plt.imshow(np.transpose(im, (1, 2, 0)).squeeze())

# Cell output: mean and standard deviation of the normalised image batch.
x[0].mean(), x[0].std()

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator: custom initialisation first, then optionally overwrite it with
# previously saved weights.
G = Generator(nz, nf, nc).to(device)
G.apply(weights_init)
if load_weights_generator != '':
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can also be restored on a CPU-only machine.
    G.load_state_dict(torch.load(load_weights_generator, map_location=device))


# Discriminator: same procedure.
D = Discriminator(nc, nf).to(device)
D.apply(weights_init)
if load_weights_discriminator != '':
    D.load_state_dict(torch.load(load_weights_discriminator, map_location=device))

In [None]:
# Print both architectures to inspect the layer stacks.
print(G, D)

In [None]:
# Make sure the dataset, checkpoint and figure directories exist.
for _dir in (data_root, model_root, figure_root):
    os.makedirs(_dir, exist_ok=True)

# Binary cross-entropy between the discriminator output and the target label.
criterion = nn.BCELoss()

# Target values used for real and generated samples.
real_label = 1
fake_label = 0

# One fixed batch of noise, reused for every saved sample grid so that the
# grids from different points in training show the same latent points.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# setup optimizer: one Adam instance per network, same hyper-parameters
optimizerD = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Adversarial training: for every minibatch take one discriminator step
# followed by one generator step; checkpoint both networks after each epoch.
for epoch in range(n_epochs):

    for it, (x_real, _) in enumerate(trainloader):

        # The last batch may be smaller than the configured batch size. Keep
        # its size in a loop-local name so the notebook-level `batch_size`
        # setting (also used for `fixed_noise` and the DataLoader) is not
        # overwritten by running this cell.
        n_samples = x_real.size(0)
        # BCELoss needs floating-point targets, so request a float dtype
        # explicitly instead of relying on what torch.full infers from the
        # integer fill values.
        real_labels = torch.full((n_samples,), real_label, dtype=torch.float, device=device)
        fake_labels = torch.full((n_samples,), fake_label, dtype=torch.float, device=device)

        ##############################################################
        # Update Discriminator: Maximize E[log(D(x))] + E[log(1 - D(G(z)))]
        ##############################################################

        D.zero_grad()

        # a minibatch of samples from data distribution
        x_real = x_real.to(device)

        y = D(x_real)
        loss_D_real = criterion(y, real_labels)
        loss_D_real.backward()

        # mean discriminator output on real images (for logging)
        D_x = y.mean().item()


        # a minibatch of samples from the model distribution
        z = torch.randn(n_samples, nz, 1, 1, device=device)

        x_fake = G(z)
        # https://github.com/pytorch/examples/issues/116
        # If we do not detach, then, although x_fake is not needed for gradient update of D,
        #   as a consequence of backward pass which clears all the variables in the graph
        #   graph for G will not be available for gradient update of G
        # Also for performance considerations, detaching x_fake will prevent computing
        #   gradients for parameters in G
        y = D(x_fake.detach())
        loss_D_fake = criterion(y, fake_labels)
        loss_D_fake.backward()

        # mean discriminator output on generated images before the D step (for logging)
        D_G_z1 = y.mean().item()
        loss_D = loss_D_real + loss_D_fake

        # gradients of the real and fake passes have accumulated; apply them together
        optimizerD.step()

        ##############################################################
        # Update Generator: Minimize E[log(1 - D(G(z)))] => Maximize E[log(D(G(z)))]
        ##############################################################

        G.zero_grad()

        # Score the same generated batch with the just-updated discriminator,
        # this time keeping the graph so gradients reach G. Real labels are
        # used as targets, which implements the "maximize log D(G(z))" form.
        y = D(x_fake)
        loss_G = criterion(y, real_labels)
        loss_G.backward()

        # mean discriminator output on generated images after the D step (for logging)
        D_G_z2 = y.mean().item()

        optimizerG.step()


        ##############################################################
        # print
        ##############################################################

        if it % n_batches_print == n_batches_print-1:
            print(f"[{epoch+1}/{n_epochs}][{it+1}/{len(trainloader)}] loss: {loss_D.item()+loss_G.item():.8} loss_D: {loss_D.item():.8}  loss_G: {loss_G.item():.8} D_x: {D_x:.8} D(G(z1)): {D_G_z1:.8} D(G(z2)): {D_G_z2:.8}" )

            # Sample grid from the fixed noise batch. No gradients are needed
            # for visualisation, so skip building the autograd graph.
            with torch.no_grad():
                x_fake = G(fixed_noise)
            vutils.save_image(x_fake, os.path.join(figure_root, f'dcgan_fake_samples_epoch={epoch}_it={it}.png'))

    # checkpointing
    torch.save(G.state_dict(), os.path.join(model_root, f'G_epoch_{epoch}.pt'))
    torch.save(D.state_dict(), os.path.join(model_root, f'D_epoch_{epoch}.pt'))

In [20]:
# Repository housekeeping via an IPython shell escape (not part of the DCGAN
# workflow): stages every change in the enclosing git working tree.
! git add -A

On branch master
Your branch is up to date with 'origin/master'.

Changes to be committed:
  (use "git reset HEAD <file>..." to unstage)

	new file:   ../beta-vae/beta-vae.py
	renamed:    ../vae/README.md -> ../cvae/README.md
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=0.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=1.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=10.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=100.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=101.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=102.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=103.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=104.png
	new file:   ../cvae/figures/decode_along_a_lattice_cvae_c=0_epochs=10

	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=144.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=145.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=146.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=147.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=148.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=149.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=15.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=150.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=151.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=152.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=153.png
	new file:   ../cvae/figures/latent_sample_decoded_cvae_epochs=154.png
	new file:   ../cvae/figures/latent_

In [None]:
from IPython.display import Image

# Display one of the sample grids written by the training loop. The path is
# built from `figure_root`, the directory the training loop saves into, so
# the two stay in sync if the setting changes.
Image(filename=os.path.join(figure_root, 'dcgan_fake_samples_epoch=6_it=599.png'))