# 6. BBox Regression Network

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import time

from tqdm import tqdm
from dataset import *
from vnet import *
from training import *
from niiutility import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

## 6.1 Set Up Torch Global Variables and Load the Memory Map

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 2
NUM_WORKERS = 8
NUM_TRAIN = 259        # the first NUM_TRAIN dataset positions form the training split
LEARNING_RATE = 1e-3

dtype = torch.float32  # crops are cast to this dtype before entering the model

# Decide on the device once; the cuDNN flags are only relevant on the GPU path.
_on_gpu = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if _on_gpu else 'cpu')
if _on_gpu:
    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('using GPU for training')



using GPU for training


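* `BvMaskDataset` returns an image together with its bounding-box tuple of 6 values. The loaders below are built on `DatasetBVSegmentation`, whose batches are dicts with `image`, `label` and `half` entries.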

In [9]:
#-------------------------LOAD THE DATA SET-------------------------------------------#
# regen=True draws a fresh train/validation split; otherwise the saved split below is reused.
regen = False

if regen:
    data_index = np.arange(370)
    # Shuffle the training part and the validation part independently.
    # np.random.shuffle works in place and returns None; basic slices are views,
    # so these two calls reorder data_index itself.
    np.random.shuffle(data_index[:NUM_TRAIN])
    np.random.shuffle(data_index[NUM_TRAIN:])
    # tolist() prints plain Python ints, so the output can be pasted into the else branch.
    print(data_index.tolist())
else:
    data_index = np.array ([120, 148, 24, 124, 200, 76, 159, 186, 125, 164, 251, 155, 0, 252, 238, 103, 53, 179, 244, 149, 45, 31, 131, 115, 82, 216, 130, 213, 245, 199, 229, 254, 56, 158, 32, 86, 221, 84, 81, 196, 117, 38, 28, 218, 257, 7, 237, 181, 105, 194, 16, 104, 75, 78, 249, 87, 144, 1, 183, 203, 54, 255, 129, 253, 202, 25, 34, 132, 80, 89, 137, 201, 114, 189, 110, 4, 71, 195, 97, 33, 157, 21, 250,
                            192, 258, 49, 47, 119, 191, 217, 143, 68, 190, 11, 176, 206, 108, 226, 50, 69, 118, 61, 35, 57, 243, 154, 15, 102, 146, 174, 163, 156, 233, 37, 180, 100, 184, 55, 239, 135, 151, 101, 205, 220, 169, 134, 228, 234, 51, 145, 29, 207, 141, 142, 44, 175, 12, 198, 52, 8, 30, 17, 10, 2, 126, 256, 40, 85, 46, 139, 178, 235, 23, 70, 188, 209, 93, 5, 153, 172, 127, 64, 241, 182, 18, 236, 187, 79, 210, 96,
                            3, 99, 63, 123, 171, 48, 6, 165, 43, 9, 230, 211, 19, 242, 162, 161, 173, 73, 106, 59, 136, 90, 112, 167, 246, 227, 109, 225, 41, 160, 133, 22, 177, 168, 14, 152, 107, 122, 223, 74, 62, 185, 222, 13, 150, 42, 212, 248, 147, 128, 67, 83, 214, 94, 98, 219, 232, 58, 247, 88, 66, 140, 116, 92, 113, 215, 27, 60, 138, 231, 39, 72, 166, 170, 91, 77, 224, 208, 240, 95, 26, 204, 197, 121, 36, 111, 193,
                            65, 20, 324, 330, 325, 262, 351, 366, 269, 365, 369, 297, 316, 363, 293, 267, 302, 313, 352, 307, 335, 290, 356, 286, 328, 320, 340, 333, 322, 355, 315, 296, 299, 312, 341, 261, 306, 282, 283, 292, 298, 321, 346, 310, 361, 265, 314, 331, 358, 305, 349, 276, 285, 368, 271, 317, 367, 336, 279, 275, 323, 277, 281, 337, 309, 260, 357, 266, 278, 294, 319, 332, 273, 303, 280, 284, 304, 334, 360, 300, 353, 308, 345, 364, 311, 270, 362, 339, 289, 342, 348, 263, 287, 274, 295, 327, 268, 272, 318, 354, 259, 329, 350, 359, 344, 338, 343, 288, 291, 347, 264, 326, 301])

# SubsetRandomSampler reshuffles its indices every epoch, so the order inside
# data_index only matters for which ids land in the first NUM_TRAIN positions
# (training) versus the remaining positions (validation).
# NOTE(review): assumes DatasetBVSegmentation(data_index)[i] loads id data_index[i] -- confirm.
dataset_trans = DatasetBVSegmentation(data_index,
                                      transform=transforms.Compose([
                                          RandomAffineOld(180, 30),
                                          RandomFilp(0.5),
                                      ]))

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#
# NOTE(review): validation_loader reuses dataset_trans, so validation samples also go
# through the (presumably random) affine/flip transforms; consider a second dataset
# instance without augmentation for validation.
data_size = len(dataset_trans)
train_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
                          num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_trans, batch_size=BATCH_SIZE,
                               sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, data_size)),
                               num_workers=NUM_WORKERS)

In [None]:
# Preview the first four (augmented) training batches: print tensor sizes and plot them.
for i_batch, sample_batched in zip(range(4), train_loader):
    images = sample_batched['image']
    labels = sample_batched['label']
    print(i_batch, images.size(), labels.size())
    show_batch_image(images, BATCH_SIZE, None)

## Loading the Region Proposal Network
* uses checkpoint `checkpoint2019-05-01 09:49:28.653642.pth`
* takes a tensor of size (batch, 1, X, Y, Z)
* returns a bounding box of 6 values, read as (x1, x2, y1, y2, z1, z2)

In [10]:
def _random_bbox_crop(x, y, out_shape, crop_size=128, fill_value=-64):
    """
    Cut a random crop_size^3 window out of every sample so that it contains the ground-truth box.

    The box of sample b is read with loadbvmask(y[b]) as (x1, x2, y1, y2, z1, z2).
    For each axis, the box extent is clipped to [0, crop_size - 1] and a random
    offset in [0, crop_size - extent) is drawn (x, y, z order), so the box fits
    inside the window. The window start is then clipped to
    [0, axis_length - crop_size], keeping the window inside the volume.

    Inputs:
    - x: image batch, numpy array indexed as [b, channel, X, Y, Z]
    - y: label batch, indexed like x
    - out_shape: shape of the returned arrays; its spatial dims must equal crop_size
    - crop_size: edge length of the cubic window
    - fill_value: initial value of the image buffer (-64 presumably matches the
      dataset's mean removal -- TODO confirm)

    Returns: (xslice, yslice), float64 numpy arrays of shape out_shape.
    """
    xslice = np.full(out_shape, fill_value, dtype=np.float64)
    yslice = np.zeros(out_shape)  # ground-truth mask crop
    for b in range(out_shape[0]):
        x1, x2, y1, y2, z1, z2 = loadbvmask(y[b])  # ground-truth bv box tuple
        starts = []
        for lo, hi, dim in zip((x1, y1, z1), (x2, y2, z2), x.shape[2:5]):
            extent = int(np.clip(hi - lo, 0, crop_size - 1))
            offset = np.random.randint(crop_size - extent)
            # Bound derived from the volume size, so the slice stays crop_size
            # long for any axis length >= crop_size.
            max_start = max(dim - crop_size, 0)
            starts.append(int(np.clip(lo - offset, 0, max_start)))
        sx, sy, sz = starts
        xslice[b] = x[b, :, sx:sx + crop_size, sy:sy + crop_size, sz:sz + crop_size]
        yslice[b] = y[b, :, sx:sx + crop_size, sy:sy + crop_size, sz:sz + crop_size]
    return xslice, yslice


def _batch_to_crops(batch, device, dtype, crop_size=128):
    """
    Crop one loader batch around its ground-truth boxes and move it to the device.

    batch is a dict with 'image', 'label' and 'half' tensors; the crop buffers take
    the shape of batch['half']. Returns (xslice, yslice) torch tensors on `device`
    with dtype `dtype`.
    """
    x = batch['image'].numpy()
    y = batch['label'].numpy()
    out_shape = tuple(batch['half'].shape)
    xslice, yslice = _random_bbox_crop(x, y, out_shape, crop_size=crop_size)
    return (torch.from_numpy(xslice).to(device=device, dtype=dtype),
            torch.from_numpy(yslice).to(device=device, dtype=dtype))


def train_seg(model, traindata, valdata, optimizer, scheduler, device, dtype, lossFun, logger, epochs=1, startepoch=0, usescheduler=False):
    """
    Train a segmentation model on random crops that contain the ground-truth box.

    Each batch is cropped with _batch_to_crops, passed through the model and
    optimized on lossFun. After every epoch the mean validation loss is computed
    with check_accuracy_seg. A checkpoint is written whenever the epoch index is
    a multiple of 50.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - traindata / valdata: DataLoaders yielding dicts with 'image', 'label' and 'half' tensors.
    - optimizer: An Optimizer object we will use to train the model.
    - scheduler: LR scheduler stepped on the validation loss when usescheduler is True.
      Its state is stored in checkpoints; it may be None when usescheduler is False.
    - device, dtype: device and dtype the crops are moved to before the forward pass.
    - lossFun: callable lossFun(scores, target, cirrculum=...) returning a scalar tensor.
    - logger: dict with 'train' and 'validation' lists; one float per epoch is appended to each.
    - epochs: (Optional) A Python integer giving the number of epochs to train for.
    - startepoch: (Optional) index of the first epoch (used for printing and checkpointing).
    - usescheduler: (Optional) whether to call scheduler.step(validation_loss).

    Returns: Nothing, but prints losses and saves checkpoints during training.
    """
    import datetime  # checkpoint file names are time-stamped

    model = model.to(device=device)  # move the model parameters to CPU/GPU
    cirrculum = 0
    N = len(traindata)
    for e in range(epochs):
        epoch = e + startepoch
        # check_accuracy_seg leaves the model in eval mode, so re-enter training mode each epoch.
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(traindata, total=N):
            xslice, yslice = _batch_to_crops(batch, device, dtype)

            scores = model(xslice)
            loss = lossFun(scores, yslice, cirrculum=cirrculum)

            # .item() yields a Python float, so the autograd graph is not kept alive.
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = epoch_loss / N
        print('Epoch {0} finished ! Training Loss: {1:.4f}'.format(epoch, train_loss))

        # Get validation loss
        loss_val = check_accuracy_seg(model, valdata, device, dtype,
                                      cirrculum=cirrculum, lossFun=lossFun)

        logger['train'].append(train_loss)
        logger['validation'].append(loss_val)

        # Taking a scheduler step on validation loss
        if usescheduler:
            scheduler.step(loss_val)

        if epoch % 50 == 0:
            model_save_path = 'checkpoint' + str(datetime.datetime.now()) + '.pth'
            # 'epoch' is the index of the next epoch to run, i.e. the value to resume from.
            state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler.state_dict() if scheduler is not None else None,
                     'logger': logger}
            torch.save(state, model_save_path)
            print('Checkpoint {} saved !'.format(epoch + 1))


def check_accuracy_seg(model, dataloader, device, dtype, cirrculum, lossFun):
    """
    Return the mean loss of `model` over `dataloader` as a Python float.

    Batches are cropped with _batch_to_crops exactly as in training. Crop offsets
    are random, so the value varies between calls even for fixed weights.

    Inputs:
    - model: module to evaluate (switched to eval mode and left there).
    - dataloader: DataLoader yielding dicts with 'image', 'label' and 'half' tensors.
    - device, dtype: device and dtype the crops are moved to.
    - cirrculum: forwarded to lossFun.
    - lossFun: callable lossFun(scores, target, cirrculum=...) returning a scalar tensor.

    Returns: mean per-batch loss (float), so it can be logged, pickled and plotted
    without holding device tensors.
    """
    model.eval()  # set model to evaluation mode
    N = len(dataloader)
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, total=N):
            xslice, yslice = _batch_to_crops(batch, device, dtype)
            scores = model(xslice)
            total_loss += lossFun(scores, yslice, cirrculum=cirrculum).item()

    mean_loss = total_loss / N
    print('     validation loss = {0:.4f}'.format(mean_loss))
    return mean_loss


In [11]:
# Build the segmentation model, its optimizer and LR scheduler, then optionally
# restore their state (and the loss history) from a previous checkpoint.
LoadCKP = True
CKPPath = 'checkpoint2019-05-11 06:23:50.403772.pth'

model = VNet(classnum=1, slim=True)
model.apply(weights_init)  # fresh initialisation; presumably overwritten by loadckp below
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=100, verbose=True)
logger = dict(train=[], validation=[])

if LoadCKP:
    loadckp(model, optimizer, scheduler, logger, CKPPath, device=device)

loading checkpoint 'checkpoint2019-05-11 06:23:50.403772.pth'
loaded checkpoint 'checkpoint2019-05-11 06:23:50.403772.pth' (epoch 701)


In [None]:
#-------------------------RESUME TRAINING FROM THE CHECKPOINT------------------------#

from loss import *

# train_seg stores 'epoch' as (last finished epoch + 1), and loadckp reports the
# checkpoint above as "(epoch 701)", so training resumes at epoch index 701.
# Starting at 700 would re-label epoch 700 and, since 700 % 50 == 0, immediately
# write a second "Checkpoint 701".
# NOTE(review): keep START_EPOCH in sync with CKPPath (or read it from the checkpoint).
START_EPOCH = 701

train_seg(model, train_loader, validation_loader, optimizer, scheduler,
          device=device, dtype=dtype, lossFun=dice_loss, logger=logger,
          epochs=5000, startepoch=START_EPOCH)

130it [08:05,  8.35s/it]

Epoch 350 finished ! Training Loss: 0.1037



56it [01:46,  1.20s/it]


     validation loss = 0.1587
Checkpoint 351 saved !


130it [07:27,  2.91s/it]

Epoch 351 finished ! Training Loss: 0.1040



56it [01:47,  1.14s/it]


     validation loss = 0.1167


130it [07:24,  2.92s/it]

Epoch 352 finished ! Training Loss: 0.1007



56it [01:46,  1.14s/it]


     validation loss = 0.0886


130it [07:25,  2.93s/it]

Epoch 353 finished ! Training Loss: 0.1015



56it [01:46,  1.21s/it]


     validation loss = 0.0950


130it [07:26,  2.92s/it]

Epoch 354 finished ! Training Loss: 0.1016



56it [01:48,  1.24s/it]


     validation loss = 0.0870


130it [07:27,  2.92s/it]

Epoch 355 finished ! Training Loss: 0.1017



56it [01:49,  1.17s/it]


     validation loss = 0.0901


130it [07:25,  2.92s/it]

Epoch 356 finished ! Training Loss: 0.1017



56it [01:43,  1.17s/it]


     validation loss = 0.0867


130it [07:25,  2.92s/it]

Epoch 357 finished ! Training Loss: 0.1013



56it [01:45,  1.15s/it]


     validation loss = 0.0873


130it [07:24,  2.92s/it]


Epoch 358 finished ! Training Loss: 0.1039


56it [01:47,  1.14s/it]


     validation loss = 0.0865


130it [07:26,  2.93s/it]

Epoch 359 finished ! Training Loss: 0.1038



56it [01:42,  1.11s/it]


     validation loss = 0.0851


130it [07:26,  2.92s/it]

Epoch 360 finished ! Training Loss: 0.1006



56it [01:50,  1.14s/it]


     validation loss = 0.0937


130it [07:24,  2.92s/it]

Epoch 361 finished ! Training Loss: 0.0997



56it [01:45,  1.11s/it]


     validation loss = 0.0863


130it [07:21,  2.92s/it]

Epoch 362 finished ! Training Loss: 0.1023



56it [01:50,  1.15s/it]


     validation loss = 0.0842


130it [07:23,  2.91s/it]

Epoch 363 finished ! Training Loss: 0.1004



56it [01:45,  1.21s/it]


     validation loss = 0.0808


130it [07:27,  2.93s/it]

Epoch 364 finished ! Training Loss: 0.1008



56it [01:44,  1.15s/it]


     validation loss = 0.0810


130it [07:23,  2.93s/it]

Epoch 365 finished ! Training Loss: 0.1007



56it [01:53,  1.13s/it]


     validation loss = 0.0899


130it [07:23,  2.92s/it]

Epoch 366 finished ! Training Loss: 0.0999



56it [01:45,  1.20s/it]


     validation loss = 0.0837


130it [07:24,  2.92s/it]

Epoch 367 finished ! Training Loss: 0.1012



56it [01:45,  1.12s/it]


     validation loss = 0.0832


130it [07:27,  2.93s/it]

Epoch 368 finished ! Training Loss: 0.0994



56it [01:47,  1.11s/it]


     validation loss = 0.0868


130it [07:25,  2.92s/it]

Epoch 369 finished ! Training Loss: 0.1014



56it [01:48,  1.18s/it]


     validation loss = 0.0808


130it [07:24,  2.92s/it]

Epoch 370 finished ! Training Loss: 0.1009



56it [01:51,  1.29s/it]


     validation loss = 0.0842


130it [07:23,  2.91s/it]

Epoch 371 finished ! Training Loss: 0.1040



56it [01:49,  1.10s/it]


     validation loss = 0.0912


130it [07:24,  2.93s/it]

Epoch 372 finished ! Training Loss: 0.0999



56it [01:47,  1.16s/it]


     validation loss = 0.0900


130it [07:25,  2.92s/it]

Epoch 373 finished ! Training Loss: 0.1010



56it [01:46,  1.24s/it]


     validation loss = 0.0813


130it [07:26,  2.92s/it]

Epoch 374 finished ! Training Loss: 0.1011



56it [01:48,  1.08s/it]


     validation loss = 0.0888


130it [07:27,  2.92s/it]

Epoch 375 finished ! Training Loss: 0.1005



56it [01:44,  1.20s/it]


     validation loss = 0.0839


130it [07:26,  2.92s/it]

Epoch 376 finished ! Training Loss: 0.1005



56it [01:46,  1.15s/it]


     validation loss = 0.0829


130it [07:26,  2.91s/it]

Epoch 377 finished ! Training Loss: 0.1002



56it [01:48,  1.20s/it]


     validation loss = 0.0900


130it [07:26,  2.93s/it]

Epoch 378 finished ! Training Loss: 0.0999



56it [01:45,  1.14s/it]


     validation loss = 0.0807


130it [07:27,  2.92s/it]

Epoch 379 finished ! Training Loss: 0.1023



56it [01:47,  1.12s/it]


     validation loss = 0.0856


130it [07:26,  2.91s/it]


Epoch 380 finished ! Training Loss: 0.0993


56it [01:46,  1.14s/it]


     validation loss = 0.0849


130it [07:25,  2.92s/it]

Epoch 381 finished ! Training Loss: 0.1023



56it [01:49,  1.22s/it]


     validation loss = 0.0808


130it [07:24,  2.92s/it]

Epoch 382 finished ! Training Loss: 0.0991



56it [01:43,  1.15s/it]


     validation loss = 0.0808


130it [07:26,  2.92s/it]

Epoch 383 finished ! Training Loss: 0.0998



56it [01:45,  1.13s/it]


     validation loss = 0.0904


130it [07:25,  2.92s/it]

Epoch 384 finished ! Training Loss: 0.0998



56it [01:47,  1.14s/it]


     validation loss = 0.0832


130it [07:24,  2.91s/it]

Epoch 385 finished ! Training Loss: 0.0994



56it [01:43,  1.18s/it]


     validation loss = 0.0802


130it [07:26,  2.91s/it]

Epoch 386 finished ! Training Loss: 0.1001



56it [01:46,  1.15s/it]


     validation loss = 0.0814


130it [07:26,  2.91s/it]

Epoch 387 finished ! Training Loss: 0.1007



56it [01:46,  1.19s/it]


     validation loss = 0.0892


130it [07:25,  2.92s/it]

Epoch 388 finished ! Training Loss: 0.1011



56it [01:47,  1.19s/it]


     validation loss = 0.0818


130it [07:25,  2.94s/it]

Epoch 389 finished ! Training Loss: 0.1005



56it [01:51,  1.17s/it]


     validation loss = 0.0813


130it [07:25,  2.92s/it]

Epoch 390 finished ! Training Loss: 0.1007



56it [01:41,  1.10s/it]


     validation loss = 0.0829


130it [07:23,  2.92s/it]

Epoch 391 finished ! Training Loss: 0.1003



56it [01:52,  1.24s/it]


     validation loss = 0.0824


130it [07:24,  2.91s/it]

Epoch 392 finished ! Training Loss: 0.1003



56it [01:45,  1.11s/it]


     validation loss = 0.0830


130it [07:26,  2.92s/it]

Epoch 393 finished ! Training Loss: 0.0997



56it [01:45,  1.25s/it]


     validation loss = 0.0836


130it [07:25,  2.92s/it]

Epoch 394 finished ! Training Loss: 0.0995



56it [01:45,  1.16s/it]


     validation loss = 0.0829


130it [07:24,  2.92s/it]

Epoch 395 finished ! Training Loss: 0.0995



56it [01:44,  1.15s/it]


     validation loss = 0.0848


130it [07:23,  2.91s/it]

Epoch 396 finished ! Training Loss: 0.1005



56it [01:53,  1.21s/it]


     validation loss = 0.0827


122it [07:01,  3.30s/it]

In [None]:
#-------------------------SAVE THE MODEL STATE DICT----------------------------------#
# Only the weights are written (no optimizer/scheduler state or loss history).
# NOTE(review): the "504" in the file name is hard-coded and may not match the current epoch.
PATH = 'SEGNET-504.pth'
weights = model.state_dict()
torch.save(weights, PATH)

In [None]:
# Load the region proposal network weights onto the CPU first, then move the
# network to the working device in inference (eval) mode.
filename = 'checkpoint2019-05-01 09:49:28.653642.pth'
RPNcheckpoint = torch.load(filename, map_location=torch.device('cpu'))

RPN = LNet(img_size=(128, 128, 128), out_size=6)
RPN.load_state_dict(RPNcheckpoint['state_dict'])
RPN = RPN.eval().to(device=device)

In [None]:
# Run the loader plus ground-truth cropping for 500 epochs without a model.
# NOTE(review): the crops built here are never consumed; this looks like a smoke
# test / timing run of the data pipeline.
for e in range(500):
    print('epoch {} begins' .format(e))
    for t, batch in tqdm(enumerate(train_loader)):
        x, y, xhalf = (batch[key].numpy() for key in ('image', 'label', 'half'))

        # Buffers shaped like batch['half']: image pre-filled with -64 (same mean
        # removal), mask with zeros.
        xslice = np.zeros(xhalf.shape) - 64
        yslice = np.zeros(xhalf.shape)

        for b in range(xhalf.shape[0]):
            x1, x2, y1, y2, z1, z2 = loadbvmask(y[b])  # ground-truth bv box tuple
            lows = (x1, y1, z1)
            highs = (x2, y2, z2)
            # Box extent per axis, capped so that a random offset always exists.
            extents = np.clip([hi - lo for lo, hi in zip(lows, highs)], a_min=0, a_max=127)
            offsets = [np.random.randint(128 - d) for d in extents]
            sx, sy, sz = np.clip([lo - off for lo, off in zip(lows, offsets)], a_min=0, a_max=128)
            window = (slice(sx, sx + 128), slice(sy, sy + 128), slice(sz, sz + 128))
            xslice[b] = x[(b, slice(None)) + window]
            yslice[b] = y[(b, slice(None)) + window]

In [None]:
# Disabled prototype: crop validation volumes around the boxes predicted by the
# region proposal network (RPN) instead of the ground-truth boxes.
# Kept as comments rather than a bare ''' string, so running the cell neither
# executes it nor echoes the whole source back as the cell's output.
# NOTE(review): unfinished -- `xrescale` is allocated but never filled, and the
# loop only prints the crop-to-128 scale factor.
#
# for t, batch in enumerate(validation_loader):
#     x = batch['image']
#     y = batch['label']
#     xhalf = batch['half']
#     x = x.numpy()  # move to device, e.g. GPU
#     y = y.numpy()
#     xhalf = xhalf.to(device=device, dtype=dtype)
#
#     # Getting the bbox from region proposal network
#     with torch.no_grad():
#         BBox = RPN(xhalf)
#
#     BBox = BBox.cpu().numpy()
#     batchSize = BBox.shape[0]
#
#     thold = 18
#
#     for b in range(batchSize):
#
#         box = BBox[b]
#
#         # Avoid inverse accident
#         x1 = np.min(box[:2])
#         x2 = np.max(box[:2])
#         y1 = np.min(box[2:4])
#         y2 = np.max(box[2:4])
#         z1 = np.min(box[4:])
#         z2 = np.max(box[4:])
#
#         # Add threshold
#         x1 = np.max((0,   x1-thold))
#         x2 = np.min((127, x2+thold))
#         y1 = np.max((0,   y1-thold))
#         y2 = np.min((127, y2+thold))
#         z1 = np.max((0,   z1-thold))
#         z2 = np.min((127, z2+thold))
#
#         x1,x2,y1,y2,z1,z2 = int(x1*2), int(x2*2), int(y1*2), int(y2*2), int(z1*2), int(z2*2)
#         # Round data
#         xslice = x[b, 0, x1:x2+1, y1:y2+1, z1:z2+1]
#         print(np.max(xslice.shape)/128)
#         xrescale = np.zeros([128,128,128])