# 3. Multi-class V-Net on BV

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from dataset import *
from vnet import *
from training import *
from niiutility import show_image, show_batch_image

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 3.1 Set up Torch global variables and load the memory map

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#
# Run configuration shared by the cells below.
USE_GPU = True        # use CUDA if it is also available on this machine
BATCH_SIZE = 8        # samples per DataLoader batch
NUM_WORKERS = 6       # DataLoader worker processes
NUM_TRAIN = 80        # dataset positions [0, NUM_TRAIN) train; the rest validate
LEARNING_RATE = 1e-2  # initial Adam learning rate

# Tensor dtype handed to the training loop.
dtype = torch.float32

# Choose the compute device: CUDA only when it is both requested and present.
_want_cuda = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if _want_cuda else 'cpu')
if _want_cuda:
    print('using GPU for training')

using GPU for training


In [3]:
#-------------------------LOAD THE DATA SET-------------------------------------------#
# Ordering of scan indices. The first NUM_TRAIN positions form the training split and
# the remainder the validation split. Set `regen = True` to draw a new random ordering.
regen = False

if regen:
    # Indices 0..106 with index 46 removed (reason not recorded here -- TODO document).
    data_index = np.delete(np.arange(107), 46)
    # np.random.shuffle permutes in place and returns None; do not assign its result.
    np.random.shuffle(data_index)
else:
    data_index = np.array ([50,17,81,39,36,88,33,77,7,1,52,43,34,40,41,18,72,58,51,
                  63,78,35,16,79,0,89,70,67,60,13,76,8,2,47,4,97,29,85,32,
                  55,30,49,44,11,101,22,37,10,92,68,5,64,105,95,20,38,99,
                  84,86,91,96,71,98,104,45,69,103,27,19,59,73,106,93,24,80,
                  66,28,90,3,102,31,26,94,62,54,48,12,61,87,42,65,74,53,57,
                  14,56,83,100,25,6,75,82,23,9,21,15])

# Transform pipeline: downSample(2) followed by the RandomFilp / RandomAffine transforms.
dataset_trans = BvMaskDataset(data_index,
                              transform=transforms.Compose([
                                  downSample(2),
                                  RandomFilp(0.5),
                                  RandomAffine(180, 15)
                              ]))

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_trans)
assert 0 < NUM_TRAIN < data_size, 'NUM_TRAIN must leave at least one validation sample'

train_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
                          num_workers=NUM_WORKERS)
# NOTE(review): validation reads from the same `dataset_trans`, so it also goes through
# the Random* transforms; validation loss is therefore measured on randomly transformed
# samples. Consider a separate dataset with a deterministic pipeline -- confirm which of
# the custom transforms are required (e.g. for tensor conversion) before changing it.
validation_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                               sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, data_size)),
                               num_workers=NUM_WORKERS)

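* Build the LNet model, the Adam optimizer and the LR scheduler, optionally resuming from a checkpoint. (Printing the first 4 batches of data is not done in this notebook.)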

In [4]:
from vnet import LNet
# Import the scheduler explicitly instead of relying on it leaking in through
# `from training import *`.
from torch.optim.lr_scheduler import ReduceLROnPlateau

# When True, the freshly built model/optimizer/scheduler are handed to `loadckp`
# together with CKPPath, and the returned objects replace them.
LoadCKP = True

CKPPath = 'checkpoint2019-04-05 19:46:58.793496.pth'

model = LNet(img_size=(96, 128, 128), out_size=6)
# Initialise weights; presumably superseded by the checkpoint when LoadCKP is True.
model.apply(weights_init)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Reduce the learning rate once the monitored (minimised) quantity has not improved
# for 50 consecutive scheduler steps.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=50, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

loading checkpoint 'checkpoint2019-04-05 19:46:58.793496.pth'
loaded checkpoint 'checkpoint2019-04-05 19:46:58.793496.pth' (epoch 701)


In [None]:
#-------------------------TRAIN (RESUMING FROM THE STATE BUILT ABOVE)------------------#
# Trains with the MSE loss (presumably provided by `loss`).
# NOTE(review): `startepoch` is hard-coded and should match the epoch of the checkpoint
# loaded above -- update it together with CKPPath, or when training from scratch.
from loss import *

train(
    model,
    train_loader,
    validation_loader,
    optimizer,
    scheduler,
    device=device,
    dtype=dtype,
    lossFun=MSE,
    epochs=1500,
    startepoch=701,
)

Epoch 701 finished ! Training Loss: 207.0958251953125
     validation loss = 222.0139
Epoch 702 finished ! Training Loss: 200.1828621758355
     validation loss = 93.8115
Epoch 703 finished ! Training Loss: 171.88488939073352
     validation loss = 102.4838
Epoch 704 finished ! Training Loss: 171.37626139322916
     validation loss = 105.9510
Epoch 705 finished ! Training Loss: 189.24322424994574
     validation loss = 118.0190
Epoch 706 finished ! Training Loss: 174.1217778523763
     validation loss = 88.8087
Epoch 707 finished ! Training Loss: 174.5380342271593
     validation loss = 102.0236
Epoch 708 finished ! Training Loss: 184.65301767985025
     validation loss = 190.2569
Epoch 709 finished ! Training Loss: 169.331295437283
     validation loss = 145.6956
Epoch 710 finished ! Training Loss: 197.8787138197157
     validation loss = 120.7738
Epoch 711 finished ! Training Loss: 161.75042639838324
     validation loss = 81.5241
Epoch 712 finished ! Training Loss: 160.7870763142903

     validation loss = 76.2876
Epoch 797 finished ! Training Loss: 143.56932152642145
     validation loss = 51.6047
Epoch 798 finished ! Training Loss: 136.53563944498697
     validation loss = 58.4983
Epoch 799 finished ! Training Loss: 125.8340326944987
     validation loss = 55.7549
Epoch 800 finished ! Training Loss: 133.67003970675998
     validation loss = 93.9013
Epoch 801 finished ! Training Loss: 148.07248009575738
     validation loss = 128.8094
Epoch 802 finished ! Training Loss: 152.47383202446832
     validation loss = 65.3616
Epoch 803 finished ! Training Loss: 122.12136925591363
     validation loss = 79.7180
Epoch 804 finished ! Training Loss: 143.43042161729602
     validation loss = 73.7016
Epoch 805 finished ! Training Loss: 137.01218583848743
     validation loss = 68.7487
Epoch 806 finished ! Training Loss: 122.6809344821506
     validation loss = 74.7633
Epoch 807 finished ! Training Loss: 110.83785883585612
     validation loss = 41.9685
Epoch 808 finished ! Tra

Epoch 892 finished ! Training Loss: 126.1908425225152
     validation loss = 57.0550
Epoch 893 finished ! Training Loss: 126.65455500284831
     validation loss = 57.6058
Epoch 894 finished ! Training Loss: 134.50270165337457
     validation loss = 60.6133
Epoch 895 finished ! Training Loss: 121.28860855102539
     validation loss = 57.0276
Epoch 896 finished ! Training Loss: 133.06022135416666
     validation loss = 58.6913
Epoch 897 finished ! Training Loss: 107.89890882703993
     validation loss = 84.3360
Epoch 898 finished ! Training Loss: 113.38502036200629
     validation loss = 63.5748
Epoch 899 finished ! Training Loss: 115.74192640516493
     validation loss = 63.2825
Epoch 900 finished ! Training Loss: 111.2618637084961
     validation loss = 72.0284
Checkpoint 901 saved !
Epoch 901 finished ! Training Loss: 90.96461232503255
     validation loss = 51.0269
Epoch 902 finished ! Training Loss: 131.134157816569
     validation loss = 74.0868
Epoch 903 finished ! Training Loss: 

In [None]:
#-------------------------SAVE THE MODEL STATE DICT----------------------------------#
# Writes only the model's state_dict; optimizer and scheduler state are not included.
PATH = 'LNET-404.pth'
torch.save(obj=model.state_dict(), f=PATH)