Huge thanks to @devnag on GitHub for the PyTorch sample code. I had never built GANs in PyTorch before (only in TensorFlow).

References:

https://github.com/pytorch/examples/blob/master/dcgan/main.py
https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/deep_convolutional_gan/model.py

In [1]:
import os
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets
from torch.utils import data

In [3]:
# --- general params ---
batch_size = 24
lr = 0.00001  # NOTE(review): well below the 2e-4 used in the referenced DCGAN example; confirm intended
image_path = os.path.join('imgs', 'test')  # ImageFolder root for the "monster" images
# fixed seed so that re-running the notebook draws the same random numbers
# (weight init, noise, DataLoader shuffling)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
# --- generator params ---
noise_dim = 100        # length of the latent noise vector fed to the Generator
g_filter_depth = 64    # base channel count; generator layers use multiples of it
g_kernel_size = 4
g_stride = 2
g_padding = 1
# --- discriminator params ---
d_filter_depth_in = 1  # image channels (MNIST batches are single-channel)
d_filter_depth = 64    # base channel count; discriminator layers use multiples of it
d_kernel_size = 4
d_stride = 2
d_padding = 1

In [4]:
class Generator(nn.Module):
    '''
    DCGAN-style generator: decodes a batch of noise vectors into images.

    The first transposed convolution (kernel 4, stride 1, no padding) turns the
    1x1 noise "image" into a 4x4 feature map; each following layer uses
    g_kernel_size / g_stride / g_padding from the config cell. With the
    default 4 / 2 / 1 each of those layers doubles the spatial size, giving
    4 -> 8 -> 16 -> 32 -> 64. The output has d_filter_depth_in channels and
    is squashed into [-1, 1] by Tanh.
    '''
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # projection: 1x1 noise -> 4x4 feature map
            nn.ConvTranspose2d(noise_dim,
                               g_filter_depth*8,
                               kernel_size=4,
                               stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(g_filter_depth*8),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(g_filter_depth*8,
                               g_filter_depth*4,
                               kernel_size=g_kernel_size,
                               stride=g_stride,
                               padding=g_padding,
                               bias=False),
            nn.BatchNorm2d(g_filter_depth*4),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(g_filter_depth*4,
                               g_filter_depth*2,
                               kernel_size=g_kernel_size,
                               stride=g_stride,
                               padding=g_padding,
                               bias=False),
            nn.BatchNorm2d(g_filter_depth*2),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(g_filter_depth*2,
                               g_filter_depth,
                               kernel_size=g_kernel_size,
                               stride=g_stride,
                               padding=g_padding,
                               bias=False),
            nn.BatchNorm2d(g_filter_depth),
            nn.LeakyReLU(0.2),

            # output layer: as many channels as the Discriminator's input
            nn.ConvTranspose2d(g_filter_depth,
                               d_filter_depth_in,
                               kernel_size=g_kernel_size,
                               stride=g_stride,
                               padding=g_padding,
                               bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        '''
        inputs: noise batch whose first two dims are (batch, noise_dim), e.g.
        shape (batch, noise_dim); it is reshaped to (batch, noise_dim, 1, 1)
        so the transposed convolutions can treat it as a 1x1 image.
        Returns the generated image batch.
        '''
        inputs = inputs.view(inputs.size(0), inputs.size(1), 1, 1)
        output = self.main(inputs)
        return output
    
class Discriminator(nn.Module):
    '''
    DCGAN-style discriminator: scores each image in a batch as real vs. fake.

    Four convolutions using d_kernel_size / d_stride / d_padding from the
    config cell (4 / 2 / 1 by default) halve the spatial size
    64 -> 32 -> 16 -> 8 -> 4. A final 4x4 convolution with stride 1 and no
    padding collapses the 4x4 map to one value per image, and Sigmoid maps
    it into (0, 1). For a 64x64 input the output shape is (batch, 1, 1, 1);
    other input sizes will not reduce to 1x1.
    '''
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # no BatchNorm on the first layer, as in the DCGAN reference
            nn.Conv2d(in_channels=d_filter_depth_in,
                      out_channels=d_filter_depth,
                      kernel_size=d_kernel_size,
                      stride=d_stride,
                      padding=d_padding,
                      bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=d_filter_depth,
                      out_channels=d_filter_depth*2,
                      kernel_size=d_kernel_size,
                      stride=d_stride,
                      padding=d_padding,
                      bias=False),
            nn.BatchNorm2d(d_filter_depth*2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=d_filter_depth*2,
                      out_channels=d_filter_depth*4,
                      kernel_size=d_kernel_size,
                      stride=d_stride,
                      padding=d_padding,
                      bias=False),
            nn.BatchNorm2d(d_filter_depth*4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(in_channels=d_filter_depth*4,
                      out_channels=d_filter_depth*8,
                      kernel_size=d_kernel_size,
                      stride=d_stride,
                      padding=d_padding,
                      bias=False),
            nn.BatchNorm2d(d_filter_depth*8),
            nn.LeakyReLU(0.2),

            # stride 1 / padding 0 so the 4x4 map becomes a single score
            # (mirrors the Generator's first layer); stride 2 / padding 1
            # here would leave a 2x2 map per image
            nn.Conv2d(in_channels=d_filter_depth*8,
                      out_channels=1,
                      kernel_size=4,
                      stride=1,
                      padding=0,
                      bias=False),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        '''
        inputs: image batch with d_filter_depth_in channels (64x64 expected).
        Returns per-image scores in (0, 1).
        '''
        output = self.main(inputs)
        return output

In [5]:
def weights_init(m):
    '''
    DCGAN-style initializer, intended for ``nn.Module.apply``.

    Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...)
    get weights drawn from N(0, 0.02). Modules whose class name contains
    "BatchNorm" get weights from N(1, 0.02) and a zero bias. Every other
    module is left untouched.
    '''
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
def to_variable(x):
    '''
    Move a tensor to the GPU when CUDA is available and wrap it in a Variable.

    Note: ``Variable(x)`` uses the default ``requires_grad=False``, so this
    does NOT enable gradient tracking for ``x`` itself. (Since PyTorch 0.4,
    ``Variable(x)`` simply returns a Tensor.)
    '''
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

def denorm(x):
    '''
    Map values from the [-1, 1] range (e.g. Tanh output) back to [0, 1].

    Computes (x + 1) / 2 and clamps to [0, 1], so inputs outside [-1, 1]
    saturate at the bounds.
    '''
    shifted = (x + 1) / 2
    return torch.clamp(shifted, 0, 1)

In [10]:
n_epochs = 1

# "Monster" images: resize the short side to 64, then center-crop so every
# sample is exactly 64x64 (Resize with an int keeps the aspect ratio, and the
# default DataLoader collate cannot stack differently sized images), then map
# each channel from [0, 1] to [-1, 1].
# NOTE(review): ImageFolder yields 3-channel tensors here, while the
# Discriminator's first conv takes d_filter_depth_in channels (1 in the config
# cell) — add a grayscale step or change that param before training on these.
monster_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# MNIST: mean 0.5 / std 0.5 maps pixels to [-1, 1], the same range as the
# Generator's Tanh output and the range denorm() inverts. The dataset
# statistics (0.1307, 0.3081) would put real pixels in roughly [-0.42, 2.82],
# which generated images can never reach.
mnist_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# load images
monster_dataset = ImageFolder(image_path, monster_transform)

mnist_dataset = datasets.MNIST('imgs/mnist', download=True, transform=mnist_transform)

# load the data
monster_loader = data.DataLoader(dataset=monster_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=1)

mnist_loader = data.DataLoader(dataset=mnist_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=1)

In [55]:
# Peek at the first batch only, to check the tensor shape the loader yields.
single_mnist_batch = []
for first_batch in mnist_loader:
    batch_images = first_batch[0]
    print(batch_images.shape)
    single_mnist_batch.append(batch_images)
    break

torch.Size([24, 1, 64, 64])


In [56]:
# Stack the collected batches along a new leading dimension.
single_mnist_batch = torch.stack(single_mnist_batch, dim=0)
single_mnist_batch.shape

torch.Size([1, 24, 1, 64, 64])

In [57]:
# Build both networks and apply the DCGAN-style weight init.
generator = Generator()
generator = generator.apply(weights_init)
discriminator = Discriminator()
discriminator = discriminator.apply(weights_init)

# Move the models to the GPU *before* constructing their optimizers, as the
# torch.optim documentation recommends.
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()

d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)

# loader consumed by the training loop
data_loader = mnist_loader

In [62]:
# Least-squares GAN training: squared-error losses on the discriminator output.
# Debug cap: each epoch stops after the batch with this index (inclusive, so
# 101 batches); set to None to run through the whole loader.
stop_after_batch = 100

for epoch in range(n_epochs):
    for i, (images, _labels) in enumerate(data_loader):
        # --- train discriminator ---
        real_images = to_variable(images)
        # local name, so the global batch_size config is not overwritten
        # (the last batch may be smaller)
        n_images = real_images.size(0)
        outputs = discriminator(real_images)
        # push D(real) toward 1
        real_loss = torch.mean((outputs - 1)**2)

        # push D(fake) toward 0; detach() keeps this update from
        # backpropagating into the generator, whose grads are not needed here
        noise = to_variable(torch.randn(n_images, noise_dim))
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        fake_loss = torch.mean(outputs**2)

        # backpropagate the combined real + fake loss for the discriminator
        total_loss = real_loss + fake_loss
        discriminator.zero_grad()
        total_loss.backward()
        d_optimizer.step()

        # --- train generator ---
        # fresh noise; the generator wants D(fake) to be close to 1.0
        noise = to_variable(torch.randn(n_images, noise_dim))
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        gen_loss = torch.mean((outputs - 1)**2)

        # gen_loss.backward() also fills D's grads; they are cleared by
        # discriminator.zero_grad() before the next discriminator step
        generator.zero_grad()
        discriminator.zero_grad()
        gen_loss.backward()
        g_optimizer.step()

        if i % 20 == 0:
            print('batch {}'.format(i))

        if stop_after_batch is not None and i >= stop_after_batch:
            break

    if epoch % 5 == 0:
        print('epoch {}'.format(epoch))
        # .item() extracts a Python float; indexing a 0-dim loss with
        # .data[0] raises an IndexError in PyTorch >= 0.4
        print('generator loss: {}, discriminator loss: {}'.format(gen_loss.item(), total_loss.item()))

image 0
image 20


Process Process-16:
Traceback (most recent call last):
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    r = index_queue.get()
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/Users/yvanscher/anaconda/envs/gen_monsters/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/Users/yvanscher/anaconda/en

KeyboardInterrupt: 

In [63]:
# Decode one batch of fresh latent vectors with the trained generator.
noise = to_variable(torch.randn((batch_size, noise_dim)))
fake_img = generator(noise)

In [64]:
# The Generator ends in Tanh, so pixels are in [-1, 1]. ToPILImage scales float
# tensors by 255 before casting to 8-bit, so rescale to [0, 1] with denorm first
# (negative values would otherwise be corrupted). .cpu() is needed because the
# PIL conversion goes through numpy, which cannot read GPU tensors.
result = transforms.ToPILImage()(denorm(fake_img.data[0]).cpu())

In [65]:
# Let Jupyter render the PIL image inline (rich display) instead of
# Image.show(), which hands it to an external OS image viewer.
result