# Fully-Connected Network (PyTorch)

In [None]:
import os
import itertools

import numpy as np

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim
import torch.utils.data as data_utils

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from verta import ModelDBClient


# Input data lives under ../data/mnist; make sure the directory exists.
data_dir = os.path.join(os.pardir, "data", "mnist")
os.makedirs(data_dir, exist_ok=True)

# Generated artifacts go under ../output/pytorch; create it on first use.
output_dir = os.path.join(os.pardir, "output", "pytorch")
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Host and port handed to ModelDBClient(HOST, PORT) further down.
# Both are left blank in this notebook and must be filled in before running.
HOST = 
PORT = 

---

## Constants

In [None]:
# Locations of the train/test archives inside the data directory.
TRAIN_DATA_PATH, TEST_DATA_PATH = (
    os.path.join(data_dir, filename)
    for filename in ("train.npz", "test.npz")
)

# Output path templates; the "{}" placeholder is filled via str.format().
LOSS_PLOT_PATH, MODEL_PATH = (
    os.path.join(output_dir, template)
    for template in ("{}.png", "{}.pt")
)

In [None]:
# Candidate values for each hyperparameter explored by the grid search.
GRID = {
    'hidden_size': [512, 1024],
    'dropout': [.2, .3, .5],
    'batch_size': [256, 512],
    'num_epochs': [4],
    'learning_rate': [0.001],
}

# Expand the Cartesian product into one {name: value} dict per combination.
# Iterating GRID yields its keys in the same order as GRID.values().
grid = [
    dict(zip(GRID, combination))
    for combination in itertools.product(*GRID.values())
]

## Client

In [None]:
# Build the ModelDB client from the HOST/PORT configured above.
client = ModelDBClient(HOST, PORT)

# Select the project and experiment on the client before creating runs.
proj = client.set_project("MNIST Multiclassification")
expt = client.set_experiment(
    "Pytorch FC-NN",
    "one layer with dropout",
    tags=["test", "neural-net"],
)

## Data

In [None]:
# Load the training archive and hold out 20% of it for validation.
# np.load on a .npz returns an archive object backed by an open file; the
# context manager closes it once the 'x' and 'y' arrays have been read.
with np.load(TRAIN_DATA_PATH) as data:
    x_train, x_val, y_train, y_val = train_test_split(data['x'], data['y'], test_size=.2)

# features become float tensors, labels become long (class-index) tensors
x_train, x_val = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_val, dtype=torch.float)
y_train, y_val = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_val, dtype=torch.long)

# rescale pixel values from ints in [0, 255] to reals in [0, 1]
x_train, x_val = x_train/255, x_val/255

In [None]:
# create Dataset object to support batch training
class TrainingDataset(data_utils.Dataset):
    """Index-addressable pairing of features with their labels.

    Implements `__len__` and `__getitem__` so a DataLoader can draw
    (feature, label) samples from it.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        # The label container determines the number of samples.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

## Model

In [None]:
class Net(nn.Module):
    """Fully-connected classifier with one hidden layer and dropout.

    Each sample is flattened to 28*28 features, passed through a ReLU
    hidden layer with dropout, and mapped to 10 class probabilities.

    Args:
        hidden_size: Number of units in the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, hidden_size=512, dropout=0.2):
        super().__init__()
        self.fc      = nn.Linear(28*28, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.output  = nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Return per-class probabilities (softmax over the last dimension)."""
        x = x.view(x.shape[0], -1)  # flatten non-batch dimensions
        x = func.relu(self.fc(x))
        x = self.dropout(x)
        x = func.softmax(self.output(x), dim=-1)
        return x

    def predict(self, x):
        """Return the most probable class index per sample as a NumPy array.

        Inference runs in eval mode so dropout is disabled and predictions
        are deterministic; the previous train/eval mode is restored
        afterwards so calling this mid-training does not change the model's
        state.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probabilities = self(x)
        finally:
            self.train(was_training)
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise
        return probabilities.cpu().numpy().argmax(axis=1)

    def score(self, x, y):
        """Return the fraction of samples in `x` predicted as their label in `y`."""
        return np.mean(self.predict(x) == y.detach().cpu().numpy())

## Training

In [None]:
# Grid search: train one model per hyperparameter combination and log the
# hyperparameters, losses and validation accuracy to a fresh experiment run.
for hyperparams in grid:
    run = client.set_experiment_run(tags=["test"])

    for key, value in hyperparams.items():
        run.log_hyperparameter(key, value)

    model = Net(hyperparams['hidden_size'], hyperparams['dropout'])

    # Net.forward already applies softmax, so the model emits probabilities.
    # CrossEntropyLoss expects unnormalised logits (it applies log-softmax
    # itself), so feeding it probabilities would normalise twice. NLLLoss on
    # log-probabilities yields the intended cross-entropy.
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['learning_rate'])

    dataset = TrainingDataset(x_train, y_train)
    dataloader = data_utils.DataLoader(dataset, batch_size=hyperparams['batch_size'], shuffle=True)
    run.log_dataset("training_data", TRAIN_DATA_PATH)
    run.log_dataset("testing_data", TEST_DATA_PATH)

    losses = []  # every batch loss across all epochs, in order
    for i_epoch in range(hyperparams['num_epochs']):
        print("{} | epoch {}/{}".format(hyperparams, i_epoch+1, hyperparams['num_epochs']), end='\r')
        model.train()  # make sure dropout is active while fitting
        batch_losses = []
        for x_batch, y_batch in dataloader:
            optimizer.zero_grad()

            output = model(x_batch)

            # clamp guards against log(0) if a probability underflows to zero
            loss = criterion(torch.log(output.clamp(min=1e-12)), y_batch)
            run.log_observation("batch_loss", loss.item())
            batch_losses.append(loss.item())

            loss.backward()
            optimizer.step()
        # mean batch loss for this epoch
        run.log_observation("epoch_loss", sum(batch_losses)/len(dataloader))
        losses.extend(batch_losses)
    print()

    # Score the held-out split with dropout disabled.
    model.eval()
    val_acc = model.score(x_val, y_val)
    run.log_metric("val_acc", val_acc)
    run.log_metric("final_loss", losses[-1])

    # NOTE(review): the plotting and saving steps below are commented out, so
    # the files at these paths may not exist when they are logged -- confirm
    # that log_image/log_model accept paths to files that were not written.
    # plt.plot(losses)
    # plt.savefig(LOSS_PLOT_PATH.format(run.name), bbox_inches='tight')
    # plt.close()
    run.log_image("loss_plot", LOSS_PLOT_PATH.format(run.name))

    # torch.save(model.state_dict(), MODEL_PATH.format(run.name))
    run.log_model("validation_model", MODEL_PATH.format(run.name))