## Things to Try
- Apply dropout to generator
- Use labels
- Add noise to inputs, decay over time
- Make labels noisy by sometimes flipping them

In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
image_size = 64
num_epochs = 250
batch_size = 128

In [3]:
transform = transforms.Compose([
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5, 0.5, 0.5), 
                                                     (0.5, 0.5, 0.5)),])
dataset = dset.ImageFolder("/home/tyler/data/image/small_art/", 
                               transform)
# dataset = dset.CIFAR10(root = '/home/tyler/data/image', 
#                        download = True, transform = transform)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         batch_size = batch_size, 
                                         shuffle = True, 
                                         num_workers = 2)

In [4]:
## weight initalization for network
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

### Generator

Produces an image

Q: why the convtranspose2d numbers? A: they are the inverse of discriminator

Q: Why first convtranspose2d in channels is 100? This is just a hyper-parameter. You can choose how many channels you want to randomly generate for the input to your generator.

In [5]:
class Generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0, bias=False)
        self.deconv1_bn = nn.BatchNorm2d(d*8)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1, bias=False)
        self.deconv2_bn = nn.BatchNorm2d(d*4)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1, bias=False)
        self.deconv3_bn = nn.BatchNorm2d(d*2)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1, bias=False)
        self.deconv4_bn = nn.BatchNorm2d(d)
        self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1, bias=False)

    # forward method
    def forward(self, input):
        # x = F.relu(self.deconv1(input))
        x = F.leaky_relu(self.deconv1_bn(self.deconv1(input)), 0.2)
        x = F.leaky_relu(self.deconv2_bn(self.deconv2(x)), 0.2)
        x = F.leaky_relu(self.deconv3_bn(self.deconv3(x)), 0.2)
        x = F.leaky_relu(self.deconv4_bn(self.deconv4(x)), 0.2)
        x = F.tanh(self.deconv5(x))

        return x

## Discriminator

Just a normal conv net to tell if an image is fake or not.

In [6]:
class Discriminator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, d, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1, bias=False)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1, bias=False)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1, bias=False)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0, bias=False)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = F.sigmoid(self.conv5(x))

        return x.view(-1)

In [7]:
netG = Generator(128).to(device)
netG.apply(weights_init)
netD = Discriminator(128).to(device)
netD.apply(weights_init)

Discriminator(
  (conv1): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
)

In [8]:
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.00005, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

In [9]:
total_step = len(dataloader)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        
        ## train discriminator
        
        netD.zero_grad()
        
        ## calculate error using real image 
        real, _ = data
        input = Variable(real).to(device)
        # target is 1 b/c real image
        target = Variable(torch.empty(input.size()[0]).uniform_(0.7, 1.2) ).to(device)
        output = netD(input)
        real_score = output
        errD_real = criterion(output, target)
        
        ## calculate error using fake image
        ## first generate an image using generator then discriminate
        ## this is 100 channels, 1x1 random noise that the generate will use
        noise = Variable(torch.randn(input.size()[0], 100, 1, 1)).to(device)
        fake = netG(noise)
        target = Variable(torch.empty(input.size()[0]).uniform_(0.0, 0.3) ).to(device)
        output = netD(fake.detach())
        fake_score = output
        errD_fake = criterion(output, target)
        
        
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()
        
        
        ## train generator
        
        
        ## we want the generator to learn to create realistic images
        ## and thus to produce a 1 from the discriminator
        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0])).to(device)
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()
        
        if (i+1) % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, errD.item(), errG.item(), 
                    real_score.mean().item(), fake_score.mean().item()))
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize = True)

Epoch [0/250], Step [50/118], d_loss: 2.5547, g_loss: 11.1690, D(x): 0.51, D(G(z)): 0.59
Epoch [0/250], Step [100/118], d_loss: 0.7779, g_loss: 3.0336, D(x): 0.87, D(G(z)): 0.06
Epoch [1/250], Step [50/118], d_loss: 1.7790, g_loss: 0.9545, D(x): 0.46, D(G(z)): 0.49
Epoch [1/250], Step [100/118], d_loss: 1.5142, g_loss: 0.9826, D(x): 0.56, D(G(z)): 0.54
Epoch [2/250], Step [50/118], d_loss: 1.3838, g_loss: 0.8287, D(x): 0.58, D(G(z)): 0.54
Epoch [2/250], Step [100/118], d_loss: 1.4727, g_loss: 0.9176, D(x): 0.49, D(G(z)): 0.49
Epoch [3/250], Step [50/118], d_loss: 1.2830, g_loss: 1.0329, D(x): 0.60, D(G(z)): 0.48
Epoch [3/250], Step [100/118], d_loss: 1.4512, g_loss: 1.1504, D(x): 0.60, D(G(z)): 0.57
Epoch [4/250], Step [50/118], d_loss: 1.5745, g_loss: 0.7480, D(x): 0.60, D(G(z)): 0.63
Epoch [4/250], Step [100/118], d_loss: 1.4833, g_loss: 0.6142, D(x): 0.49, D(G(z)): 0.52
Epoch [5/250], Step [50/118], d_loss: 1.4586, g_loss: 0.7014, D(x): 0.53, D(G(z)): 0.54
Epoch [5/250], Step [100/1

Epoch [46/250], Step [50/118], d_loss: 1.4071, g_loss: 0.6052, D(x): 0.54, D(G(z)): 0.53
Epoch [46/250], Step [100/118], d_loss: 1.3634, g_loss: 0.7044, D(x): 0.59, D(G(z)): 0.56
Epoch [47/250], Step [50/118], d_loss: 1.3985, g_loss: 0.5864, D(x): 0.54, D(G(z)): 0.54
Epoch [47/250], Step [100/118], d_loss: 1.4418, g_loss: 0.7136, D(x): 0.56, D(G(z)): 0.58
Epoch [48/250], Step [50/118], d_loss: 1.4124, g_loss: 0.6427, D(x): 0.53, D(G(z)): 0.53
Epoch [48/250], Step [100/118], d_loss: 1.3874, g_loss: 0.6042, D(x): 0.56, D(G(z)): 0.55
Epoch [49/250], Step [50/118], d_loss: 1.3927, g_loss: 0.6192, D(x): 0.57, D(G(z)): 0.57
Epoch [49/250], Step [100/118], d_loss: 1.4000, g_loss: 0.6429, D(x): 0.53, D(G(z)): 0.53
Epoch [50/250], Step [50/118], d_loss: 1.3858, g_loss: 0.6909, D(x): 0.54, D(G(z)): 0.53
Epoch [50/250], Step [100/118], d_loss: 1.3836, g_loss: 0.6311, D(x): 0.55, D(G(z)): 0.55
Epoch [51/250], Step [50/118], d_loss: 1.3964, g_loss: 0.6542, D(x): 0.57, D(G(z)): 0.57
Epoch [51/250], 

Epoch [92/250], Step [50/118], d_loss: 1.3705, g_loss: 0.6670, D(x): 0.56, D(G(z)): 0.54
Epoch [92/250], Step [100/118], d_loss: 1.3730, g_loss: 0.6535, D(x): 0.57, D(G(z)): 0.56
Epoch [93/250], Step [50/118], d_loss: 1.4049, g_loss: 0.5752, D(x): 0.57, D(G(z)): 0.57
Epoch [93/250], Step [100/118], d_loss: 1.3875, g_loss: 0.5970, D(x): 0.59, D(G(z)): 0.58
Epoch [94/250], Step [50/118], d_loss: 1.3619, g_loss: 0.7272, D(x): 0.60, D(G(z)): 0.58
Epoch [94/250], Step [100/118], d_loss: 1.3662, g_loss: 0.6485, D(x): 0.54, D(G(z)): 0.51
Epoch [95/250], Step [50/118], d_loss: 1.3386, g_loss: 0.6564, D(x): 0.59, D(G(z)): 0.56
Epoch [95/250], Step [100/118], d_loss: 1.3352, g_loss: 0.7258, D(x): 0.57, D(G(z)): 0.53
Epoch [96/250], Step [50/118], d_loss: 1.4018, g_loss: 0.6507, D(x): 0.58, D(G(z)): 0.58
Epoch [96/250], Step [100/118], d_loss: 1.3760, g_loss: 0.6718, D(x): 0.56, D(G(z)): 0.54
Epoch [97/250], Step [50/118], d_loss: 1.3836, g_loss: 0.5813, D(x): 0.57, D(G(z)): 0.56
Epoch [97/250], 

Epoch [137/250], Step [100/118], d_loss: 1.3442, g_loss: 0.7990, D(x): 0.57, D(G(z)): 0.53
Epoch [138/250], Step [50/118], d_loss: 1.3037, g_loss: 0.7719, D(x): 0.56, D(G(z)): 0.49
Epoch [138/250], Step [100/118], d_loss: 1.3609, g_loss: 0.8400, D(x): 0.59, D(G(z)): 0.55
Epoch [139/250], Step [50/118], d_loss: 1.2885, g_loss: 0.7651, D(x): 0.56, D(G(z)): 0.48
Epoch [139/250], Step [100/118], d_loss: 1.2790, g_loss: 0.9197, D(x): 0.61, D(G(z)): 0.52
Epoch [140/250], Step [50/118], d_loss: 1.3406, g_loss: 0.9179, D(x): 0.61, D(G(z)): 0.56
Epoch [140/250], Step [100/118], d_loss: 1.3022, g_loss: 0.7431, D(x): 0.56, D(G(z)): 0.48
Epoch [141/250], Step [50/118], d_loss: 1.3078, g_loss: 0.8196, D(x): 0.60, D(G(z)): 0.54
Epoch [141/250], Step [100/118], d_loss: 1.3037, g_loss: 0.7561, D(x): 0.55, D(G(z)): 0.47
Epoch [142/250], Step [50/118], d_loss: 1.2971, g_loss: 0.8145, D(x): 0.54, D(G(z)): 0.45
Epoch [142/250], Step [100/118], d_loss: 1.3255, g_loss: 0.7118, D(x): 0.54, D(G(z)): 0.48
Epoc

Epoch [183/250], Step [50/118], d_loss: 1.0964, g_loss: 0.9808, D(x): 0.62, D(G(z)): 0.35
Epoch [183/250], Step [100/118], d_loss: 1.1940, g_loss: 0.9696, D(x): 0.57, D(G(z)): 0.38
Epoch [184/250], Step [50/118], d_loss: 1.1331, g_loss: 0.7471, D(x): 0.60, D(G(z)): 0.38
Epoch [184/250], Step [100/118], d_loss: 1.2060, g_loss: 1.6716, D(x): 0.86, D(G(z)): 0.62
Epoch [185/250], Step [50/118], d_loss: 1.2288, g_loss: 0.7903, D(x): 0.49, D(G(z)): 0.25
Epoch [185/250], Step [100/118], d_loss: 1.0157, g_loss: 1.0765, D(x): 0.70, D(G(z)): 0.41
Epoch [186/250], Step [50/118], d_loss: 1.0782, g_loss: 1.3733, D(x): 0.75, D(G(z)): 0.48
Epoch [186/250], Step [100/118], d_loss: 1.0382, g_loss: 1.0750, D(x): 0.63, D(G(z)): 0.31
Epoch [187/250], Step [50/118], d_loss: 1.1537, g_loss: 0.7817, D(x): 0.56, D(G(z)): 0.33
Epoch [187/250], Step [100/118], d_loss: 1.2098, g_loss: 1.0656, D(x): 0.72, D(G(z)): 0.56
Epoch [188/250], Step [50/118], d_loss: 1.0667, g_loss: 1.3126, D(x): 0.73, D(G(z)): 0.44
Epoch

Epoch [228/250], Step [100/118], d_loss: 0.7554, g_loss: 1.6018, D(x): 0.86, D(G(z)): 0.33
Epoch [229/250], Step [50/118], d_loss: 0.8692, g_loss: 1.2933, D(x): 0.72, D(G(z)): 0.27
Epoch [229/250], Step [100/118], d_loss: 0.9026, g_loss: 1.4732, D(x): 0.70, D(G(z)): 0.22
Epoch [230/250], Step [50/118], d_loss: 0.8287, g_loss: 1.4666, D(x): 0.76, D(G(z)): 0.29
Epoch [230/250], Step [100/118], d_loss: 0.8941, g_loss: 1.4683, D(x): 0.76, D(G(z)): 0.32
Epoch [231/250], Step [50/118], d_loss: 1.0285, g_loss: 1.9272, D(x): 0.89, D(G(z)): 0.51
Epoch [231/250], Step [100/118], d_loss: 0.9080, g_loss: 1.0737, D(x): 0.71, D(G(z)): 0.26
Epoch [232/250], Step [50/118], d_loss: 0.9766, g_loss: 0.9084, D(x): 0.73, D(G(z)): 0.35
Epoch [232/250], Step [100/118], d_loss: 0.9156, g_loss: 1.4559, D(x): 0.74, D(G(z)): 0.30
Epoch [233/250], Step [50/118], d_loss: 0.8880, g_loss: 1.7807, D(x): 0.79, D(G(z)): 0.35
Epoch [233/250], Step [100/118], d_loss: 0.9838, g_loss: 1.9519, D(x): 0.87, D(G(z)): 0.49
Epoc