In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import get_dimensions
from loss import *
from train import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run-wide configuration: hardware switch, loader parallelism and batch size.
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 1

# float32 takes half the memory of float64.
dtype = torch.float32

# Use CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Enable cuDNN and let it benchmark convolution algorithms for speed.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')
else:
    print('using CPU for training')

using GPU for training


In [3]:
# Training set: every sample goes through the augmentation pipeline below
# (random affine, random flip, and a randomly applied elastic deformation).
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
        transforms.RandomApply([ElasticTransformation(256*2, 256*0.08)]),
    ]),
)

# Validation set: no augmentation, so the loss is measured on unmodified data.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoader handles batching and loads samples in NUM_WORKERS worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)

# Validation is not shuffled: order has no effect on the averaged loss, and a
# fixed order keeps every evaluation pass reproducible.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS)

In [4]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(m):
    """Re-initialise the weight of a single module in place.

    Written to be passed to ``Module.apply`` so it is called once per submodule.

    * ``nn.Conv3d``      -> Kaiming-normal weights.
    * ``nn.BatchNorm3d`` -> weights drawn from N(0, 1); skipped when the layer
      was built with ``affine=False`` and therefore has no weight tensor.
    * Any other module is left untouched. Biases are never modified.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Conv3d):
        init.kaiming_normal_(m.weight.data)
    elif isinstance(m, nn.BatchNorm3d) and m.weight is not None:
        # Guard: BatchNorm3d(affine=False) has weight=None, which would
        # otherwise raise AttributeError on `.data`.
        init.normal_(m.weight.data, mean=0, std=1)

def downsample_label(label, scale_factor):
    """Resize ``label`` by ``scale_factor`` with trilinear interpolation.

    Args:
        label: tensor passed straight to ``F.interpolate``; trilinear mode
            requires a 5-D input.
        scale_factor: spatial scaling factor forwarded to ``F.interpolate``.

    Returns:
        The interpolated tensor (computed with ``align_corners=True``).
    """
    interp_options = dict(scale_factor=scale_factor, mode='trilinear', align_corners=True)
    resized = F.interpolate(label, **interp_options)
    return resized

In [5]:
def shape_test(model, cuda_bool, input_shape=(1, 1, 256, 256, 256)):
    """Run an all-zeros tensor through ``model`` and print each output's size.

    Args:
        model: callable taking one tensor and returning either a single tensor
            or an iterable of tensors.
        cuda_bool: when truthy, the input and the model are first moved to the
            notebook-level ``device`` / ``dtype``; otherwise both are used as
            given.
        input_shape: shape of the dummy input. Defaults to the original
            hard-coded ``(1, 1, 256, 256, 256)`` volume.

    Returns:
        None. One ``torch.Size`` is printed per output tensor.
    """
    x = torch.zeros(input_shape)
    if cuda_bool:
        x = x.to(device=device, dtype=dtype)
        model = model.to(device=device, dtype=dtype)

    # Only shapes are needed, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        scores = model(x)

    # A bare tensor would otherwise be iterated along its first axis and report
    # per-sample shapes instead of the output shape.
    if isinstance(scores, torch.Tensor):
        scores = (scores,)

    for i in scores:
        print(i.size())

In [6]:
"""
test_dictionary = train_dataset[33]

image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256).to(device=device, dtype=dtype)
label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256).to(device=device, dtype=dtype)

#image_2 = test_dictionary['image2_data'].view(1, 1, 128, 128, 128).to(device=device, dtype=dtype)
#label_2 = test_dictionary['image2_label'].view(1, 3, 128, 128, 128).to(device=device, dtype=dtype)

#label_2_resize_2 = downsample_label(label_2, 1/2).to(device=device, dtype=dtype)

from dense_unet_3d import Model

m = Model()
#m = nn.DataParallel(m)
#m = convert_model(m)
m.apply(init_weights)
m = m.to(device=device, dtype=dtype)
#shape_test(m, True)

optimizer = optim.Adam(m.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
epoch=0
#shape_test(m, True)
"""

"\ntest_dictionary = train_dataset[33]\n\nimage_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256).to(device=device, dtype=dtype)\nlabel_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256).to(device=device, dtype=dtype)\n\n#image_2 = test_dictionary['image2_data'].view(1, 1, 128, 128, 128).to(device=device, dtype=dtype)\n#label_2 = test_dictionary['image2_label'].view(1, 3, 128, 128, 128).to(device=device, dtype=dtype)\n\n#label_2_resize_2 = downsample_label(label_2, 1/2).to(device=device, dtype=dtype)\n\nfrom dense_unet_3d import Model\n\nm = Model()\n#m = nn.DataParallel(m)\n#m = convert_model(m)\nm.apply(init_weights)\nm = m.to(device=device, dtype=dtype)\n#shape_test(m, True)\n\noptimizer = optim.Adam(m.parameters(), lr=1e-3)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)\nepoch=0\n#shape_test(m, True)\n"

In [7]:
from dense_unet_3d import Model

# Resume training: rebuild the model, optimizer and scheduler, then restore
# their states from a saved checkpoint.
m = Model()
optimizer = optim.Adam(m.parameters())
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# map_location remaps stored tensors onto the device chosen for this run, so a
# checkpoint written from CUDA tensors can still be loaded on a CPU-only machine.
checkpoint = torch.load(
    '../3d_dense_unet_elastic_save/2019-08-05 12:11:38.146979 epoch: 75.pth',
    map_location=device,
)
m.load_state_dict(checkpoint['state_dict_1'])
# Move the model before restoring the optimizer state.
# NOTE(review): presumably so the restored state is placed on the parameters' device - confirm.
m = m.to(device, dtype)

optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

# Last epoch stored in the checkpoint; the training loop continues from epoch + 1.
epoch = checkpoint['epoch']
print(epoch)

75


In [None]:
# Main training loop: resumes at `epoch + 1`, trains on `train_loader`, and every
# 5th epoch evaluates on `validation_loader` and writes a checkpoint via `save_1`.
epochs = 5000

# Per-epoch loss history handed to save_1 with each checkpoint.
# NOTE(review): this starts empty on every resume; any history stored in the
# loaded checkpoint is not carried over - confirm that is intended.
logger = {'train':[], 'validation_1':[]}

# NOTE(review): `scheduler` is saved with each checkpoint but scheduler.step()
# is never called in this loop, so it does not act on the learning rate here.

# The context manager closes the log file even if the run is interrupted
# or raises part-way through.
with open('train_3d_dense_unet_elastic.txt','a') as record:

    for e in tqdm(range(epoch+1, epochs)):
        # iterate over epochs

        # Training mode: dropout active, batchnorm uses batch statistics.
        # Set once per epoch because validation below switches to eval mode.
        m.train()

        epoch_loss = 0

        for t, batch in enumerate(train_loader):
            # iterate over the training mini-batches

            # move data to the device and cast to the working dtype
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            out_1 = m(image_1)

            # calculate loss
            loss_1 = dice_loss_3(out_1, label_1)

            # accumulate mini-batch loss into the epoch total
            epoch_loss += loss_1.item()

            # clear old gradients, backpropagate, take one optimizer step
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()

            torch.cuda.empty_cache()

        # mean training loss over the t+1 mini-batches of this epoch
        train_loss = epoch_loss/(t+1)
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n'

        logger['train'].append(train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e%5 == 0:
            # validate and checkpoint every 5 epochs

            # Eval mode: dropout disabled, batchnorm uses running statistics.
            m.eval()

            with torch.no_grad():
                # no gradients are needed for evaluation

                # The "_2" names are kept from an earlier version of this cell;
                # the tensors are read from the 'image1_*' keys.
                valloss_2 = 0

                for v, vbatch in enumerate(validation_loader):
                    # iterate over the validation mini-batches

                    # move data to the device, cast dtype, and add a leading
                    # dimension when get_dimensions reports 4
                    image_2_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    if get_dimensions(image_2_val) == 4:
                        image_2_val.unsqueeze_(0)
                    label_2_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_2_val) == 4:
                        label_2_val.unsqueeze_(0)

                    # inference
                    out_2_val = m(image_2_val)

                    # calculate loss
                    loss_2 = dice_loss_3(out_2_val, label_2_val)

                    # accumulate mini-batch loss
                    valloss_2 += loss_2.item()

                # mean validation loss over the v+1 mini-batches
                val_loss = valloss_2/(v+1)
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(val_loss) + '\n'

                logger['validation_1'].append(val_loss)

                torch.cuda.empty_cache()

                print(outstr)
                record.write(outstr)
                record.flush()

                save_1('3d_dense_unet_elastic_save', m, optimizer, logger, e, scheduler)

  0%|          | 1/4924 [14:01<1151:19:38, 841.92s/it]

Epoch 76 finished ! Training Loss: 0.1968



  0%|          | 2/4924 [25:44<1093:45:46, 799.99s/it]

Epoch 77 finished ! Training Loss: 0.1931



  0%|          | 3/4924 [37:37<1058:05:39, 774.06s/it]

Epoch 78 finished ! Training Loss: 0.2185



  0%|          | 4/4924 [49:27<1031:39:34, 754.87s/it]

Epoch 79 finished ! Training Loss: 0.1996

Epoch 80 finished ! Training Loss: 0.1969



  0%|          | 5/4924 [1:02:05<1032:32:05, 755.67s/it]

------- 1st valloss=0.2117

Checkpoint 80 saved !


  0%|          | 6/4924 [1:13:53<1013:00:57, 741.53s/it]

Epoch 81 finished ! Training Loss: 0.1955



  0%|          | 7/4924 [1:25:44<1000:22:49, 732.43s/it]

Epoch 82 finished ! Training Loss: 0.1935



  0%|          | 8/4924 [1:37:31<989:30:34, 724.62s/it] 

Epoch 83 finished ! Training Loss: 0.2046



  0%|          | 9/4924 [1:49:25<984:58:49, 721.45s/it]

Epoch 84 finished ! Training Loss: 0.1827

Epoch 85 finished ! Training Loss: 0.2011



  0%|          | 10/4924 [2:02:15<1004:51:42, 736.16s/it]

------- 1st valloss=0.2474

Checkpoint 85 saved !


  0%|          | 11/4924 [2:14:04<993:26:08, 727.94s/it] 

Epoch 86 finished ! Training Loss: 0.1979



  0%|          | 12/4924 [2:26:02<988:57:29, 724.81s/it]

Epoch 87 finished ! Training Loss: 0.1898



  0%|          | 13/4924 [2:37:56<984:20:50, 721.57s/it]

Epoch 88 finished ! Training Loss: 0.1946



  0%|          | 14/4924 [2:49:50<981:20:51, 719.52s/it]

Epoch 89 finished ! Training Loss: 0.1989

Epoch 90 finished ! Training Loss: 0.1978



  0%|          | 15/4924 [3:02:39<1001:23:09, 734.36s/it]

------- 1st valloss=0.1374

Checkpoint 90 saved !


  0%|          | 16/4924 [3:14:27<990:22:20, 726.43s/it] 

Epoch 91 finished ! Training Loss: 0.1832



  0%|          | 17/4924 [3:26:12<981:15:15, 719.89s/it]

Epoch 92 finished ! Training Loss: 0.1965



  0%|          | 18/4924 [3:37:56<974:25:37, 715.03s/it]

Epoch 93 finished ! Training Loss: 0.1749



  0%|          | 19/4924 [3:49:50<973:45:36, 714.69s/it]

Epoch 94 finished ! Training Loss: 0.2083

Epoch 95 finished ! Training Loss: 0.1892



  0%|          | 20/4924 [4:02:34<993:50:42, 729.58s/it]

------- 1st valloss=0.1623

Checkpoint 95 saved !


  0%|          | 21/4924 [4:14:25<986:17:21, 724.18s/it]

Epoch 96 finished ! Training Loss: 0.1765



  0%|          | 22/4924 [4:26:21<982:22:39, 721.45s/it]

Epoch 97 finished ! Training Loss: 0.1940



  0%|          | 23/4924 [4:38:10<977:13:08, 717.81s/it]

Epoch 98 finished ! Training Loss: 0.2191



  0%|          | 24/4924 [4:49:54<971:15:56, 713.58s/it]

Epoch 99 finished ! Training Loss: 0.1879

Epoch 100 finished ! Training Loss: 0.1951



  1%|          | 25/4924 [5:02:32<989:31:04, 727.14s/it]

------- 1st valloss=0.4101

Checkpoint 100 saved !


  1%|          | 26/4924 [5:14:16<979:48:07, 720.15s/it]

Epoch 101 finished ! Training Loss: 0.1972



  1%|          | 27/4924 [5:26:05<975:01:11, 716.78s/it]

Epoch 102 finished ! Training Loss: 0.2005



  1%|          | 28/4924 [5:37:58<973:25:00, 715.75s/it]

Epoch 103 finished ! Training Loss: 0.2110



  1%|          | 29/4924 [5:49:53<972:45:18, 715.41s/it]

Epoch 104 finished ! Training Loss: 0.1804

Epoch 105 finished ! Training Loss: 0.1813



  1%|          | 30/4924 [6:02:36<991:53:27, 729.63s/it]

------- 1st valloss=0.1729

Checkpoint 105 saved !


  1%|          | 31/4924 [6:14:25<983:08:55, 723.35s/it]

Epoch 106 finished ! Training Loss: 0.1871



  1%|          | 32/4924 [6:26:12<976:35:01, 718.66s/it]

Epoch 107 finished ! Training Loss: 0.1793



  1%|          | 33/4924 [6:38:03<973:02:16, 716.20s/it]

Epoch 108 finished ! Training Loss: 0.1968



  1%|          | 34/4924 [6:49:57<972:13:10, 715.74s/it]

Epoch 109 finished ! Training Loss: 0.1833

Epoch 110 finished ! Training Loss: 0.1934



  1%|          | 35/4924 [7:02:39<990:45:20, 729.54s/it]

------- 1st valloss=0.1810

Checkpoint 110 saved !


  1%|          | 36/4924 [7:14:23<979:58:32, 721.75s/it]

Epoch 111 finished ! Training Loss: 0.1815



  1%|          | 37/4924 [7:26:13<975:08:55, 718.34s/it]

Epoch 112 finished ! Training Loss: 0.1917



  1%|          | 38/4924 [7:38:03<971:18:15, 715.66s/it]

Epoch 113 finished ! Training Loss: 0.1891



  1%|          | 39/4924 [7:49:53<969:03:33, 714.15s/it]

Epoch 114 finished ! Training Loss: 0.1696

Epoch 115 finished ! Training Loss: 0.1901



  1%|          | 40/4924 [8:02:33<987:25:07, 727.83s/it]

------- 1st valloss=0.3829

Checkpoint 115 saved !


  1%|          | 41/4924 [8:14:25<980:40:50, 723.01s/it]

Epoch 116 finished ! Training Loss: 0.1687



  1%|          | 42/4924 [8:26:07<972:11:03, 716.89s/it]

Epoch 117 finished ! Training Loss: 0.1684



  1%|          | 43/4924 [8:38:04<971:48:47, 716.76s/it]

Epoch 118 finished ! Training Loss: 0.1977



  1%|          | 44/4924 [8:49:55<969:28:27, 715.19s/it]

Epoch 119 finished ! Training Loss: 0.1866

Epoch 120 finished ! Training Loss: 0.2015



  1%|          | 45/4924 [9:02:45<991:24:29, 731.52s/it]

------- 1st valloss=0.2168

Checkpoint 120 saved !


  1%|          | 46/4924 [9:14:43<985:43:17, 727.47s/it]

Epoch 121 finished ! Training Loss: 0.1917



  1%|          | 47/4924 [9:26:28<976:14:10, 720.62s/it]

Epoch 122 finished ! Training Loss: 0.1923



  1%|          | 48/4924 [9:38:20<972:41:27, 718.15s/it]

Epoch 123 finished ! Training Loss: 0.1977



  1%|          | 49/4924 [9:50:15<971:18:16, 717.27s/it]

Epoch 124 finished ! Training Loss: 0.1996

Epoch 125 finished ! Training Loss: 0.2010



  1%|          | 50/4924 [10:02:52<987:16:50, 729.22s/it]

------- 1st valloss=0.2244

Checkpoint 125 saved !


  1%|          | 51/4924 [10:14:44<979:51:30, 723.88s/it]

Epoch 126 finished ! Training Loss: 0.1702



  1%|          | 52/4924 [10:26:36<975:00:24, 720.45s/it]

Epoch 127 finished ! Training Loss: 0.2073



  1%|          | 53/4924 [10:38:29<971:46:42, 718.21s/it]

Epoch 128 finished ! Training Loss: 0.1867



  1%|          | 54/4924 [10:50:20<968:41:04, 716.07s/it]

Epoch 129 finished ! Training Loss: 0.1906

Epoch 130 finished ! Training Loss: 0.1786



  1%|          | 55/4924 [11:03:01<986:34:26, 729.44s/it]

------- 1st valloss=0.1440

Checkpoint 130 saved !


  1%|          | 56/4924 [11:15:02<982:53:32, 726.87s/it]

Epoch 131 finished ! Training Loss: 0.1872



  1%|          | 57/4924 [11:26:57<978:00:22, 723.41s/it]

Epoch 132 finished ! Training Loss: 0.1849



  1%|          | 58/4924 [11:38:44<971:02:56, 718.41s/it]

Epoch 133 finished ! Training Loss: 0.1942



  1%|          | 59/4924 [11:50:39<969:39:55, 717.53s/it]

Epoch 134 finished ! Training Loss: 0.1768

Epoch 135 finished ! Training Loss: 0.1838



  1%|          | 60/4924 [12:03:18<986:18:50, 730.00s/it]

------- 1st valloss=0.3030

Checkpoint 135 saved !


  1%|          | 61/4924 [12:15:11<979:07:21, 724.83s/it]

Epoch 136 finished ! Training Loss: 0.1838



  1%|▏         | 62/4924 [12:27:07<975:15:00, 722.11s/it]

Epoch 137 finished ! Training Loss: 0.1732



  1%|▏         | 63/4924 [12:38:59<970:53:54, 719.04s/it]

Epoch 138 finished ! Training Loss: 0.1946



In [None]:
# Overfit sanity check: repeatedly optimise the model on one fixed
# (image_1, label_1) pair and log the loss at every step.
# NOTE(review): the original header mentioned a "modified ICNet Model" with
# outputs upsampled by a factor of 4; the model loaded above comes from
# dense_unet_3d, so that description looks stale - confirm.
# NOTE(review): image_1 / label_1 are not defined in this cell. They come from
# the string-disabled scratch cell or leak from the last training batch; define
# them explicitly before running this on a fresh kernel.
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

# The context manager closes the log file even if the loop is interrupted.
with open('over_fit_3d_dense_unet.txt','w+') as record:

    # Release cached GPU memory once up front; it is released again after every
    # step below.
    torch.cuda.empty_cache()

    for e in tqdm(range(epochs)):
        out_1 = m(image_1)
        loss_1 = dice_loss_3(out_1, label_1)

        # single-output model: the total loss is just loss_1
        loss = loss_1

        outstr = 'in epoch {}, loss = {}'.format(e, loss.item()) + '\n'

        print(outstr)
        record.write(outstr)
        record.flush()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()