# 3. Multi-class Vnet on BV

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from dataset import *
from vnet import *
from training import *
from niiutility import show_image, show_batch_image

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 3.1 Setup Torch Global Variable, load memory map 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 12
NUM_WORKERS = 6
NUM_TRAIN = 84
LEARNING_RATE = 1e-2

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')

using GPU for training


In [3]:
#-------------------------LOAD THE DATA SET-------------------------------------------#

data_index = np.arange(107)
data_index = np.delete(data_index, 46)
dataset_trans = niiDataset(data_index, 
                         transform=transforms.Compose([
                             downSample(4),
                             RandomFilp(0.5),
                             RandomAffine(15, 5)
                         ])
                     )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_trans)
train_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE, \
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),\
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN,data_size)),\
                    num_workers=NUM_WORKERS)

* Print first 4 batch of data

In [None]:
for i_batch, sample_batched in enumerate(train_loader):
    print(i_batch, sample_batched['image'].size(), \
          sample_batched['label'].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        show_batch_image(sample_batched['image'],sample_batched['label'],BATCH_SIZE)
        break

In [4]:
from vnet_mask import VNetMask
#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

model = VNetMask()
model.apply(weights_init)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
train2(model, train_loader, validation_loader, optimizer,\
      device=device, dtype=dtype, epochs=1500, print_every=100)

epoch 0 begins: 
     Iteration 0, loss = 0.8731
     validation loss = 0.6749
epoch 1 begins: 
     Iteration 0, loss = 0.7681
     validation loss = 0.6351
epoch 2 begins: 
     Iteration 0, loss = 0.7760
     validation loss = 0.5999
epoch 3 begins: 
     Iteration 0, loss = 0.7766
     validation loss = 0.6011
epoch 4 begins: 
     Iteration 0, loss = 0.7534
     validation loss = 0.5788
epoch 5 begins: 
     Iteration 0, loss = 0.7171
     validation loss = 0.5910
epoch 6 begins: 
     Iteration 0, loss = 0.6713
     validation loss = 0.5632
epoch 7 begins: 
     Iteration 0, loss = 0.6836
     validation loss = 0.5839
epoch 8 begins: 
     Iteration 0, loss = 0.6595
     validation loss = 0.7085
epoch 9 begins: 
     Iteration 0, loss = 0.6922
     validation loss = 0.5328
epoch 10 begins: 
     Iteration 0, loss = 0.6939
     validation loss = 0.5263
epoch 11 begins: 
     Iteration 0, loss = 0.5734
     validation loss = 0.4728
epoch 12 begins: 
     Iteration 0, loss = 0.6361


In [None]:
#-------------------------SAVE THE MODEL STATE DICT----------------------------------#
PATH = 'Vet_Mask_class_1'
torch.save(model.state_dict(), PATH)