# Pegasus

## Google COLAB Settings
This section contains cells that need to be run when executing the code in Google COLAB with a GPU runtime. In that case, uncomment the cells and run them in order; otherwise, skip directly to [Imports](#Imports).

### Installs

In [None]:
# %%capture
# # Current COLAB runtimes already provide CUDA-enabled builds of torch and torchvision,
# # so only livelossplot needs installing.
# !pip install -q livelossplot

### Google Drive
This portion is exclusively for development on _my_ end. I use Google Drive to access the training/testing data without having to redownload it each time the Google COLAB runtime is reset. 

Anyone without access to my Google credentials cannot mount my Drive and should skip directly to [Imports](#Imports). In that case, torchvision downloads the CIFAR-10 data from the web each time the COLAB runtime is reset.

#### Mounting Drive
This mounts Google Drive into the COLAB runtime at `/content/gdrive`. If Drive is already mounted, it is not mounted again. Mounting asks for authentication.

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

#### Importing data from Drive
Here I copy the CIFAR data, which I previously downloaded and stored in my Google Drive, into the COLAB runtime so that it does not have to be redownloaded each time the runtime is reset. The `-n` (no-clobber) flag prevents existing files from being overwritten.

In [None]:
# !cp -r -n /content/gdrive/My\ Drive/Education/Undergraduate/Year_3/Computer_Science/SSA/Machine_Learning/Coursework/ML_Classifier-Pegasus-Generator/data/ /content/
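Outside a notebook shell, the same no-clobber recursive copy can be sketched with Python's standard library. This is an illustration only: `copy_no_clobber` is a hypothetical helper, and it assumes Python 3.8+ (for the `dirs_exist_ok` argument of `shutil.copytree`). Files that already exist at the destination are left untouched, mirroring `cp -r -n`.

```python
import os
import shutil


def copy_no_clobber(src: str, dst: str) -> None:
    """Recursively copy src into dst, skipping files that already exist (like `cp -r -n`)."""

    def copy_if_missing(s: str, d: str) -> None:
        # Only copy when the destination file does not exist yet.
        if not os.path.exists(d):
            shutil.copy2(s, d)

    # dirs_exist_ok=True lets copytree merge into an existing directory tree
    # instead of raising FileExistsError.
    shutil.copytree(src, dst, copy_function=copy_if_missing, dirs_exist_ok=True)
```

Note that `shutil.copytree(src, dst)` copies the *contents* of `src` into `dst`, so `dst` should name the target `data` directory itself.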

## Imports

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from livelossplot import PlotLosses

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Functions

In [None]:
def cycle(iterable):
    """helper function to make getting another batch of data easier"""
    while True:
        for x in iterable:
            yield x
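The generator restarts its `for` loop over `iterable` on every pass, so it never runs out. The sketch below takes a finite prefix with `itertools.islice` to show the wrap-around, using a hypothetical `CountingIterable` that records how often it is iterated. It also illustrates a design point: `itertools.cycle` is not a drop-in replacement here, because it saves a copy of the items from its first pass and replays them, whereas re-iterating a `DataLoader` created with `shuffle=True` draws a fresh shuffle on each pass.

```python
import itertools


def cycle(iterable):
    """Yield items from iterable forever, re-iterating it on each pass."""
    while True:
        for x in iterable:
            yield x


class CountingIterable:
    """Hypothetical iterable that records how many times it is iterated."""

    def __init__(self, items):
        self.items = list(items)
        self.passes = 0

    def __iter__(self):
        self.passes += 1
        return iter(self.items)


source = CountingIterable([1, 2, 3])
first_seven = list(itertools.islice(cycle(source), 7))
print(first_seven)    # [1, 2, 3, 1, 2, 3, 1]
print(source.passes)  # 3

cached = CountingIterable([1, 2, 3])
list(itertools.islice(itertools.cycle(cached), 7))
print(cached.passes)  # 1: itertools.cycle iterates once and replays its saved copy
```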

## Classes

In [None]:
class MyNetwork(nn.Module):
    """define the model (a simple autoencoder)"""
    def __init__(self):
        super(MyNetwork, self).__init__()
        layers = nn.ModuleList()
        layers.append(nn.Linear(in_features=3*32*32, out_features=512))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(in_features=512, out_features=32))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(in_features=32, out_features=512))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(in_features=512, out_features=3*32*32))
        layers.append(nn.Sigmoid())
        self.layers = layers

    def forward(self, x):
        z = self.encode(x)
        x = self.decode(z)
        return x

    def encode(self, x):
        """encode (flatten as linear, then run first half of network)"""
        x = x.view(x.size(0), -1)
        for i in range(4):
            x = self.layers[i](x)
        return x

    def decode(self, x):
        """decode (run second half of network then unflatten)"""
        for i in range(4,8):
            x = self.layers[i](x)
        x = x.view(x.size(0), 3, 32, 32)
        return x
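Each `nn.Linear(in_features, out_features)` layer holds an `out × in` weight matrix plus an `out`-element bias, while `ReLU` and `Sigmoid` have no parameters, so the model size follows directly from the layer widths. A plain-Python sketch of that arithmetic (the `layer_widths` list simply restates the widths used in `MyNetwork`); the total should equal the parameter count printed during set-up:

```python
# Layer widths of the autoencoder, as used in MyNetwork (input -> code -> output).
layer_widths = [3 * 32 * 32, 512, 32, 512, 3 * 32 * 32]


def linear_params(n_in: int, n_out: int) -> int:
    """Parameters in a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


total = sum(linear_params(a, b) for a, b in zip(layer_widths, layer_widths[1:]))
print(total)  # 3182624

# Compression factor of the 32-dimensional code relative to a flattened image.
print(layer_widths[0] / layer_widths[2])  # 96.0
```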

## Dataset Setup

In [None]:
# define class names for CIFAR 10
class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

### Transforms

In [None]:
# define transforms to be applied to training data
train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

# define transforms to be applied to testing data
test_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
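`ToTensor` converts a PIL image (or an `H × W × C` `uint8` array) into a `C × H × W` float tensor scaled to `[0, 1]`. That range matches the decoder's final `Sigmoid`, which is why the input pixels can be compared directly with the reconstruction. A NumPy sketch of the equivalent conversion (`to_chw_float` is a hypothetical helper, not the torchvision implementation):

```python
import numpy as np


def to_chw_float(image: np.ndarray) -> np.ndarray:
    """Mimic ToTensor for a uint8 HWC image: reorder to CHW and scale to [0, 1]."""
    assert image.dtype == np.uint8 and image.ndim == 3
    return np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0


# A random 32x32 RGB image, the same shape as a CIFAR-10 sample.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
chw = to_chw_float(img)
print(chw.shape)  # (3, 32, 32)
```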

### Getting the Data
If not already present, this cell will download the [CIFAR 10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset from the web. Otherwise it will simply read it from the existing directory. The transforms defined [above](#Transforms) will be applied.

In [None]:
# download training set
train_set = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=train_transforms)
# download test set
test_set = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=test_transforms)

### Loading the Data
Once downloaded, the datasets are wrapped in `DataLoader`s, which provide iterable batches for PyTorch to use.

In [None]:
# define batch size
BATCH_SIZE = 16

# wrap the training set in a shuffled DataLoader
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)

# wrap the test set in an unshuffled DataLoader
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=BATCH_SIZE, drop_last=True)

# create iterators for later use
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

# diagnostic prints
print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
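With `drop_last=True`, each loader yields only full batches, so the number of batches per pass is the floor of the dataset size divided by `BATCH_SIZE`. The arithmetic below assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images. Both sizes happen to be multiples of 16, so `drop_last` discards nothing here. It also shows how much of the training set one 1,000-batch "epoch" of the training loop covers.

```python
BATCH_SIZE = 16
TRAIN_SIZE, TEST_SIZE = 50_000, 10_000  # standard CIFAR-10 split

train_batches = TRAIN_SIZE // BATCH_SIZE  # drop_last=True would discard a partial batch
test_batches = TEST_SIZE // BATCH_SIZE
print(train_batches, test_batches)  # 3125 625

# The training loop draws 1000 batches per "epoch" from the cycling iterator.
images_per_epoch = 1000 * BATCH_SIZE
print(images_per_epoch, images_per_epoch / TRAIN_SIZE)  # 16000 0.32
```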

### Viewing (some of) the Data

In [None]:
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_loader.dataset[i][0].permute(1, 2, 0), cmap=plt.cm.binary)
    plt.xlabel(class_names[test_loader.dataset[i][1]])

## Training

### Set up

In [None]:
# initialise the network instance
N = MyNetwork().to(device)

print(f'> Number of network parameters {len(torch.nn.utils.parameters_to_vector(N.parameters()))}')

# initialise the optimiser
optimiser = torch.optim.Adam(N.parameters(), lr=0.001)
# initialise the epoch counter
epoch = 0
# initialise the livelossplot instance
liveplot = PlotLosses()
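The optimiser is `torch.optim.Adam` with a learning rate of 0.001; its other hyperparameters keep PyTorch's defaults (`betas=(0.9, 0.999)`, `eps=1e-8`). A NumPy sketch of Adam's bias-corrected update rule (an illustration of the published rule with a hypothetical `adam_step` helper, not PyTorch's code) shows why the first step moves each parameter by roughly `lr`, whatever the gradient's scale:

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on arrays; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # correct the bias from zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


# Gradients spanning five orders of magnitude all give a first step close to lr.
grads = np.array([1e-3, 1.0, 100.0])
param, m, v = adam_step(np.zeros(3), grads, np.zeros(3), np.zeros(3), t=1)
print(param)  # every entry is close to -0.001
```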

### Train

In [None]:
# training loop (training on the test set as well is fine here, since the goal is generating the pegasus)
while (epoch<10):
    
    # arrays for metrics
    train_loss_arr = np.zeros(0)

    # draw 1000 batches (16,000 images, about a third of the training set) per epoch
    for i in range(1000):
        # get data and respective target batch samples
        x,t = next(train_iterator)
        # move them to the selected device (the GPU if one is available)
        x,t = x.to(device), t.to(device)
        
        # set the gradient to zero
        optimiser.zero_grad()
        # calculate a prediction
        p = N(x)
        # calculate the loss
        loss = ((p-x)**2).mean() # simple l2 loss
        # backpropagate the loss 
        loss.backward()
        # update the weights
        optimiser.step()
        # record the loss for this batch
        train_loss_arr = np.append(train_loss_arr, loss.item())

    # plot the training loss
    liveplot.update({
        'loss': train_loss_arr.mean()
    })
    liveplot.draw()
    
    # move on to the next epoch
    epoch = epoch+1
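The loss `((p-x)**2).mean()` is the mean squared error averaged over every element of the batch, i.e. `BATCH_SIZE × 3 × 32 × 32` values; it matches `F.mse_loss(p, x)` with its default `reduction='mean'`. A NumPy sketch with random stand-ins for images and reconstructions (a hypothetical `mse` helper) checks that this equals the average of per-image errors, and that a constant prediction of 0.5 against all-zero targets gives 0.25:

```python
import numpy as np


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error over all elements, as in ((p - x) ** 2).mean()."""
    return float(((pred - target) ** 2).mean())


B = 16
rng = np.random.default_rng(0)
x = rng.random((B, 3, 32, 32))  # stand-in for a batch of images in [0, 1]
p = rng.random((B, 3, 32, 32))  # stand-in for reconstructions

# With equal-sized images, averaging all elements equals averaging per-image errors.
per_image = ((p - x) ** 2).reshape(B, -1).mean(axis=1)
print(np.isclose(mse(p, x), per_image.mean()))  # True

print(mse(np.full((B, 3, 32, 32), 0.5), np.zeros((B, 3, 32, 32))))  # 0.25
```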

## Results

In [None]:
# get the Tensors for a horse and for a bird
example_1 = (test_loader.dataset[13][0]).to(device)  # horse
example_2 = (test_loader.dataset[160][0]).to(device) # bird

# run them through the encoder
example_1_code = N.encode(example_1.unsqueeze(0))
example_2_code = N.encode(example_2.unsqueeze(0))

# decode a linear interpolation of the two codes (90% horse, 10% bird)
bad_pegasus = N.decode(0.9*example_1_code + 0.1*example_2_code).squeeze(0)

# plot the result of the decoding
plt.grid(False)
plt.imshow(bad_pegasus.detach().cpu().permute(1, 2, 0), cmap=plt.cm.binary)
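The blend `0.9*example_1_code + 0.1*example_2_code` is a point on the straight line between the two codes, z(α) = (1 − α)·z₁ + α·z₂ with α = 0.1. Sweeping α from 0 to 1 and decoding each point would give a horse-to-bird morph. The NumPy sketch below builds such a sweep; `z_horse` and `z_bird` are hypothetical random stand-ins for `N.encode` outputs (the real codes are non-negative because the encoder ends in a ReLU, and any blend of them stays non-negative). The resulting array could be converted to a float tensor on `device` and passed to `N.decode` as a batch.

```python
import numpy as np


def interpolate_codes(z1: np.ndarray, z2: np.ndarray, steps: int) -> np.ndarray:
    """Return `steps` codes evenly spaced on the line from z1 (alpha=0) to z2 (alpha=1)."""
    alphas = np.linspace(0.0, 1.0, steps)[:, None]  # shape (steps, 1)
    return (1.0 - alphas) * z1[None, :] + alphas * z2[None, :]


rng = np.random.default_rng(0)
# Hypothetical stand-ins for 32-dimensional encoder outputs.
z_horse = rng.random(32)
z_bird = rng.random(32)

codes = interpolate_codes(z_horse, z_bird, steps=11)
print(codes.shape)  # (11, 32)

# Row 1 corresponds to alpha = 0.1, the 90% horse / 10% bird mix decoded above.
print(np.allclose(codes[1], 0.9 * z_horse + 0.1 * z_bird))  # True
```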