## Testing binary classification with a fully-connected neural network

In [None]:
import sys
import os

# Put the directory two levels above the current working directory on the
# import path (only once) so the `src` package imported below can be found.
module_path = os.path.abspath('../..')
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src import constants as c
from src.model import VAE
from src import visualization as v

In [None]:
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
from sklearn import decomposition, manifold

from tqdm import tqdm, tnrange, tqdm_notebook

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Preprocessing applied to every image: resize, centre-crop to
# c.image_size, then convert to a tensor.
data_transforms = transforms.Compose([
    transforms.Resize(c.image_size),
    transforms.CenterCrop(c.image_size),
    transforms.ToTensor(),
])

# One ImageFolder dataset per split, read from <data_home>/surgical_data/<split>.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(c.data_home, 'surgical_data/', split),
        data_transforms,
    )
    for split in ['train', 'val']
}

# Shuffled mini-batch loaders over the same datasets, keyed by split name.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=True,
    )
    for split, dataset in image_datasets.items()
}

## Load the trained VAE to generate encodings

In [None]:
# Build one VAE per latent dimensionality (only zdim=10 here) on c.device.
models = {zdim: VAE(image_channels=c.image_channels,
                    image_size=c.image_size,
                    h_dim1=1024,
                    h_dim2=128,
                    zdim=zdim).to(c.device) for zdim in [10]}

# Restore the saved weights for every model.
for zdim, model in models.items():
    # Build the path with os.path.join, like the other data paths in this
    # notebook, so it does not depend on c.data_home ending with a separator.
    weights_path = os.path.join(
        c.data_home,
        "weights/tools_vae_{}_epoch_50_zdim_{}.torch".format(c.image_size, zdim),
    )
    # map_location remaps the stored tensors onto c.device, so a checkpoint
    # written on a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=c.device))

In [None]:
# Read the per-split label files. The column names are supplied explicitly
# via `names`, giving the columns 'Frame' and 'Tool'.
train_labels, val_labels = (
    pd.read_csv(os.path.join(c.data_home, split_dir, 'labels.csv'),
                names=['Frame', 'Tool'])
    for split_dir in ('surgical_data/train/', 'surgical_data/val/')
)

In [None]:
# encoded_inputs[zdim][split] collects one latent vector per image, in the
# order the images appear in image_datasets[split].
encoded_inputs = {zdim: {'train': [], 'val': []} for zdim in [10]}

# Encoding only: no gradients are needed.
with torch.no_grad():
    for zdim in tqdm_notebook(encoded_inputs):
        for split in ['train', 'val']:
            dataset = image_datasets[split]
            collected = encoded_inputs[zdim][split]
            for index in tnrange(len(dataset)):
                # dataset[index] is (image, class index); only the image is
                # used. Reshape it into a batch before moving it to the device.
                data = dataset[index][0].view(
                    -1, c.image_channels, c.image_size, c.image_size
                ).to(c.device)
                # Encode, draw a latent sample, and bring it back as NumPy.
                encoder_output = models[zdim].encode(data)
                latent_vector = (
                    models[zdim].sampling(*encoder_output).cpu().detach().numpy()
                )
                # Split into one chunk per batch element and keep each row.
                chunks = np.split(latent_vector, data.shape[0])
                collected.extend([chunk[0] for chunk in chunks])

In [None]:
# Training table: latent vectors side by side with their labels, without the
# 'Frame' column and without rows that have a missing value.
train_df = (
    pd.concat([pd.DataFrame(encoded_inputs[10]['train']), train_labels], axis=1)
    .drop(columns=['Frame'])
    .dropna()
)
train_df

In [None]:
# Validation table: latent vectors side by side with their labels, without
# the 'Frame' column and without rows that have a missing value.
val_df = (
    pd.concat([pd.DataFrame(encoded_inputs[10]['val']), val_labels], axis=1)
    .drop(columns=['Frame'])
    .dropna()
)
val_df

In [None]:
# Separate the label column from the latent-vector features. Without this
# split the targets are undefined and 'Tool' would be fed to the classifier
# as an extra input feature.
# NOTE(review): assumes 'Tool' already holds numeric class indices (0/1);
# confirm against labels.csv.
train_target = train_df['Tool']
val_target = val_df['Tool']
train_features = train_df.drop(columns=['Tool'])
val_features = val_df.drop(columns=['Tool'])

train_dataset = torch.utils.data.TensorDataset(
    torch.Tensor(np.array(train_features, dtype=np.float32)),
    torch.Tensor(np.array(train_target)))
val_dataset = torch.utils.data.TensorDataset(
    torch.Tensor(np.array(val_features, dtype=np.float32)),
    torch.Tensor(np.array(val_target)))

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=64,
                                           shuffle=True)
# Validation only measures performance, so a fixed order is used to keep
# the evaluation deterministic.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=64,
                                         shuffle=False)

In [None]:
class LatentSpaceClassifier(nn.Module):
    """Three-layer fully-connected classifier over latent vectors.

    Args:
        zdim: Number of input features (size of a latent vector).
        hdim1: Width of the first hidden layer.
        hdim2: Width of the second hidden layer.

    The network has two outputs and returns log-probabilities, so it is
    meant to be paired with ``nn.NLLLoss``.
    """

    def __init__(self, zdim, hdim1, hdim2):
        super().__init__()
        self.fc1 = nn.Linear(zdim, hdim1)
        self.fc2 = nn.Linear(hdim1, hdim2)
        self.fc3 = nn.Linear(hdim2, 2)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 2)``."""
        x = F.relu(self.fc1(x))
        # Dropout counters possible overfitting. F.dropout defaults to
        # training=True, so the module's mode must be passed explicitly;
        # otherwise units are still dropped after clf.eval().
        x = F.relu(F.dropout(self.fc2(x), training=self.training))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [None]:
# Classifier over 10-dimensional latent vectors with hidden widths 20 and 5.
clf = LatentSpaceClassifier(zdim=10, hdim1=20, hdim2=5)
clf = clf.to(c.device)

# Plain SGD with momentum over all classifier parameters.
optimizer = torch.optim.SGD(
    clf.parameters(),
    lr=1e-3,
    momentum=0.9,
)
# Negative log-likelihood loss, matching the classifier's log-softmax output.
criterion = nn.NLLLoss()




def train():
    """Train ``clf`` for one pass over ``train_loader``.

    Uses the module-level ``clf``, ``optimizer``, ``criterion`` and
    ``train_loader``, and reads the module-level ``epoch`` (set by the
    enclosing epoch loop) for the progress message.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the
        sample-weighted average loss over the whole pass and ``accuracy``
        is the fraction of correctly classified training samples.
    """
    # Make sure mode-dependent layers (e.g. dropout) are in training mode.
    clf.train()

    running_loss = 0.0
    train_correct = 0
    train_total = 0
    for data, target in train_loader:
        # Tensors take part in autograd directly; no Variable wrapper needed.
        data = data.to(c.device)
        target = target.to(c.device).long()

        optimizer.zero_grad()
        output = clf(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # criterion averages over the batch; weight by batch size so a
        # smaller final batch does not skew the epoch mean.
        running_loss += loss.item() * batch_size
        predicted = output.argmax(dim=1)
        train_total += batch_size
        train_correct += (predicted == target).sum().item()

    # Report the mean over the epoch rather than only the last mini-batch.
    mean_loss = running_loss / train_total
    accuracy = train_correct / train_total
    tqdm.write('Train Epoch: {} \tLoss: {:.6f}\tTraining Accuracy: {:.3f}%'.format(epoch, mean_loss, 100 * accuracy))

    return (mean_loss, accuracy)

    
def test():
    """Evaluate ``clf`` on ``val_loader`` without updating its weights.

    Uses the module-level ``clf``, ``criterion`` and ``val_loader``, and
    reads the module-level ``epoch`` (set by the enclosing epoch loop) for
    the progress message.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the
        sample-weighted average loss over the validation set and
        ``accuracy`` is the fraction of correctly classified samples.
    """
    running_loss = 0.0
    correct = 0
    total = 0

    # Evaluate in eval mode, then restore whatever mode the model was in so
    # this function does not silently change the state seen by the caller.
    was_training = clf.training
    clf.eval()
    try:
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(c.device)
                target = target.to(c.device).long()

                output = clf(data)
                batch_size = target.size(0)
                # criterion averages over the batch; weight by batch size
                # to get a true mean over the whole validation set.
                running_loss += criterion(output, target).item() * batch_size

                predicted = output.argmax(dim=1)
                total += batch_size
                correct += (predicted == target).sum().item()
    finally:
        clf.train(was_training)

    mean_loss = running_loss / total
    accuracy = correct / total
    tqdm.write('Epoch: {}\tValidation Accuracy: {:.3f}%'.format(epoch, 100 * accuracy))

    return (mean_loss, accuracy)


# Per-epoch history, consumed by the plotting cell.
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

# `epoch` has to stay a module-level name: train() and test() read it when
# they print their progress messages.
for epoch in tnrange(150):
    train_loss, train_acc = train()
    val_loss, val_acc = test()

    train_losses += [train_loss]
    train_accuracies += [train_acc]
    val_losses += [val_loss]
    val_accuracies += [val_acc]

In [None]:
# 2x2 grid of learning curves: training on the top row, validation on the
# bottom row; loss on the left, accuracy on the right.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))

panels = (
    (ax[0, 0], train_losses, "Training Loss"),
    (ax[0, 1], train_accuracies, "Training Accuracy"),
    (ax[1, 0], val_losses, "Validation Loss"),
    (ax[1, 1], val_accuracies, "Validation Accuracy"),
)
for axis, values, title in panels:
    axis.plot(values)
    axis.set_title(title)