### Initialize notebook

In [None]:
# Torch imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import PIL.Image as IMG
from torch.utils.data.sampler import WeightedRandomSampler

# Other imports
import matplotlib.pyplot as plt
import numpy as np
import os
os.chdir('/home/ak/Spring2018/ature')

from neuralnet.utils.datasets import PatchesGenerator

from utils import img_utils as imgutil
from commons.IMAGE import Image
from neuralnet.trainer import NNTrainer
import neuralnet.utils.measurements as mnt
import neuralnet.utils.data_utils as nndutils

sep = os.sep

%load_ext autoreload
%autoreload 2

# Define folders (create them if needed)
Dirs = {}

Dirs['train_data']= 'data'+sep+'DRIVE'+sep+'training'+sep +'patches'
Dirs['test_data'] = 'data'+sep+'DRIVE'+sep+'test' +sep+ 'images1'
Dirs['checkpoint']   = 'assests' +sep+ 'nnet_models'

Dirs['data']      = 'data'+sep+'DRIVE'+sep+'test'
Dirs['images']    = Dirs['data'] +sep+ 'images2'
Dirs['mask']      = Dirs['data'] +sep+ 'mask'
Dirs['truth']     = Dirs['data'] +sep+ '1st_manual'
Dirs['segmented'] = Dirs['data'] +sep+ 'drive_segmented'

for k, folder in Dirs.items():
    os.makedirs(folder, exist_ok=True)

# Set up execution flags
Flags = {}
Flags['useGPU'] = False  # forwarded as use_gpu= to the trainer's train/test calls

# Class names in label order: label k is reported as classes[k].
classes = ('white', 'green', 'black', 'red')
# Derived from `classes` so the class count and the names can never disagree.
num_classes = len(classes)

batch_size = 100
epochs = 1
patch_size = 31  # side length of the square patches fed to the network


def get_mask_file(file_name):
    """Map an image file name to its mask file name.

    Keeps everything before the first underscore, e.g.
    '01_test.tif' -> '01_test_mask.gif'.
    """
    prefix, _, _ = file_name.partition('_')
    return prefix + '_test_mask.gif'

def get_ground_truth_file(file_name):
    """Map an image file name to its manual ground-truth file name.

    Keeps everything before the first underscore, e.g.
    '01_test.tif' -> '01_manual1.gif'.
    """
    prefix, _, _ = file_name.partition('_')
    return prefix + '_manual1.gif'

def get_segmented_file(file_name):
    """Return the output name for a segmented image: `file_name` + '_SEG.PNG'."""
    return ''.join((file_name, '_SEG.PNG'))

### Define the network

In [None]:
class Net(nn.Module):
    """Small CNN that classifies a square image patch into `num_classes` classes.

    Three conv -> ReLU -> max-pool stages (20, 50, 50 feature maps) feed three
    fully connected layers (100 -> 20 -> num_classes), with dropout applied
    before fc1 and fc2 while training. `forward` returns raw scores (no softmax).

    Args:
        width: side length of the square input patch.
        channels: number of input channels.
        num_classes: number of output classes (default 4).

    Raises:
        ValueError: if `width` is too small for the conv/pool stack to leave a
            positive spatial size.
    """

    def __init__(self, width, channels, num_classes=4):
        super(Net, self).__init__()

        self.channels = channels
        self.width = int(width)  # updated in place to the spatial size after each stage
        self.num_classes = num_classes

        # Stage 1: 5x5 conv (stride 2, padding 2), then 2x2 max-pool (stride 2).
        self.kern_size = 5
        self.kern_stride = 2
        self.kern_padding = 2
        self.mxp_kern_size = 2
        self.mxp_stride = 2
        self.pool1 = nn.MaxPool2d(kernel_size=self.mxp_kern_size, stride=self.mxp_stride)
        self.conv1 = nn.Conv2d(self.channels, 20, self.kern_size,
                               stride=self.kern_stride, padding=self.kern_padding)
        self._update_output_size()

        # Stage 2: 5x5 conv (stride 1, padding 2) keeps the size; 1x1 pool is a no-op.
        self.kern_size = 5
        self.kern_stride = 1
        self.kern_padding = 2
        self.mxp_kern_size = 1
        self.mxp_stride = 1
        self.pool2 = nn.MaxPool2d(kernel_size=self.mxp_kern_size, stride=self.mxp_stride)
        self.conv2 = nn.Conv2d(20, 50, self.kern_size,
                               stride=self.kern_stride, padding=self.kern_padding)
        self._update_output_size()

        # Stage 3: 5x5 conv (stride 1, padding 1) shrinks by 2; 1x1 pool is a no-op.
        self.kern_size = 5
        self.kern_stride = 1
        self.kern_padding = 1
        self.mxp_kern_size = 1
        self.mxp_stride = 1
        self.pool3 = nn.MaxPool2d(kernel_size=self.mxp_kern_size, stride=self.mxp_stride)
        self.conv3 = nn.Conv2d(50, 50, self.kern_size,
                               stride=self.kern_stride, padding=self.kern_padding)
        self._update_output_size()

        if self.width < 1:
            raise ValueError('input width %r leaves a non-positive feature map (%d) '
                             'after the conv/pool stack' % (width, self.width))

        # Flattened size of the last feature map: 50 channels x width x width.
        self.linearWidth = 50 * self.width * self.width
        self.fc1 = nn.Linear(self.linearWidth, 100)
        self.fc2 = nn.Linear(100, 20)
        self.fc3 = nn.Linear(20, num_classes)
        # NOTE(review): defined but never applied in forward(), which returns raw
        # scores; kept so code referring to `net.sm` keeps working.
        self.sm = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return unnormalised class scores, one row of `num_classes` per sample."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = x.view(-1, self.linearWidth)
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def _update_output_size(self):
        """Advance `self.width` through one conv + max-pool stage.

        Reads the current kern_* / mxp_* attributes and uses integer floor
        division, the same output-size formula Conv2d and MaxPool2d apply
        (dilation 1, no ceil_mode).
        """
        width_in = self.width
        width_conv = (width_in - self.kern_size + 2 * self.kern_padding) // self.kern_stride + 1
        self.width = (width_conv - self.mxp_kern_size) // self.mxp_stride + 1
        print('output width { ' + str(width_in) + ' -conv-> ' + str(width_conv)
              + ' -maxpool-> ' + str(self.width) + ' }')

# The network's input width must equal the patch side length, so reuse
# patch_size instead of repeating the literal 31.
width = patch_size
channels = 1  # single-channel patches
net = Net(width, channels)
optimizer = optim.Adam(net.parameters(), lr=0.0005)

### Load training data

In [None]:
# Training-time preprocessing: whiten the patch, then apply random flips and a
# random rotation (on the PIL image) as augmentation before converting to a tensor.
transform = transforms.Compose([
        imgutil.whiten_image2d,
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(40),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor()
    ])

trainset = PatchesGenerator(Dirs=Dirs, patch_size=patch_size, num_classes=num_classes, transform=transform,
                 fget_mask=get_mask_file, fget_truth=get_ground_truth_file, fget_segmented=get_segmented_file)

### Fix skewed classes by sampling based on class weights
# Weight every patch by the inverse frequency of its class so that each class is
# drawn with roughly equal probability. `return_inverse` maps each label to its
# position among the sorted unique labels, so labels need not be 0..K-1.
_, label_idx_train, ccounts_train = np.unique(
    np.asarray(trainset.labels), return_inverse=True, return_counts=True)
cweights_train = 1.0 / ccounts_train
dweights_train = cweights_train[np.ravel(label_idx_train)]

# Patches drawn per epoch: twice the second-smallest class count (falls back to
# the only count when a single class is present). The sampler draws with
# replacement by default.
dmin = np.partition(ccounts_train, 1)[1] if ccounts_train.size > 1 else ccounts_train[0]

# shuffle must stay False: sample ordering is delegated to the sampler.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=3,
                                          sampler=WeightedRandomSampler(dweights_train, int(2*dmin)))

### Load test dataset

In [None]:
# Evaluation preprocessing: whitening only, no random augmentation.
transform = transforms.Compose([imgutil.whiten_image2d,
                                transforms.ToPILImage(),
                                transforms.ToTensor()])

# Point the generator at the held-out test images. This mutates the shared
# Dirs dict, so later cells that pass Dirs also see the test image folder.
Dirs['images'] = Dirs['test_data']
testset = PatchesGenerator(
    Dirs=Dirs,
    patch_size=patch_size,
    num_classes=num_classes,
    transform=transform,
    fget_mask=get_mask_file,
    fget_truth=get_ground_truth_file,
    fget_segmented=get_segmented_file,
)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=3
)

### Train and evaluate the network

In [None]:
# Train the network, writing checkpoints to checkpoint4Way.nn.tar under
# Dirs['checkpoint'], then evaluate on the test loader.
trainer = NNTrainer(
    model=net,
    checkpoint_dir=Dirs['checkpoint'],
    checkpoint_file='checkpoint4Way.nn.tar',
)
# Presumably restores a previous run's weights; uncomment to use (confirm in NNTrainer).
# trainer.resume_from_checkpoint()
trainer.train(optimizer=optimizer,
              dataloader=trainloader,
              epochs=epochs,
              use_gpu=Flags['useGPU'])
# NOTE(review): force_checkpoint=True presumably evaluates the saved checkpoint
# rather than the in-memory weights — confirm in NNTrainer.test.
acc, y_pred, y_true = trainer.test(dataloader=testloader,
                                   use_gpu=Flags['useGPU'],
                                   force_checkpoint=True)

In [None]:
# Re-run the evaluation and plot its confusion matrix with readable class names.
acc, y_pred, y_true = trainer.test(
    dataloader=testloader,
    use_gpu=Flags['useGPU'],
    force_checkpoint=True,
)
mnt.plot_confusion_matrix(y_pred=y_pred, y_true=y_true, classes=classes)

### Check per-class performance

In [None]:
def per_class_accuracy(model, loader, n_classes):
    """Count correct predictions and samples per class over a data loader.

    The model is run in eval mode (dropout disabled) without gradient tracking,
    and its previous train/eval mode is restored afterwards.

    Args:
        model: network returning per-class scores, one row per sample.
        loader: iterable of (images, labels) batches; labels are class indices.
        n_classes: number of classes.

    Returns:
        (correct, total): two lists of length n_classes holding integer counts.
    """
    correct = [0] * n_classes
    total = [0] * n_classes
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in loader:
                predicted = model(images).argmax(dim=1)
                # Count every sample in the batch, whatever the batch size.
                for label, pred in zip(labels.view(-1).tolist(), predicted.view(-1).tolist()):
                    total[label] += 1
                    correct[label] += int(label == pred)
    finally:
        model.train(was_training)
    return correct, total


class_correct, class_total = per_class_accuracy(net, testloader, num_classes)
for i in range(num_classes):
    if class_total[i] == 0:
        # Avoid ZeroDivisionError for a class absent from the test set.
        print('Accuracy of %5s : n/a (no test samples)' % classes[i])
    else:
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))

### Slide the trained network over each image to generate a segmented image

In [None]:
# Segmentation preprocessing: same deterministic whitening as the test set.
transform = transforms.Compose([imgutil.whiten_image2d,
                                transforms.ToPILImage(),
                                transforms.ToTensor()])

# NOTE(review): Dirs['images'] still points at Dirs['test_data'] if the
# test-loading cell has run. Confirm in PatchesGenerator what segment_mode=True changes.
segset = PatchesGenerator(
    Dirs=Dirs,
    patch_size=patch_size,
    num_classes=num_classes,
    transform=transform,
    fget_mask=get_mask_file,
    fget_truth=get_ground_truth_file,
    fget_segmented=get_segmented_file,
    segment_mode=True,
)

# One item per batch, in dataset order, loaded by a single worker process.
segloader = torch.utils.data.DataLoader(
    segset, batch_size=1, shuffle=False, num_workers=1
)

In [None]:
# Work in progress: run the trained model over segset and set pixels predicted as
# class index 0 or 3 (named 'white' and 'red' in `classes`) to 255 in `seg`.
# NOTE(review): as written, the loop prints the first ID and stops — everything
# after the `break` below is unreachable.
# NOTE(review): `i` and `j` are never assigned in this cell (`i` would be left
# over from an earlier cell's loop, `j` is undefined). So `seg[i, j] = 255` would
# raise or write the wrong pixel once the `break` is removed. Presumably the
# pixel coordinates should come from `ID` — confirm what PatchesGenerator yields
# in segment_mode.
# NOTE(review): `seg` is sized from the first image only (segset.images[0]);
# confirm whether segloader covers patches from more than one image.
seg = np.zeros_like(segset.images[0].working_arr)
for ID, image, label in segloader:
#     IMG.fromarray(np.array(image*255, np.uint8).reshape(31,31)).show()
#     break
    print(ID)
    break
    outputs = trainer.model(Variable(image))
    _, predicted = torch.max(outputs.data, 1)
    if predicted[0]==3 or predicted[0] == 0:
        seg[i, j] = 255

In [None]:
# Wrap the segmentation array as a PIL image; as the cell's last expression,
# Jupyter renders it inline.
# NOTE(review): PIL needs uint8 data for an 8-bit grayscale image — confirm the
# dtype of segset.images[0].working_arr, which `seg` inherits via np.zeros_like.
IMG.fromarray(seg)