In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from dense_vnet.DenseVNet import DenseVNet
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- Run configuration ------------------------------------------------------
USE_GPU = True      # set False to force CPU even when CUDA is present
NUM_WORKERS = 12    # worker count handed to the DataLoaders
BATCH_SIZE = 2      # volumes per loader batch

# Tensors are moved to this dtype; float32 takes half the memory of float64.
dtype = torch.float32

if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device("cuda:0")

    # Enable cuDNN and let it benchmark convolution algorithms; this pays off
    # when the input shapes stay constant from one iteration to the next.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('using GPU for training')

using GPU for training


In [3]:
# Training set: random affine + random flip augmentation is applied per sample.
# (The arguments of random_affine / random_filp are passed through unchanged;
# their exact meaning is defined in the project's data utilities.)
train_dataset = get_full_resolution_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no augmentation, so the metric reflects the raw data.
validation_dataset = get_full_resolution_dataset(
    data_type='nii_test',
    transform=None,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation is NOT shuffled: the averaged loss does not depend on order, and a
# fixed order keeps the per-batch log lines comparable from epoch to epoch.
# NOTE: the last batch of either loader may hold fewer than BATCH_SIZE samples
# unless drop_last=True is passed.
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
# DataLoader handles batching and loads samples in NUM_WORKERS worker processes.

In [4]:
from bv_refinement_network.RefinementModel import RefinementModel, RefinementModel_NoDown
from refinenet import refine_net

if torch.cuda.device_count() > 1:
    # nn.DataParallel splits every input batch along dim 0 across the GPUs.
    print("Let's use", torch.cuda.device_count(), "GPUs!")

# ---- Refinement network (the only model that is trained here) ---------------
refine_model = refine_net(num_classes=1)
refine_model = nn.DataParallel(refine_model)
# convert_model comes from the sync_batchnorm package
# (presumably swaps in synchronized BatchNorm layers -- TODO confirm).
refine_model = convert_model(refine_model)
refine_model = refine_model.to(device, dtype)

optimizer = optim.Adam(refine_model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=25)

# ---- Coarse segmentation network (pre-trained, used as a fixed predictor) ---
deeplab = DeepLab(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# map_location lets the checkpoint load on whichever device was selected above,
# including the CPU fallback (a CUDA-saved checkpoint otherwise fails to load
# on a machine without a GPU).
checkpoint = torch.load(
    '../deeplab_dilated_save/2019-08-10 09:28:43.844872 epoch: 1160.pth',
    map_location=device,
)  # best one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

# The optimizer only holds refine_model's parameters, so DeepLab is never
# updated. Freeze it explicitly so autograd does not record a graph through it
# or accumulate unused .grad buffers during refinement training.
for param in deeplab.parameters():
    param.requires_grad_(False)

# Epoch counter the training loop starts from (training begins at epoch + 1).
epoch = 0

Let's use 2 GPUs!


In [None]:
# Disabled scratch code: the whole cell is a bare string literal, so nothing in
# it executes. It sketches loading one training sample and moving it to the
# configured device/dtype for manual inspection.
'''
test_dictionary = train_dataset[33]

image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256)
label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256)
bv_label = label_1.narrow(1,2,1).to(device, dtype)
if get_dimensions(bv_label) == 4:
    bv_label.unsqueeze_(0)

image_1 = image_1.to(device=device, dtype=dtype) 
label_1 = label_1.to(device=device, dtype=dtype)
'''

"\ntest_dictionary = train_dataset[33]\n\nimage_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256)\nlabel_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256)\nbv_label = label_1.narrow(1,2,1).to(device, dtype)\nif get_dimensions(bv_label) == 4:\n    bv_label.unsqueeze_(0)\n\nimage_1 = image_1.to(device=device, dtype=dtype) \nlabel_1 = label_1.to(device=device, dtype=dtype)\n"

In [None]:
def get_bboxes(image, label, output, batchsize, box_size):
    """Crop a cubic box of edge ``box_size`` around each sample's predicted center.

    For every batch element the center is obtained from
    ``loadbvcenter(binarize_output(output[b]))`` (project helpers; presumably
    the voxel coordinates of the predicted vessel -- confirm against their
    definitions). The box is shifted where necessary so it lies completely
    inside the volume, which guarantees every crop has exactly ``box_size``
    voxels per axis.

    Args:
        image: tensor indexed as ``[batch, channel, x, y, z]``; its last three
            dimensions define the volume extent.
        label: tensor cropped with the same window as ``image``.
        output: coarse prediction, cropped with the same window and also used
            to locate the center.
        batchsize: number of leading batch entries to process.
        box_size: edge length of the cubic crop (odd values are supported).

    Returns:
        Tuple ``(image_final, label_final, output_final)``, each of shape
        ``(batchsize, 1, box_size, box_size, box_size)`` on the notebook-level
        ``device`` / ``dtype``.

    Raises:
        ValueError: if ``box_size`` is larger than any spatial dimension.
    """
    spatial = tuple(int(s) for s in image.shape[-3:])
    if any(box_size > size for size in spatial):
        raise ValueError(
            'box_size {0} does not fit inside a volume of shape {1}'.format(box_size, spatial))

    # Allocate directly on the target device instead of on CPU followed by a copy.
    out_shape = (batchsize, 1, box_size, box_size, box_size)
    image_final = torch.zeros(out_shape, device=device, dtype=dtype)
    label_final = torch.zeros(out_shape, device=device, dtype=dtype)
    output_final = torch.zeros(out_shape, device=device, dtype=dtype)

    half_size = box_size // 2
    for b in range(batchsize):
        center = loadbvcenter(binarize_output(output[b]))

        # Clamp the window start per axis using the real volume size (the
        # previous hard-coded 175 / 256 bounds only worked for one geometry),
        # then derive the end from the start so the extent is always box_size.
        starts = []
        for coord, size in zip(center, spatial):
            start = min(max(int(coord) - half_size, 0), size - box_size)
            starts.append(start)
        x1, y1, z1 = starts
        x2, y2, z2 = x1 + box_size, y1 + box_size, z1 + box_size

        image_final[b] = image[b, :, x1:x2, y1:y2, z1:z2]
        label_final[b] = label[b, :, x1:x2, y1:y2, z1:z2]
        output_final[b] = output[b, :, x1:x2, y1:y2, z1:z2]
    return image_final, label_final, output_final

In [None]:
epochs = 5000
VALIDATE_EVERY = 1   # run the validation pass every N epochs
BOX_SIZE = 192       # edge length of the cubic crop fed to the refinement net
VOLUME_SIZE = 256    # edge length of the full-resolution input volume


def _refine_sample_loss(image, label):
    """Run coarse segmentation + refinement on ONE volume and return its dice loss.

    Steps: DeepLab predicts a coarse 3-channel map on the full volume; channel 2
    is used to locate a center via ``loadbvcenter(binarize_output(...))``
    (project helpers); a BOX_SIZE^3 window around that center is cropped from
    the coarse channel, the image and the channel-2 label; the coarse crop and
    image crop are stacked and fed to ``refine_model`` at full, 1/2 and 1/4
    resolution.

    Args:
        image: one sample's image tensor with VOLUME_SIZE^3 elements.
        label: one sample's label tensor with 3 * VOLUME_SIZE^3 elements.

    Returns:
        ``(loss, bv_coarse)``: the dice loss tensor for this sample and the
        coarse channel-2 prediction of shape (1, 1, V, V, V), returned so the
        notebook can keep it around for inspection.
    """
    vol = (VOLUME_SIZE, VOLUME_SIZE, VOLUME_SIZE)
    box = (BOX_SIZE, BOX_SIZE, BOX_SIZE)

    image_1 = image.to(device=device, dtype=dtype).view(1, 1, *vol)
    label_1 = label.to(device=device, dtype=dtype).view(1, 3, *vol)
    bv_label = label_1[:, 2, :, :, :].view(1, 1, *vol)

    # DeepLab is a fixed predictor here (the optimizer only holds refine_model
    # parameters), so its forward pass must not be recorded by autograd.
    with torch.no_grad():
        out_coarse = deeplab(image_1).view(1, 3, *vol)
    bv_coarse = out_coarse[:, 2, :, :, :].view(1, 1, *vol)

    # Keep the crop window inside the volume: the center may sit no closer than
    # half a box to either border. (The previous upper bound, box + half = 288,
    # was larger than the volume and therefore never clipped anything, so
    # border crops came out truncated.)
    half_size = BOX_SIZE // 2
    x, y, z = loadbvcenter(binarize_output(bv_coarse).view(1, *vol))
    x, y, z = np.clip([x, y, z], a_min=half_size, a_max=VOLUME_SIZE - half_size)
    x1, x2 = x - half_size, x + half_size
    y1, y2 = y - half_size, y + half_size
    z1, z2 = z - half_size, z + half_size

    def _crop(volume):
        # reshape_image is a project helper, kept exactly as the original
        # pipeline called it (presumably resizes to the requested shape).
        cropped = volume.view(*vol)[x1:x2, y1:y2, z1:z2]
        cropped = reshape_image(cropped, *box).to(device, dtype)
        return cropped.view(1, 1, *box)

    bbox_bv = _crop(bv_coarse)
    bbox_bv_label = _crop(bv_label)
    bbox_image = _crop(image_1)

    # Two input channels: coarse prediction + raw image, at three scales.
    bbox_concat = torch.cat([bbox_bv, bbox_image], dim=1)
    bbox_concat_2 = F.interpolate(bbox_concat, scale_factor=1/2, mode='trilinear', align_corners=True)
    bbox_concat_4 = F.interpolate(bbox_concat, scale_factor=1/4, mode='trilinear', align_corners=True)

    refine_out = refine_model(bbox_concat, bbox_concat_2, bbox_concat_4)
    return dice_loss(refine_out, bbox_bv_label), bv_coarse


logger = {'train': [], 'validation_1': []}

# Best validation loss seen so far.
min_val = 1

# The context manager closes the log file even if training is interrupted.
with open('train_bv_refine_refnet2.txt', 'a+') as record:
    for e in tqdm(range(epoch + 1, epochs)):
        # ---- training pass ---------------------------------------------------
        # train(): dropout active, batchnorm uses batch statistics.
        refine_model.train()
        deeplab.eval()
        epoch_loss = 0

        for t, batch in enumerate(train_loader):
            # The last batch can be shorter than BATCH_SIZE, so use its real size.
            n_samples = len(batch['image1_data'])
            train_losses = []
            for minibatch in range(n_samples):
                loss, bv_coarse = _refine_sample_loss(
                    batch['image1_data'][minibatch], batch['image1_label'][minibatch])
                print(loss)
                train_losses.append(loss)

            # Samples are processed one volume at a time; average their losses
            # and take a single optimizer step per loader batch.
            loss = sum(train_losses) / n_samples
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): `scheduler` is created and checkpointed but never
            # stepped, so the learning rate stays at its initial value.

        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'

        logger['train'].append(epoch_loss/(t+1))

        print(outstr)
        record.write(outstr)
        record.flush()

        # ---- validation pass -------------------------------------------------
        if e % VALIDATE_EVERY == 0:
            # eval(): dropout disabled, batchnorm uses running statistics.
            deeplab.eval()
            refine_model.eval()

            with torch.no_grad():
                valloss_1 = 0

                for v, vbatch in enumerate(validation_loader):
                    n_samples = len(vbatch['image1_data'])
                    val_losses = []
                    for minibatch in range(n_samples):
                        loss, bv_coarse = _refine_sample_loss(
                            vbatch['image1_data'][minibatch], vbatch['image1_label'][minibatch])
                        val_losses.append(loss)

                    avg_loss = sum(val_losses) / n_samples
                    print(avg_loss)
                    valloss_1 += avg_loss.item()

                avg_val_loss = (valloss_1 / (v+1))
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(avg_val_loss) + '\n'

                logger['validation_1'].append(avg_val_loss)

                # Checkpoint on every new best validation loss, and otherwise
                # every 10th epoch as a periodic snapshot.
                if avg_val_loss < min_val:
                    min_val = avg_val_loss
                    save_1('refine_bv_refnet2', refine_model, optimizer, logger, e, scheduler)
                elif e % 10 == 0:
                    save_1('refine_bv_refnet2', refine_model, optimizer, logger, e, scheduler)

                print(outstr)
                record.write(outstr)
                record.flush()

  0%|          | 0/4999 [00:00<?, ?it/s]

tensor(0.9888, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9950, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9869, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9808, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9911, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9802, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9782, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9866, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9930, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9891, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9791, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9820, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9893, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9793, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9896, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9848, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9718, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.9759, device='cuda:0',

tensor(0.1378, device='cuda:0')
tensor(0.2878, device='cuda:0')
tensor(0.1723, device='cuda:0')
tensor(0.1953, device='cuda:0')
tensor(0.3876, device='cuda:0')
tensor(0.2375, device='cuda:0')
tensor(0.1605, device='cuda:0')
tensor(0.2699, device='cuda:0')
tensor(0.1169, device='cuda:0')
tensor(0.1835, device='cuda:0')


  0%|          | 1/4999 [10:15<853:51:49, 615.03s/it]

tensor(0.1403, device='cuda:0')
Checkpoint 1 saved !
------- 1st valloss=0.2090

tensor(0.3142, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3105, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1186, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2387, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1518, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1499, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3560, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3478, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1314, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.4406, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1959, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1437, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1194, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1671, devi

tensor(0.2130, device='cuda:0')
tensor(0.2250, device='cuda:0')
tensor(0.2281, device='cuda:0')
tensor(0.1594, device='cuda:0')
tensor(0.1694, device='cuda:0')
tensor(0.2965, device='cuda:0')
tensor(0.2199, device='cuda:0')
tensor(0.1497, device='cuda:0')
tensor(0.1312, device='cuda:0')
tensor(0.1228, device='cuda:0')
tensor(0.2151, device='cuda:0')
tensor(0.1602, device='cuda:0')


  0%|          | 2/4999 [19:50<837:03:08, 603.04s/it]

Checkpoint 2 saved !
------- 1st valloss=0.1864

tensor(0.2581, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3775, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1516, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1058, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2661, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3134, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1890, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1778, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1939, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2176, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1288, device='cuda:0', grad_fn=<RsubBackward1>)
tensor

tensor(0.2187, device='cuda:0')
tensor(0.2916, device='cuda:0')
tensor(0.1485, device='cuda:0')
tensor(0.1676, device='cuda:0')
tensor(0.1739, device='cuda:0')
tensor(0.1548, device='cuda:0')
tensor(0.1565, device='cuda:0')
tensor(0.1628, device='cuda:0')
tensor(0.2204, device='cuda:0')
tensor(0.1717, device='cuda:0')
tensor(0.1419, device='cuda:0')


  0%|          | 3/4999 [29:21<823:46:54, 593.60s/it]

Checkpoint 3 saved !
------- 1st valloss=0.1654

tensor(0.1754, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1214, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1283, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1370, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1087, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1558, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1146, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.4271, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1148, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1130, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1724, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1837, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3445, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1856, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2706, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1479, device='cuda:0', grad_fn=<Rs

tensor(0.2897, device='cuda:0')
tensor(0.1344, device='cuda:0')
tensor(0.2132, device='cuda:0')
tensor(0.2056, device='cuda:0')
tensor(0.1536, device='cuda:0')
tensor(0.1176, device='cuda:0')
tensor(0.1626, device='cuda:0')
tensor(0.1525, device='cuda:0')
tensor(0.2000, device='cuda:0')
tensor(0.1613, device='cuda:0')
tensor(0.1309, device='cuda:0')


  0%|          | 4/4999 [38:46<811:50:47, 585.11s/it]

Checkpoint 4 saved !
------- 1st valloss=0.1591

tensor(0.1475, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1751, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2252, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1762, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1487, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2001, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2129, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1404, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.4409, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1352, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2190, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2054, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2972, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1101, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1563, device='cuda:0', grad_fn=<RsubBa

tensor(0.2245, device='cuda:0')
tensor(0.1403, device='cuda:0')
tensor(0.1789, device='cuda:0')
tensor(0.1489, device='cuda:0')
tensor(0.1208, device='cuda:0')
tensor(0.1980, device='cuda:0')
tensor(0.1456, device='cuda:0')
tensor(0.1521, device='cuda:0')
tensor(0.1662, device='cuda:0')
tensor(0.1804, device='cuda:0')
tensor(0.1322, device='cuda:0')


  0%|          | 5/4999 [48:12<803:43:00, 579.37s/it]

Checkpoint 5 saved !
------- 1st valloss=0.1569

tensor(0.1816, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0872, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1257, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1978, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1192, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1362, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1641, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2353, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2117, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1619, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1810, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.3449, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1186, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1174, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1354, device='cuda:0', grad_fn=<RsubBa

tensor(0.1268, device='cuda:0')
tensor(0.1407, device='cuda:0')
tensor(0.1582, device='cuda:0')
tensor(0.1690, device='cuda:0')
tensor(0.1509, device='cuda:0')
tensor(0.1641, device='cuda:0')
tensor(0.1451, device='cuda:0')
tensor(0.1522, device='cuda:0')
tensor(0.1122, device='cuda:0')
tensor(0.2496, device='cuda:0')
tensor(0.1480, device='cuda:0')


  0%|          | 6/4999 [57:43<799:50:26, 576.69s/it]

Checkpoint 6 saved !
------- 1st valloss=0.1518

tensor(0.1445, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1215, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1081, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2216, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1262, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0899, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1196, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1055, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0943, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0995, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2155, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1009, device='cuda:0', grad_fn=<RsubBackward1>)
te

tensor(0.1426, device='cuda:0')
tensor(0.1644, device='cuda:0')
tensor(0.1506, device='cuda:0')
tensor(0.1800, device='cuda:0')
tensor(0.1357, device='cuda:0')
tensor(0.0973, device='cuda:0')
tensor(0.1504, device='cuda:0')
tensor(0.1832, device='cuda:0')
tensor(0.1243, device='cuda:0')
tensor(0.1852, device='cuda:0')
tensor(0.2773, device='cuda:0')


  0%|          | 7/4999 [1:07:12<796:28:10, 574.38s/it]

Checkpoint 7 saved !
------- 1st valloss=0.1516

tensor(0.1647, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1698, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2158, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1006, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1813, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1567, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1854, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2075, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1623, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1928, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1863, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2189, device='cuda:0', grad_fn=<RsubBackward1>)
te

tensor(0.1603, device='cuda:0')
tensor(0.1367, device='cuda:0')
tensor(0.1563, device='cuda:0')
tensor(0.2209, device='cuda:0')
tensor(0.1374, device='cuda:0')
tensor(0.1153, device='cuda:0')
tensor(0.1282, device='cuda:0')
tensor(0.1960, device='cuda:0')
tensor(0.0847, device='cuda:0')
tensor(0.1752, device='cuda:0')
tensor(0.1342, device='cuda:0')


  0%|          | 8/4999 [1:16:39<793:17:13, 572.20s/it]

Checkpoint 8 saved !
------- 1st valloss=0.1507

tensor(0.0903, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1939, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1133, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1834, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1816, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1788, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1947, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.4506, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1154, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2362, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1467, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1314, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1151, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2215, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1546, device='cuda:0', grad_fn=<RsubBa

tensor(0.1925, device='cuda:0')
tensor(0.1283, device='cuda:0')
tensor(0.1191, device='cuda:0')
tensor(0.1647, device='cuda:0')
tensor(0.1548, device='cuda:0')
tensor(0.1366, device='cuda:0')
tensor(0.1839, device='cuda:0')
tensor(0.2247, device='cuda:0')
tensor(0.1009, device='cuda:0')
tensor(0.1412, device='cuda:0')
tensor(0.1612, device='cuda:0')


  0%|          | 9/4999 [1:26:07<791:24:13, 570.95s/it]

------- 1st valloss=0.1583

tensor(0.1582, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1022, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(1., device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1151, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1447, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1076, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2114, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1356, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.0901, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1720, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1481, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1520, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1653, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1703, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1436, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.2135, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(0.1082, device='cuda:0', grad_fn=<RsubBackward1>)
tensor(

In [None]:
# Debug check of the coarse prediction's shape. `bv_coarse` is not defined in
# this cell: it relies on the value left in the kernel namespace by the
# training cell, so it fails on a fresh kernel until that cell has run.
print(bv_coarse.shape)

In [None]:
# Evaluate the coarse DeepLab model alone on the validation set and report the
# three per-class values returned by dice_loss_3_debug (named bg / bd / bv here).
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0

    # Wrap the loader (not the enumerate object) so tqdm knows the total.
    for v, vbatch in enumerate(tqdm(validation_loader)):
        # move data to device, convert dtype to desirable dtype
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        # do the inference
        output = deeplab(image_1)

        # One-hot encode the per-voxel winning class along dim 1. Done directly
        # on the tensor, avoiding a device -> CPU -> numpy -> device round trip.
        class_max = output.max(dim=1, keepdim=True)[0]
        out_1 = (output == class_max).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        # calculate the per-class losses
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

        # Show slices for the poorly segmented cases only.
        if bv.item() >= 0.2 or bd.item() >= 0.1:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)