In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from dense_vnet.DenseVNet import DenseVNet
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device("cuda:0")
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = get_full_resolution_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = get_full_resolution_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
from vnet import VNet
from bv_refinement_network.RefinementModel import RefinementModel_ELU
from refinenet import refine_net

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    
checkpoint_refine = torch.load('../refine_bv_resize_save/2019-08-21 16:39:17.658870 epoch: 34.pth')
    
refine_model = refine_net(num_classes=1)
#refine_model = nn.DataParallel(refine_model)
#refine_model = convert_model(refine_model)

refine_model.load_state_dict(checkpoint_refine['state_dict_1'])

refine_model = refine_model.to(device, dtype)

optimizer = optim.Adam(refine_model.parameters(), lr=1e-3)
optimizer.load_state_dict(checkpoint_refine['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
scheduler.load_state_dict(checkpoint_refine['scheduler'])

deeplab = DeepLab(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

checkpoint = torch.load('../deeplab_dilated_save/2019-08-10 09:28:43.844872 epoch: 1160.pth') # best one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

epoch = checkpoint_refine['epoch']
#epoch = 0
print(epoch)

Let's use 2 GPUs!
18


In [None]:
def get_bboxes(image, label, output, batchsize, box_size):
    image_final = torch.zeros((batchsize, 1, box_size, box_size, box_size)).to(device, dtype)
    label_final = torch.zeros((batchsize, 1, box_size, box_size, box_size)).to(device, dtype)
    output_final = torch.zeros((batchsize, 1, box_size, box_size, box_size)).to(device, dtype)
    half_size = int(box_size/2)
    image_size_x = int(image.shape[-3])
    image_size_y = int(image.shape[-2])
    image_size_z = int(image.shape[-1])
    for b in range(batchsize):
        out = output[b]
        x,y,z = loadbvcenter(binarize_output(out))
        x, y, z = np.clip([x, y, z], a_min=half_size, a_max=181)
        x1 = max(x-half_size, 0)
        x2 = min(x+half_size, image_size_x)
        y1 = max(y-half_size, 0)
        y2 = min(y+half_size, image_size_y)
        z1 = max(z-half_size, 0)
        z2 = min(z+half_size, image_size_z)
        image_final[b] = image[b, :, x1:x2, y1:y2, z1:z2]
        label_final[b] = label[b, :, x1:x2, y1:y2, z1:z2]
        output_final[b] = output[b, :, x1:x2, y1:y2, z1:z2]
    return image_final, label_final, output_final

In [None]:
epochs = 5000

record = open('train_bv_refine_resize2.txt','a+')

logger = {'train':[], 'validation_1': []}

min_val = 1

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches
    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
        train_losses=[]
        for minibatch in range(BATCH_SIZE):
            refine_model.train()
            deeplab.eval()
            # Set the model flag to train
            # 1. enable dropout
            # 2. batchnorm behave differently in train and test
            #print(batch['image1_data'])
            image_1 = batch['image1_data'][minibatch].to(device=device, dtype=dtype)
            image_1 = image_1.view(1,1,256,256,256)

            label_1 = batch['image1_label'][minibatch].to(device=device, dtype=dtype)
            label_1 = label_1.view(1,3,256,256,256)

            bv_label = label_1[:, 2, :, :, :]
            bv_label = bv_label.view(1,1,256,256,256)

            original_res = [a[minibatch].item() for a in batch['original_resolution']]

            image_1_resize = F.interpolate(image_1, size=original_res, mode='trilinear', align_corners=True)
            image_1_resize = image_1_resize.view(1,1,original_res[0], original_res[1], original_res[2])

            bv_label_resize = F.interpolate(bv_label, size=original_res, mode='trilinear', align_corners=True)

            # Get coarse output from deeplab model from 256 resolution input
            out_coarse = deeplab(image_1)
            out_coarse = out_coarse.view(1,3,256,256,256)

            bv_coarse = out_coarse[:, 2, :, :, :]
            bv_coarse = bv_coarse.view(1,1,256,256,256)

            bv_coarse_resize = F.interpolate(bv_coarse, size=original_res, mode='trilinear', align_corners=True)
            
            box_size = 192
            half_size = int(box_size / 2)
            
            image_size_x = int(image_1_resize.shape[-3])
            image_size_y = int(image_1_resize.shape[-2])
            image_size_z = int(image_1_resize.shape[-1])
            
            x,y,z = loadbvcenter(binarize_output(bv_coarse_resize).view([1] + original_res))
            x, y, z = np.clip([x, y, z], a_min=box_size-half_size, a_max=box_size+half_size)
            x1 = max(x-half_size, 0)
            x2 = min(x+half_size, image_size_x)
            y1 = max(y-half_size, 0)
            y2 = min(y+half_size, image_size_y)
            z1 = max(z-half_size, 0)
            z2 = min(z+half_size, image_size_z)
            
            
            bbox_bv = bv_coarse_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
            bbox_bv = reshape_image(bbox_bv.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_bv = bbox_bv.view(1,1,box_size,box_size,box_size)
            
            bbox_bv_label = bv_label_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
            bbox_bv_label = reshape_image(bbox_bv_label.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_bv_label = bbox_bv_label.view(1,1,box_size,box_size,box_size)

            #bbox_image = get_bounding_box_image(image_1, (256,256,256)).to(device, dtype)
            bbox_image = image_1_resize[:, :, x1:x2, y1:y2, z1:z2]
            bbox_image = reshape_image(bbox_image.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_image = bbox_image.view(1, 1, box_size, box_size, box_size)
            
            #bbox_iamge, bbox_bv_label, bbox_bv = get_bboxes(image_1_resize, bv_label_resize, bv_coarse_resize, 1, 200)
            
            bbox_concat = torch.cat([bbox_bv, bbox_image], dim=1)
            bbox_concat_2 = F.interpolate(bbox_concat, scale_factor=1/2, mode='trilinear', align_corners=True)
            bbox_concat_4 = F.interpolate(bbox_concat, scale_factor=1/4, mode='trilinear', align_corners=True)

            refine_out = refine_model(bbox_concat, bbox_concat_2, bbox_concat_4)
            #refine_out = refine_model(seg_image_concat)
            # do the inference

            #print(refine_out.shape)
            #print(bbox_bv_label.shape)

            loss = dice_loss(refine_out, bbox_bv_label)
            print(loss)
            train_losses.append(loss)
        
        loss = sum(train_losses) / BATCH_SIZE
        train_losses=[]
        epoch_loss += loss.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%1 == 0:
    # do validation every 5 epoches
        deeplab.eval()
        refine_model.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                val_losses = []
                for minibatch in range(BATCH_SIZE):
                    image_1 = vbatch['image1_data'][minibatch].to(device=device, dtype=dtype)
                    image_1 = image_1.view(1,1,256,256,256)

                    label_1 = vbatch['image1_label'][minibatch].to(device=device, dtype=dtype)
                    label_1 = label_1.view(1,3,256,256,256)

                    bv_label = label_1[:, 2, :, :, :]
                    bv_label = bv_label.view(1,1,256,256,256)

                    original_res = [a[minibatch].item() for a in vbatch['original_resolution']]

                    image_1_resize = F.interpolate(image_1, size=original_res, mode='trilinear', align_corners=True)
                    image_1_resize = image_1_resize.view(1,1,original_res[0], original_res[1], original_res[2])

                    bv_label_resize = F.interpolate(bv_label, size=original_res, mode='trilinear', align_corners=True)

                    # Get coarse output from deeplab model from 256 resolution input
                    out_coarse = deeplab(image_1)
                    out_coarse = out_coarse.view(1,3,256,256,256)

                    bv_coarse = out_coarse[:, 2, :, :, :]
                    bv_coarse = bv_coarse.view(1,1,256,256,256)

                    bv_coarse_resize = F.interpolate(bv_coarse, size=original_res, mode='trilinear', align_corners=True)

                    box_size = 192
                    half_size = int(box_size / 2)

                    image_size_x = int(image_1_resize.shape[-3])
                    image_size_y = int(image_1_resize.shape[-2])
                    image_size_z = int(image_1_resize.shape[-1])

                    x,y,z = loadbvcenter(binarize_output(bv_coarse_resize).view([1] + original_res))
                    x, y, z = np.clip([x, y, z], a_min=box_size-half_size, a_max=box_size+half_size)
                    x1 = max(x-half_size, 0)
                    x2 = min(x+half_size, image_size_x)
                    y1 = max(y-half_size, 0)
                    y2 = min(y+half_size, image_size_y)
                    z1 = max(z-half_size, 0)
                    z2 = min(z+half_size, image_size_z)


                    bbox_bv = bv_coarse_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
                    bbox_bv = reshape_image(bbox_bv.squeeze(), box_size, box_size, box_size).to(device, dtype)
                    bbox_bv = bbox_bv.view(1,1,box_size,box_size,box_size)

                    bbox_bv_label = bv_label_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
                    bbox_bv_label = reshape_image(bbox_bv_label.squeeze(), box_size, box_size, box_size).to(device, dtype)
                    bbox_bv_label = bbox_bv_label.view(1,1,box_size,box_size,box_size)

                    #bbox_image = get_bounding_box_image(image_1, (256,256,256)).to(device, dtype)
                    bbox_image = image_1_resize[:, :, x1:x2, y1:y2, z1:z2]
                    bbox_image = reshape_image(bbox_image.squeeze(), box_size, box_size, box_size).to(device, dtype)
                    bbox_image = bbox_image.view(1, 1, box_size, box_size, box_size)

                    #bbox_iamge, bbox_bv_label, bbox_bv = get_bboxes(image_1_resize, bv_label_resize, bv_coarse_resize, 1, 200)

                    bbox_concat = torch.cat([bbox_bv, bbox_image], dim=1)
                    bbox_concat_2 = F.interpolate(bbox_concat, scale_factor=1/2, mode='trilinear', align_corners=True)
                    bbox_concat_4 = F.interpolate(bbox_concat, scale_factor=1/4, mode='trilinear', align_corners=True)

                    refine_out = refine_model(bbox_concat, bbox_concat_2, bbox_concat_4)
                    
                    loss = dice_loss(refine_out, bbox_bv_label)
                    val_losses.append(loss)
                
                avg_loss = sum(val_losses) / BATCH_SIZE
                val_losses = []
                print(avg_loss)
            
                # calculate loss
                valloss_1 += avg_loss.item()
                
            
            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            #scheduler.step(avg_val_loss)
            
            if avg_val_loss < min_val:
                min_val = avg_val_loss
                save_1('refine_bv_resize_save', refine_model, optimizer, logger, e, scheduler)
            elif e % 10 == 0:
                save_1('refine_bv_resize_save', refine_model, optimizer, logger, e, scheduler)
            
            print(outstr)
            record.write(outstr)
            record.flush()
    


record.close()

  0%|          | 0/4981 [00:00<?, ?it/s]

tensor(0.0785, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1797, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0953, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0441, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0845, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1401, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9982, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1195, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0597, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0951, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0743, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0946, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1408, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackwa

tensor(0.1089, device='cuda:0')
tensor(0.0628, device='cuda:0')
tensor(0.0919, device='cuda:0')
tensor(0.0798, device='cuda:0')
tensor(0.0651, device='cuda:0')
tensor(0.1229, device='cuda:0')
tensor(0.0967, device='cuda:0')
tensor(0.1014, device='cuda:0')
tensor(0.1024, device='cuda:0')


  0%|          | 1/4981 [10:02<833:36:40, 602.61s/it]

tensor(0.0514, device='cuda:0')
Checkpoint 19 saved !
------- 1st valloss=0.0898

tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0780, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0944, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0815, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0796, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0687, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1286, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1030, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0635, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0610, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0443, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1587, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0565, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0826, device=

tensor(0.0657, device='cuda:0')
tensor(0.0739, device='cuda:0')
tensor(0.1090, device='cuda:0')
tensor(0.0846, device='cuda:0')
tensor(0.1774, device='cuda:0')
tensor(0.0972, device='cuda:0')
tensor(0.0822, device='cuda:0')
tensor(0.0884, device='cuda:0')
tensor(0.0926, device='cuda:0')
tensor(0.0948, device='cuda:0')
tensor(0.0654, device='cuda:0')
tensor(0.0749, device='cuda:0')
tensor(0.0968, device='cuda:0')


  0%|          | 2/4981 [19:24<816:25:44, 590.31s/it]

Checkpoint 20 saved !
------- 1st valloss=0.0860

tensor(0.0651, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1186, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1149, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0430, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0655, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0383, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1200, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9998, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0521, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0606, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1594, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0905, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0934, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1159, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0969, device='cuda:0', grad_fn=<R

tensor(0.1055, device='cuda:0')
tensor(0.0964, device='cuda:0')
tensor(0.0801, device='cuda:0')
tensor(0.0644, device='cuda:0')
tensor(0.0634, device='cuda:0')
tensor(0.0477, device='cuda:0')
tensor(0.1862, device='cuda:0')
tensor(0.0966, device='cuda:0')
tensor(0.0846, device='cuda:0')
tensor(0.1030, device='cuda:0')
tensor(0.0659, device='cuda:0')
tensor(0.1501, device='cuda:0')


  0%|          | 3/4981 [28:43<803:23:19, 581.00s/it]

------- 1st valloss=0.0870

tensor(0.0803, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1637, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0797, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1031, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0978, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0627, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0780, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0495, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1170, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1174, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1248, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1590, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0795, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0654, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1053, 

tensor(0.1068, device='cuda:0')
tensor(0.1002, device='cuda:0')
tensor(0.0515, device='cuda:0')
tensor(0.0900, device='cuda:0')
tensor(0.0688, device='cuda:0')
tensor(0.1032, device='cuda:0')
tensor(0.0948, device='cuda:0')
tensor(0.0874, device='cuda:0')
tensor(0.0820, device='cuda:0')
tensor(0.0624, device='cuda:0')
tensor(0.0978, device='cuda:0')


  0%|          | 4/4981 [38:01<793:53:22, 574.24s/it]

------- 1st valloss=0.0955

tensor(0.1140, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1200, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1059, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1135, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0768, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0444, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0873, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0483, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0556, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2005, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1328, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1270, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0694, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='

tensor(0.1034, device='cuda:0')
tensor(0.0881, device='cuda:0')
tensor(0.0902, device='cuda:0')
tensor(0.1011, device='cuda:0')
tensor(0.0801, device='cuda:0')
tensor(0.0806, device='cuda:0')
tensor(0.0925, device='cuda:0')
tensor(0.0660, device='cuda:0')
tensor(0.0857, device='cuda:0')
tensor(0.0736, device='cuda:0')
tensor(0.0588, device='cuda:0')


  0%|          | 5/4981 [47:24<788:43:45, 570.62s/it]

Checkpoint 23 saved !
------- 1st valloss=0.0830

tensor(0.1086, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1662, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0884, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0509, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0877, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1188, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1340, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9995, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1181, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1150, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0393, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1186, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0742, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0716, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1208, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubB

tensor(0.0806, device='cuda:0')
tensor(0.0522, device='cuda:0')
tensor(0.0899, device='cuda:0')
tensor(0.0862, device='cuda:0')
tensor(0.0853, device='cuda:0')
tensor(0.1304, device='cuda:0')
tensor(0.1783, device='cuda:0')
tensor(0.0678, device='cuda:0')
tensor(0.0826, device='cuda:0')
tensor(0.0786, device='cuda:0')
tensor(0.0672, device='cuda:0')
tensor(0.0741, device='cuda:0')


  0%|          | 6/4981 [56:46<785:11:13, 568.18s/it]

------- 1st valloss=0.0845

tensor(0.0860, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0578, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0643, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0752, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0851, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1102, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0845, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0867, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0571, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0718, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2306, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1003, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0841, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0534, 

tensor(0.0613, device='cuda:0')
tensor(0.0730, device='cuda:0')
tensor(0.0878, device='cuda:0')
tensor(0.0805, device='cuda:0')
tensor(0.0700, device='cuda:0')
tensor(0.0899, device='cuda:0')
tensor(0.1180, device='cuda:0')
tensor(0.0660, device='cuda:0')
tensor(0.0770, device='cuda:0')
tensor(0.0724, device='cuda:0')
tensor(0.1083, device='cuda:0')


  0%|          | 7/4981 [1:06:05<781:19:48, 565.50s/it]

------- 1st valloss=0.0851

tensor(0.1432, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1056, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0530, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0430, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0507, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1096, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1182, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1519, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0768, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0501, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0410, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0813, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0605, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0988, 

tensor(0.0932, device='cuda:0')
tensor(0.1001, device='cuda:0')
tensor(0.0907, device='cuda:0')
tensor(0.1212, device='cuda:0')
tensor(0.0724, device='cuda:0')
tensor(0.0579, device='cuda:0')
tensor(0.0989, device='cuda:0')
tensor(0.0566, device='cuda:0')
tensor(0.1041, device='cuda:0')
tensor(0.0775, device='cuda:0')
tensor(0.0570, device='cuda:0')
tensor(0.0903, device='cuda:0')


  0%|          | 8/4981 [1:15:22<777:41:13, 562.97s/it]

------- 1st valloss=0.0841

tensor(0.0975, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1154, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0556, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0583, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0474, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0613, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0681, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1351, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0667, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0601, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0654, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0719, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0465, devi

tensor(0.0804, device='cuda:0')
tensor(0.1076, device='cuda:0')
tensor(0.0947, device='cuda:0')
tensor(0.0891, device='cuda:0')
tensor(0.0829, device='cuda:0')
tensor(0.0788, device='cuda:0')
tensor(0.0620, device='cuda:0')
tensor(0.0635, device='cuda:0')
tensor(0.1934, device='cuda:0')
tensor(0.0663, device='cuda:0')
tensor(0.0816, device='cuda:0')


  0%|          | 9/4981 [1:24:45<777:29:53, 562.95s/it]

------- 1st valloss=0.0848

tensor(0.1119, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0921, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0602, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0910, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1181, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0649, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0709, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9999, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0534, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0975, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1409, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0707, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0927, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0421, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(

tensor(0.0509, device='cuda:0')
tensor(0.1034, device='cuda:0')
tensor(0.0808, device='cuda:0')
tensor(0.0963, device='cuda:0')
tensor(0.0890, device='cuda:0')
tensor(0.0694, device='cuda:0')
tensor(0.1031, device='cuda:0')
tensor(0.0736, device='cuda:0')
tensor(0.0825, device='cuda:0')
tensor(0.1068, device='cuda:0')
tensor(0.0936, device='cuda:0')


  0%|          | 10/4981 [1:34:01<774:11:33, 560.67s/it]

------- 1st valloss=0.0868

tensor(0.0547, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0676, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0775, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0423, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0990, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0441, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0378, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0501, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0814, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0900, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0672, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1185, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0716, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0931, 

tensor(0.1849, device='cuda:0')
tensor(0.0699, device='cuda:0')
tensor(0.0957, device='cuda:0')
tensor(0.0831, device='cuda:0')
tensor(0.0667, device='cuda:0')
tensor(0.0518, device='cuda:0')
tensor(0.0639, device='cuda:0')
tensor(0.0533, device='cuda:0')
tensor(0.0443, device='cuda:0')
tensor(0.0648, device='cuda:0')
tensor(0.0893, device='cuda:0')


  0%|          | 11/4981 [1:43:19<772:51:33, 559.82s/it]

Checkpoint 29 saved !
------- 1st valloss=0.0798

tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1465, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0488, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0399, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0803, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0630, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1021, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0879, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0982, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0833, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0502, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1265, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1023, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0599, device='cuda:0', grad_fn=<RsubBackw

tensor(0.0638, device='cuda:0')
tensor(0.0879, device='cuda:0')
tensor(0.0700, device='cuda:0')
tensor(0.0626, device='cuda:0')
tensor(0.0674, device='cuda:0')
tensor(0.0707, device='cuda:0')
tensor(0.1096, device='cuda:0')
tensor(0.0691, device='cuda:0')
tensor(0.0846, device='cuda:0')
tensor(0.0738, device='cuda:0')
tensor(0.0809, device='cuda:0')
tensor(0.0537, device='cuda:0')


  0%|          | 12/4981 [1:52:47<776:11:46, 562.35s/it]

Checkpoint 30 saved !
------- 1st valloss=0.0823

tensor(0.0983, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0452, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0774, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0684, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1158, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0806, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1534, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0711, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0691, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0548, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0415, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0776, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0837, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1695, device='cuda:0', grad_fn=<RsubBackw

tensor(0.0976, device='cuda:0')
tensor(0.0748, device='cuda:0')
tensor(0.1038, device='cuda:0')
tensor(0.0916, device='cuda:0')
tensor(0.0766, device='cuda:0')
tensor(0.0658, device='cuda:0')
tensor(0.0776, device='cuda:0')
tensor(0.0629, device='cuda:0')
tensor(0.1036, device='cuda:0')
tensor(0.1973, device='cuda:0')
tensor(0.0595, device='cuda:0')
tensor(0.0675, device='cuda:0')


  0%|          | 13/4981 [2:02:07<775:16:49, 561.80s/it]

------- 1st valloss=0.0810

tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0971, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0774, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0539, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0927, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0924, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0768, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0947, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1128, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2159, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0912, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0514, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0599, device='cuda

tensor(0.0762, device='cuda:0')
tensor(0.0739, device='cuda:0')
tensor(0.0920, device='cuda:0')
tensor(0.0850, device='cuda:0')
tensor(0.0656, device='cuda:0')
tensor(0.0695, device='cuda:0')
tensor(0.0927, device='cuda:0')
tensor(0.0856, device='cuda:0')
tensor(0.0881, device='cuda:0')
tensor(0.0762, device='cuda:0')
tensor(0.0734, device='cuda:0')


  0%|          | 14/4981 [2:11:24<773:10:44, 560.39s/it]

------- 1st valloss=0.0850

tensor(0.0800, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0584, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1083, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0888, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0868, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0820, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1449, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1686, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0873, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1113, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0687, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0642, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0718, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1.0000, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0638, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.15

In [None]:
deeplab.eval()
refine_model.eval()

with torch.no_grad():

    val_loss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
        # move data to device, convert dtype to desirable dtype
        val_losses = []
        for minibatch in range(BATCH_SIZE):
            image_1 = vbatch['image1_data'][minibatch].to(device=device, dtype=dtype)
            image_1 = image_1.view(1,1,256,256,256)

            label_1 = vbatch['image1_label'][minibatch].to(device=device, dtype=dtype)
            label_1 = label_1.view(1,3,256,256,256)

            bv_label = label_1[:, 2, :, :, :]
            bv_label = bv_label.view(1,1,256,256,256)

            original_res = [a[minibatch].item() for a in vbatch['original_resolution']]

            image_1_resize = F.interpolate(image_1, size=original_res, mode='trilinear', align_corners=True)
            image_1_resize = image_1_resize.view(1,1,original_res[0], original_res[1], original_res[2])

            bv_label_resize = F.interpolate(bv_label, size=original_res, mode='trilinear', align_corners=True)

            # Get coarse output from deeplab model from 256 resolution input
            out_coarse = deeplab(image_1)
            out_coarse = out_coarse.view(1,3,256,256,256)

            bv_coarse = out_coarse[:, 2, :, :, :]
            bv_coarse = bv_coarse.view(1,1,256,256,256)

            bv_coarse_resize = F.interpolate(bv_coarse, size=original_res, mode='trilinear', align_corners=True)
            
            box_size = 192
            half_size = int(box_size / 2)
            
            image_size_x = int(image_1_resize.shape[-3])
            image_size_y = int(image_1_resize.shape[-2])
            image_size_z = int(image_1_resize.shape[-1])
            
            x,y,z = loadbvcenter(binarize_output(bv_coarse_resize).view([1] + original_res))
            x, y, z = np.clip([x, y, z], a_min=box_size-half_size, a_max=box_size+half_size)
            x1 = max(x-half_size, 0)
            x2 = min(x+half_size, image_size_x)
            y1 = max(y-half_size, 0)
            y2 = min(y+half_size, image_size_y)
            z1 = max(z-half_size, 0)
            z2 = min(z+half_size, image_size_z)
            
            
            bbox_bv = bv_coarse_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
            bbox_bv = reshape_image(bbox_bv.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_bv = bbox_bv.view(1,1,box_size,box_size,box_size)
            
            bbox_bv_label = bv_label_resize.view(original_res)[x1:x2, y1:y2, z1:z2]
            bbox_bv_label = reshape_image(bbox_bv_label.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_bv_label = bbox_bv_label.view(1,1,box_size,box_size,box_size)

            #bbox_image = get_bounding_box_image(image_1, (256,256,256)).to(device, dtype)
            bbox_image = image_1_resize[:, :, x1:x2, y1:y2, z1:z2]
            bbox_image = reshape_image(bbox_image.squeeze(), box_size, box_size, box_size).to(device, dtype)
            bbox_image = bbox_image.view(1, 1, box_size, box_size, box_size)
            
            #bbox_iamge, bbox_bv_label, bbox_bv = get_bboxes(image_1_resize, bv_label_resize, bv_coarse_resize, 1, 200)
            
            bbox_concat = torch.cat([bbox_bv, bbox_image], dim=1)
            bbox_concat_2 = F.interpolate(bbox_concat, scale_factor=1/2, mode='trilinear', align_corners=True)
            bbox_concat_4 = F.interpolate(bbox_concat, scale_factor=1/4, mode='trilinear', align_corners=True)

            refine_out = refine_model(bbox_concat, bbox_concat_2, bbox_concat_4)

            loss = dice_loss(refine_out, bbox_bv_label)
            val_losses.append(loss)
            
            if loss.item() > .04:
                show_image_slice(image_1)
                show_image_slice(bv_label_resize)
                show_image_slice(bv_coarse)
                show_image_slice(bbox_image)
                show_image_slice(bbox_bv_label)
                show_image_slice(bbox_bv)
                show_image_slice(refine_out)
        
        loss = sum(val_losses) / BATCH_SIZE
        print(loss.item())
        val_loss += loss.item()
        val_losses = []
        '''
        if loss.item() > .05:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)
        '''

    outstr = 'bv loss = {0:.4f}'\
        .format(val_loss/(v+1)) + '\n'
    print(outstr)