In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from model import *
from loss import *
from sync_batchnorm import convert_model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Notebook-wide knobs: GPU opt-in flag, DataLoader worker count, mini-batch size.
USE_GPU = True
NUM_WORKERS = 6
BATCH_SIZE = 3

# Single precision everywhere; it needs half the memory of float64.
dtype = torch.float32

# Pick the compute device once; the other cells read `device` and `dtype`.
if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')

    # Turn cuDNN on and let it benchmark convolution algorithms for speed.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('using GPU for training')

using GPU for training


* class down_unit
    * pass
* class up_unit
    * pass
* class input_unit
    * pass
* class output_unit
    * pass

In [3]:
# Training set: every sample goes through random_affine(90, 15) and then
# random_filp(0.5).
# NOTE(review): the meaning of these arguments lives in data_utility - confirm there.
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([random_affine(90, 15), random_filp(0.5)]),
)

# Validation set: no transform, so samples are served as stored.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# The loaders group samples into batches of BATCH_SIZE and fetch them with
# NUM_WORKERS background workers.
# NOTE(review): the validation loader shuffles too - confirm that is intended.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [4]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(m):
    """Re-initialise the weights of a single layer in place.

    Written to be handed to ``Module.apply``, which calls it once per
    sub-module.

    * ``nn.Conv3d``      -> Kaiming-normal weights
    * ``nn.BatchNorm3d`` -> weights drawn from N(0, 1)
    * anything else      -> left untouched

    Biases are never modified. Returns ``None``.
    """
    if isinstance(m, nn.Conv3d):
        init.kaiming_normal_(m.weight.data)
        return

    # NOTE(review): N(0, 1) for the BatchNorm scale is unusual - confirm it is
    # intended. Also confirm this branch still matches the layers produced by
    # convert_model(); they may not be nn.BatchNorm3d instances.
    if isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=0, std=1)

def downsample_label(label, scale_factor, mode='trilinear'):
    """Resize a label volume by ``scale_factor`` along every spatial axis.

    Args:
        label: floating-point tensor accepted by ``F.interpolate``; the
            default mode requires a 5-D (N, C, D, H, W) layout.
        scale_factor: multiplier applied to each spatial dimension
            (e.g. ``1/2`` halves the volume, ``1/4`` quarters it).
        mode: interpolation algorithm forwarded to ``F.interpolate``.
            The default ``'trilinear'`` keeps the original behaviour, which
            blends neighbouring voxels and so yields soft labels. Pass
            ``'nearest'`` to keep hard (e.g. 0/1) label values.

    Returns:
        The resized tensor, on the same device and dtype as ``label``.
    """
    # F.interpolate only accepts align_corners for the linear/cubic modes and
    # raises if it is passed together with e.g. 'nearest' or 'area'.
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        return F.interpolate(label, scale_factor=scale_factor, mode=mode,
                             align_corners=True)
    return F.interpolate(label, scale_factor=scale_factor, mode=mode)

In [5]:
# Take one training sample (index 33); the overfitting cells train on it alone.
test_dictionary = train_dataset[33]

# Reshape every entry to (1, channels, D, H, W): one image channel, three
# label channels, at the 256 / 128 / 64 cubes stored under keys 1 / 2 / 4.
image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256)
label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256)

image_2 = test_dictionary['image2_data'].view(1, 1, 128, 128, 128)
label_2 = test_dictionary['image2_label'].view(1, 3, 128, 128, 128)

image_4 = test_dictionary['image4_data'].view(1, 1, 64, 64, 64)
label_4 = test_dictionary['image4_label'].view(1, 3, 64, 64, 64)

# Half- and quarter-size copies of each label. They are computed here, before
# the tensors are moved to `device`, exactly as in the original cell.
label_1_resize_2, label_2_resize_2, label_4_resize_2 = (
    downsample_label(full_label, 1/2) for full_label in (label_1, label_2, label_4)
)
label_1_resize_4, label_2_resize_4, label_4_resize_4 = (
    downsample_label(full_label, 1/4) for full_label in (label_1, label_2, label_4)
)

# Move every tensor to the selected device and cast it to the working dtype.
image_4, label_4, label_4_resize_4, label_4_resize_2 = (
    t.to(device=device, dtype=dtype)
    for t in (image_4, label_4, label_4_resize_4, label_4_resize_2)
)
image_2, label_2, label_2_resize_4, label_2_resize_2 = (
    t.to(device=device, dtype=dtype)
    for t in (image_2, label_2, label_2_resize_4, label_2_resize_2)
)
image_1, label_1, label_1_resize_4, label_1_resize_2 = (
    t.to(device=device, dtype=dtype)
    for t in (image_1, label_1, label_1_resize_4, label_1_resize_2)
)

# Sanity check: report the shape of each label next to its two resized copies.
print("Label 1:", label_1.shape, "Downsampled 1/4:", label_1_resize_4.shape, "Downsampled 1/2:", label_1_resize_2.shape)
print("Label 2:", label_2.shape, "Downsampled 1/4:", label_2_resize_4.shape, "Downsampled 1/2:", label_2_resize_2.shape)
print("Label 4:", label_4.shape, "Downsampled 1/4:", label_4_resize_4.shape, "Downsampled 1/2:", label_4_resize_2.shape)

from model import *

# Three candidate networks, each with three output classes. Module.apply and
# Module.to both return the module, so build / init / move can be chained.
icnet1 = ModifiedICNet(num_classes=3).apply(init_weights).to(device=device, dtype=dtype)

icnet2 = OriginalICNet(num_classes=3).apply(init_weights).to(device=device, dtype=dtype)

# The full-resolution net is wrapped in DataParallel and passed through
# convert_model before its weights are initialised.
full_res_icnet = convert_model(nn.DataParallel(FullResolutionICNet(num_classes=3)))
full_res_icnet = full_res_icnet.apply(init_weights).to(device=device, dtype=dtype)

import torch.optim as optim

# One Adam optimizer per network, all with the same learning rate.
optimizer1 = optim.Adam(icnet1.parameters(), lr=1e-2)
optimizer2 = optim.Adam(icnet2.parameters(), lr=1e-2)
optimizer3 = optim.Adam(full_res_icnet.parameters(), lr=1e-2)

Label 1: torch.Size([1, 3, 256, 256, 256]) Downsampled 1/4: torch.Size([1, 3, 64, 64, 64]) Downsampled 1/2: torch.Size([1, 3, 128, 128, 128])
Label 2: torch.Size([1, 3, 128, 128, 128]) Downsampled 1/4: torch.Size([1, 3, 32, 32, 32]) Downsampled 1/2: torch.Size([1, 3, 64, 64, 64])
Label 4: torch.Size([1, 3, 64, 64, 64]) Downsampled 1/4: torch.Size([1, 3, 16, 16, 16]) Downsampled 1/2: torch.Size([1, 3, 32, 32, 32])


In [None]:
from model import *

def shape_test(model, cuda_bool, input_shape=(1, 1, 256, 256, 256)):
    """Run one forward pass on an all-zero volume and print each output's size.

    Args:
        model: callable that takes the dummy tensor and returns either a
            single tensor or an iterable of tensors.
        cuda_bool: when true, the dummy input is moved to the notebook-level
            ``device`` and cast to ``dtype`` (that is the CPU if no GPU was
            selected); when false it stays a default CPU tensor.
        input_shape: shape of the dummy input. The default is the
            single-channel 256-cube volume the function originally hard-coded.

    Returns:
        None. One ``torch.Size`` line is printed per output.
    """
    x = torch.zeros(input_shape)
    if cuda_bool:
        x = x.to(device=device, dtype=dtype)

    # Only the shapes matter, so skip autograd bookkeeping; keeping the graph
    # of a 256-cube forward pass alive wastes a lot of memory.
    with torch.no_grad():
        scores = model(x)

    # A bare tensor would otherwise be iterated along its batch axis and
    # report the per-sample size instead of the real output size.
    if isinstance(scores, torch.Tensor):
        scores = (scores,)

    for out in scores:
        print(out.size())

# Build the full-resolution network the same way as for training (DataParallel
# wrapper, then convert_model), move it to the selected device and print the
# size of each of its outputs.
m = convert_model(nn.DataParallel(FullResolutionICNet(num_classes=3)))
m = m.to(device=device, dtype=dtype)
shape_test(m, True)

* network
    * test with GPU

In [None]:
# Overfit the modified ICNet (icnet1) on a single embryo image.
# Author's note: this variant upsamples its final outputs by a factor of 4
# instead of a factor of 2. All three outputs are scored against the labels
# that were downsampled by 1/2.
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

# The context manager closes the log even when the loop is interrupted or
# raises part-way through; the original open()/close() pair leaked the handle.
with open('over_fit_modified_model.txt', 'w+') as record:
    for e in tqdm(range(epochs)):

        # Forward pass on the full-size image; one output per pyramid level.
        out_1, out_2, out_4 = icnet1(image_1)

        loss_4 = dice_loss_3(out_4, label_4_resize_2)
        loss_2 = dice_loss_3(out_2, label_2_resize_2)
        loss_1 = dice_loss_3(out_1, label_1_resize_2)

        # Total loss is the plain sum of the three per-level Dice losses.
        loss = loss_1 + loss_2 + loss_4

        outstr = 'in epoch {}, loss = {}, loss_1: {}'.format(e, loss.item(), loss_1.item()) + '\n'

        # Flush after every write so the log is complete if the run is stopped.
        print(outstr)
        record.write(outstr)
        record.flush()

        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()

In [None]:
# Overfit the original ICNet (icnet2) on a single embryo image.
# All three outputs are scored against the labels downsampled by 1/4.

from loss import *
from tqdm import tqdm

epochs = 5000

# The context manager closes the log even when the loop is interrupted or
# raises part-way through; the original open()/close() pair leaked the handle.
with open('over_fit_original_model.txt', 'w+') as record:
    for e in tqdm(range(epochs)):

        # Forward pass on the full-size image; one output per pyramid level.
        out_1, out_2, out_4 = icnet2(image_1)

        loss_4 = dice_loss_3(out_4, label_4_resize_4)
        loss_2 = dice_loss_3(out_2, label_2_resize_4)
        loss_1 = dice_loss_3(out_1, label_1_resize_4)

        # Total loss is the plain sum of the three per-level Dice losses.
        loss = loss_4 + loss_2 + loss_1

        outstr = 'in epoch {}, loss = {}, loss_1: {}, loss_2: {}, loss_4: {}'.format(e, loss.item(), loss_1.item(), loss_2.item(), loss_4.item()) + '\n'

        # Flush after every write so the log is complete if the run is stopped.
        print(outstr)
        record.write(outstr)
        record.flush()

        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()

In [None]:
# Overfit the full-resolution ICNet (full_res_icnet) on a single embryo image.
# Each output is scored against the label of its own pyramid level
# (label_1 / label_2 / label_4) rather than a downsampled copy.
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

# The context manager closes the log even when the loop is interrupted or
# raises part-way through; the original open()/close() pair leaked the handle.
with open('over_fit_full_res_model_check2.txt', 'w+') as record:
    for e in tqdm(range(epochs)):

        # Forward pass on the full-size image; one output per pyramid level.
        out_1, out_2, out_4 = full_res_icnet(image_1)

        loss_4 = dice_loss_3(out_4, label_4)
        loss_2 = dice_loss_3(out_2, label_2)
        loss_1 = dice_loss_3(out_1, label_1)

        # Total loss is the plain sum of the three per-level Dice losses.
        loss = loss_1 + loss_2 + loss_4

        outstr = 'in epoch {}, loss = {}, loss_1: {}, loss_2: {}'.format(e, loss.item(), loss_1.item(), loss_2.item()) + '\n'

        # Flush after every write so the log is complete if the run is stopped.
        print(outstr)
        record.write(outstr)
        record.flush()

        optimizer3.zero_grad()
        loss.backward()
        optimizer3.step()

  0%|          | 0/5000 [00:00<?, ?it/s]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 0, loss = 2.251490592956543, loss_1: 0.7215265035629272



In [None]:
# Overfit DeepLab (`deeplab`) on a single embryo image, scoring its single
# output against the full-size label.
# NOTE(review): `deeplab` is not created in any visible cell - it has to exist
# (and sit on `device`) before this cell is run.
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

# DeepLab needs its own optimizer. The original cell stepped optimizer1, which
# holds icnet1's parameters, so deeplab was never updated and its gradients
# were never cleared. The learning rate matches the other optimizers.
optimizer_deeplab = optim.Adam(deeplab.parameters(), lr=1e-2)

# The context manager closes the log even when the loop is interrupted or
# raises part-way through; the original open()/close() pair leaked the handle.
with open('over_fit_deeplab.txt', 'w+') as record:
    for e in tqdm(range(epochs)):

        out_1 = deeplab(image_1)

        loss_1 = dice_loss_3(out_1, label_1)

        # Only one output, so the total loss is that single Dice loss.
        loss = loss_1

        outstr = 'in epoch {}, loss = {}'.format(e, loss.item()) + '\n'

        # Flush after every write so the log is complete if the run is stopped.
        print(outstr)
        record.write(outstr)
        record.flush()

        optimizer_deeplab.zero_grad()
        loss.backward()
        optimizer_deeplab.step()