In [None]:
# CODE CITATIONS: 
# https://github.com/pytorch/tutorials/blob/master/beginner_source/dcgan_faces_tutorial.py
# https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475

# imports
import math
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

# hyperparameters
batch_size  = 64
n_channels  = 3      # image channels produced by G / consumed by D
dataset = 'stl10'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g_feature_map_size = 64   # base channel width of the generator
d_feature_map_size = 64   # base channel width of the discriminator

# Seed torch's global RNG (CPU and CUDA) so weight init, latent noise and
# DataLoader shuffling are reproducible across Restart & Run All.
# NOTE: GPU runs may still differ slightly because some cuDNN kernels are nondeterministic.
seed = 0
torch.manual_seed(seed)

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield items from ``iterable`` endlessly, restarting it after each pass.

    ``iterable`` must be re-iterable (e.g. a DataLoader or list). If a full
    pass yields nothing (empty iterable, or an exhausted one-shot iterator),
    the generator stops instead of spinning forever in a busy loop.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            return

if dataset == 'stl10':
    # Resize/crop to 64x64 (the resolution G produces and D expects) and scale
    # pixels from [0, 1] to [-1, 1] to match the generator's Tanh output range.
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.STL10(
            'drive/My Drive/training/stl10',
            split='train+unlabeled',
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(64),
                torchvision.transforms.CenterCrop(64),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        ),
        shuffle=True, batch_size=batch_size, drop_last=True)
    # cycle() already returns a generator, so no extra iter() is needed
    train_iterator = cycle(train_loader)
    class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck'] # these are slightly different to CIFAR-10
else:
    # Fail loudly here rather than with a NameError on train_iterator later.
    raise ValueError(f"unsupported dataset {dataset!r}; only 'stl10' is implemented")

In [None]:
# Preview one batch of real training images as a single image grid.
plt.rcParams['figure.dpi'] = 175
x, t = next(train_iterator)
x, t = x.to(device), t.to(device)
plt.grid(False)
# make_grid gives (C, H, W); imshow wants (H, W, C)
real_grid = torchvision.utils.make_grid(x, normalize=True).cpu().data
plt.imshow(real_grid.permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, intended for use with ``Module.apply``.

    Any module whose class name contains 'Conv' gets weights ~ N(0, 0.02).
    Any module whose class name contains 'BatchNorm' gets weights ~ N(1, 0.02)
    and zero bias. All other modules are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

# define the model
class Generator(nn.Module):
    """DCGAN generator.

    Maps latent noise of shape (N, latent_size, 1, 1) to images of shape
    (N, n_channels, 64, 64) with values in [-1, 1] (Tanh output). Spatial
    size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 through transposed convolutions.
    Channel widths shrink from 8x to 1x of ``g_feature_map_size``.
    """

    def __init__(self, latent_size=100):
        super().__init__()
        widths = [g_feature_map_size * mult for mult in (8, 4, 2, 1)]

        # project the 1x1 latent up to a 4x4 feature map
        modules = [
            nn.ConvTranspose2d(latent_size, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # each stage doubles spatial size and narrows the channel count
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # final upsample to image channels, squashed into [-1, 1]
        modules += [
            nn.ConvTranspose2d(widths[-1], n_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)

class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps images of shape (N, n_channels, 64, 64) to a probability of being
    real, shape (N, 1, 1, 1), via strided convolutions (64 -> 32 -> 16 -> 8
    -> 4 -> 1) and a final Sigmoid. The first conv has no BatchNorm, as in DCGAN.
    """

    def __init__(self):
        super().__init__()
        widths = [d_feature_map_size * mult for mult in (1, 2, 4, 8)]

        stack = [
            nn.Conv2d(n_channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # each stage halves spatial size and widens the channel count
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # collapse the 4x4 map to a single real/fake probability
        stack.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        self.layer = nn.Sequential(*stack)

    def forward(self, x):
        return self.layer(x)


# Build both networks on the target device. Construction order (G, then D)
# and init order are kept fixed so RNG consumption is the same every run.
G = Generator().to(device)
D = Discriminator().to(device)

# re-initialise conv / batch-norm weights with the DCGAN scheme
G.apply(weights_init)
D.apply(weights_init)

print(f'> Number of generator parameters {sum(p.numel() for p in G.parameters())}')
print(f'> Number of discriminator parameters {sum(p.numel() for p in D.parameters())}')

# Adam with the DCGAN settings (lr 2e-4, beta1 0.5) for both players
optimiser_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimiser_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
bce_loss = nn.BCELoss()

# epoch counter lives outside the training cell so that cell can be resumed
epoch = 0

In [None]:
# training loop
# Resumable: `epoch` is initialised in the model-setup cell, so re-running this
# cell continues training from where it stopped.
num_epochs = 300        # stop once `epoch` reaches this value
iters_per_epoch = 821   # batches drawn from the endless train_iterator per epoch
latent_dim = 100        # must match Generator(latent_size=100)

while epoch < num_epochs:

    # per-iteration losses for this epoch (lists: O(1) append, unlike np.append)
    gen_losses = []
    dis_losses = []

    # iterate over some of the train dataset
    for _ in range(iters_per_epoch):

        # sample x from the dataset (the labels t are not used by the GAN)
        x, t = next(train_iterator)
        x, t = x.to(device), t.to(device)
        real_labels = torch.full((x.size(0),), 1., dtype=torch.float, device=device)
        fake_labels = torch.full((x.size(0),), 0., dtype=torch.float, device=device)

        # train discriminator: push D(real) -> 1 and D(fake) -> 0.
        # g is detached so this step does not backpropagate into G.
        g = G(torch.randn(x.size(0), latent_dim, 1, 1).to(device))
        D.zero_grad()
        l_r = bce_loss(D(x).view(-1), real_labels)
        l_f = bce_loss(D(g.detach()).view(-1), fake_labels)
        loss_d = l_r + l_f
        loss_d.backward()
        optimiser_D.step()

        # train generator (non-saturating loss): push D(fake) -> 1.
        # This backward also writes D's grads; D.zero_grad() above clears them next iteration.
        G.zero_grad()
        loss_g = bce_loss(D(g).view(-1), real_labels)
        loss_g.backward()
        optimiser_G.step()

        # collect stats
        gen_losses.append(loss_g.item())
        dis_losses.append(loss_d.item())

    gen_loss_arr = np.array(gen_losses)
    dis_loss_arr = np.array(dis_losses)

    # plot some examples; eval mode so BatchNorm uses running statistics,
    # no_grad because no gradients are needed for a preview
    G.eval()
    rand_noise = torch.randn(x.size(0), latent_dim, 1, 1).to(device)
    with torch.no_grad():
        g = G(rand_noise)
    # each label is paired with its own loss array
    print('loss d: {:.3f}, loss g: {:.3f}'.format(dis_loss_arr.mean(), gen_loss_arr.mean()))
    plt.grid(False)
    # make_grid gives (C, H, W); imshow wants (H, W, C)
    plt.imshow(torchvision.utils.make_grid(g, normalize=True).cpu().data.clamp(0, 1).permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()
    plt.pause(0.0001)
    G.train()

    epoch = epoch + 1

In [None]:
# Show the most recent batch of *generated* samples (`g`, as last assigned by
# the training loop's preview) as an image grid.
plt.rcParams['figure.dpi'] = 175
plt.grid(False)
sample_grid = torchvision.utils.make_grid(g, normalize=True).cpu().data
# (C, H, W) -> (H, W, C) for imshow
plt.imshow(sample_grid.permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()

In [None]:
# now show some interpolations
def slerp(low, high, val, eps=1e-6):
    """Spherical linear interpolation between batches of latent vectors.

    The angle ``omega`` is measured between the directions of ``low`` and
    ``high`` (normalised over dim 1). The un-normalised vectors are then
    blended, so the result equals ``low`` at ``val`` = 0 and ``high`` at ``val`` = 1.

    Args:
        low, high: tensors of the same shape, batch-first, with the vector
            dimension on dim 1 (e.g. (N, latent, 1, 1) latents for G, or (N, D)).
        val: interpolation weight(s) in [0, 1]. Either a scalar or a tensor
            that broadcasts against ``low`` once dim 1 is reduced to size 1
            (e.g. (N, 1, 1, 1)).
        eps: where |sin(omega)| is below this (parallel or antiparallel
            directions), slerp is undefined, so plain linear interpolation is
            returned instead.

    Returns:
        The interpolated tensor, broadcast shape of ``low``/``high``/``val``.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # keepdim keeps omega broadcastable for any batch size / rank; the clamp
    # stops rounding from pushing the cosine outside acos's domain (-> NaN)
    cos_omega = (low_norm * high_norm).sum(dim=1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    a = (torch.sin((1 - val) * omega) / so) * low
    b = (torch.sin(val * omega) / so) * high
    res = a + b
    # sin(omega) ~ 0 would divide by zero above; use lerp for those rows
    lerp = (1 - val) * low + val * high
    return torch.where(so.abs() < eps, lerp, res)

# Build a col_size x col_size grid of latents. Each column c interpolates (by
# slerp) from rand_noise[c] in the top row to the c-th of the last col_size
# noise vectors in the bottom row.
col_size = int(np.sqrt(batch_size))
z0 = rand_noise[0:col_size].repeat(col_size, 1, 1, 1)   # z for top row
z1 = rand_noise[-col_size:].repeat(col_size, 1, 1, 1)   # z for bottom row
# row r gets weight r / (col_size - 1), shared by every column of that row;
# view(-1, ...) gives col_size**2 rows, matching z0/z1 even if batch_size is not a square
t = torch.linspace(0, 1, col_size).unsqueeze(1).repeat(1, col_size).view(-1, 1, 1, 1).to(device)
z_point = slerp(z0, z1, t)

# Sample in eval mode (BatchNorm uses running stats and is not updated) without
# autograd, then restore the previous mode so a resumed training run is unaffected.
was_training = G.training
G.eval()
with torch.no_grad():
    slerp_g = G(z_point)  # sample the model at the resulting interpolated latents
G.train(was_training)

plt.rcParams['figure.dpi'] = 175
plt.grid(False)
# nrow=col_size so each grid row is one interpolation step; (C, H, W) -> (H, W, C)
plt.imshow(torchvision.utils.make_grid(slerp_g, nrow=col_size, normalize=True).cpu().data.permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()