In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils


# Device configuration: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Training hyper-parameters used by the cells below.
batch_size = 128   # images per optimisation step
num_epochs = 20    # full passes over the dataset
image_size = 64    # images are resized to image_size x image_size

In [3]:
# Preprocessing: resize to a square image, convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Load the images from disk, applying the preprocessing above to each one.
dataset = dset.ImageFolder("/home/tyler/data/image/celebs/", transform)
# Alternative dataset, kept for reference:
#dataset = dset.CIFAR10(root = '/home/tyler/data/image',
#                       download = True,
#                       transform = transform)

# Shuffled mini-batches, loaded by two worker processes.
loader_options = {
    'batch_size': batch_size,
    'shuffle': True,
    'num_workers': 2,
}
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

In [4]:
## weight initialization for the networks

def weights_init(m):
    """Re-initialise one module's parameters in place with Gaussian values.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    sub-module of a network.

    * Class name contains ``'Conv'``: weights are drawn from N(0.0, 0.02).
    * Class name contains ``'BatchNorm'``: weights are drawn from N(1.0, 0.02)
      and biases are set to 0.
    * Any other module is left untouched.

    Parameters that are absent or ``None`` (for example on
    ``BatchNorm2d(affine=False)``) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

### Generator

Produces an image from a random noise tensor.

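Q: Why these `ConvTranspose2d` numbers? A: They mirror the discriminator's `Conv2d` layers in reverse. The first layer expands the 1×1 noise to 4×4, and each following layer doubles the height and width (4 → 8 → 16 → 32 → 64) while the channel count shrinks (1024 → 512 → 256 → 128 → 3).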

Q: Why does the first `ConvTranspose2d` take 100 input channels? A: This is just a hyperparameter: it is the number of channels in the random noise fed to the generator. You can choose any value, as long as the noise you sample has the same number of channels.

In [5]:
class Generator(nn.Module):
    """Transposed-convolution generator that turns latent noise into an image.

    The first layer expands a ``(N, latent_dim, 1, 1)`` input to 4x4 and each
    of the four following layers doubles the height and width, so the output
    is ``(N, out_channels, 64, 64)``. The final ``Tanh`` keeps values in
    [-1, 1].

    Args:
        latent_dim: number of channels of the input noise tensor
            (default 100).
        base_channels: width of the last hidden layer; the earlier hidden
            layers use 8x, 4x and 2x this value (default 128, i.e.
            1024/512/256/128).
        out_channels: number of channels of the generated image (default 3).
    """
    def __init__(self, latent_dim=100, base_channels=128, out_channels=3):
        super(Generator, self).__init__()
        width = base_channels
        self.main = nn.Sequential(
            ## in channels, out channels, kernel size, stride, padding
            nn.ConvTranspose2d(latent_dim, width * 8, 4, 1, 0),   # 1x1   -> 4x4
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(width * 8, width * 4, 4, 2, 1),    # 4x4   -> 8x8
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(width * 4, width * 2, 4, 2, 1),    # 8x8   -> 16x16
            nn.BatchNorm2d(width * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(width * 2, width, 4, 2, 1),        # 16x16 -> 32x32
            nn.BatchNorm2d(width),
            nn.ReLU(True),
            nn.ConvTranspose2d(width, out_channels, 4, 2, 1),     # 32x32 -> 64x64
            nn.Tanh())

    def forward(self, input):
        """Map a batch of noise tensors to a batch of generated images."""
        return self.main(input)

## Discriminator

Just a normal convolutional network that tells whether an image is real or fake: it outputs a score between 0 (fake) and 1 (real).

In [10]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores images as real (1) or fake (0).

    Four stride-2 convolutions halve a 64x64 input down to 4x4, and a final
    4x4 convolution followed by ``Sigmoid`` reduces it to a single score per
    image.

    Args:
        in_channels: number of channels of the input images (default 3).
        base_channels: width of the first hidden layer; later hidden layers
            use 2x, 4x and 8x this value (default 128, i.e.
            128/256/512/1024).
    """
    def __init__(self, in_channels=3, base_channels=128):
        super(Discriminator, self).__init__()
        width = base_channels
        self.main = nn.Sequential(
            ## in channels, out channels, kernel size, stride, padding
            nn.Conv2d(in_channels, width, 4, 2, 1),       # 64x64 -> 32x32
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(width, width * 2, 4, 2, 1),         # 32x32 -> 16x16
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(width * 2, width * 4, 4, 2, 1),     # 16x16 -> 8x8
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(width * 4, width * 8, 4, 2, 1),     # 8x8   -> 4x4
            nn.BatchNorm2d(width * 8),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(width * 8, 1, 4, 1, 0),             # 4x4   -> 1x1
            nn.Sigmoid())

    def forward(self, input):
        """Return the scores flattened to one dimension (one per 64x64 image)."""
        return self.main(input).view(-1)

In [11]:
# Build the generator on the selected device, then re-initialise its
# parameters by running weights_init over every sub-module.
netG = Generator().to(device)
netG.apply(weights_init)
# Same for the discriminator. This apply() call is the last expression in the
# cell, so the module it returns is what the notebook displays below.
netD = Discriminator().to(device)
netD.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

In [12]:
# Binary cross-entropy between the discriminator's scores and the 0/1 targets.
criterion = nn.BCELoss()

# Both networks are optimised with Adam using identical settings.
adam_settings = {'lr': 0.0002, 'betas': (0.5, 0.999)}
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

In [None]:
total_step = len(dataloader)

## the sample images are written into this folder; create it up front so the
## first save does not fail when the folder is missing
results_dir = "./results"
os.makedirs(results_dir, exist_ok = True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):

        ## train discriminator

        netD.zero_grad()

        ## calculate error using real image
        real, _ = data
        input = real.to(device)
        ## the last batch of an epoch can be smaller than batch_size, so
        ## size the targets and the noise from the batch itself
        batch_len = input.size(0)
        # target is 1 b/c real image
        target = torch.ones(batch_len, device = device)
        output = netD(input)
        real_score = output
        errD_real = criterion(output, target)

        ## calculate error using fake image
        ## first generate an image using generator then discriminate
        ## this is 100 channels, 1x1 random noise that the generator will use
        noise = torch.randn(batch_len, 100, 1, 1, device = device)
        fake = netG(noise)
        # target is 0 b/c fake image
        target = torch.zeros(batch_len, device = device)
        ## detach so this backward pass does not reach the generator
        output = netD(fake.detach())
        fake_score = output
        errD_fake = criterion(output, target)

        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        ## train generator

        ## we want the generator to learn to create realistic images
        ## and thus to produce a 1 from the discriminator
        netG.zero_grad()
        target = torch.ones(batch_len, device = device)
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, errD.item(), errG.item(), 
                    real_score.mean().item(), fake_score.mean().item()))
            vutils.save_image(real, '%s/real_samples.png' % results_dir, normalize = True)
            ## regenerate from the same noise with the just-updated generator;
            ## no gradients are needed for images that are only saved
            with torch.no_grad():
                fake = netG(noise)
            vutils.save_image(fake, '%s/fake_samples_epoch_%03d.png' % (results_dir, epoch), normalize = True)

Epoch [0/20], Step [200/1583], d_loss: 0.0000, g_loss: 27.6310, D(x): 1.00, D(G(z)): 0.00
