## Testing binary classification with a fully-connected neural network

In [None]:
import sys
import os

module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src import constants as c
from src.model import VAE
from src import visualization as v

In [None]:
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
from sklearn import decomposition, manifold

from tqdm import tqdm, tnrange, tqdm_notebook

import torch
from torchvision import datasets, transforms
import torch.nn as nn

In [None]:
# Resize, then center-crop to c.image_size so every image has the same spatial
# size, and convert to a tensor for the VAE.
data_transforms = transforms.Compose(
    [
        transforms.Resize(c.image_size),
        transforms.CenterCrop(c.image_size),
        transforms.ToTensor(),
    ]
)

# One ImageFolder per split, read from <data_home>/surgical_data/<split>.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(c.data_home, 'surgical_data/', split),
        data_transforms,
    )
    for split in ['train', 'val']
}

# Shuffled mini-batch loaders over the datasets above (same keys, same order).
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=True,
    )
    for split, dataset in image_datasets.items()
}

## Load the pretrained VAE to generate latent encodings

In [None]:
# Pretrained VAEs keyed by latent dimensionality (only zdim=10 is used here).
models = {
    zdim: VAE(image_channels=c.image_channels,
              image_size=c.image_size,
              h_dim1=1024,
              h_dim2=128,
              zdim=zdim).to(c.device)
    for zdim in [10]
}

for zdim, model in models.items():
    # os.path.join works whether or not c.data_home ends with a separator,
    # matching how the other data paths in this notebook are built.
    weights_path = os.path.join(
        c.data_home,
        'weights',
        'tools_vae_{}_epoch_50_zdim_{}.torch'.format(c.image_size, zdim),
    )
    # map_location lets weights saved from a GPU load onto whatever c.device is
    # (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(weights_path, map_location=c.device))
    # The VAEs are only used for inference below; eval() puts layers such as
    # dropout / batch-norm (if the VAE has any) into inference mode.
    model.eval()

In [None]:
# Label table for the surgical images; it is joined column-wise onto the
# latent encodings in a later cell.
labels = pd.read_csv(
    filepath_or_buffer=os.path.join(
        c.data_home,
        'surgical_data/',
        'surgical_labels.csv',
    ),
)

In [None]:
# Encode every image into each VAE's latent space: the train split first, then
# val. Row order matters because the encodings are later concatenated
# column-wise with `labels`, so the dataset order must be preserved.
encoded_inputs = {zdim: [] for zdim in [10]}

with torch.no_grad():
    for zdim in tqdm_notebook(encoded_inputs):
        vae = models[zdim]
        # Inference only: switch dropout / batch-norm (if present) to eval mode.
        vae.eval()
        for split in ['train', 'val']:
            # shuffle=False keeps ImageFolder order (unlike `dataloaders`, which
            # shuffles); batching avoids one forward pass per image.
            loader = torch.utils.data.DataLoader(image_datasets[split],
                                                 batch_size=c.batch_size,
                                                 shuffle=False)
            for images, _ in tqdm_notebook(loader, desc=split):
                # Same reshape as the per-image version, applied to the batch.
                images = images.view(-1, c.image_channels,
                                     c.image_size, c.image_size).to(c.device)
                # NOTE(review): assumes sampling(*encode(x)) returns one
                # (zdim,)-row per input, i.e. shape (batch, zdim).
                latent = vae.sampling(*vae.encode(images)).cpu().numpy()
                # Iterating a 2-D array yields its rows: one vector per image.
                encoded_inputs[zdim].extend(latent)

In [None]:
# Per latent size: latent dimensions as columns 0..zdim-1, with the label table
# appended column-wise. pd.concat(axis='columns') aligns on the row index; the
# encodings get a fresh RangeIndex.
# NOTE(review): this assumes `labels` has a default RangeIndex and rows in the
# same order as the encoded images — confirm, otherwise labels are misaligned.
dataframes = {
    zdim: pd.concat(
        objs=[pd.DataFrame(encoded_inputs[zdim]), labels],
        axis='columns',
    )
    for zdim in [10]
}
# Stack the per-zdim frames; the dict keys become the outer index level.
latent_space = pd.concat(dataframes)

In [None]:
class LatentSpaceClassifier(nn.Module):
    """Fully-connected classifier over VAE latent vectors.

    Three linear layers with ReLU activations between them. It returns
    log-probabilities, so the output pairs with ``nn.NLLLoss``.

    Args:
        zdim: size of the input latent vector.
        hdim1: width of the first hidden layer.
        hdim2: width of the second hidden layer.
        n_classes: number of output classes (default 2, i.e. binary).
    """

    def __init__(self, zdim, hdim1, hdim2, n_classes=2):
        super(LatentSpaceClassifier, self).__init__()
        self.fc1 = nn.Linear(zdim, hdim1)
        self.fc2 = nn.Linear(hdim1, hdim2)
        self.fc3 = nn.Linear(hdim2, n_classes)

    def forward(self, x):
        """Return class log-probabilities for ``x`` of shape (..., zdim)."""
        # torch.relu / torch.log_softmax are used directly: the notebook's
        # import cells never bind the usual ``F`` (torch.nn.functional) alias.
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # Normalise explicitly over the last (class) axis; the implicit-dim
        # form is deprecated and would pick dim 0 for 3-D input.
        return torch.log_softmax(self.fc3(x), dim=-1)

In [None]:
# Classifier over 10-d latent vectors, hidden layers of 100 and 50 units.
clf = LatentSpaceClassifier(zdim=10, hdim1=100, hdim2=50).to(c.device)

In [None]:
# Negative log-likelihood loss, matching the classifier's log-softmax output.
criterion = nn.NLLLoss()
# SGD with momentum over every classifier parameter.
optimizer = torch.optim.SGD(params=clf.parameters(), lr=1e-3, momentum=0.9)

In [None]:
# Train `clf` on the latent encodings of the training split.
#
# The previous version of this cell came from an MNIST example. It referenced
# names never defined in this notebook (`epochs`, `train_loader`, `net`,
# `log_interval`, `Variable`), reshaped inputs to 28*28, and used
# `loss.data[0]`, which fails on 0-dim loss tensors in current PyTorch.
#
# NOTE(review) assumptions to confirm against surgical_labels.csv:
#   * the binary target is the LAST column of `labels` and holds 0/1 integers;
#   * `labels` rows follow the encoding order (train split first, then val).
zdim = 10                          # must match the classifier's input size
label_column = labels.columns[-1]
epochs = 10                        # tunable
log_interval = 10                  # print progress every N batches

# Rows for this zdim; the encodings occupy integer columns 0..zdim-1.
frame = latent_space.loc[zdim]
# Encodings were appended train-first, so the first len(train) rows are train.
train_frame = frame.iloc[:len(image_datasets['train'])]
features = torch.tensor(train_frame[list(range(zdim))].values,
                        dtype=torch.float32)
targets = torch.tensor(train_frame[label_column].values, dtype=torch.long)
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(features, targets),
    batch_size=c.batch_size,
    shuffle=True)

clf.train()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(c.device), target.to(c.device)
        optimizer.zero_grad()
        net_out = clf(data)
        loss = criterion(net_out, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))