# 3. Multi-class Vnet on BV

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from dataset import *
from vnet import *
from training import *
from niiutility import show_image, show_batch_image

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 3.1 Setup Torch Global Variable, load memory map 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 1
NUM_WORKERS = 6
NUM_TRAIN = 80
LEARNING_RATE = 1e-2

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')

using GPU for training


In [3]:
#-------------------------LOAD THE DATA SET-------------------------------------------#

data_index = np.arange(107)
data_index = np.delete(data_index, 46)
dataset_trans = niiDataset(data_index, 
                         transform=transforms.Compose([
                             downSample(2),
                             RandomFilp(0.5),
                             RandomAffine(15, 10)
                         ])
                     )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_trans)
train_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE, \
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),\
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN,data_size)),\
                    num_workers=NUM_WORKERS)

* Print first 4 batch of data

In [None]:
for i_batch, sample_batched in enumerate(train_loader):
    print(i_batch, sample_batched['image'].size(), \
          sample_batched['label'].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        show_batch_image(sample_batched['image'],sample_batched['label'],BATCH_SIZE)
        break

In [4]:
from vnet import VNet

#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

LoadCKP = True
CKPPath = 'checkpoint2019-03-31 13:33:50.772063.pth'

model = VNet(classnum=3, slim=False)
model.apply(weights_init)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=30, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

loading checkpoint 'checkpoint2019-03-31 13:33:50.772063.pth'
loaded checkpoint 'checkpoint2019-03-31 13:33:50.772063.pth' (epoch 151)


In [5]:
from loss import *

train(model, train_loader, validation_loader, optimizer, scheduler, \
      device=device, dtype=dtype,lossFun=dice_loss_2, epochs=500)

PATH = 'Vet_currculum_330.pth'
torch.save(model.state_dict(), PATH)


Epoch 0 finished ! Training Loss: 0.0854576003702381
     validation loss = 0.1801
Checkpoint 1 saved !
Epoch 1 finished ! Training Loss: 0.08587881960446321
     validation loss = 0.2020
Epoch 2 finished ! Training Loss: 0.08642978985098344
     validation loss = 0.1816
Epoch 3 finished ! Training Loss: 0.0859228742273548
     validation loss = 0.2073
Epoch 4 finished ! Training Loss: 0.08838204869741126
     validation loss = 0.2114
Epoch 5 finished ! Training Loss: 0.08479313608966296
     validation loss = 0.2070
Epoch 6 finished ! Training Loss: 0.08395553993273384
     validation loss = 0.1767
Epoch 7 finished ! Training Loss: 0.08295910494237006
     validation loss = 0.1944
Epoch 8 finished ! Training Loss: 0.08371647098396398
     validation loss = 0.1723
Epoch 9 finished ! Training Loss: 0.08301997561998005
     validation loss = 0.1814
Epoch 10 finished ! Training Loss: 0.0857431073732014
     validation loss = 0.1775
Epoch 11 finished ! Training Loss: 0.0834873235678371
   

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/home/xu/anaconda3/envs/cs231/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._s

KeyboardInterrupt: 

In [None]:
#-------------------------SAVE THE MODEL STATE DICT----------------------------------#
PATH = 'Vet_currculum_330.pth'
torch.save(model.state_dict(), PATH)

## Checking the result

* load model dict from *.pth state dict
* show low res image slice
* save image to file

In [None]:
from training import check_img
PATH = 'Vet_currculum_330.pth'

model = VNet(classnum=3, slim=True)
model.load_state_dict(torch.load(PATH))
model = model.to(device=device)
model.eval()

data_index = np.arange(4)
dataset_test = niiDataset(data_index, 
                         transform=transforms.Compose([
                             downSample(4),
                         ])
                     )

validation_loader = DataLoader(dataset_test, batch_size=1)

from loss import *

check_img(model, validation_loader, device, dtype, cirrculum=2, lossFun=dice_loss_2)