# Pegasus

## Google COLAB Settings
This section contains steps that should be run when executing the code on Google COLAB, so as to have access to a GPU. If that is the case, uncomment the cells and run them sequentially. Otherwise, feel free to skip directly to [Imports](#Imports).

### Installs

In [None]:
# %%capture
# from os.path import exists
# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
# !pip install livelossplot

### Google Drive
This portion is exclusively for development on _my_ end. I use Google Drive to access the training/testing data without having to redownload it each time the Google COLAB runtime is reset. 

Of course, anyone who does not have access to my Google credentials will not be able to access my Drive. Such users should skip directly to [Imports](#Imports). The result will be that torchvision downloads the CIFAR data from the web each time the COLAB runtime is reset.

#### Mounting Drive
This mounts Google Drive to the local runtime. If Drive is already mounted, it will not be mounted again. Mounting will ask for authentication.

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

#### Importing data from Drive
Here I import the CIFAR data which I have previously downloaded and stored in my Google Drive. To do so I copy the corresponding directory from my Google Drive into the COLAB Runtime to avoid having to redownload it each time my COLAB runtime is reset. The ```-n``` flag is set to avoid overwriting.

In [None]:
# !cp -r -n /content/gdrive/My\ Drive/Education/Undergraduate/Year_3/Computer_Science/SSA/Machine_Learning/Coursework/ML_Classifier-Pegasus-Generator/data/ /content/

## Imports

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from livelossplot import PlotLosses

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Dataset Setup

In [None]:
# define class names for CIFAR 10
class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

### Transforms

In [None]:
# define transforms to be applied to training and testing data
transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

inverse_normalize = torchvision.transforms.Normalize(
    mean=[-1, -1, -1],
    std=[2, 2, 2]
)
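
The constants in `inverse_normalize` follow from how `Normalize` works: it computes `(x - mean) / std` per channel. With mean 0.5 and std 0.5 a pixel `x` in [0, 1] becomes `y = 2x - 1`, and applying the same formula with mean -1 and std 2 gives `(y + 1) / 2 = x`. A minimal sketch of that arithmetic in plain Python, where the helper `normalize` is hypothetical and only mimics the per-channel formula:

```python
def normalize(value, mean, std):
    """Mimic the per-channel formula (value - mean) / std."""
    return (value - mean) / std


# a few example pixel intensities in the [0, 1] range produced by ToTensor
pixels = [0.0, 0.25, 0.5, 1.0]

# forward transform, as in `transforms`
normalized = [normalize(p, mean=0.5, std=0.5) for p in pixels]

# inverse transform, as in `inverse_normalize`
restored = [normalize(n, mean=-1, std=2) for n in normalized]

print(normalized)
print(restored)
```

The restored values match the original pixels, which is why the plotting code can pass `inverse_normalize` output straight to `imshow`.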

### Getting the Data
If not already present, this cell will download the [CIFAR 10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset from the web. Otherwise it will simply read it from the existing directory. The transforms defined [above](#Transforms) will be applied.

#### Retaining only Relevant Data
It should be noted that the eventual goal of this script is to generate a pegasus -- that is, a horse with wings. As such, there is no need to train the network on data that is not relevant; it might as well focus only on building internal representations of horses and wings. It therefore makes more sense to train the network only on images of horses, birds and planes (after all, no one said this _couldn't_ be a cyborg-pegasus). It is hypothesized that this will make training more efficient and will render [mode collapse](https://arxiv.org/pdf/1611.02163.pdf) less of an issue, as there are fewer modes to collapse to.

To achieve this effect, we make use of the [SubsetRandomSampler class provided by torch.utils.data](https://pytorch.org/docs/stable/data.html#torch.utils.data.SubsetRandomSampler), passing to it a list of the indices that correspond exclusively to these classes.

In [None]:
# download training set
train_set = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=transforms)
# download test set
test_set = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=transforms)

# get the labels (numbers) corresponding to airplane, horse, bird class names
accepted_labels = [i for i in range(len(class_names)) if class_names[i] in ['airplane', 'horse', 'bird']]

# get the indices in the downloaded data sets corresponding to airplanes, horses and birds.
# note: recent torchvision releases expose the CIFAR10 labels as `targets` (older ones used `train_labels` / `test_labels`)
relevant_train_indices = [i for i, label in enumerate(train_set.targets) if label in accepted_labels]
relevant_test_indices = [i for i, label in enumerate(test_set.targets) if label in accepted_labels]

# instantiating the samplers to feed into the DataLoader so that only airplanes, horses and birds are loaded
train_sampler = torch.utils.data.SubsetRandomSampler(relevant_train_indices)
# we can use random sampling on the test data too since we are using it to train anyway (it's not being used for testing)
test_sampler = torch.utils.data.SubsetRandomSampler(relevant_test_indices)
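
To make the filtering step concrete, here is a small sketch in plain Python on made-up labels. The names `toy_labels` and `subset_random_order` are hypothetical; the latter only imitates what a subset sampler does conceptually (yield the given indices in random order, without replacement) and is not the PyTorch implementation.

```python
import random

toy_class_names = ['airplane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
# hypothetical labels, one per image in a tiny pretend dataset
toy_labels = [7, 1, 0, 3, 2, 7, 9, 0]

# numeric labels of the classes we want to keep
wanted = {'airplane', 'horse', 'bird'}
accepted = {i for i, name in enumerate(toy_class_names) if name in wanted}

# dataset positions whose label is one of the accepted classes
relevant = [i for i, label in enumerate(toy_labels) if label in accepted]


def subset_random_order(indices, rng):
    """Return the given indices in random order, each exactly once."""
    return rng.sample(indices, k=len(indices))


epoch_order = subset_random_order(relevant, random.Random(0))
print(relevant)
print(epoch_order)
```

Every epoch visits exactly the relevant positions, just in a different order, so images of the other seven classes are never loaded.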

### Loading the Data
Having obtained the data, it needs to be loaded into an iterable format for pytorch to use.

In [None]:
# define batch size
BATCH_SIZE = 128

# load training set into a torch DataLoader; the sampler draws only the relevant indices, in random order
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, drop_last=True, sampler = train_sampler)

# load test set into a torch DataLoader; it is also sampled in random order since it is used as extra training data
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, drop_last=True, sampler = test_sampler)

# diagnostic print size of training dataset
print(f'> Size of training dataset: {len(relevant_train_indices)} + {len(relevant_test_indices)} (train + test)')
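
Because `drop_last=True`, any incomplete final batch is discarded. The sketch below works out how many batches each loader yields per epoch, assuming the standard CIFAR-10 split of 5000 training and 1000 test images per class; `batches_per_epoch` is a hypothetical helper, not part of PyTorch.

```python
def batches_per_epoch(num_samples, batch_size, drop_last=True):
    """Number of batches produced when iterating once over the samples."""
    full, remainder = divmod(num_samples, batch_size)
    if drop_last or remainder == 0:
        return full
    return full + 1


TRAIN_PER_CLASS = 5000   # CIFAR-10 training images per class
TEST_PER_CLASS = 1000    # CIFAR-10 test images per class
CLASSES_KEPT = 3         # airplane, bird, horse
BATCH = 128

n_train = TRAIN_PER_CLASS * CLASSES_KEPT
n_test = TEST_PER_CLASS * CLASSES_KEPT

train_batches = batches_per_epoch(n_train, BATCH)
test_batches = batches_per_epoch(n_test, BATCH)

# images that are skipped each epoch because they do not fill a batch
train_dropped = n_train - train_batches * BATCH
test_dropped = n_test - test_batches * BATCH

print(train_batches, train_dropped)
print(test_batches, test_dropped)
```

Since the samplers reshuffle every epoch, the dropped images differ from epoch to epoch, so no image is excluded permanently.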

### Viewing (some of) the Data
...It almost seems as if birds are not worth training on since their wings are often closed...

In [None]:
# convert test loader into a list so that we may index it.
test_loader_list = list(test_loader)

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # undo the normalization and reorder (channels, height, width) -> (height, width, channels) for imshow
    img = inverse_normalize(test_loader_list[0][0][i]).permute(1, 2, 0)
    plt.imshow(img, cmap=plt.cm.binary)
    plt.xlabel(class_names[test_loader_list[0][1][i]])
    plt.tight_layout()

## Networks

In [None]:
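def initialize_params(module, mean, stdDev):
    """Helper function that initializes the weights and biases of a module.

    If the module is a 2D Convolution or a 2D Transpose Convolution, its weights are drawn
    from a normal distribution with the given mean and standard deviation, and its biases
    (if present) are set to zero. Any other module is left untouched.
    """
    if isinstance(module, (nn.ConvTranspose2d, nn.Conv2d)):
        module.weight.data.normal_(mean, stdDev)
        # layers created with bias=False have no bias tensor to reset
        if module.bias is not None:
            module.bias.data.zero_()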



class Generator(nn.Module):
    """A deep convolutional Generator network, 
    designed to generate 32x32 colour images from a random 100D noise vector.
    Based on DCGAN but with one less layer, 
    mostly due to the fact that the images here are half the size of those in the original paper.
    """
    def __init__(self, b=128):
        # inheriting initialization parameters from nn.Module
        super(Generator, self).__init__()
        # transpose convolve input into 512x4x4 tensor
        self.tconv4 = nn.ConvTranspose2d(100, b*4, kernel_size = 4, stride = 1, padding = 0, bias = False)
        self.bnorm4 = nn.BatchNorm2d(b*4)
        # transpose convolve into 256x8x8 tensor
        self.tconv8 = nn.ConvTranspose2d(b*4, b*2, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.bnorm8 = nn.BatchNorm2d(b*2)
        # transpose convolve into 128x16x16 tensor
        self.tconv16 = nn.ConvTranspose2d(b*2, b, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.bnorm16 = nn.BatchNorm2d(b)
        # transpose convolve into 3x32x32 tensor (an image)
        self.tconv32 = nn.ConvTranspose2d(b, 3, kernel_size = 4, stride = 2, padding = 1)
        self.tanh = nn.Tanh()
        
        # activation function 
        self.relu = nn.ReLU()
    
    def initialize(self, mean, stdDev):
        """Instance method for initializing the weights to specific values"""
        for module in self._modules:
            initialize_params(self._modules[module], mean, stdDev)
    
    def forward(self, z):
        """Defines the order of operations to follow when the generator is called"""
        z = self.relu(self.bnorm4(self.tconv4(z)))
        z = self.relu(self.bnorm8(self.tconv8(z)))
        z = self.relu(self.bnorm16(self.tconv16(z)))
        z = self.tanh(self.tconv32(z))
        return z

class Discriminator(nn.Module):
    """A deep convolutional Discriminator network, 
    designed to determine whether an input image is real (output 1) or fake (output 0).
    Based on DCGAN but with one less layer, 
    mostly due to the fact that the images here are half the size of those in the original paper.
    """
    
    def __init__(self, b=128):
        # inheriting initialization parameters from nn.Module
        super(Discriminator, self).__init__()
        # convolve image into 128x16x16 tensor
        self.conv16 = nn.Conv2d(3, b, kernel_size = 4, stride = 2, padding = 1)
        # convolve into 256x8x8 tensor
        self.conv8 = nn.Conv2d(b, b*2, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.bnorm8 = nn.BatchNorm2d(b*2)
        # convolve into 512x4x4 tensor
        self.conv4 = nn.Conv2d(b*2, b*4, kernel_size = 4, stride =2, padding = 1, bias = False)
        self.bnorm4 = nn.BatchNorm2d(b*4)
        # compute sigmoid
        self.sigmoid = nn.Sigmoid()
        # activation function
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        
    def initialize(self, mean, stdDev):
        """Instance method for initializing the weights to specific values"""
        for module in self._modules:
            initialize_params(self._modules[module], mean, stdDev)
        
    def forward(self, x):
        """Defines the order of operations to follow when the discriminator is called"""
        x = self.leaky_relu(self.conv16(x))
        x = self.leaky_relu(self.bnorm8(self.conv8(x)))
        x = self.leaky_relu(self.bnorm4(self.conv4(x)))
        # element-wise sigmoid; for 32x32 inputs x has shape (N, b*4, 4, 4) at this point
        x = self.sigmoid(x)
        return x
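
The shape comments in both networks can be checked with the standard output-size formulas for convolutions and transpose convolutions. This is a sketch in plain Python, assuming the PyTorch defaults of dilation 1 and no output padding, and a noise input with a 1x1 spatial size; `conv_out` and `tconv_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tconv_out(size, kernel, stride, padding):
    """Spatial output size of a transpose convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for tconv4, tconv8, tconv16, tconv32
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
# (kernel, stride, padding) for conv16, conv8, conv4
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector is treated as a 100-channel 1x1 "image"
generator_sizes = []
for kernel, stride, padding in generator_layers:
    size = tconv_out(size, kernel, stride, padding)
    generator_sizes.append(size)

size = 32  # CIFAR images are 32x32
discriminator_sizes = []
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    discriminator_sizes.append(size)

print(generator_sizes)
print(discriminator_sizes)
```

The generator grows 1 -> 4 -> 8 -> 16 -> 32 and the discriminator shrinks 32 -> 16 -> 8 -> 4, so as written the discriminator's output keeps a 4x4 spatial extent rather than being a single value per image.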

## Training

### Set up

In [None]:
# initialize the neural network instances
G = Generator().to(device)
D = Discriminator().to(device)

print(f'> Number of Generator parameters {sum(p.numel() for p in G.parameters())}')
print(f'> Number of Discriminator parameters {sum(p.numel() for p in D.parameters())}')

# define optimizer parameters
learning_rate = 0.0002
betas = (0.5, 0.999)

# initialise the optimisers
opt_G = torch.optim.Adam(G.parameters(), lr = learning_rate, betas = betas)
opt_D = torch.optim.Adam(D.parameters(), lr = learning_rate, betas = betas )

# initialize the epochs
epoch = 0

# initialize livelossplot instance
liveplot = PlotLosses()
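
As a sanity check on the parameter counts printed above, the totals can be worked out by hand from the layer definitions with the default `b=128`. The sketch below uses hypothetical helpers `conv_params` and `bnorm_params`; it relies on the fact that a (transpose) convolution has `in * out * k * k` weights plus one bias per output channel when enabled, and a batch norm layer has one learnable scale and one shift per channel.

```python
def conv_params(in_ch, out_ch, kernel, bias):
    """Learnable parameters of a 2D (transpose) convolution with a square kernel."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


def bnorm_params(channels):
    """Learnable parameters of a 2D batch norm layer (scale and shift)."""
    return 2 * channels


b = 128  # base channel width used by both networks

generator_total = (
    conv_params(100, b * 4, 4, bias=False) + bnorm_params(b * 4)      # tconv4, bnorm4
    + conv_params(b * 4, b * 2, 4, bias=False) + bnorm_params(b * 2)  # tconv8, bnorm8
    + conv_params(b * 2, b, 4, bias=False) + bnorm_params(b)          # tconv16, bnorm16
    + conv_params(b, 3, 4, bias=True)                                 # tconv32
)

discriminator_total = (
    conv_params(3, b, 4, bias=True)                                   # conv16
    + conv_params(b, b * 2, 4, bias=False) + bnorm_params(b * 2)      # conv8, bnorm8
    + conv_params(b * 2, b * 4, 4, bias=False) + bnorm_params(b * 4)  # conv4, bnorm4
)

print(generator_total)
print(discriminator_total)
```

If the printed values from the notebook differ from these totals, a layer definition has probably been changed.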

### Train

## Results

In [None]:
# get the Tensors for a horse and for a bird
example_1 = (test_loader.dataset[13][0]).to(device)  # horse
example_2 = (test_loader.dataset[160][0]).to(device) # bird

# run them through the encoder (N must be a trained model exposing encode/decode; it is not defined in the cells above)
example_1_code = N.encode(example_1.unsqueeze(0))
example_2_code = N.encode(example_2.unsqueeze(0))

# decode an interpolation of the two
bad_pegasus = N.decode(0.9*example_1_code + 0.1*example_2_code).squeeze(0)

# plot the result of the decoding
plt.grid(False)
# reorder (channels, height, width) -> (height, width, channels) for imshow
plt.imshow(bad_pegasus.detach().cpu().permute(1, 2, 0), cmap=plt.cm.binary)
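
The interpolation in the cell above is a weighted average of two latent codes. A minimal sketch of that step on toy lists follows; `interpolate`, `horse_code` and `bird_code` are hypothetical stand-ins for the encoder outputs, which in the notebook are tensors.

```python
def interpolate(code_a, code_b, weight_a):
    """Blend two equally long codes: weight_a of the first, the rest of the second."""
    weight_b = 1.0 - weight_a
    return [weight_a * a + weight_b * b for a, b in zip(code_a, code_b)]


# toy latent codes standing in for the encoded horse and bird
horse_code = [1.0, 0.0, 2.0]
bird_code = [0.0, 1.0, -2.0]

# 90% horse, 10% bird, mirroring the weights used above
mostly_horse = interpolate(horse_code, bird_code, weight_a=0.9)
print(mostly_horse)
```

Sweeping `weight_a` from 1.0 down to 0.0 moves the blended code from the horse to the bird, which is one way to look for a mix that decodes to something pegasus-like.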