# Load code libraries

In [3]:
# We use as few libraries as possible to make the code more portable, and to make it easier to understand.
# I.e., there are fewer libraries to learn.

import torch
import pandas as pd

import pathlib

In [4]:
# A few convenience functions for manipulating data and plotting

from generatedata.df_to_tensor import df_to_tensor
from generatedata.StartTargetData import StartTargetData
from generatedata.load_data import load_data

# Load data

In [5]:
# --- Active dataset: MNIST1D ---------------------------------------------
# Each sample is handled as one row whose columns are split into an "x" part
# (x_idx) and a "y" part (y_idx); z_size is the total column count.
name, z_size = 'MNIST1D', 50
x_idx, y_idx = range(40), range(40, 50)

# --- Alternative: MNIST (uncomment to replace the four values above) -------
# name = 'MNIST'
# z_size = 794
# x_idx = range(784)
# y_idx = range(784, 794)

# Fetch the named dataset and pull out its 'start' and 'target' entries.
data_dict = load_data(name)
z_start, x_target = data_dict['start'], data_dict['target']

# Dynamical system

In [6]:

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Turn the start/target data into tensors and place both on the compute device
# before wrapping them in a dataset.
z_start_tensor, x_target_tensor = (
    df_to_tensor(frame).to(device) for frame in (z_start, x_target)
)

# Paired (start, target) dataset, served in shuffled mini-batches of 100 rows.
train_data = StartTargetData(z_start_tensor, x_target_tensor)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=100,
    shuffle=True,
)

Where to start?  Let's look at an MLP.  We will use the following notation:

$$
\begin{bmatrix}
I   && 0   && 0 \\
W_1 && 0   && 0 \\
0   && W_2 && 0 \\
\end{bmatrix}
$$

In [8]:
class MLP(torch.nn.Module):
    """Three-layer ReLU MLP (with a skip connection) acting on a state ``z``.

    The columns of ``z`` selected by ``x_idx`` are the network input.  The map
    returns those ``x`` columns unchanged, concatenated with the log-softmax
    class scores, i.e. an output with ``len(x_idx) + len(y_idx)`` columns laid
    out as ``[x, y]``.

    NOTE: ``train`` indexes the *output* with ``y_idx``, so this class assumes
    the ``x`` columns come first and the ``y`` columns directly after them
    (e.g. ``x_idx = range(40)``, ``y_idx = range(40, 50)``).

    Args:
        x_idx: Column indices of ``z`` holding the input features.
        y_idx: Column indices of ``z`` holding the one-hot labels / outputs.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, x_idx, y_idx, hidden_size = 100):
        super().__init__()
        self.x_idx = x_idx
        self.y_idx = y_idx
        self.hidden_size = hidden_size

        self.sigma = torch.nn.ReLU()
        self.reset()

    def reset(self):
        """(Re)initialise every weight and bias.

        New ``Parameter`` objects are created on each call, so an optimizer
        built on the previous parameters must be rebuilt afterwards.
        """
        # Stay on the device the model already lives on.  On the first call
        # (from __init__) no parameter exists yet, so the default device is used.
        current = getattr(self, 'W1', None)
        device = None if current is None else current.device

        n_x = len(self.x_idx)
        n_y = len(self.y_idx)
        n_h = self.hidden_size

        with torch.no_grad():
            # Let's make all of the choices as explicit as possible!
            # So we start with raw tensors (not even Linear layers!)
            # NOTE: the *_raw attributes are only the construction tensors; after the
            # model is moved to another device they no longer alias the parameters.
            self.W1_raw = torch.zeros(size=(n_h, n_x), requires_grad=True, dtype=torch.float32, device=device)
            self.W2_raw = torch.zeros(size=(n_h, n_h), requires_grad=True, dtype=torch.float32, device=device)
            self.W3_raw = torch.zeros(size=(n_y, n_h), requires_grad=True, dtype=torch.float32, device=device)

            # Biases use the same explicit dtype as the weights so the two never
            # disagree if the global default dtype has been changed.
            self.b1_raw = torch.zeros(size=(n_h,), dtype=torch.float32, device=device)
            self.b2_raw = torch.zeros(size=(n_h,), dtype=torch.float32, device=device)
            self.b3_raw = torch.zeros(size=(n_y,), dtype=torch.float32, device=device)

            # Weights: the same Kaiming-uniform initialisation that torch.nn.Linear uses.
            torch.nn.init.kaiming_uniform_(self.W1_raw, a=5**0.5)
            torch.nn.init.kaiming_uniform_(self.W2_raw, a=5**0.5)
            torch.nn.init.kaiming_uniform_(self.W3_raw, a=5**0.5)
            # Biases: explicitly zero (this differs from torch.nn.Linear, which
            # draws its biases from a uniform distribution).
            torch.nn.init.zeros_(self.b1_raw)
            torch.nn.init.zeros_(self.b2_raw)
            torch.nn.init.zeros_(self.b3_raw)

            # Now we make them into parameters
            self.W1 = torch.nn.Parameter(self.W1_raw)
            self.W2 = torch.nn.Parameter(self.W2_raw)
            self.W3 = torch.nn.Parameter(self.W3_raw)
            self.b1 = torch.nn.Parameter(self.b1_raw)
            self.b2 = torch.nn.Parameter(self.b2_raw)
            self.b3 = torch.nn.Parameter(self.b3_raw)

    def forward(self, z):
        """Apply the map once.

        Args:
            z: 2-D tensor whose ``x_idx`` columns hold the input features.

        Returns:
            Tensor ``[x, log_softmax(scores)]`` concatenated along dim 1.
        """
        # Extract the x part of the overall z
        x = z[:, self.x_idx]

        # The first layer
        # NOTE: the transpose and the left multiplication.  This is to be consistent with how Pytorch does it.
        h_1 = x @ self.W1.T + self.b1
        h_1 = self.sigma(h_1)

        # The second layer
        # NOTE: the skip connection here! This would normally be
        #    h_2 = h_1 @ self.W2.T + self.b2
        # but https://github.com/greydanus/mnist1d puts in the skip connection, so we do the same!
        h_2 = h_1 + h_1 @ self.W2.T + self.b2
        h_2 = self.sigma(h_2)

        # The third layer
        y = h_2 @ self.W3.T + self.b3
        # The log_softmax is the output of the map and is paired with the NLLLoss
        y = torch.nn.functional.log_softmax(y, dim=1)

        # This is the identity part of the map, so that we can iterate the map if we want.
        # NOTE:  you have to be careful with this so that the gradients can flow through it.
        #        There are many ways to do this, but this one has the desired effect.
        return torch.cat([x, y], dim=1)

    # NOTE: this deliberately shares its name with torch.nn.Module.train.  PyTorch itself
    # calls ``module.train(mode)`` (e.g. from ``eval()``), so a bool argument is forwarded
    # to the base class instead of being treated as a data loader.  The first parameter
    # keeps the name ``model`` (rather than ``self``) so that existing calls of the form
    # ``MLP.train(model, train_loader, ...)`` and ``MLP.train(model=..., ...)`` still work.
    def train(model, train_loader=True, epochs = 100, lr = 0.01, verbose = False):
        """Fit ``model`` on ``train_loader``, or switch the training mode.

        Args:
            model: The MLP instance to train.
            train_loader: Iterable of ``(z_start, z_target)`` batches, where the
                ``model.y_idx`` columns of ``z_target`` are one-hot labels.
                If a ``bool`` is given instead, this behaves exactly like
                ``torch.nn.Module.train(mode)``.
            epochs: Number of passes over ``train_loader``.
            lr: Learning rate for the Adam optimizer.
            verbose: If true, print the loss and accuracy of the last batch
                every 10 epochs.

        Returns:
            ``None`` after fitting; the module itself when called with a bool.
        """
        if isinstance(train_loader, bool):
            return torch.nn.Module.train(model, train_loader)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # If you use the CrossEntropyLoss, then you *do not* need to use the log_softmax in the forward function.
        # loss_fn = torch.nn.CrossEntropyLoss()
        # If you use the NLLLoss, then you need to use the log_softmax in the forward function.
        loss_fn = torch.nn.NLLLoss()
        # Use the model's own label columns rather than a notebook-level global.
        y_idx = model.y_idx
        steps = 0

        for epoch in range(epochs):
            # Stay None if the loader yields no batches, so the report below is skipped.
            loss = accuracy = None

            # Apply the map to the start data
            for z_start, z_target in train_loader:
                # Note, not all methods will do the gradient update in the same way.
                optimizer.zero_grad()
                # Also, the forward function, which is inference focused, may not always
                # be the function you want to use for training.
                z_pred = model(z_start)

                # z_target is a one-hot encoded vector, so we need to extract the index of the 1
                labels = torch.argmax(z_target[:, y_idx], dim=1)
                scores = z_pred[:, y_idx]
                loss = loss_fn(scores, labels)
                accuracy = (torch.argmax(scores, dim=1) == labels).float().mean()

                # Note, not all methods will do the gradient update in the same way.
                # accelerator.backward(loss)
                loss.backward()
                optimizer.step()
                steps += 1

            if verbose and epoch % 10 == 0 and loss is not None:
                print(f'Epoch {epoch}, Steps {steps}, Loss {loss.item()}, Accuracy {accuracy}')

# Build the network, move its parameters to the compute device, then fit it,
# reporting the last-batch loss and accuracy every 10 epochs.
model = MLP(x_idx, y_idx).to(device)
MLP.train(
    model,
    train_loader,
    epochs=150,
    lr=0.01,
    verbose=True,
)

Epoch 0, Steps 10, Loss 1.9109129905700684, Accuracy 0.17999999225139618
Epoch 10, Steps 110, Loss 0.5957124829292297, Accuracy 0.7799999713897705
Epoch 20, Steps 210, Loss 0.016593709588050842, Accuracy 1.0
Epoch 30, Steps 310, Loss 0.004736598581075668, Accuracy 1.0
Epoch 40, Steps 410, Loss 0.0019781608134508133, Accuracy 1.0
Epoch 50, Steps 510, Loss 0.0013258801773190498, Accuracy 1.0
Epoch 60, Steps 610, Loss 0.0011669707018882036, Accuracy 1.0
Epoch 70, Steps 710, Loss 0.0009910791413858533, Accuracy 1.0
Epoch 80, Steps 810, Loss 0.0006620872300118208, Accuracy 1.0
Epoch 90, Steps 910, Loss 0.0005466172005981207, Accuracy 1.0
Epoch 100, Steps 1010, Loss 0.0003733691992238164, Accuracy 1.0
Epoch 110, Steps 1110, Loss 0.0003196106117684394, Accuracy 1.0
Epoch 120, Steps 1210, Loss 0.0002991199435200542, Accuracy 1.0
Epoch 130, Steps 1310, Loss 0.0002387588901910931, Accuracy 1.0
Epoch 140, Steps 1410, Loss 0.0002096103853546083, Accuracy 1.0
