# A06 Q2: Batch Normalization

In [None]:
import numpy as np
import torch
import torch.nn as nn
import utils
import matplotlib.pyplot as plt
from tqdm import tqdm

# Dataset

In [None]:
# Make a PyTorch wrapper for our UClasses dataset
class UClasses(torch.utils.data.Dataset):
    '''
     ds = UClasses(n=300)

     PyTorch Dataset wrapper around the NumPy dataset utils.UClasses
     (constructed with binary=False).

     Inputs:
      n       forwarded to utils.UClasses (presumably the number of
              samples -- confirm against utils)

     Attributes:
      x       float32 tensor of inputs, one sample per row
      t       integer class index per sample: the argmax (along axis 1)
              of the NumPy dataset's targets

     Usage:
      x, t = ds[k]    # one (input, class index) pair
    '''
    def __init__(self, n=300):
        super().__init__()
        np_ds = utils.UClasses(n=n, binary=False)  # heavy lifting done by NumPy code
        self.x = torch.tensor(np_ds.inputs(), dtype=torch.float32)
        # Collapse each per-class target row to a single class index.
        self.t = torch.argmax(torch.tensor(np_ds.targets()), axis=1, keepdim=False)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.t[idx]

    def inputs(self):
        '''Return the input tensor (one sample per row).'''
        return self.x

    def targets(self):
        '''Return the tensor of integer class-index targets.'''
        return self.t

    @staticmethod
    def _colour_indices(labels):
        '''
         Map a label tensor to a flat list of integer colour indices.

          - 2-D labels with more than one column (one-hot rows or per-class
            scores): the argmax of each row is used.
          - Floating-point labels otherwise: treated as binary and
            thresholded at 0.5.
          - Integer/boolean labels otherwise (e.g. class indices such as
            self.t): used directly, so every class keeps its own colour.
        '''
        if labels.dim() > 1 and labels.size(1) > 1:
            cidx = torch.argmax(labels, dim=1)
        elif labels.is_floating_point():
            cidx = (labels > 0.5).long()
        else:
            # Class indices must not be thresholded: that would merge
            # classes 1, 2, ... into a single colour.
            cidx = labels.long()
        # Flatten so (N,1) column labels also yield one index per sample.
        return cidx.reshape(-1).tolist()

    def plot(self, labels=None, *args, **kwargs):
        '''
         ds.plot(labels=None)

         Scatter-plot the first two input columns, coloured by label.

         Inputs:
          labels  tensor of labels to colour by; defaults to self.t.
                  Integer class indices, one-hot/score rows (2-D with more
                  than one column; argmax is used), or binary float values
                  (thresholded at 0.5) are accepted. Colours cycle when
                  there are more than five classes.
          *args, **kwargs  accepted for compatibility; currently ignored
        '''
        if labels is None:
            labels = self.t
        colour_options = ['y', 'r', 'g', 'b', 'k']
        colours = [colour_options[k % len(colour_options)]
                   for k in self._colour_indices(labels)]
        plt.scatter(self.x[:, 0].detach(), self.x[:, 1].detach(), color=colours, marker='.')
        plt.axis('equal')

In [None]:
# Build the training set (n=1000 is forwarded to utils.UClasses) and
# scatter-plot it, using the dataset's own targets for the colours.
train = UClasses(n=1000)
train.plot(labels=None)

In [None]:
# Inspect the first target. UClasses.__init__ stores the argmax (axis 1) of the
# NumPy targets, so this is an integer class index rather than a target row.
train.targets()[0]

In [None]:
# Mini-batches of 200 samples, reshuffled at the start of every epoch.
train_dl = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=200,
    shuffle=True,
)

# A. `BatchNorm` Class

In [None]:
class BatchNorm(nn.Module):
    '''
     lyr = BatchNorm(eps=0.001)
     
     Creates a PyTorch layer (Module) that implements batch normalization, so
     that its outputs are remapped. For each node in the layer, the output for the
     batch is normalized so that it is zero-mean and approximately unit-variance.
     
     Inputs:
      eps     stability parameter, to avoid division by zero
             
     Usage:
      lyr = BatchNorm(eps=0.001)
      y = lyr(x)    # x is a batch (tensor), with one sample in each row
                    # y is the same shape as x

     Notes:
      - Mean and variance are always computed from the current batch (no
        running averages, no learnable scale/shift), so the layer behaves the
        same in train() and eval() modes.
      - The variance is the biased (population) estimate, and eps is added
        inside the square root: y = (x - mean) / sqrt(var + eps).
        With eps > 0 the output variance is var/(var + eps), i.e. slightly
        below 1, and a zero-variance column maps to 0 instead of NaN.
    '''
    def __init__(self, eps=0.001):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Per-node statistics: reduce over the batch dimension (rows).
        mu = x.mean(dim=0, keepdim=True)
        centred = x - mu
        var = (centred ** 2).mean(dim=0, keepdim=True)  # biased variance
        y = centred / torch.sqrt(var + self.eps)
        
        return y

# B. Demonstrate `BatchNorm`

# Network base class, `NNBase`

In [None]:
# This network base class saves us from having to duplicate the learn function.

class NNBase(nn.Module):
    '''
     You should not instantiate this class directly. 
     Base class for other simple neural network classes: a derived class
     defines its layers and forward(), and sets self.loss_fcn.
     eg.
       class MyDerivedNN(NNBase):
          ...
    '''
    # Losses that expect integer class-index targets rather than float values.
    _CLASS_INDEX_LOSSES = (nn.CrossEntropyLoss, nn.NLLLoss)

    def __init__(self):
        super().__init__()
        self.losses = []      # mean per-batch training loss, one entry per epoch
        self.loss_fcn = None  # Should be overridden in derived class
        
    def forward(self, x):
        return x

    def _prepare_targets(self, t):
        '''
         Convert a batch of targets to the dtype self.loss_fcn expects.

         Integer targets used with a class-index loss (CrossEntropyLoss,
         NLLLoss) are kept as integers (long); casting them to float would
         make those losses reject them. Everything else is converted to
         float, as before.
        '''
        if isinstance(self.loss_fcn, self._CLASS_INDEX_LOSSES) and not t.is_floating_point():
            return t.long()
        return t.float()
        
    def learn(self, dl, epochs=10, lr=0.1, plot=True):
        '''
         net.learn(dl, epochs=10, lr=0.1, plot=True)
         
         Performs SGD on the neural network.
         Inputs:
          dl      DataLoader object (PyTorch), yielding (x, t) batches
          epochs  number of epochs to perform
          lr      learning rate
          plot    whether or not to plot the learning curve

         Appends the mean per-batch loss of each epoch to self.losses.
         Raises TypeError if self.loss_fcn has not been set.
        '''
        if self.loss_fcn is None:
            raise TypeError('NNBase.learn: self.loss_fcn is None; set it in the derived class')
        optim = torch.optim.SGD(self.parameters(), lr=lr)
        # Make sure mode-dependent layers (e.g. nn.BatchNorm1d, dropout) train
        # properly even if the net was previously switched to eval().
        self.train()
        for _ in tqdm(range(epochs)):
            total_loss = 0.
            n_batches = 0
            for x, t in dl:
                y = self(x)
                loss = self.loss_fcn(y, self._prepare_targets(t))
                optim.zero_grad()
                loss.backward()
                optim.step()
                total_loss += loss.item()
                n_batches += 1
            # Guard against an empty loader instead of dividing by zero.
            self.losses.append(total_loss / max(n_batches, 1))
        if plot:
            plt.figure(figsize=(4,4))
            plt.plot(self.losses)

# C. Compare learning performance

## Create two network classes

In [None]:
# NormalNet
# Create a simple neural network that does NOT use batchnorm.
# You should use NNBase as the base class.


In [None]:
# BNNet
# Create a simple neural network that DOES use batchnorm after each
# layer (except the output layer).
# You should use NNBase as the base class.
# Apply batchnorm between the activation function and the connections
# to the next layer.


## Experiments
Let's compare the learning curves for the following cases:
1. Normal NN (no batchnorm)
2. Batchnorm

In [None]:
# Code to run experiments



## Results

In [None]:
# Code to plot the results of the trials

