## Things to Try
- Apply dropout to generator
- Use labels
- Add noise to inputs, decay over time
- Make labels noisy by sometimes flipping them

In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
image_size = 64
num_epochs = 250
batch_size = 128

In [3]:
transform = transforms.Compose([transforms.Resize((image_size, image_size)), 
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5, 0.5, 0.5), 
                                                     (0.5, 0.5, 0.5)),])
dataset = dset.ImageFolder("/home/tyler/data/image/small_images", 
                               transform)
# dataset = dset.CIFAR10(root = '/home/tyler/data/image', 
#                        download = True, transform = transform)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         batch_size = batch_size, 
                                         shuffle = True, 
                                         num_workers = 2)

In [4]:
## weight initalization for network
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

### Generator

Produces an image

Q: why the convtranspose2d numbers? A: they are the inverse of discriminator

Q: Why first convtranspose2d in channels is 100? This is just a hyper-parameter. You can choose how many channels you want to randomly generate for the input to your generator.

In [5]:
class Generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0, bias=False)
        self.deconv1_bn = nn.BatchNorm2d(d*8)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1, bias=False)
        self.deconv2_bn = nn.BatchNorm2d(d*4)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1, bias=False)
        self.deconv3_bn = nn.BatchNorm2d(d*2)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1, bias=False)
        self.deconv4_bn = nn.BatchNorm2d(d)
        self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1, bias=False)

    # forward method
    def forward(self, input):
        # x = F.relu(self.deconv1(input))
        x = F.leaky_relu(self.deconv1_bn(self.deconv1(input)), 0.2)
        x = F.leaky_relu(self.deconv2_bn(self.deconv2(x)), 0.2)
        x = F.leaky_relu(self.deconv3_bn(self.deconv3(x)), 0.2)
        x = F.leaky_relu(self.deconv4_bn(self.deconv4(x)), 0.2)
        x = F.tanh(self.deconv5(x))

        return x

## Discriminator

Just a normal conv net to tell if an image is fake or not.

In [6]:
class Discriminator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, d, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1, bias=False)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1, bias=False)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1, bias=False)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0, bias=False)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = F.sigmoid(self.conv5(x))

        return x.view(-1)

In [7]:
netG = Generator(128).to(device)
netG.apply(weights_init)
netD = Discriminator(128).to(device)
netD.apply(weights_init)

Discriminator(
  (conv1): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
)

In [8]:
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.00005, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

In [9]:
total_step = len(dataloader)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        
        ## train discriminator
        
        netD.zero_grad()
        
        ## calculate error using real image 
        real, _ = data
        input = Variable(real).to(device)
        # target is 1 b/c real image
        target = Variable(torch.empty(input.size()[0]).uniform_(0.7, 1.2) ).to(device)
        output = netD(input)
        real_score = output
        errD_real = criterion(output, target)
        
        ## calculate error using fake image
        ## first generate an image using generator then discriminate
        ## this is 100 channels, 1x1 random noise that the generate will use
        noise = Variable(torch.randn(input.size()[0], 100, 1, 1)).to(device)
        fake = netG(noise)
        target = Variable(torch.empty(input.size()[0]).uniform_(0.0, 0.3) ).to(device)
        output = netD(fake.detach())
        fake_score = output
        errD_fake = criterion(output, target)
        
        
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()
        
        
        ## train generator
        
        
        ## we want the generator to learn to create realistic images
        ## and thus to produce a 1 from the discriminator
        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0])).to(device)
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()
        
        if (i+1) % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, errD.item(), errG.item(), 
                    real_score.mean().item(), fake_score.mean().item()))
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize = True)

Epoch [0/250], Step [50/103], d_loss: 2.2717, g_loss: 10.5983, D(x): 0.53, D(G(z)): 0.67
Epoch [0/250], Step [100/103], d_loss: 1.8072, g_loss: 7.3007, D(x): 0.64, D(G(z)): 0.60
Epoch [1/250], Step [50/103], d_loss: 1.8046, g_loss: 1.1761, D(x): 0.47, D(G(z)): 0.55
Epoch [1/250], Step [100/103], d_loss: 1.6038, g_loss: 1.1792, D(x): 0.55, D(G(z)): 0.57
Epoch [2/250], Step [50/103], d_loss: 1.3698, g_loss: 1.0881, D(x): 0.58, D(G(z)): 0.50
Epoch [2/250], Step [100/103], d_loss: 1.5638, g_loss: 1.1816, D(x): 0.58, D(G(z)): 0.59
Epoch [3/250], Step [50/103], d_loss: 1.3466, g_loss: 0.9192, D(x): 0.64, D(G(z)): 0.57
Epoch [3/250], Step [100/103], d_loss: 1.0820, g_loss: 1.2337, D(x): 0.64, D(G(z)): 0.33
Epoch [4/250], Step [50/103], d_loss: 1.6869, g_loss: 0.7402, D(x): 0.47, D(G(z)): 0.53
Epoch [4/250], Step [100/103], d_loss: 1.3235, g_loss: 0.8120, D(x): 0.57, D(G(z)): 0.49
Epoch [5/250], Step [50/103], d_loss: 1.4563, g_loss: 0.8139, D(x): 0.60, D(G(z)): 0.59
Epoch [5/250], Step [100/1

Epoch [46/250], Step [50/103], d_loss: 1.3315, g_loss: 0.6828, D(x): 0.54, D(G(z)): 0.48
Epoch [46/250], Step [100/103], d_loss: 1.3707, g_loss: 0.7663, D(x): 0.63, D(G(z)): 0.58
Epoch [47/250], Step [50/103], d_loss: 1.3920, g_loss: 0.8246, D(x): 0.59, D(G(z)): 0.55
Epoch [47/250], Step [100/103], d_loss: 1.3266, g_loss: 0.7143, D(x): 0.54, D(G(z)): 0.47
Epoch [48/250], Step [50/103], d_loss: 1.3459, g_loss: 0.9104, D(x): 0.66, D(G(z)): 0.58
Epoch [48/250], Step [100/103], d_loss: 1.3990, g_loss: 0.6400, D(x): 0.50, D(G(z)): 0.49
Epoch [49/250], Step [50/103], d_loss: 1.3997, g_loss: 0.7791, D(x): 0.54, D(G(z)): 0.52
Epoch [49/250], Step [100/103], d_loss: 1.2957, g_loss: 0.9525, D(x): 0.64, D(G(z)): 0.53
Epoch [50/250], Step [50/103], d_loss: 1.3006, g_loss: 0.8646, D(x): 0.61, D(G(z)): 0.54
Epoch [50/250], Step [100/103], d_loss: 1.4118, g_loss: 0.7364, D(x): 0.59, D(G(z)): 0.57
Epoch [51/250], Step [50/103], d_loss: 1.3412, g_loss: 1.1962, D(x): 0.65, D(G(z)): 0.59
Epoch [51/250], 

Epoch [92/250], Step [50/103], d_loss: 1.1835, g_loss: 1.5030, D(x): 0.69, D(G(z)): 0.44
Epoch [92/250], Step [100/103], d_loss: 1.1903, g_loss: 1.1642, D(x): 0.59, D(G(z)): 0.33
Epoch [93/250], Step [50/103], d_loss: 0.9677, g_loss: 1.7912, D(x): 0.82, D(G(z)): 0.47
Epoch [93/250], Step [100/103], d_loss: 1.0473, g_loss: 0.9233, D(x): 0.58, D(G(z)): 0.25
Epoch [94/250], Step [50/103], d_loss: 0.8280, g_loss: 2.1817, D(x): 0.86, D(G(z)): 0.37
Epoch [94/250], Step [100/103], d_loss: 0.8598, g_loss: 1.1367, D(x): 0.73, D(G(z)): 0.25
Epoch [95/250], Step [50/103], d_loss: 1.3284, g_loss: 2.3993, D(x): 0.92, D(G(z)): 0.66
Epoch [95/250], Step [100/103], d_loss: 1.1138, g_loss: 0.9476, D(x): 0.54, D(G(z)): 0.18
Epoch [96/250], Step [50/103], d_loss: 0.9065, g_loss: 1.6704, D(x): 0.77, D(G(z)): 0.36
Epoch [96/250], Step [100/103], d_loss: 0.8321, g_loss: 1.5341, D(x): 0.79, D(G(z)): 0.33
Epoch [97/250], Step [50/103], d_loss: 1.0027, g_loss: 0.9203, D(x): 0.74, D(G(z)): 0.41
Epoch [97/250], 

Epoch [137/250], Step [100/103], d_loss: 0.7341, g_loss: 1.8399, D(x): 0.84, D(G(z)): 0.19
Epoch [138/250], Step [50/103], d_loss: 1.0374, g_loss: 3.5274, D(x): 0.88, D(G(z)): 0.52
Epoch [138/250], Step [100/103], d_loss: 0.7052, g_loss: 2.0994, D(x): 0.91, D(G(z)): 0.17
Epoch [139/250], Step [50/103], d_loss: 0.8347, g_loss: 1.4817, D(x): 0.76, D(G(z)): 0.24
Epoch [139/250], Step [100/103], d_loss: 1.7221, g_loss: 5.3928, D(x): 0.91, D(G(z)): 0.77
Epoch [140/250], Step [50/103], d_loss: 0.7581, g_loss: 2.0598, D(x): 0.95, D(G(z)): 0.37
Epoch [140/250], Step [100/103], d_loss: 0.9801, g_loss: 2.9377, D(x): 0.95, D(G(z)): 0.51
Epoch [141/250], Step [50/103], d_loss: 0.8141, g_loss: 2.1636, D(x): 0.93, D(G(z)): 0.38
Epoch [141/250], Step [100/103], d_loss: 1.8282, g_loss: 4.3393, D(x): 0.95, D(G(z)): 0.80
Epoch [142/250], Step [50/103], d_loss: 0.7292, g_loss: 2.2139, D(x): 0.92, D(G(z)): 0.32
Epoch [142/250], Step [100/103], d_loss: 0.9902, g_loss: 3.0051, D(x): 0.89, D(G(z)): 0.50
Epoc

Epoch [183/250], Step [50/103], d_loss: 0.7333, g_loss: 1.8212, D(x): 0.88, D(G(z)): 0.16
Epoch [183/250], Step [100/103], d_loss: 0.6767, g_loss: 2.4431, D(x): 0.91, D(G(z)): 0.15
Epoch [184/250], Step [50/103], d_loss: 0.6862, g_loss: 2.6230, D(x): 0.97, D(G(z)): 0.20
Epoch [184/250], Step [100/103], d_loss: 0.7598, g_loss: 2.5133, D(x): 0.94, D(G(z)): 0.24
Epoch [185/250], Step [50/103], d_loss: 0.5870, g_loss: 2.1998, D(x): 0.93, D(G(z)): 0.18
Epoch [185/250], Step [100/103], d_loss: 0.8082, g_loss: 2.4593, D(x): 0.95, D(G(z)): 0.38
Epoch [186/250], Step [50/103], d_loss: 0.7490, g_loss: 1.2687, D(x): 0.80, D(G(z)): 0.17
Epoch [186/250], Step [100/103], d_loss: 0.7425, g_loss: 1.4140, D(x): 0.84, D(G(z)): 0.13
Epoch [187/250], Step [50/103], d_loss: 0.8225, g_loss: 2.2897, D(x): 0.85, D(G(z)): 0.25
Epoch [187/250], Step [100/103], d_loss: 0.6962, g_loss: 2.0851, D(x): 0.94, D(G(z)): 0.21
Epoch [188/250], Step [50/103], d_loss: 0.6250, g_loss: 2.2802, D(x): 0.96, D(G(z)): 0.20
Epoch

Epoch [228/250], Step [100/103], d_loss: 0.7041, g_loss: 1.8802, D(x): 0.86, D(G(z)): 0.21
Epoch [229/250], Step [50/103], d_loss: 0.6908, g_loss: 1.6021, D(x): 0.88, D(G(z)): 0.17
Epoch [229/250], Step [100/103], d_loss: 0.8375, g_loss: 2.7461, D(x): 0.96, D(G(z)): 0.34
Epoch [230/250], Step [50/103], d_loss: 0.7279, g_loss: 1.6428, D(x): 0.85, D(G(z)): 0.18
Epoch [230/250], Step [100/103], d_loss: 0.7189, g_loss: 2.0260, D(x): 0.87, D(G(z)): 0.10
Epoch [231/250], Step [50/103], d_loss: 0.9417, g_loss: 3.3204, D(x): 0.98, D(G(z)): 0.50
Epoch [231/250], Step [100/103], d_loss: 0.6962, g_loss: 2.0346, D(x): 0.95, D(G(z)): 0.30
Epoch [232/250], Step [50/103], d_loss: 0.6306, g_loss: 1.3847, D(x): 0.91, D(G(z)): 0.13
Epoch [232/250], Step [100/103], d_loss: 0.6695, g_loss: 1.6490, D(x): 0.90, D(G(z)): 0.20
Epoch [233/250], Step [50/103], d_loss: 0.6862, g_loss: 2.7590, D(x): 0.95, D(G(z)): 0.10
Epoch [233/250], Step [100/103], d_loss: 0.6930, g_loss: 2.2079, D(x): 0.97, D(G(z)): 0.16
Epoc