In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils


# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Seed PyTorch's global RNG so that weight initialisation, data shuffling and noise
# sampling start from the same RNG state on every Restart & Run All.
manual_seed = 999
torch.manual_seed(manual_seed)

# Hyper-parameters used by the cells below.
image_size = 64    # size handed to transforms.Resize
num_epochs = 500   # number of passes over the dataloader
batch_size = 100   # images per DataLoader batch

In [3]:
# Pre-processing: resize to image_size, convert to a tensor, then normalise each RGB
# channel with mean 0.5 and std 0.5, which maps ToTensor's [0, 1] output to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The dataset folder can be overridden with the CIFAR10_ROOT environment variable so the
# notebook also runs on other machines; the previous location remains the default.
data_root = os.environ.get('CIFAR10_ROOT', '/home/tyler/data/image')

dataset = dset.CIFAR10(root=data_root, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


In [4]:
## weight initialization for network

def weights_init(m):
    """Re-initialise a single module in place; meant to be passed to ``nn.Module.apply``.

    * Class name contains ``'Conv'``: weights are drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Otherwise, class name contains ``'BatchNorm'``: weights are drawn from a
      normal distribution with mean 1.0 and std 0.02, and the bias is set to 0.
    * Any other module is left untouched.

    A module that matches by name but has no ``weight``/``bias`` tensor (for
    example ``BatchNorm2d(..., affine=False)``, where both are ``None``) is
    skipped instead of raising ``AttributeError``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)

### Generator

Produces an image from a random noise vector.

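Q: Why these `ConvTranspose2d` settings? A: They mirror the discriminator's `Conv2d` layers in reverse. The first layer expands the 1x1 noise to a 4x4 feature map, and each following stride-2 layer doubles the spatial size (4 → 8 → 16 → 32 → 64) while reducing the number of channels.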

Q: Why does the first `ConvTranspose2d` take 100 input channels? A: This is just a hyper-parameter. You can choose how many channels of random noise to generate as the input to your generator, as long as the noise you sample has the same number of channels.

In [5]:
class Generator(nn.Module):
    """DCGAN generator: maps a noise tensor to an image with values in [-1, 1].

    Args:
        nz: number of channels of the input noise (default 100).
        ngf: base number of feature maps; the hidden layers use 8x, 4x, 2x and
            1x this value (default 64, i.e. 512 / 256 / 128 / 64 channels).
        nc: number of channels of the generated image (default 3).

    The defaults give the 100 -> 512 -> 256 -> 128 -> 64 -> 3 channel layout.
    For an input of shape (N, nz, 1, 1) the first layer produces a 4x4 map and
    each of the four stride-2 layers doubles it, so the output has shape
    (N, nc, 64, 64). The final Tanh bounds the output to [-1, 1].
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            ## in channels, out channels, kernel size, stride, padding
            # (N, nz, 1, 1) -> (N, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, input):
        """Run the noise tensor ``input`` through the layer stack and return the image batch."""
        return self.main(input)

## Discriminator

A standard convolutional network that scores an image: outputs close to 1 mean "real", outputs close to 0 mean "fake" (generated).

In [6]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to one score in [0, 1] per image.

    Args:
        nc: number of channels of the input images (default 3).
        ndf: base number of feature maps; the hidden layers use 1x, 2x, 4x and
            8x this value (default 64, i.e. 64 / 128 / 256 / 512 channels).

    The defaults give the 3 -> 64 -> 128 -> 256 -> 512 -> 1 channel layout.
    For an input of shape (N, nc, 64, 64) each of the four stride-2 layers
    halves the spatial size down to 4x4, and the last 4x4 convolution reduces
    it to a single value that the Sigmoid squashes into [0, 1].
    """

    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the first layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1)
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, input):
        """Score the image batch ``input``; the result is flattened to a 1-D tensor."""
        return self.main(input).view(-1)

In [7]:
# Build each network, move it to the selected device and re-initialise its weights.
# Both .to() and .apply() return the module itself, so the calls can be chained.
netG = Generator().to(device).apply(weights_init)
netD = Discriminator().to(device).apply(weights_init)

# Last expression of the cell: shows the discriminator's layer summary.
netD

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [8]:
# Binary cross-entropy between the discriminator's scores and the real/fake targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (netD, netG)
)

In [9]:
# Sample images are written to this folder; create it up front so save_image does not
# fail on a missing directory.
results_dir = "./results"
os.makedirs(results_dir, exist_ok=True)

# One fixed batch of noise, reused for every saved sample grid so that images written at
# different points of training come from the same latent vectors.
fixed_noise = torch.randn(batch_size, 100, 1, 1, device=device)

total_step = len(dataloader)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        real, _ = data
        real_images = real.to(device)
        # Size everything from the batch actually received rather than from batch_size.
        n_samples = real_images.size(0)
        real_target = torch.ones(n_samples, device=device)    # label 1 = real image
        fake_target = torch.zeros(n_samples, device=device)   # label 0 = generated image

        ## train discriminator

        netD.zero_grad()

        ## calculate error using real image
        real_score = netD(real_images)
        errD_real = criterion(real_score, real_target)

        ## calculate error using fake image
        ## first generate an image using generator then discriminate
        ## this is 100 channels, 1x1 random noise that the generator will use
        noise = torch.randn(n_samples, 100, 1, 1, device=device)
        fake = netG(noise)
        # detach() keeps this backward pass from reaching the generator's parameters
        fake_score = netD(fake.detach())
        errD_fake = criterion(fake_score, fake_target)

        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        ## train generator

        ## we want the generator to learn to create realistic images
        ## and thus to produce a 1 from the discriminator
        netG.zero_grad()
        output = netD(fake)
        errG = criterion(output, real_target)
        errG.backward()
        optimizerG.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, errD.item(), errG.item(),
                          real_score.mean().item(), fake_score.mean().item()))
            vutils.save_image(real, '%s/real_samples.png' % results_dir, normalize = True)
            # Sampling is for inspection only, so skip building an autograd graph.
            with torch.no_grad():
                samples = netG(fixed_noise)
            vutils.save_image(samples, '%s/fake_samples_epoch_%03d.png' % (results_dir, epoch), normalize = True)

Epoch [0/500], Step [200/500], d_loss: 0.2599, g_loss: 8.9420, D(x): 0.81, D(G(z)): 0.00
Epoch [0/500], Step [400/500], d_loss: 0.4362, g_loss: 4.9172, D(x): 0.71, D(G(z)): 0.00
Epoch [1/500], Step [200/500], d_loss: 0.3738, g_loss: 3.3402, D(x): 0.86, D(G(z)): 0.18
Epoch [1/500], Step [400/500], d_loss: 0.3290, g_loss: 3.1749, D(x): 0.79, D(G(z)): 0.06
Epoch [2/500], Step [200/500], d_loss: 1.9517, g_loss: 8.1983, D(x): 0.99, D(G(z)): 0.79
Epoch [2/500], Step [400/500], d_loss: 0.8382, g_loss: 4.5285, D(x): 0.94, D(G(z)): 0.46
Epoch [3/500], Step [200/500], d_loss: 0.9334, g_loss: 1.4541, D(x): 0.56, D(G(z)): 0.16
Epoch [3/500], Step [400/500], d_loss: 0.4521, g_loss: 5.1011, D(x): 0.90, D(G(z)): 0.27
Epoch [4/500], Step [200/500], d_loss: 0.3517, g_loss: 3.4442, D(x): 0.90, D(G(z)): 0.21
Epoch [4/500], Step [400/500], d_loss: 0.7710, g_loss: 2.4556, D(x): 0.75, D(G(z)): 0.32
Epoch [5/500], Step [200/500], d_loss: 0.7425, g_loss: 1.7695, D(x): 0.63, D(G(z)): 0.19
Epoch [5/500], Step [

Epoch [46/500], Step [200/500], d_loss: 0.1819, g_loss: 3.8989, D(x): 0.97, D(G(z)): 0.13
Epoch [46/500], Step [400/500], d_loss: 0.1324, g_loss: 3.8826, D(x): 0.92, D(G(z)): 0.04
Epoch [47/500], Step [200/500], d_loss: 0.0383, g_loss: 4.5517, D(x): 0.98, D(G(z)): 0.02
Epoch [47/500], Step [400/500], d_loss: 0.1437, g_loss: 3.9682, D(x): 0.93, D(G(z)): 0.06
Epoch [48/500], Step [200/500], d_loss: 0.9022, g_loss: 2.6150, D(x): 0.48, D(G(z)): 0.02
Epoch [48/500], Step [400/500], d_loss: 0.5232, g_loss: 2.2944, D(x): 0.73, D(G(z)): 0.12
Epoch [49/500], Step [200/500], d_loss: 0.4064, g_loss: 2.3282, D(x): 0.77, D(G(z)): 0.09
Epoch [49/500], Step [400/500], d_loss: 0.0761, g_loss: 5.9181, D(x): 0.99, D(G(z)): 0.06
Epoch [50/500], Step [200/500], d_loss: 0.0156, g_loss: 5.6157, D(x): 1.00, D(G(z)): 0.01
Epoch [50/500], Step [400/500], d_loss: 0.0083, g_loss: 6.5039, D(x): 1.00, D(G(z)): 0.01
Epoch [51/500], Step [200/500], d_loss: 0.0030, g_loss: 7.4506, D(x): 1.00, D(G(z)): 0.00
Epoch [51/

Process Process-130:
Process Process-129:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/tyler/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/tyler/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/tyler/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/tyler/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/tyler/anaconda3/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/transforms/transforms.py", line 61, in __call__
    return F.to_tensor(pic)
  File "/home/tyler/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/tyler/anaconda3/lib/python3

KeyboardInterrupt: 