Huge thanks to @devnag on GitHub for the PyTorch sample code. I had never built GANs in PyTorch before (only in TensorFlow).

References:

https://github.com/pytorch/examples/blob/master/dcgan/main.py
https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/deep_convolutional_gan/model.py

In [1]:
import os
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils import data

In [3]:
# --- training params ---
batch_size = 6                              # images per mini-batch
lr = 0.00001                                # Adam learning rate (both networks)
image_path = os.path.join('imgs', 'test')   # ImageFolder root directory
seed = 0                                    # RNG seed for repeatable runs

# --- generator params ---
noise_dim = 100         # length of the latent noise vector
g_filter_depth = 64     # base channel count
g_kernel_size = 4
g_stride = 2
g_padding = 1

# --- discriminator params ---
d_filter_depth_in = 3   # input image channels
d_filter_depth = 64     # base channel count
d_kernel_size = 4
d_stride = 2
d_padding = 1

# Seed torch's RNG (weight init, noise sampling, shuffling) so that a
# Restart & Run All reproduces the same run; cuDNN kernels may still be
# nondeterministic on the GPU.
torch.manual_seed(seed)

In [14]:
class Generator(nn.Module):
    '''
    DCGAN-style generator: maps a batch of latent vectors of length
    `noise_dim` to 3-channel images squashed into [-1, 1] by a final Tanh.

    Layer sizes come from the config globals noise_dim, g_filter_depth,
    g_kernel_size, g_stride and g_padding. With the default config
    (kernel 4, stride 2, padding 1) the output is (batch, 3, 64, 64).
    '''
    def __init__(self):
        super(Generator, self).__init__()

        def up_block(in_channels, out_channels):
            # transposed conv -> batchnorm -> leaky relu; with the default
            # config each block doubles the spatial size
            return [
                nn.ConvTranspose2d(in_channels,
                                   out_channels,
                                   kernel_size=g_kernel_size,
                                   stride=g_stride,
                                   padding=g_padding,
                                   bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ]

        layers = [
            # project the 1x1 latent "image" to a 4x4 feature map; kernel 4,
            # stride 1, no padding is deliberately independent of the g_*
            # upsampling params
            nn.ConvTranspose2d(noise_dim,
                               g_filter_depth*8,
                               kernel_size=4,
                               stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(g_filter_depth*8),
            nn.LeakyReLU(0.2),
        ]
        layers += up_block(g_filter_depth*8, g_filter_depth*4)
        layers += up_block(g_filter_depth*4, g_filter_depth*2)
        layers += up_block(g_filter_depth*2, g_filter_depth)
        layers += [
            # final upsample to RGB; no batchnorm on the output layer
            nn.ConvTranspose2d(g_filter_depth,
                               3,
                               kernel_size=g_kernel_size,
                               stride=g_stride,
                               padding=g_padding,
                               bias=False),
            nn.Tanh(),
        ]
        # A flat Sequential keeps the same parameter names (main.0 .. main.13)
        # as a hand-written layer list, so existing state_dicts still load.
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        '''
        inputs: latent batch; presumably (batch, noise_dim) or
                (batch, noise_dim, 1, 1) -- it is viewed as the latter.
        returns: generated images in [-1, 1], shape (batch, 3, H, W).
        '''
        inputs = inputs.view(inputs.size(0), inputs.size(1), 1, 1)
        return self.main(inputs)
    
class Discriminator(nn.Module):
    '''
    DCGAN-style discriminator: maps a batch of images with
    `d_filter_depth_in` channels to scores in (0, 1) via a final Sigmoid
    (the training loop pushes real images towards 1 and fakes towards 0).

    Every conv uses the d_kernel_size / d_stride / d_padding config values.
    With the default config (4 / 2 / 1) each of the five convs halves H and
    W, so a 64x64 input gives an output of shape (batch, 1, 2, 2), i.e. a
    2x2 map of scores rather than one score per image.
    NOTE(review): the pytorch/examples DCGAN uses stride 1 / padding 0 on the
    last conv to get a single score; here the loss averages over the 2x2
    map instead. Left unchanged to keep the output shape.
    '''
    def __init__(self):
        super(Discriminator, self).__init__()

        def down_block(in_channels, out_channels, batch_norm=True):
            # conv -> [batchnorm] -> leaky relu
            block = [nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=d_kernel_size,
                               stride=d_stride,
                               padding=d_padding,
                               bias=False)]
            if batch_norm:
                block.append(nn.BatchNorm2d(out_channels))
            block.append(nn.LeakyReLU(0.2))
            return block

        layers = []
        # no batchnorm on the first layer, which sees the input images
        layers += down_block(d_filter_depth_in, d_filter_depth, batch_norm=False)
        layers += down_block(d_filter_depth, d_filter_depth*2)
        layers += down_block(d_filter_depth*2, d_filter_depth*4)
        layers += down_block(d_filter_depth*4, d_filter_depth*8)
        layers += [
            nn.Conv2d(in_channels=d_filter_depth*8,
                      out_channels=1,
                      kernel_size=d_kernel_size,
                      stride=d_stride,
                      padding=d_padding,
                      bias=False),
            nn.Sigmoid(),
        ]
        # A flat Sequential keeps the same parameter names (main.0 .. main.12)
        # as a hand-written layer list, so existing state_dicts still load.
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        '''
        inputs: image batch, assumed (batch, d_filter_depth_in, H, W).
        returns: Sigmoid scores, shape (batch, 1, H', W').
        '''
        return self.main(inputs)

In [15]:
def weights_init(m):
    '''
    DCGAN-style initialization; apply it with `module.apply(weights_init)`.

    - modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02)
    - modules whose class name contains 'BatchNorm':
      weight ~ N(1, 0.02), bias = 0

    Matching modules without a learnable weight (e.g. BatchNorm2d with
    affine=False, or a container whose name happens to contain 'Conv') are
    skipped instead of raising AttributeError. Conv biases are left as-is.
    '''
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # nothing learnable to initialize
        return
    if 'Conv' in classname:
        weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)
        
def to_variable(x):
    '''
    Wrap a tensor for use in the autograd graph, moving it to the GPU
    first when CUDA is available.

    Note: the wrapper is created with the default requires_grad=False, so
    gradients are NOT tracked for `x` itself; gradients still flow into
    the parameters of any module that consumes it.
    '''
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

def denorm(x):
    '''
    Rescale values from [-1, 1] (the Tanh output range) into [0, 1],
    clipping anything that lands outside that interval.
    '''
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)

In [23]:
n_epochs = 50

# side length of the square training images; the generator produces
# 64x64 images, so keep the two in sync
image_size = 64

# transforms.Scale was deprecated in favour of transforms.Resize (and later
# removed from torchvision); use Resize when this torchvision provides it
resize_transform = getattr(transforms, 'Resize', None) or transforms.Scale

# declare transformation to apply to image data for GAN:
# - resize so the *shorter* side is image_size (aspect ratio preserved),
# - center-crop to image_size x image_size so every image has the same
#   square shape (needed to stack a batch and to match the generator output),
# - convert to a [0, 1] tensor and normalize each channel to [-1, 1],
#   the range of the generator's Tanh output
transform = transforms.Compose([
    resize_transform(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# load images (ImageFolder expects image_path/<class_name>/<image files>)
dataset = ImageFolder(image_path, transform)

# batch and shuffle the data
data_loader = data.DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)

In [24]:
# build both networks and apply the DCGAN weight initialization
# (Module.apply returns the module itself)
generator = Generator()
generator = generator.apply(weights_init)
discriminator = Discriminator()
discriminator = discriminator.apply(weights_init)

# move the models to the GPU *before* building their optimizers, as the
# torch.optim documentation recommends, so each optimizer is created over
# the final (device-resident) parameters
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()

d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)

In [25]:
def loss_value(loss):
    '''
    Return the Python float held by a scalar loss.

    `.item()` exists from PyTorch 0.4 on; older versions (where losses are
    1-element Variables) expose the value via `.data[0]`, which raises on
    the 0-dim tensors that newer versions return.
    '''
    return loss.item() if hasattr(loss, 'item') else loss.data[0]


for epoch in range(n_epochs):
    # the loader yields (images, class_labels); a GAN ignores the labels
    for images, _ in data_loader:
        # --- train discriminator ---
        real_images = to_variable(images)
        # The last batch of an epoch can be smaller than `batch_size`. Use a
        # separate name so the configured `batch_size` is not overwritten
        # for later cells.
        cur_batch_size = real_images.size(0)

        # least-squares loss: push the discriminator's score for real
        # images towards 1
        outputs = discriminator(real_images)
        real_loss = torch.mean((outputs - 1)**2)

        # make fake images from the generator and push the discriminator's
        # score for them towards 0. detach() keeps this loss from
        # back-propagating into the generator, whose gradients would only
        # be discarded before its own update.
        noise = to_variable(torch.randn(cur_batch_size, noise_dim))
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        fake_loss = torch.mean(outputs**2)

        # backpropagate the loss from both real and fake images
        # and update the discriminator
        total_loss = real_loss + fake_loss
        discriminator.zero_grad()
        total_loss.backward()
        d_optimizer.step()

        # --- train generator ---
        noise = to_variable(torch.randn(cur_batch_size, noise_dim))

        # get the generator loss by measuring how close to 1.0 (the "real"
        # target) the discriminator scores each fresh fake image
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        gen_loss = torch.mean((outputs - 1)**2)

        # backpropagate and update the generator only; the discriminator
        # gradients produced here are cleared before its next update
        generator.zero_grad()
        discriminator.zero_grad()
        gen_loss.backward()
        g_optimizer.step()

    # report the losses of the last batch every 5 epochs
    if not epoch % 5:
        print('epoch {}'.format(epoch))
        print('generator loss: {}, discriminator loss: {}'.format(
            loss_value(gen_loss), loss_value(total_loss)))

epoch 0
generator loss: 0.3930845260620117, discriminator loss: 0.4526112675666809
epoch 5
generator loss: 0.4451755881309509, discriminator loss: 0.33018141984939575
epoch 10
generator loss: 0.4393256604671478, discriminator loss: 0.1418849676847458
epoch 15
generator loss: 0.6556639671325684, discriminator loss: 0.06843001395463943
epoch 20
generator loss: 0.7342545390129089, discriminator loss: 0.08750296384096146
epoch 25
generator loss: 0.7211910486221313, discriminator loss: 0.030159831047058105
epoch 30
generator loss: 0.6499167680740356, discriminator loss: 0.05949952080845833
epoch 35
generator loss: 0.670330822467804, discriminator loss: 0.05165122449398041
epoch 40
generator loss: 0.737967848777771, discriminator loss: 0.030380435287952423
epoch 45
generator loss: 0.6957354545593262, discriminator loss: 0.03537507727742195


In [30]:
# Draw a fresh batch of latent vectors (one standard-normal row of length
# noise_dim per sample) and decode it into images with the generator.
noise = to_variable(torch.randn((batch_size, noise_dim)))
fake_img = generator(noise)

In [31]:
# Convert the first generated sample into a PIL image for viewing.
# The generator ends in Tanh, so its output lies in [-1, 1], while
# ToPILImage expects float tensors in [0, 1]. Rescale with denorm() first;
# otherwise out-of-range values are mangled by the uint8 conversion and the
# colours come out garbled. .cpu() because the PIL conversion goes through
# numpy, which cannot read CUDA tensors.
result = transforms.ToPILImage()(denorm(fake_img.data[0]).cpu())

In [32]:
# Render the sample inline: as the cell's last expression, the PIL image is
# displayed by the notebook. (Image.show() instead writes a temp file and
# launches an external OS viewer, which does nothing useful on a remote or
# headless kernel.)
result