In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import get_dimensions
from loss import *
from train import *
from sync_batchnorm import convert_model
from cascade_fcn import UNet3D
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 1

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [4]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(m):
    if isinstance(m, nn.Conv3d):
        init.kaiming_normal_(m.weight.data)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=0, std=1)

def downsample_label(label, scale_factor):
    return F.interpolate(label, scale_factor=scale_factor, mode='trilinear', align_corners=True)

In [None]:
def shape_test(model, cuda_bool):
    x = torch.zeros((1, 1, 256, 256, 256))
    x = x.to(device=device, dtype=dtype) if cuda_bool else x
    model = model.to(device=device, dtype=dtype) if cuda_bool else model
    scores = model(x)
    for i in scores:
        print(i.size())

In [None]:
#test_dictionary = train_dataset[33]

#image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256).to(device=device, dtype=dtype)
#label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256).to(device=device, dtype=dtype)

#image_2 = test_dictionary['image2_data'].view(1, 1, 128, 128, 128).to(device=device, dtype=dtype)
#label_2 = test_dictionary['image2_label'].view(1, 3, 128, 128, 128).to(device=device, dtype=dtype)

#label_2_resize_2 = downsample_label(label_2, 1/2).to(device=device, dtype=dtype)

m = UNet3D(1, 3, final_sigmoid = False)
m = nn.DataParallel(m)
m = convert_model(m)
m = m.to(device, dtype)
optimizer = optim.Adam(m.parameters(), lr=1e-2, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
epoch = 0

"""
checkpoint = torch.load('../3d_res_save/2019-07-31 20:01:12.105267.pth')
m.load_state_dict(checkpoint['state_dict_1'])
#m.apply(init_weights)
m = m.to(device=device, dtype=dtype)

optimizer.load_state_dict(checkpoint['optimizer'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
print(epoch)
"""
#shape_test(m, True)

"\ncheckpoint = torch.load('../3d_res_save/2019-07-31 20:01:12.105267.pth')\nm.load_state_dict(checkpoint['state_dict_1'])\n#m.apply(init_weights)\nm = m.to(device=device, dtype=dtype)\n\noptimizer.load_state_dict(checkpoint['optimizer'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)\nscheduler.load_state_dict(checkpoint['scheduler'])\nepoch = checkpoint['epoch']\nprint(epoch)\n"

In [None]:
epochs = 5000

record = open('train_3d_unet_dropout+weight_decay.txt','w+')

logger = {'train':[], 'validation_1':[]}

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
        
        # iter over the train mini batches
    
        m.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        #image_2 = batch['image2_data'].to(device=device, dtype=dtype)
        #label_2 = batch['image2_label'].to(device=device, dtype=dtype)
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        
        # move data to device, convert dtype to desirable dtype
        
        # Downsample labels to coincide with icnet model outputs
        #label_1_resize_2 = downsample_label(label_1, 1/2)
        #label_2_resize_2 = downsample_label(label_2, 1/2)
        #label_4_resize_2 = downsample_label(label_4, 1/2)
        
        out_1 = m(image_1)
       
        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        
        optimizer.step()
        #scheduler.step(loss)
        # take a gradient descent step
        torch.cuda.empty_cache()
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%5 == 0:
    # do validation every 5 epoches
    
        m.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_2 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
            
                #image_4_val = vbatch['image4_data'].to(device=device, dtype=dtype)
                #if get_dimensions(image_4_val) == 4:
                #    image_4_val.unsqueeze_(0)
                #label_4_val = vbatch['image4_label'].to(device=device, dtype=dtype)
                #if get_dimensions(label_4_val) == 4:
                #    label_4_val.unsqueeze_(0)
                
                #image_2_val = vbatch['image2_data'].to(device=device, dtype=dtype)
                #if get_dimensions(image_2_val) == 4:
                #    image_2_val.unsqueeze_(0)
                #label_2_val = vbatch['image2_label'].to(device=device, dtype=dtype)
                #if get_dimensions(label_2_val) == 4:
                #    label_2_val.unsqueeze_(0)
                
                image_2_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_2_val) == 4:
                    image_2_val.unsqueeze_(0)
                label_2_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_2_val) == 4:
                    label_2_val.unsqueeze_(0)
                
                #print("image_1_val:", image_1_val.shape)
                #print("label_1_val:", label_1_val.shape)
                # move data to device, convert dtype to desirable dtype
                # add dimension to labels if they are 4D tensors
                
                # Downsample labels to coincide with icnet model outputs
                #label_1_val_resize_2 = downsample_label(label_1_val, 1/2) 
                #label_2_val_resize_2 = downsample_label(label_2_val, 1/2) 
                #label_4_val_resize_2 = downsample_label(label_4_val, 1/2) 
                #print("label_1_val_resize:", label_1_val_resize_2.shape) 
                
                out_2_val = m(image_2_val)
                # do the inference
                
                #print("out_4:", out_4_val.shape)
                #print("label_4:", label_4_val_resize_2.shape)
                #loss_4 = dice_loss_3(out_4_val, label_4_val_resize_2)
                #loss_2 = dice_loss_3(out_2_val, label_2_val_resize_2)
                loss_2 = dice_loss_3(out_2_val, label_2_val)
                # calculate loss

                #valloss_4 += loss_4.item()
                #valloss_2 += loss_2.item()
                valloss_2 += loss_2.item()
                # record mini batch loss
            
            outstr = '------- 1st valloss={0:.4f}'\
                .format(valloss_2/(v+1)) + '\n'
            
            logger['validation_1'].append(valloss_2/(v+1))
            
            scheduler.step(valloss_2/(v+1))
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            #scheduler.step(valloss_2/(v+1))
            
            save_1('3d_unet_dropout_weight_decay_save', m, optimizer, logger, e, scheduler)
            torch.cuda.empty_cache()

record.close()

  0%|          | 1/4999 [08:23<699:37:33, 503.93s/it]

Epoch 1 finished ! Training Loss: 0.5331



  0%|          | 2/4999 [16:02<680:33:55, 490.30s/it]

Epoch 2 finished ! Training Loss: 0.4722



  0%|          | 3/4999 [23:40<667:07:45, 480.72s/it]

Epoch 3 finished ! Training Loss: 0.4695



  0%|          | 4/4999 [31:16<656:35:47, 473.22s/it]

Epoch 4 finished ! Training Loss: 0.4641

Epoch 5 finished ! Training Loss: 0.4737

------- 1st valloss=0.4818

Checkpoint 5 saved !


  0%|          | 6/4999 [47:21<659:54:31, 475.80s/it]

Epoch 6 finished ! Training Loss: 0.4680



  0%|          | 7/4999 [54:59<652:15:03, 470.37s/it]

Epoch 7 finished ! Training Loss: 0.4636



  0%|          | 8/4999 [1:02:36<646:34:07, 466.37s/it]

Epoch 8 finished ! Training Loss: 0.4648



  0%|          | 9/4999 [1:10:15<643:14:12, 464.06s/it]

Epoch 9 finished ! Training Loss: 0.4647

Epoch 10 finished ! Training Loss: 0.4648

------- 1st valloss=0.5138

Checkpoint 10 saved !


  0%|          | 11/4999 [1:26:10<649:17:05, 468.61s/it]

Epoch 11 finished ! Training Loss: 0.4614



  0%|          | 12/4999 [1:33:50<645:30:23, 465.98s/it]

Epoch 12 finished ! Training Loss: 0.4762



  0%|          | 13/4999 [1:41:24<640:40:47, 462.58s/it]

Epoch 13 finished ! Training Loss: 0.4684



  0%|          | 14/4999 [1:49:00<637:32:22, 460.41s/it]

Epoch 14 finished ! Training Loss: 0.4657

Epoch 15 finished ! Training Loss: 0.4698

------- 1st valloss=0.6084

Checkpoint 15 saved !


  0%|          | 16/4999 [2:04:54<646:40:51, 467.20s/it]

Epoch 16 finished ! Training Loss: 0.4632



  0%|          | 17/4999 [2:12:32<642:47:40, 464.48s/it]

Epoch 17 finished ! Training Loss: 0.4689



  0%|          | 18/4999 [2:20:06<638:16:40, 461.31s/it]

Epoch 18 finished ! Training Loss: 0.4657



  0%|          | 19/4999 [2:27:47<637:54:08, 461.13s/it]

Epoch 19 finished ! Training Loss: 0.4647

Epoch 20 finished ! Training Loss: 0.4704

------- 1st valloss=0.4684

Checkpoint 20 saved !


  0%|          | 21/4999 [2:43:45<647:37:48, 468.35s/it]

Epoch 21 finished ! Training Loss: 0.4685



  0%|          | 22/4999 [2:51:22<642:37:39, 464.83s/it]

Epoch 22 finished ! Training Loss: 0.4712



  0%|          | 23/4999 [2:59:02<640:34:38, 463.44s/it]

Epoch 23 finished ! Training Loss: 0.4649



  0%|          | 24/4999 [3:06:37<637:11:07, 461.08s/it]

Epoch 24 finished ! Training Loss: 0.4617

Epoch 25 finished ! Training Loss: 0.4787

------- 1st valloss=0.5668

Checkpoint 25 saved !


  1%|          | 26/4999 [3:22:40<648:29:02, 469.44s/it]

Epoch 26 finished ! Training Loss: 0.4692



In [None]:
m.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = m(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)

In [None]:
# overfit model on single embryo image (modified ICNet Model)
# upsample final outputs by a factor of 4 instead of factor 2
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

record = open('over_fit_multi_model.txt','w+')

for e in tqdm(range(epochs)):
    #torch.cuda.empty_cache()
    
    out_1 = m(image_1)
    #out_1 = m(image_1)
    #loss_1 = dice_loss_3(out_1, label_1)
    loss_1 = dice_loss_3(out_1, label_1)
    
    #loss = loss_4 + loss_2 + loss_1 
    loss = loss_1
    
    outstr = 'in epoch {}, loss = {}'.format(e, loss.item()) + '\n'
    
    print(outstr) 
    record.write(outstr)
    record.flush()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    #torch.cuda.empty_cache()

record.close()