In [1]:
# Silence every warning unless the interpreter was started with explicit -W options.
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE(review): wildcard imports from project-local modules. Later cells use
# names that are never imported explicitly (e.g. `nn`, `pyramid_dataset`,
# `random_affine`, `dice_loss_3`, `get_dimensions`, `save_1`), so they
# presumably arrive through these star imports -- confirm and import explicitly.
from data_utility import *
from data_utils import *
from loss import *
from train import *
# NOTE(review): `datetime` and `convert_model` are not referenced in the cells
# visible here.
import datetime

from sync_batchnorm import convert_model

# Notebook magics: render matplotlib figures inline and reload edited modules
# automatically before each cell execution.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- run configuration -------------------------------------------------
USE_GPU = True      # set to False to force CPU even when CUDA is available
NUM_WORKERS = 12    # passed as num_workers to the DataLoaders
BATCH_SIZE = 2      # passed as batch_size to the DataLoaders

# float32 takes half the memory of float64
dtype = torch.float32

# ---- device selection --------------------------------------------------
_cuda_requested_and_present = USE_GPU and torch.cuda.is_available()

if not _cuda_requested_and_present:
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')

    # let cuDNN benchmark its kernels and pick the fastest ones
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    print('using GPU for training')

using GPU for training


In [3]:
def init_weights(m):
    """Initialise the weight tensor of a single module in place.

    Written to be handed to ``Module.apply`` so it is called once per
    sub-module:

    * ``nn.Conv3d``      -> weights filled by ``kaiming_normal_``
    * ``nn.BatchNorm3d`` -> weights drawn from a normal distribution with
      mean 0 and std 1

    Every other module type is left untouched, and biases are never modified.
    """
    if isinstance(m, nn.BatchNorm3d):
        torch.nn.init.normal_(m.weight.data, mean=0, std=1)
        return

    if isinstance(m, nn.Conv3d):
        torch.nn.init.kaiming_normal_(m.weight.data)

def shape_test(model, cuda_bool, input_shape=(1, 1, 256, 256, 256)):
    """Push one all-zeros tensor through ``model`` and print each output size.

    Args:
        model: Callable taking a single tensor. It may return one tensor or
            an iterable of tensors.
        cuda_bool: When truthy, the dummy input is moved to the notebook-level
            ``device`` and cast to the notebook-level ``dtype`` before the
            forward pass. When falsy, the input is used as created by
            ``torch.zeros`` and those globals are not touched.
        input_shape: Shape of the dummy input. Defaults to the original
            hard-coded ``(1, 1, 256, 256, 256)``.

    Returns:
        None. The sizes are printed, one per line.
    """
    x = torch.zeros(input_shape)
    if cuda_bool:
        x = x.to(device=device, dtype=dtype)

    # Only the output shapes are needed, so skip building an autograd graph.
    with torch.no_grad():
        scores = model(x)

    # A bare tensor would otherwise be iterated along its first axis and
    # print per-sample sizes instead of the output size.
    if isinstance(scores, torch.Tensor):
        scores = (scores,)

    for out in scores:
        print(out.size())

In [4]:
# Training data: augmentation is applied to every sample.
# NOTE(review): the meaning of the arguments to random_affine(90, 15) and
# random_filp(0.5) is defined in the project helpers -- presumably rotation /
# translation limits and a flip probability; confirm there.
train_augmentation = transforms.Compose([
    random_affine(90, 15),
    random_filp(0.5),
])
train_dataset = pyramid_dataset(data_type='nii_train',
                                transform=train_augmentation)

# Validation data: no augmentation.
validation_dataset = pyramid_dataset(data_type='nii_test',
                                     transform=None)

# DataLoader handles batching and loads samples in NUM_WORKERS worker
# processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)

# Validation is not shuffled, so every validation pass sees the same batches
# in the same order and the logged losses are comparable across epochs.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE,
                               shuffle=False, num_workers=NUM_WORKERS)

In [None]:
"""
test_dictionary = train_dataset[33]

image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256)
label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256)

image_1 = image_1.to(device=device, dtype=dtype) 
label_1 = label_1.to(device=device, dtype=dtype)
"""

from ssn3d import Model
import torch.optim as optim

# Build the network and wrap it for multi-GPU data parallelism.
m = nn.DataParallel(Model())

# apply() visits the wrapper and every nested sub-module, so the Conv3d /
# BatchNorm3d layers inside the wrapped model are initialised here.
m.apply(init_weights)

# Move parameters and buffers to the selected device and dtype.
m = m.to(device, dtype)

# Adam with a fixed initial learning rate, plus a plateau-based LR scheduler
# attached to it (default ReduceLROnPlateau settings).
optimizer = optim.Adam(m.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [None]:
epochs = 5000

# Per-epoch mean training loss, and mean validation loss for every
# validation pass.
logger = {'train':[], 'validation_1':[]}

# The context manager closes the log file even if training is interrupted
# or raises.
with open('train_ssn_dropout.txt', 'a+') as record:

    # Epochs are numbered 1..epochs inclusive, so exactly `epochs` epochs run.
    for e in tqdm(range(1, epochs + 1)):

        # Training mode: dropout active, batchnorm uses batch statistics.
        # Set once per epoch; the only switch to eval mode happens in the
        # validation block at the end of an epoch.
        m.train()

        epoch_loss = 0

        for t, batch in enumerate(train_loader):
            # move the mini batch to the target device and dtype
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            # forward pass and loss
            out_1 = m(image_1)
            loss = dice_loss_3(out_1, label_1)

            # accumulate the mini batch loss into the epoch total
            epoch_loss += loss.item()

            # clear old gradients, backpropagate, take one optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        mean_train_loss = epoch_loss/(t+1)
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, mean_train_loss) + '\n'

        logger['train'].append(mean_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e%5 == 0:
            # Validate (and checkpoint) every 5 epochs.

            # Eval mode: dropout disabled, batchnorm uses running statistics.
            m.eval()

            with torch.no_grad():
                # no gradients are needed for validation

                valloss_1 = 0

                # NOTE: validation branches for the 'image2_*' and 'image4_*'
                # batch entries used to live here as commented-out code; only
                # the 'image1_*' pair is evaluated.
                for v, vbatch in enumerate(validation_loader):
                    # move data to the device; add a leading dimension when
                    # the tensor comes back 4-D
                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)

                    # forward pass and loss
                    out_1_val = m(image_1_val)
                    loss_1 = dice_loss_3(out_1_val, label_1_val)

                    # accumulate the mini batch loss
                    valloss_1 += loss_1.item()

                mean_valloss_1 = valloss_1/(v+1)
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(mean_valloss_1) + '\n'

                logger['validation_1'].append(mean_valloss_1)

                print(outstr)
                record.write(outstr)
                record.flush()

            # Feed the validation loss to the plateau scheduler so it can
            # lower the learning rate when validation stops improving. Its
            # patience counts validation passes (one per 5 epochs), not
            # epochs. Stepped before saving so the checkpoint holds the
            # up-to-date scheduler state.
            scheduler.step(mean_valloss_1)

            save_1('ssn_dropout_save', m, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [07:30<624:52:38, 450.09s/it]

Epoch 1 finished ! Training Loss: 0.5395



  0%|          | 2/4999 [13:22<584:00:37, 420.74s/it]

Epoch 2 finished ! Training Loss: 0.4966



  0%|          | 3/4999 [19:11<553:54:35, 399.13s/it]

Epoch 3 finished ! Training Loss: 0.4658



  0%|          | 4/4999 [25:03<534:13:20, 385.03s/it]

Epoch 4 finished ! Training Loss: 0.4573

Epoch 5 finished ! Training Loss: 0.4321



  0%|          | 5/4999 [31:37<537:52:29, 387.74s/it]

------- 1st valloss=0.5283

Checkpoint 5 saved !


  0%|          | 6/4999 [37:25<521:19:45, 375.88s/it]

Epoch 6 finished ! Training Loss: 0.4193



  0%|          | 7/4999 [43:13<509:45:11, 367.61s/it]

Epoch 7 finished ! Training Loss: 0.4128



  0%|          | 8/4999 [49:08<504:21:36, 363.79s/it]

Epoch 8 finished ! Training Loss: 0.3987



  0%|          | 9/4999 [54:58<498:34:28, 359.69s/it]

Epoch 9 finished ! Training Loss: 0.3869

Epoch 10 finished ! Training Loss: 0.3671



  0%|          | 10/4999 [1:01:37<514:49:07, 371.49s/it]

------- 1st valloss=0.5516

Checkpoint 10 saved !


  0%|          | 11/4999 [1:07:29<506:22:42, 365.47s/it]

Epoch 11 finished ! Training Loss: 0.3636



  0%|          | 12/4999 [1:13:20<500:27:39, 361.27s/it]

Epoch 12 finished ! Training Loss: 0.3537



  0%|          | 13/4999 [1:19:11<496:00:25, 358.13s/it]

Epoch 13 finished ! Training Loss: 0.3378



  0%|          | 14/4999 [1:25:06<494:49:10, 357.34s/it]

Epoch 14 finished ! Training Loss: 0.3275

Epoch 15 finished ! Training Loss: 0.3235



  0%|          | 15/4999 [1:31:46<512:11:47, 369.97s/it]

------- 1st valloss=0.6755

Checkpoint 15 saved !


  0%|          | 16/4999 [1:37:36<503:59:31, 364.11s/it]

Epoch 16 finished ! Training Loss: 0.3031



  0%|          | 17/4999 [1:43:23<496:49:20, 359.00s/it]

Epoch 17 finished ! Training Loss: 0.3039



  0%|          | 18/4999 [1:49:20<495:46:35, 358.32s/it]

Epoch 18 finished ! Training Loss: 0.2717



  0%|          | 19/4999 [1:55:11<492:38:29, 356.13s/it]

Epoch 19 finished ! Training Loss: 0.2696

Epoch 20 finished ! Training Loss: 0.2519



  0%|          | 20/4999 [2:01:50<510:13:39, 368.91s/it]

------- 1st valloss=0.4225

Checkpoint 20 saved !


  0%|          | 21/4999 [2:07:37<501:07:04, 362.40s/it]

Epoch 21 finished ! Training Loss: 0.2439



  0%|          | 22/4999 [2:13:27<495:55:17, 358.71s/it]

Epoch 22 finished ! Training Loss: 0.2345



  0%|          | 23/4999 [2:19:23<494:40:13, 357.88s/it]

Epoch 23 finished ! Training Loss: 0.2223



  0%|          | 24/4999 [2:25:16<492:26:41, 356.34s/it]

Epoch 24 finished ! Training Loss: 0.2267

Epoch 25 finished ! Training Loss: 0.2228



  1%|          | 25/4999 [2:31:56<510:23:18, 369.40s/it]

------- 1st valloss=0.3183

Checkpoint 25 saved !


  1%|          | 26/4999 [2:37:54<505:38:09, 366.03s/it]

Epoch 26 finished ! Training Loss: 0.2132



  1%|          | 27/4999 [2:43:53<502:34:37, 363.89s/it]

Epoch 27 finished ! Training Loss: 0.2164



  1%|          | 28/4999 [2:49:47<498:20:09, 360.90s/it]

Epoch 28 finished ! Training Loss: 0.2069



  1%|          | 29/4999 [2:55:41<495:36:49, 359.00s/it]

Epoch 29 finished ! Training Loss: 0.1993

Epoch 30 finished ! Training Loss: 0.1985



  1%|          | 30/4999 [3:02:14<509:23:42, 369.05s/it]

------- 1st valloss=0.3221

Checkpoint 30 saved !


  1%|          | 31/4999 [3:08:03<500:59:15, 363.03s/it]

Epoch 31 finished ! Training Loss: 0.2005



  1%|          | 32/4999 [3:13:56<496:46:28, 360.05s/it]

Epoch 32 finished ! Training Loss: 0.1971



  1%|          | 33/4999 [3:19:48<493:16:41, 357.59s/it]

Epoch 33 finished ! Training Loss: 0.1936



  1%|          | 34/4999 [3:25:43<492:14:56, 356.92s/it]

Epoch 34 finished ! Training Loss: 0.1858

