In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from dense_vnet.DenseVNet import *
from data_utils import get_dimensions
from loss import *
import adabound
from train import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- run configuration -------------------------------------------------
USE_GPU = True      # set to False to force CPU even when CUDA is present
NUM_WORKERS = 12    # passed to the DataLoaders as num_workers
BATCH_SIZE = 2

# float32 takes half the memory of float64
dtype = torch.float32

# Pick the compute device once; later cells move tensors/models to `device`.
if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')

    # cuDNN switches: enable the backend and let it benchmark kernels
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    print('using GPU for training')

using GPU for training


In [3]:
# Training set: every sample passes through the random augmentation pipeline
# (random_affine followed by random_filp, both project helpers).
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no augmentation.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoader handles batching and loads samples in NUM_WORKERS worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)

# Validation is NOT shuffled: a fixed order keeps batch composition identical
# across epochs, so per-batch losses (and the averaged validation loss used
# for checkpoint selection) are comparable from one evaluation to the next.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [4]:
# Grab one fixed training sample (index 33) to use as a standalone input.
test_dictionary = train_dataset[33]

# The 'image2_data' / 'image2_label' pair (128^3 views) is currently unused.

# Reshape to (batch=1, channels, 256, 256, 256) and move to the working
# device/dtype in one step: 1 channel for the image, 3 for the label.
image_1 = (
    test_dictionary['image1_data']
    .view(1, 1, 256, 256, 256)
    .to(device=device, dtype=dtype)
)
label_1 = (
    test_dictionary['image1_label']
    .view(1, 3, 256, 256, 256)
    .to(device=device, dtype=dtype)
)

In [None]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(m):
    """Re-initialise the weight of a single module in place.

    Intended for use with ``model.apply(init_weights)``:

    * ``nn.Conv3d``      -> Kaiming-normal weights
    * ``nn.BatchNorm3d`` -> weights drawn from N(0, 1)

    Any other module type is left untouched. Biases are never modified.
    Returns ``None``.
    """
    # (layer type, initialiser) pairs; the first matching type wins.
    initialisers = (
        (nn.Conv3d, lambda weight: init.kaiming_normal_(weight)),
        (nn.BatchNorm3d, lambda weight: init.normal_(weight, mean=0, std=1)),
    )
    for layer_type, initialise in initialisers:
        if isinstance(m, layer_type):
            initialise(m.weight.data)
            break

def downsample_label(label, scale_factor):
    """Resample a 5-D tensor with trilinear interpolation.

    Args:
        label: tensor laid out as (batch, channel, D, H, W), as required by
            trilinear ``F.interpolate``.
        scale_factor: spatial scale passed straight to ``F.interpolate``
            (values below 1 shrink the volume, above 1 enlarge it).

    Returns:
        The interpolated tensor; corner values are preserved because
        ``align_corners=True`` is used.
    """
    resample_options = {
        'scale_factor': scale_factor,
        'mode': 'trilinear',
        'align_corners': True,
    }
    resampled = F.interpolate(label, **resample_options)
    return resampled

In [None]:
def shape_test(model, cuda_bool, input_shape=(1, 1, 256, 256, 256)):
    """Run a zero tensor through ``model`` and print the size of each output.

    Args:
        model: callable (normally an ``nn.Module``) to probe.
        cuda_bool: when true, both the input and the model are moved to the
            notebook-level ``device``/``dtype`` before the forward pass;
            when false, both are used as given.
        input_shape: shape of the all-zeros probe tensor. Defaults to the
            single-channel 256^3 volume used elsewhere in this notebook.

    The result of ``model(x)`` is iterated and ``.size()`` of every item is
    printed, so the model is expected to return a sequence of tensors
    (iterating a single tensor yields one entry per batch element).
    """
    x = torch.zeros(input_shape)
    if cuda_bool:
        x = x.to(device=device, dtype=dtype)
        model = model.to(device=device, dtype=dtype)
    # Previously the non-CUDA branch rebound `model` to the tensor `x`,
    # which made the call below fail; the model is now left as passed in.

    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        scores = model(x)

    for i in scores:
        print(i.size())

In [None]:
# NOTE(review): this cell is a bare triple-quoted string, so none of the code
# inside it runs; the cell only echoes the string as its output.
# The disabled code is a from-scratch setup: it builds a default DenseVNet,
# wraps it in nn.DataParallel, runs convert_model on it, applies init_weights,
# moves it to device/dtype, and creates an AdaBound optimizer with a
# ReduceLROnPlateau scheduler, starting from epoch 0.
# Convert it to real code if training from scratch is needed.
"""
dense_vnet = DenseVNet()
dense_vnet = nn.DataParallel(dense_vnet)
dense_vnet = convert_model(dense_vnet)
dense_vnet.apply(init_weights)
dense_vnet = dense_vnet.to(device=device, dtype=dtype)

#shape_test(dense_vnet, True)

#optimizer = optim.Adam(dense_vnet.parameters(), lr=1e-3)
optimizer = adabound.AdaBound(dense_vnet.parameters(), lr=1e-3, final_lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
epoch = 0
"""

'\ndense_vnet = DenseVNet()\ndense_vnet = nn.DataParallel(dense_vnet)\ndense_vnet = convert_model(dense_vnet)\ndense_vnet.apply(init_weights)\ndense_vnet = dense_vnet.to(device=device, dtype=dtype)\n\n#shape_test(dense_vnet, True)\n\n#optimizer = optim.Adam(dense_vnet.parameters(), lr=1e-3)\noptimizer = adabound.AdaBound(dense_vnet.parameters(), lr=1e-3, final_lr=0.1)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)\nepoch = 0\n'

In [None]:

# Build the 3-class network and wrap it (DataParallel, then convert_model)
# BEFORE loading weights: load_state_dict needs the parameter names of this
# wrapped module to match the keys stored in the checkpoint.
dense_vnet = DenseVNet(num_classes=3)
dense_vnet = nn.DataParallel(dense_vnet)
dense_vnet = convert_model(dense_vnet)

# Other checkpoints that were tried:
#   '../dense_vnet_dropout_save/2019-08-10 18:50:56.964852 epoch: 675.pth'  (best one)
#   '../dense_vnet_dropout_save/2019-08-11 04:09:10.113433 epoch: 759.pth'
#
# map_location remaps the stored tensors onto the device selected in the
# config cell, so a checkpoint written on a GPU can still be loaded when the
# notebook falls back to CPU.
checkpoint = torch.load('../dense_vnet_dropout_save/2019-08-12 11:58:04.178603 epoch: 1039.pth',
                        map_location=device)

dense_vnet.load_state_dict(checkpoint['state_dict_1'])
dense_vnet = dense_vnet.to(device=device, dtype=dtype)

# A fresh Adam optimizer is created; the optimizer state stored in the
# checkpoint is deliberately NOT restored (the load call is left disabled).
optimizer = optim.Adam(dense_vnet.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

#optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

# Resume epoch counting from the checkpoint and show the active learning rate(s).
epoch = checkpoint['epoch']
print(epoch)
for param_group in optimizer.param_groups:
    print(param_group['lr'])


1039
0.001


In [None]:
# Resume training from `epoch + 1` up to (but not including) `epochs`.
epochs = 5000

# Run validation every N epochs (1 = after every epoch).
validate_every = 1

# Per-epoch average losses, also handed to save_1 with each checkpoint.
logger = {'train':[], 'validation_1':[]}

# Validation loss a new epoch has to beat to be saved as "best".
# NOTE(review): presumably the best value reached by the run being resumed — confirm.
min_val = .0581

# `with` guarantees the log file is closed even if the run is interrupted
# (e.g. kernel interrupt) or an exception is raised mid-training.
with open('dense_vnet/train_dense_vnet_dropout.txt','a+') as record:

    for e in tqdm(range(epoch + 1, epochs)):

        # Train mode: dropout active, batchnorm in training behaviour.
        # Set once per epoch; nothing inside the batch loop changes it.
        dense_vnet.train()

        epoch_loss = 0

        for t, batch in enumerate(train_loader):
            # Dedicated names so the notebook-level image_1 / label_1
            # (the fixed single sample) are not overwritten by training batches.
            train_image = batch['image1_data'].to(device=device, dtype=dtype)
            train_label = batch['image1_label'].to(device=device, dtype=dtype)

            out_1 = dense_vnet(train_image)
            loss_1 = dice_loss_3(out_1, train_label)

            # accumulate the mini-batch loss for the epoch average
            epoch_loss += loss_1.item()

            # standard update: clear old gradients, backpropagate, step
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()
            #scheduler.step(loss)

        # t is the last batch index, so t + 1 is the number of batches seen
        avg_train_loss = epoch_loss/(t+1)
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, avg_train_loss) + '\n'

        logger['train'].append(avg_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e % validate_every != 0:
            continue

        # Eval mode: dropout disabled, batchnorm in inference behaviour.
        dense_vnet.eval()

        with torch.no_grad():
            valloss_1 = 0

            for v, vbatch in enumerate(validation_loader):
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                # NOTE(review): get_dimensions presumably reports the tensor rank;
                # a 4-D tensor gets a leading axis added in place — confirm.
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)

                out_1_val = dense_vnet(image_1_val)
                val_loss = dice_loss_3(out_1_val, label_1_val)

                valloss_1 += val_loss.item()

            avg_val_loss = valloss_1 / (v+1)
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'

            logger['validation_1'].append(avg_val_loss)
            #scheduler.step(valloss_1/(v+1))

            print(outstr)
            record.write(outstr)
            record.flush()

            # Checkpoint on a new best validation loss, and otherwise
            # every 10th epoch as a periodic snapshot.
            is_new_best = avg_val_loss < min_val
            if is_new_best:
                min_val = avg_val_loss
            if is_new_best or e % 10 == 0:
                save_1('dense_vnet_dropout_save', dense_vnet, optimizer, logger, e, scheduler)
            #torch.cuda.empty_cache()

  0%|          | 0/3960 [00:00<?, ?it/s]

Epoch 1040 finished ! Training Loss: 0.0788



  0%|          | 1/3960 [07:54<522:20:57, 474.98s/it]

------- 1st valloss=0.0613

Checkpoint 1040 saved !
Epoch 1041 finished ! Training Loss: 0.0823



  0%|          | 2/3960 [14:47<501:34:46, 456.21s/it]

------- 1st valloss=0.0603

Epoch 1042 finished ! Training Loss: 0.0811



  0%|          | 3/3960 [21:51<490:46:54, 446.50s/it]

------- 1st valloss=0.0629

Epoch 1043 finished ! Training Loss: 0.0786



  0%|          | 4/3960 [28:44<479:44:10, 436.56s/it]

------- 1st valloss=0.0612

Epoch 1044 finished ! Training Loss: 0.0795



  0%|          | 5/3960 [35:51<476:33:17, 433.78s/it]

------- 1st valloss=0.0607

Epoch 1045 finished ! Training Loss: 0.0780



  0%|          | 6/3960 [42:54<472:47:06, 430.46s/it]

------- 1st valloss=0.0624

Epoch 1046 finished ! Training Loss: 0.0790



  0%|          | 7/3960 [49:48<467:18:44, 425.58s/it]

------- 1st valloss=0.0604

Epoch 1047 finished ! Training Loss: 0.0766



  0%|          | 8/3960 [56:52<466:29:47, 424.95s/it]

------- 1st valloss=0.0643

Epoch 1048 finished ! Training Loss: 0.0762



  0%|          | 9/3960 [1:04:01<467:43:35, 426.17s/it]

------- 1st valloss=0.0595

Epoch 1049 finished ! Training Loss: 0.0776



  0%|          | 10/3960 [1:10:54<463:20:51, 422.29s/it]

------- 1st valloss=0.0622

Epoch 1050 finished ! Training Loss: 0.0823



  0%|          | 11/3960 [1:17:59<464:14:54, 423.22s/it]

------- 1st valloss=0.0702

Checkpoint 1050 saved !
Epoch 1051 finished ! Training Loss: 0.0780



  0%|          | 12/3960 [1:24:55<461:45:29, 421.06s/it]

------- 1st valloss=0.0655

Epoch 1052 finished ! Training Loss: 0.0782



  0%|          | 13/3960 [1:31:54<460:56:52, 420.42s/it]

------- 1st valloss=0.0622

Epoch 1053 finished ! Training Loss: 0.0763



  0%|          | 14/3960 [1:38:50<459:12:55, 418.95s/it]

------- 1st valloss=0.0627

Epoch 1054 finished ! Training Loss: 0.0787



  0%|          | 15/3960 [1:45:44<457:37:02, 417.60s/it]

------- 1st valloss=0.0627

Epoch 1055 finished ! Training Loss: 0.0750



  0%|          | 16/3960 [1:52:39<456:23:34, 416.59s/it]

------- 1st valloss=0.0643

Epoch 1056 finished ! Training Loss: 0.0775



  0%|          | 17/3960 [1:59:28<453:48:24, 414.33s/it]

------- 1st valloss=0.0708

Epoch 1057 finished ! Training Loss: 0.0769



  0%|          | 18/3960 [2:06:33<457:18:49, 417.64s/it]

------- 1st valloss=0.0656

Epoch 1058 finished ! Training Loss: 0.0784



  0%|          | 19/3960 [2:13:37<459:21:48, 419.62s/it]

------- 1st valloss=0.0692

Epoch 1059 finished ! Training Loss: 0.0798



  1%|          | 20/3960 [2:20:38<459:33:05, 419.89s/it]

------- 1st valloss=0.0665

Epoch 1060 finished ! Training Loss: 0.0772



  1%|          | 21/3960 [2:27:31<457:09:52, 417.82s/it]

------- 1st valloss=0.0658

Checkpoint 1060 saved !
Epoch 1061 finished ! Training Loss: 0.0746



  1%|          | 22/3960 [2:34:35<458:59:58, 419.60s/it]

------- 1st valloss=0.0639

Epoch 1062 finished ! Training Loss: 0.0753



  1%|          | 23/3960 [2:41:28<456:59:12, 417.87s/it]

------- 1st valloss=0.0637

Epoch 1063 finished ! Training Loss: 0.0757



  1%|          | 24/3960 [2:48:26<456:47:38, 417.80s/it]

------- 1st valloss=0.0610

Epoch 1064 finished ! Training Loss: 0.0755



  1%|          | 25/3960 [2:55:22<456:10:30, 417.34s/it]

------- 1st valloss=0.0606

Epoch 1065 finished ! Training Loss: 0.0738



  1%|          | 26/3960 [3:02:17<455:12:18, 416.56s/it]

------- 1st valloss=0.0594

Epoch 1066 finished ! Training Loss: 0.0781



  1%|          | 27/3960 [3:09:09<453:43:31, 415.31s/it]

------- 1st valloss=0.0601

Epoch 1067 finished ! Training Loss: 0.0737



  1%|          | 28/3960 [3:16:01<452:18:24, 414.12s/it]

------- 1st valloss=0.0716

Epoch 1068 finished ! Training Loss: 0.0756



  1%|          | 29/3960 [3:23:01<454:15:43, 416.01s/it]

------- 1st valloss=0.0614

Epoch 1069 finished ! Training Loss: 0.0746



  1%|          | 30/3960 [3:30:05<456:48:15, 418.45s/it]

------- 1st valloss=0.0619

Epoch 1070 finished ! Training Loss: 0.0770



  1%|          | 31/3960 [3:37:13<459:50:49, 421.34s/it]

------- 1st valloss=0.0628

Checkpoint 1070 saved !
Epoch 1071 finished ! Training Loss: 0.0762



  1%|          | 32/3960 [3:44:13<459:06:51, 420.78s/it]

------- 1st valloss=0.0630

Epoch 1072 finished ! Training Loss: 0.0788



  1%|          | 33/3960 [3:51:07<456:59:59, 418.95s/it]

------- 1st valloss=0.0653

Epoch 1073 finished ! Training Loss: 0.0742



  1%|          | 34/3960 [3:58:00<454:42:43, 416.95s/it]

------- 1st valloss=0.0631

Epoch 1074 finished ! Training Loss: 0.0761



In [None]:
# Per-class evaluation on the validation set using hard (one-hot) predictions.
dense_vnet.eval()

with torch.no_grad():

    # running sums of the three components returned by dice_loss_3_debug
    bgloss = 0
    bdloss = 0
    bvloss = 0

    # tqdm wraps the loader itself (not the enumerate object) so the
    # progress bar knows the total number of batches.
    for v, vbatch in enumerate(tqdm(validation_loader)):

        # Dedicated names so the notebook-level image_1 / label_1
        # (the fixed single sample) are not overwritten here.
        eval_image = vbatch['image1_data'].to(device=device, dtype=dtype)
        eval_label = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = dense_vnet(eval_image)

        # Harden the prediction: 1 where a channel equals the per-voxel
        # maximum over the channel axis (dim 1), 0 elsewhere. Ties keep every
        # maximal channel. Done on-device; same result as the former NumPy
        # round trip without copying the volume to host memory and back.
        #out_1 = torch.round(output)
        channel_max = output.max(dim=1, keepdim=True)[0]
        out_1 = (output == channel_max).to(device=device, dtype=dtype)

        loss_1 = dice_loss_3(out_1, eval_label)

        bg, bd, bv = dice_loss_3_debug(out_1, eval_label)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    # v is the last batch index, so v + 1 is the number of batches
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)

In [None]:
# overfit model on single embryo image (modified ICNet Model)
# upsample final outputs by a factor of 4 instead of factor 2
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

# Make the model mode explicit: an earlier evaluation cell may have left the
# network in eval mode, and the updates below are training steps.
dense_vnet.train()

# image_1 / label_1 are expected to hold the single fixed sample prepared
# earlier in the notebook. NOTE(review): other cells may also assign these
# names — confirm they still hold the intended sample before running.

# `with` guarantees the log file is closed even if the loop is interrupted.
with open('over_fit_dense_vnet3.txt','w+') as record:

    for e in tqdm(range(epochs)):

        # Same input every iteration: the goal is to overfit this one sample.
        out_2 = dense_vnet(image_1)

        #loss_1 = dice_loss_3(out_1, label_1)
        loss_2 = dice_loss_3(out_2, label_1)

        #loss = loss_1 + loss_2

        outstr = 'in epoch {}, loss_2 = {}'.format(e, loss_2.item()) + '\n'

        print(outstr)
        record.write(outstr)
        record.flush()

        # standard update: clear old gradients, backpropagate, step
        optimizer.zero_grad()
        loss_2.backward()
        optimizer.step()