In [None]:
#This is adapted from the pytorch blog/tutorial here: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# CelebA: a large set of celebrity face photos, used here as the "real" images.

image_size = 64

# Preprocessing steps, applied to every image in this order.
celeb_preprocessing = [
    transforms.Resize(image_size),       # resize down to image_size
    transforms.CenterCrop(image_size),   # keep the central image_size x image_size square
    transforms.ToTensor(),               # PIL image -> tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per channel: (x - 0.5) / 0.5
]

celebs = dset.CelebA(
    root="./celeba",
    download=True,
    transform=transforms.Compose(celeb_preprocessing),
)
print(celebs)

In [None]:
# Serve the dataset in shuffled batches of 64 images (num_workers=0: no worker processes).
dataloader = torch.utils.data.DataLoader(
    dataset=celebs,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Preview one batch of real training images, tiled into a single picture.
real_batch = next(iter(dataloader))
# real_batch[0] holds the images; take at most the first 64 of them.
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# Move the channel axis from first to last position before handing the grid to imshow.
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

In [None]:
# Custom weight initialisation, meant to be passed to netG.apply / netD.apply.
def weights_init(m):
    """Re-initialise a single module in place, chosen by its class name.

    * class name contains 'Conv'      -> weight ~ Normal(mean=0.0, std=0.02)
    * class name contains 'BatchNorm' -> weight ~ Normal(mean=1.0, std=0.02), bias = 0
    * anything else                   -> left untouched

    Modules that match by name but do not own the parameter (for example
    ``BatchNorm2d(affine=False)``, whose weight and bias are None, or a container
    class that merely has 'Conv' in its name) are skipped instead of raising.

    Args:
        m: the module to initialise; ``nn.Module.apply`` calls this once per submodule.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute missing" and "attribute is None".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
# Generator
nc = 3     # channels in the training images (3 = RGB)
nz = 100   # length of the latent vector z, i.e. the generator's input size
ngf = 64   # base number of feature maps in the generator


class Generator(nn.Module):
    """Turn a latent vector into a 64x64 colour image.

    The input is treated as a 1x1 "image" with nz channels. Each transposed
    convolution (an upsampling convolution) enlarges the picture, and after the
    first layer each one also halves the number of channels:

        nz x 1 x 1 -> (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
                   -> (ngf*2) x 16 x 16 -> ngf x 32 x 32 -> nc x 64 x 64

    Every hidden stage is ConvTranspose2d -> BatchNorm2d -> ReLU; batch norm keeps
    the stage's outputs in a consistent range, which makes training easier.
    The last stage ends in Tanh, which squashes the output into [-1, 1].
    """

    def __init__(self):
        super().__init__()
        hidden_widths = (ngf * 8, ngf * 4, ngf * 2, ngf)

        layers = []
        in_channels = nz
        for depth, out_channels in enumerate(hidden_widths):
            # First stage: stride 1, no padding (1x1 -> 4x4 with the 4x4 kernel).
            # Later stages: stride 2, padding 1 (doubles height and width).
            stride, padding = (1, 0) if depth == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
            in_channels = out_channels

        # Final upscale to nc x 64 x 64, then Tanh instead of BatchNorm/ReLU.
        layers.append(nn.ConvTranspose2d(in_channels, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Same module order as listing them by hand, so state_dict keys are main.0 ... main.13
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Build the generator and re-initialise its layers with weights_init
# (conv weights: mean 0, stdev 0.02; batch-norm weights: mean 1, stdev 0.02; batch-norm biases: 0).
# Module.apply visits every submodule and returns the module itself.
netG = Generator().apply(weights_init)

# Show the layer structure
print(netG)

![network](https://pytorch.org/tutorials/_images/dcgan_generator.png)

In [None]:
# The discriminator is a binary classifier: it takes an image and outputs whether it
# looks like a training image or not.
# The generator tries to fool the discriminator, so that it can no longer tell generated
# images from training images.
# Both start out bad, but the generator gets better at making images as the
# discriminator gets better at telling generated and training images apart.

ndf = 64   # base number of feature maps in the discriminator


class Discriminator(nn.Module):
    """Score an nc x 64 x 64 image with a single value between 0.0 (fake) and 1.0 (real).

    Each stride-2 convolution halves the height and width while the channel count grows:

        nc x 64 x 64 -> ndf x 32 x 32 -> (ndf*2) x 16 x 16
                     -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4 -> 1 x 1 x 1

    Hidden stages use LeakyReLU(0.2): like ReLU, but negative inputs are scaled
    down rather than set to 0. The first stage has no batch norm.
    """

    def __init__(self):
        super().__init__()

        # input stage: (nc) x 64 x 64 -> (ndf) x 32 x 32
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # three downsampling stages: ndf -> ndf*2 -> ndf*4 -> ndf*8 channels
        in_channels = ndf
        for multiplier in (2, 4, 8):
            out_channels = ndf * multiplier
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            in_channels = out_channels

        # output stage: (ndf*8) x 4 x 4 -> one value, squashed to 0.0 .. 1.0 by the sigmoid
        layers.extend([
            nn.Conv2d(in_channels, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Build the discriminator and re-initialise its layers with weights_init
# (conv weights: mean 0, stdev 0.02; batch-norm weights: mean 1, stdev 0.02; batch-norm biases: 0).
# Module.apply visits every submodule and returns the module itself.
netD = Discriminator().apply(weights_init)

# Show the layer structure
print(netD)

In [None]:
# Optimisation setup

# Binary cross-entropy loss: was the discriminator right or not, weighted by how
# confident it was in its answer.
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused to visualise the generator's progress.
# A "latent vector" is a random input that becomes an image when run through the generator.
fixed_noise = torch.randn(64, nz, 1, 1)

# Label convention used during training
real_label, fake_label = 1, 0

# Adam hyperparameters, shared by both optimizers
lr = 0.0002    # learning rate: how aggressively weights are updated
beta1 = 0.5    # first of Adam's two betas: decay rate for its running average of the gradient
adam_betas = (beta1, 0.999)

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

**Note:** the training loop below is very slow (it took about 2 hours on my machine). Feel free to skip it and load my trained model instead, as shown at the end of the notebook.

In [None]:
# Training Loop
#
# Each iteration performs two updates on one batch:
#   (1) the discriminator D is trained on a real batch and on a freshly generated batch;
#   (2) the generator G is trained so that D scores its images as real.

# Lists to keep track of progress
img_list = []   # image grids generated from fixed_noise at intervals (see end of loop)
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss (real + fake), one entry per iteration
iters = 0       # iteration counter across all epochs

num_epochs = 1 # start small

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        
        ## Since these are training images, the correct label for all of them is "real"
        netD.zero_grad()  # clear gradients left over from the previous iteration
        # Format batch
        real_cpu = data[0] # real images
        b_size = real_cpu.size(0)  # number of images in this batch
        label = torch.full((b_size,), real_label, dtype=torch.float)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)  # flattened to one "probability of real" per image
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item() # Want the mean to be close to 1, i.e. D is confident that the real images are real

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)  # reuse the label tensor, now filled with the "fake" label
        # Classify all fake batch with D
        # (detach: this pass trains D only, so no gradients should flow back into G)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch; they add to the gradients from the real batch
        errD_fake.backward()

        # D_G_z1: D's mean output on the fake images before D's weights are updated
        D_G_z1 = output.mean().item()  # If 0, then D network was sure that all fake images were fake
        # Total discriminator loss for reporting (sum of the real-batch and fake-batch losses)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z))) In english, get D to classify all of these as real
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # for the generator's loss, the target for fake images is "real"
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # (no detach this time: the gradients must reach G)
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item() # D's mean output on the same fake images after D's update
        # Update G (only G's optimizer steps here; D's weights are not changed in this step)
        optimizerG.step()

        # Output training stats every 50 batches
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise:
        # every 500 iterations, and once more on the last batch of the last epoch
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

# How to read the printed stats:
# D(x): D's mean output on real images, i.e. how confident it is that the real images are real (want close to 1)
# D(G(z)): D's mean output on generated images, i.e. how strongly it believes the fakes are real
#          (0 means it is sure they are fake); printed before / after D's update

In [None]:
# Plot the per-iteration losses recorded by the training loop.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
#%%capture
# Animate the generator snapshots saved in img_list: one frame per saved grid.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [
    # each frame is a list of artists; the channel axis is moved last for imshow
    [plt.imshow(np.transpose(snapshot, (1, 2, 0)), animated=True)]
    for snapshot in img_list
]
ani = animation.ArtistAnimation(
    fig,
    ims,
    interval=1000,
    repeat_delay=1000,
    blit=True,
)

# Render the animation as an interactive HTML/JavaScript player in the cell output
HTML(ani.to_jshtml())

In [None]:
# Save only the learned parameters (the state dicts), not the model objects themselves.
generator_weights = netG.state_dict()
discriminator_weights = netD.state_dict()
torch.save(generator_weights, "GNet.pt")
torch.save(discriminator_weights, "DNet.pt")


In [None]:
# Rebuild both networks from the saved parameter files.
# NOTE(review): torch.load is pickle-based, so only load checkpoint files from a source
# you trust; on recent PyTorch versions consider also passing weights_only=True.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the models in this notebook live on the CPU.
loadedG = Generator()
loadedG.load_state_dict(torch.load("GNet.pt", map_location="cpu"))
loadedG.eval()  # evaluation mode, e.g. batch norm uses its running statistics

loadedD = Discriminator()
loadedD.load_state_dict(torch.load("DNet.pt", map_location="cpu"))
loadedD.eval()
