<a href="https://colab.research.google.com/github/wylhtydtm/Nematode-project/blob/master/vgg16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import tables
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [None]:
!pip install livelossplot --quiet
from livelossplot import PlotLosses

In [None]:
def shift_and_normalize(data):  #Preprocessing step
    """Shift non-zero pixels by an estimated background level, then divide by 255.

    Pixels equal to 0 are masked out, so the shift leaves them at 0. The
    background level is the 95th percentile of ``data``, computed over all
    pixels (zeros included):
      * 3D input: one percentile per image over axes (1, 2), broadcast back;
      * any other rank: a single percentile over the whole array.

    The work is done IN PLACE through a masked-array view of ``data``. That is
    why ``data`` must be a floating-point ndarray: ``/= 255`` fails on integer
    dtypes. The same array object is returned.

    Args:
        data: float ndarray, a 2D image or a 3D stack (n_images, h, w).

    Returns:
        ``data`` itself, shifted and scaled.
    """
    # A MaskedArray view shares memory with `data`; in-place ops on it skip
    # masked (zero) pixels and write straight through to `data`.
    masked_view = data.view(np.ma.MaskedArray)
    masked_view.mask = data == 0

    is_stack = data.ndim == 3
    if is_stack:
        # 95th percentile used as the background value, one per image
        background = np.percentile(data, 95, axis=(1, 2))[:, None, None]
    else:
        background = np.percentile(data, 95)
    masked_view -= background

    data /= 255
    return data

In [None]:
def img_rescale(img, for_pillow=False):
    """
    Min-max rescale a 2D image to [0, 1].

    With ``for_pillow=True`` the result is scaled to [0, 255] and cast to
    uint8, keeping the 2D shape, so it can be passed to ``PIL.Image.fromarray``.
    Otherwise it is cast to float32 and given a leading channel axis, giving
    shape (1, H, W).

    No batch axis is added, because the DataLoader stacks samples into batches.
    If you call this outside a DataLoader, you still need to add the batch
    dimension yourself.

    A constant image (zero intensity range) cannot be min-max scaled. It is
    mapped to all zeros instead of the NaNs that a 0/0 division would produce.

    Args:
        img: 2D array.
        for_pillow: if True, return a uint8 image for PIL instead of a
            float32 array.

    Returns:
        uint8 array of shape (H, W) if ``for_pillow``, else float32 array of
        shape (1, H, W).

    Raises:
        ValueError: if ``img`` is not 2D.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if img.ndim != 2:
        raise ValueError('img_rescale only works with 2d array for now')
    img = img - img.min()
    img_range = img.max()
    if img_range == 0:
        # 0/0 would give NaN everywhere (and NaN -> uint8 is undefined);
        # a flat image is mapped to zeros instead.
        img = np.zeros(img.shape, dtype=np.float64)
    else:
        img = img / img_range
    if for_pillow:
        img *= 255
        img = img.astype(np.uint8)
    else:
        img = img.astype(np.float32)[None, :, :]  # (channel, H, W)
    return img

In [None]:
class new_dataset(Dataset):
    """Worm / non-worm ROI dataset read lazily from an HDF5 file.

    For the split ``which_set``, the file must contain:
      * ``/<which_set>`` with a ``sample_data`` table that has at least the
        ``img_row_id``, ``is_worm`` and ``is_avelinos`` columns;
      * ``/<which_set>/mask``, an image array indexed by ``img_row_id``.

    Each item is the central ``roi_size`` x ``roi_size`` crop of one mask
    image, preprocessed with ``shift_and_normalize`` and ``img_rescale``.
    The label is ``is_worm``.

    Args:
        hdf5_filename: path to the HDF5 file.
        which_set: name of the split group to read, e.g. 'train', 'val', 'test'.
        transform: optional torchvision transform. If given, the image is
            converted to an 8-bit RGB PIL image before the transform is applied.
    """

    def __init__(self, hdf5_filename, which_set='train', transform=None):

        self.fname = hdf5_filename
        self.set_name = which_set
        # size in hdf5 file is 160x160 (in theory), but we train on 80x80
        self.roi_size = 80  # size we want to train on
        # Read the labels table and the stored image size in a single open.
        with tables.File(self.fname, 'r') as fid:
            tmp = pd.DataFrame.from_records(
                fid.get_node('/' + self.set_name)['sample_data'].read())
            # Use this split's own mask array rather than always '/train/mask'.
            dataset_size = fid.get_node('/' + self.set_name + '/mask').shape[1]
        self.label_info = tmp[['img_row_id', 'is_worm', 'is_avelinos']]
        # Integer [start, stop) bounds of the central crop. The same bounds are
        # used on both image axes, which assumes square images (TODO confirm).
        self.dd = list(self._crop_bounds(dataset_size, self.roi_size))
        # any transform?
        self.transform = transform

    @staticmethod
    def _crop_bounds(dataset_size, roi_size):
        """Return integer (start, stop) of a roi_size window centred in dataset_size pixels."""
        if roi_size > dataset_size:
            raise ValueError(
                f'roi_size ({roi_size}) is larger than the stored images ({dataset_size})')
        start = (dataset_size - roi_size) // 2
        return start, start + roi_size

    def __len__(self):
        return len(self.label_info)

    def __getitem__(self, index):
        """Return ``(img, labels)`` for one sample.

        ``img`` is the output of ``self.transform`` applied to an RGB PIL image.
        When there is no transform, it is a float32 numpy array of shape
        (1, roi_size, roi_size).

        ``labels`` is a float32 tensor of shape (1, 1) holding ``is_worm``.
        """
        if torch.is_tensor(index):
            index = index.tolist()

        # Look the image up via img_row_id rather than `index` itself, so the
        # label table could be shuffled/subset without breaking the mapping.
        label_info = self.label_info.iloc[index]
        img_row_id = label_info['img_row_id']
        start, stop = self.dd

        # read images from disk
        with tables.File(self.fname, 'r') as fid:
            roi_data = fid.get_node('/' + self.set_name + '/mask')[
                img_row_id, start:stop, start:stop].copy()

        # shift_and_normalize works in place and needs floats; pytorch wants single precision
        img = roi_data.astype(np.float32)
        img = shift_and_normalize(img)

        # torchvision transforms here expect a PIL image, so only build one
        # when a transform was given.
        if self.transform:
            img = img_rescale(img, for_pillow=True)
            img = Image.fromarray(img)
            img = img.convert(mode='RGB')
            img = self.transform(img)
        else:
            img = img_rescale(img, for_pillow=False)

        # read labels too
        labels = label_info['is_worm']
        labels = np.array(labels, dtype=np.float32).reshape(-1, 1)
        labels = torch.from_numpy(labels)

        return img, labels

In [None]:
# where are things?
hd = Path('/content/drive/My Drive')
fname = hd / 'Hydra_Phenix_dataset.hdf5'

# parameters
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 128

# Random augmentation is applied to the training set only.
training_transform = transforms.Compose([
    transforms.RandomVerticalFlip(p=0.4),
    transforms.RandomHorizontalFlip(p=0.4),
    transforms.ColorJitter(contrast=0.2, hue=0.2),
    transforms.ToTensor(),
])

# Validation accuracy is what train_model uses to pick the best weights, so the
# validation images must be deterministic: no random flips or colour jitter.
validation_transform = transforms.ToTensor()

test_transform = transforms.ToTensor()

# create datasets
train_data = new_dataset(fname, which_set='train', transform=training_transform)
val_data = new_dataset(fname, which_set='val', transform=validation_transform)
test_data = new_dataset(fname, which_set='test', transform=test_transform)

# create dataloaders; only training needs shuffling (evaluation order does not
# affect loss/accuracy totals)
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, num_workers=4)
val_loader = DataLoader(val_data, shuffle=False, batch_size=batch_size, num_workers=4)
test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size, num_workers=4)

In [None]:
# Loaders and sample counts keyed by phase name, in the form train_model reads them.
dataloaders = dict(train=train_loader, validation=val_loader)
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

In [None]:
# Checking whether the input image has the right channel
img = train_data[0][0]
img =img.unsqueeze(0)
print(img.size())

torch.Size([1, 3, 80, 80])


In [None]:
# Start from pretrained VGG16 and freeze every existing weight, so only the
# replacement classifier layer created below is trained.
model_vgg = torchvision.models.vgg16(pretrained=True)
for param in model_vgg.parameters():
    param.requires_grad = False

# Swap the last classifier layer for a 2-class head. A freshly created
# nn.Linear has requires_grad=True, so it is the only trainable part.
num_ftrs = model_vgg.classifier[6].in_features
model_vgg.classifier[6] = nn.Linear(num_ftrs, 2)

model_vgg = model_vgg.to(device)
learning_rate = 0.0001
num_epoch = 50  # number of complete passes through the training set

criterion = torch.nn.CrossEntropyLoss()
# Give the optimiser only the parameters it can actually update.
optimiser_vgg = torch.optim.Adam(
    [p for p in model_vgg.parameters() if p.requires_grad], lr=learning_rate)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




In [None]:
import time
import copy

def train_model(model, criterion, optimiser, num_epoch):
    """Train ``model`` and return it with its best validation weights loaded.

    Each epoch runs two phases over the module-level ``dataloaders``:
      * 'train': forward, backward and optimiser step;
      * 'validation': forward only.

    Summed loss and correct counts are divided by ``dataset_sizes[phase]``.
    Per-epoch loss/accuracy are printed and sent to a livelossplot
    ``PlotLosses``, under the keys loss/accuracy and val_loss/val_accuracy.

    The state dict with the highest validation accuracy seen is kept and
    loaded back into ``model`` at the end.

    Uses the globals ``dataloaders``, ``dataset_sizes`` and ``device``.

    Args:
        model: classifier producing (batch, n_classes) scores.
        criterion: loss taking (scores, int64 class labels), e.g. CrossEntropyLoss.
        optimiser: optimiser over the trainable parameters of ``model``.
        num_epoch: number of complete passes over the training set.

    Returns:
        ``model`` (the same object) with the best validation weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    liveloss = PlotLosses()
    for epoch in range(num_epoch):
        logs = {}
        print('Epoch{}/{}'.format(epoch, num_epoch - 1))
        print('-' * 15)

        # Each epoch has a training and a validation phase
        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                # The dataset yields float labels of shape (1, 1) per sample;
                # the loss wants a flat int64 class vector.
                labels = labels.to(device).view(-1).long()
                optimiser.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # backward and optimise only in the training phase
                    if is_train:
                        loss.backward()
                        optimiser.step()

                _, pred = torch.max(outputs, dim=1)
                running_loss += loss.detach() * inputs.size(0)
                running_corrects += torch.sum(pred == labels)

            # Average over the whole epoch. float() works on both tensors and the
            # 0.0 / 0 starting values (in case a loader is empty), and gives the
            # logger plain numbers instead of device tensors.
            epoch_loss = float(running_loss) / dataset_sizes[phase]
            epoch_acc = float(running_corrects) / dataset_sizes[phase]

            prefix = 'val_' if phase == 'validation' else ''
            logs[prefix + 'loss'] = epoch_loss
            logs[prefix + 'accuracy'] = epoch_acc

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        liveloss.update(logs)
        liveloss.send()

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s. Saving model...'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best Val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Train the new classifier head; train_model hands back the model with the
# weights that scored best on validation.
model_vgg = train_model(model_vgg, criterion, optimiser=optimiser_vgg, num_epoch=num_epoch)

In [None]:
# The checkpoint name records the optimiser actually used (Adam, optimiser_vgg).
# It matches the 'vgg16_epoch20_adam.pth' file the CPU-evaluation cell loads.
# NOTE(review): num_epoch is 50 in this notebook, so 'epoch20' may be stale — confirm.
PATH = '/content/drive/My Drive/vgg16_epoch20_adam.pth'
torch.save(model_vgg.state_dict(), PATH)

In [None]:
# Evaluate on CPU: move the model there and map the saved tensors to CPU too.
device = torch.device('cpu')
model_vgg = model_vgg.to(device)
# Absolute path (same Drive mount as the dataset) so loading does not depend
# on the current working directory.
checkpoint_path = '/content/drive/My Drive/vgg16_epoch20_adam.pth'
model_vgg.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [None]:
# Switch to inference mode before evaluating. Module.train(False) is what
# .eval() does; it returns the module, which is displayed as the cell output.
model_vgg.train(False)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [None]:
from torchsummary import summary
# torchsummary builds its dummy input on "cuda" unless told otherwise, which
# fails once the model has been moved to CPU; follow the model's device instead.
summary(model_vgg, (3, 80, 80), device=device.type)

In [None]:
def measure_performance(predictions, labels):
    """Print and return accuracy, precision, recall and F1 for binary predictions.

    Both inputs are cast to bool, so any non-zero value counts as the positive
    class. A metric whose denominator is zero (e.g. precision when nothing is
    predicted positive, or any metric on empty input) is reported as ``nan``.
    This avoids a divide-by-zero warning.

    Equivalent functions exist in ``sklearn.metrics``; this version avoids
    that dependency.

    Args:
        predictions: array-like of predicted labels.
        labels: array-like of ground-truth labels, same shape as ``predictions``.

    Returns:
        dict with float values under 'accuracy', 'precision', 'recall', 'f1'.

    Raises:
        ValueError: if the two inputs have different shapes. Broadcasting them
            would silently produce meaningless counts.
    """
    # go logical for ease
    predictions = np.asarray(predictions).astype(bool)
    labels = np.asarray(labels).astype(bool)
    if predictions.shape != labels.shape:
        raise ValueError(
            f'predictions {predictions.shape} and labels {labels.shape} differ in shape')

    tp = np.count_nonzero(predictions & labels)     # true positives
    tn = np.count_nonzero(~predictions & ~labels)   # true negatives
    fp = np.count_nonzero(predictions & ~labels)    # false positives
    fn = np.count_nonzero(~predictions & labels)    # false negatives

    def _ratio(num, den):
        # Undefined metric -> nan, without numpy's 0/0 RuntimeWarning.
        return num / den if den else float('nan')

    accuracy = _ratio(tp + tn, predictions.size)
    print(f"accuracy = {accuracy}")
    precision = _ratio(tp, tp + fp)
    print(f"precision = {precision}")
    recall = _ratio(tp, tp + fn)
    print(f"recall = {recall}")
    f1 = _ratio(2 * tp, 2 * tp + fp + fn)
    print(f"F1 score = {f1}")
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

In [None]:
# Run the model over the test set and collect predicted classes and
# ground-truth labels as flat numpy arrays.
model_vgg.eval()  # make sure inference mode is on even if the eval cell was skipped
labels = []
predictions = []
with torch.no_grad():
    for images, labs in test_loader:
        scores = model_vgg(images.to(device))
        # .cpu() so np.concatenate also works when `device` is a GPU
        predictions.append(torch.argmax(scores, dim=1).cpu().numpy())
        labels.append(labs.numpy())

predictions = np.concatenate(predictions, axis=0)
# Per-sample labels are (1, 1); flatten to (N,). reshape (unlike squeeze)
# keeps a 1-D result even when N == 1.
labels = np.concatenate(labels, axis=0).reshape(-1)

print("\nPerformance ")
measure_performance(predictions, labels)

In [None]:
# train_model takes (model, criterion, optimiser, num_epoch) and has no
# scheduler argument. The former `exp_lr_scheduler` was never defined in this
# notebook, so the call raised NameError/TypeError.
# NOTE(review): `device` was switched to CPU by the evaluation cell; this will
# retrain on CPU unless that cell was skipped — confirm intent.
model_vgg = train_model(model_vgg, criterion, optimiser_vgg, num_epoch)