In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
!pip install Pillow
from PIL import Image
import matplotlib.animation as animation
from IPython.display import HTML

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# root directory for dataset
data_root = './data'

# directory for generated images
fake_dir = "./gan-03-images"

# make fake_dir; exist_ok avoids the check-then-create race of os.path.exists
os.makedirs(fake_dir, exist_ok=True)

# Fix the global torch RNG so weight init, latent samples and data shuffling
# start from the same state on every run.
manual_seed = 999
torch.manual_seed(manual_seed)

num_workers = 2      # DataLoader worker processes
ngpu = 1             # GPUs to use; > 1 enables nn.DataParallel when building the models
batch_size = 256

image_size = 32      # height == width of the training images
nc = 3               # image channels
nz = 100             # length of the latent vector fed to the generator
ngf = 64             # generator base feature-map width
ndf = 64             # discriminator base feature-map width

num_epochs = 200
learning_rate = 0.0002
beta1 = 0.5          # Adam beta1 (second beta is fixed at 0.999 when the optimizers are built)
beta1 = 0.5

In [None]:
# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 on each
# channel rescales them to [-1, 1], the same range as the Generator's Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# CIFAR-10 training split, downloaded into data_root on first use.
dataset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform,
)

# Shuffled mini-batches of (image, label); labels are ignored by the GAN.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
)

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, intended for ``module.apply(weights_init)``.

    Layers whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...)
    get weights drawn from N(0, 0.02). Layers whose class name contains
    ``BatchNorm`` get weights from N(1, 0.02) and a zero bias. Any other
    module is left untouched.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent tensor into an image in [-1, 1].

    For a (N, latent_dim, 1, 1) input the four transposed convolutions grow
    the spatial size 1 -> 4 -> 8 -> 16 -> 32, giving an output of shape
    (N, out_channels, 32, 32); the final Tanh bounds values to [-1, 1].

    Args:
        ngpu: number of GPUs requested; stored on the module, not used by
            ``forward``.
        latent_dim: channels of the latent input (defaults to global ``nz``).
        base_channels: width of the last hidden layer; earlier layers use
            2x and 4x this (defaults to global ``ngf``).
        out_channels: channels of the generated image (defaults to global
            ``nc`` so it always matches the Discriminator's input).
    """

    def __init__(self, ngpu, latent_dim=None, base_channels=None, out_channels=None):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the config-cell globals so existing Generator(ngpu)
        # calls behave exactly as before.
        latent_dim = nz if latent_dim is None else latent_dim
        base_channels = ngf if base_channels is None else base_channels
        out_channels = nc if out_channels is None else out_channels
        self.main = nn.Sequential(
            # (latent_dim, 1, 1) -> (base*4, 4, 4)
            nn.ConvTranspose2d(latent_dim, base_channels * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(True),

            # -> (base*2, 8, 8)
            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),

            # -> (base, 16, 16)
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),

            # -> (out_channels, 32, 32); uses the configured channel count
            # rather than a literal 3 so it stays in sync with the Discriminator.
            nn.ConvTranspose2d(base_channels, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        return self.main(x)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: scores each image with a value in [0, 1].

    For a (N, in_channels, 32, 32) input, three stride-2 convolutions shrink
    the spatial size 32 -> 16 -> 8 -> 4 and a final 4x4 convolution reduces
    it to 1x1; Flatten + Sigmoid give an output of shape (N, 1).

    Args:
        ngpu: number of GPUs requested; stored on the module, not used by
            ``forward``.
        in_channels: channels of the input image (defaults to global ``nc``).
        base_channels: width of the first conv layer; later layers use 2x and
            4x this (defaults to global ``ndf``).
    """

    def __init__(self, ngpu, in_channels=None, base_channels=None):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the config-cell globals so existing Discriminator(ngpu)
        # calls behave exactly as before.
        in_channels = nc if in_channels is None else in_channels
        base_channels = ndf if base_channels is None else base_channels
        self.main = nn.Sequential(
            # (in_channels, 32, 32) -> (base, 16, 16)
            nn.Conv2d(in_channels, base_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),

            # NOTE(review): BatchNorm follows the LeakyReLU here, whereas the
            # DCGAN reference places it between conv and activation. Kept as-is
            # to preserve this architecture.
            # -> (base*2, 8, 8)
            nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.BatchNorm2d(base_channels * 2),

            # -> (base*4, 4, 4)
            nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.BatchNorm2d(base_channels * 4),

            # -> (1, 1, 1) -> (N, 1) probability
            nn.Conv2d(base_channels * 4, 1, 4, 1, 0, bias=False),
            nn.Flatten(1, -1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-image scores of shape (N, 1) for the image batch ``x``."""
        return self.main(x)

In [None]:
# Build both networks on the chosen device (G first, then D), then apply the
# DCGAN initialisation in the same order.
G, D = Generator(ngpu).to(device), Discriminator(ngpu).to(device)
G.apply(weights_init)
D.apply(weights_init)

# Spread each network across several GPUs only when more than one CUDA
# device was requested.
if device.type == "cuda" and ngpu > 1:
    G, D = (nn.DataParallel(net, list(range(ngpu))) for net in (G, D))

print(G, D, sep="\n")

In [None]:
# torchsummary.summary defaults to device="cuda" and builds its probe input
# there, so pass the active device type to keep this cell working on
# CPU-only machines. Input sizes come from the config cell, not literals.
summary(G, (nz, 1, 1), device=device.type)
summary(D, (nc, image_size, image_size), device=device.type)

In [None]:
def denorm(x):
    """Affinely map ``x`` from the [-1, 1] range back to [0, 1].

    Inverts Normalize(0.5, 0.5) / undoes the Tanh range so tensors can be
    saved as images. Works on scalars and tensors alike.
    """
    return 0.5 * (x + 1)

def reset_grad():
    """Zero accumulated gradients on the generator and discriminator optimizers.

    Reads the module-level ``g_optimizer`` and ``d_optimizer``.
    """
    for opt in (g_optimizer, d_optimizer):
        opt.zero_grad()

In [None]:
# Binary cross-entropy applied to the Discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Both networks share the same Adam hyperparameters.
adam_settings = {"lr": learning_rate, "betas": (beta1, 0.999)}
g_optimizer = optim.Adam(G.parameters(), **adam_settings)
d_optimizer = optim.Adam(D.parameters(), **adam_settings)

In [None]:
# Per-iteration loss history plus one image grid per epoch for the plots below.
img_list = []
g_losses = []
d_losses = []
total_step = len(data_loader)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        #==========train D==========#
        # train with real images: D should output 1 for them.
        real_images = real_images.to(device)
        outputs = D(real_images)
        # Labels are built from the outputs, so their shape always matches,
        # including a final batch smaller than batch_size.
        d_loss_real = criterion(outputs, torch.ones_like(outputs))
        real_score = outputs.detach()

        # train with fake images: D should output 0 for them.
        z = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = G(z)
        # detach() keeps d_loss.backward() from back-propagating into G;
        # those G gradients would be discarded by reset_grad() anyway.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, torch.zeros_like(outputs))
        fake_score = outputs.detach()

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        #==========train G==========#
        # G is trained to make D label a fresh batch of fakes as real.
        z = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, torch.ones_like(outputs))
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
          .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                  real_score.mean().item(), fake_score.mean().item()))

    if epoch == 0:
        torchvision.utils.save_image(denorm(real_images), os.path.join(fake_dir, 'real_images.png'))

    # Keep a detached CPU copy of the last generated batch: the plotting cells
    # convert these grids to numpy, which needs CPU tensors without autograd
    # history, and it avoids holding one grid per epoch in GPU memory.
    snapshot = fake_images.detach().cpu()
    # Format the file name before joining so braces in fake_dir are untouched.
    torchvision.utils.save_image(denorm(snapshot), os.path.join(fake_dir, 'fake_images-{}.png'.format(epoch + 1)))
    img_list.append(torchvision.utils.make_grid(snapshot, padding=2, normalize=True))

# Module.parameters() returns a Python generator, which torch.save cannot
# pickle; persist the state dicts (restore with G.load_state_dict(torch.load(...))).
torch.save(G.state_dict(), './G.ckpt')
torch.save(D.state_dict(), './D.ckpt')

In [None]:
# Per-iteration training losses recorded by the training loop.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(g_losses, label="G")
loss_ax.plot(d_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
# Each grid is (C, H, W) while imshow wants an (H, W, C) numpy array; numpy
# conversion requires a CPU tensor with no autograd history, so normalise first.
ims = [
    [plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy(), animated=True)]
    for grid in img_list
]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# Close the figure so the notebook renders only the animation below, not an
# extra static copy of the last frame.
plt.close(fig)

HTML(ani.to_jshtml())

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(data_loader))

# Build the real grid on the CPU from the first 64 images only; the batch
# never needs to visit the GPU just to be plotted.
real_grid = torchvision.utils.make_grid(real_batch[0][:64], padding=5, normalize=True)
# The stored fake grid may live on the GPU or carry autograd history;
# numpy conversion needs a plain CPU tensor.
fake_grid = img_list[-1].detach().cpu()

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(real_grid.permute(1, 2, 0).numpy())

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(fake_grid.permute(1, 2, 0).numpy())
plt.show()