In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run-wide settings shared by the cells below.
USE_GPU, NUM_WORKERS, BATCH_SIZE = True, 12, 2

# float32 needs half the memory of float64 (double).
dtype = torch.float32

# Select the compute device; fall back to the CPU when CUDA is either
# disabled via USE_GPU or not available on this machine.
if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')

    # Enable cuDNN and its benchmark mode (auto-tuner that picks
    # convolution algorithms at runtime).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('using GPU for training')

using GPU for training


In [3]:
# Training set: every sample goes through the augmentation pipeline
# (random affine followed by random flip).
# NOTE(review): the meaning of the arguments to random_affine(90, 15) and
# random_filp(0.5) is defined in the project helpers -- confirm there.
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no augmentation, samples are used as stored.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoader handles batching and loads samples in NUM_WORKERS worker
# processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation is NOT shuffled: a fixed order keeps the batches identical on
# every pass, so validation losses from different epochs are directly
# comparable (they drive the LR scheduler and best-checkpoint selection).
# The last batch may be smaller than BATCH_SIZE because drop_last is not set.
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:

# Build the network in a single expression:
#   DeepLabModified -> nn.DataParallel wrapper -> convert_model -> device/dtype.
# NOTE(review): convert_model comes from sync_batchnorm; presumably it swaps
# the batch-norm layers for synchronized ones -- confirm in that package.
# Modules are created as float by default; .to() moves the model to the
# selected device and casts it to the configured dtype.
deeplab = convert_model(
    nn.DataParallel(DeepLabModified(output_stride=16))
).to(device=device, dtype=dtype)

# Adam over every parameter of the wrapped model.
optimizer = optim.Adam(params=deeplab.parameters(), lr=1e-2)

# Multiply the learning rate by 0.1 once the monitored value (to be
# minimised) has not improved for 7 consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    patience=7,
    factor=0.1,
)

# Fresh run: no epochs completed yet (the training loop starts at epoch + 1).
epoch = 0

In [None]:
# Optional: resume from a saved checkpoint instead of the freshly built model.
#
# Leave RESUME_CHECKPOINT as None to keep the model/optimizer/scheduler created
# in the previous cell (this cell is then a no-op). Set it to a checkpoint
# path to restore them, e.g.
#   '../deeplab_modified_save/2019-08-12 17:24:36.436884 epoch: 290.pth'
#
# NOTE(review): the training cell re-initialises `logger` and `min_val`, so the
# history stored in the checkpoint is only printed here, not carried over.
RESUME_CHECKPOINT = None

if RESUME_CHECKPOINT is not None:
    deeplab = DeepLabModified(output_stride=16)
    deeplab = nn.DataParallel(deeplab)
    deeplab = convert_model(deeplab)

    # map_location lets a checkpoint written on a GPU be loaded on whatever
    # device this session uses (including CPU-only machines).
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=device)

    deeplab.load_state_dict(checkpoint['state_dict_1'])
    # Move the model before building the optimizer so the optimizer tracks the
    # parameters on their final device.
    deeplab = deeplab.to(device, dtype)

    # The constructor arguments below are placeholders; load_state_dict
    # restores the saved optimizer / scheduler state on top of them.
    optimizer = optim.Adam(deeplab.parameters(), lr=1e-3)
    optimizer.load_state_dict(checkpoint['optimizer'])

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=8)
    scheduler.load_state_dict(checkpoint['scheduler'])

    # Training continues at epoch + 1.
    epoch = checkpoint['epoch']
    print(epoch)

    # Mean of the validation losses logged so far (skipped when none exist,
    # which would otherwise divide by zero).
    logged_val = checkpoint['logger']['validation_1']
    if logged_val:
        print(sum(logged_val) / len(logged_val))

    # Learning rate(s) currently in effect after restoring the optimizer.
    for param_group in optimizer.param_groups:
        print(param_group['lr'])

In [None]:
# ---- Training configuration -------------------------------------------------
epochs = 5000                  # the loop runs epochs epoch + 1 .. epochs - 1
VAL_EVERY = 3                  # run validation every VAL_EVERY epochs
CHECKPOINT_EVERY = 15          # also checkpoint when e is a multiple of this
SCHEDULER_START_EPOCH = 200    # the LR scheduler is only stepped after this epoch
SAVE_DIR = 'deeplab_modified_save2'

# Best (lowest) average validation loss seen so far in this run.
min_val = 1

# Per-epoch average losses; passed to save_1 with every checkpoint.
logger = {'train': [], 'validation_1': []}

# The context manager guarantees the text log is closed even when the run is
# interrupted (e.g. stopping the kernel part-way through training).
with open('train_deeplab_modified2.txt', 'a+') as record:

    for e in tqdm(range(epoch + 1, epochs)):
        # ---- Training pass ----
        # Train mode enables dropout and batch-statistics batch norm. Setting
        # it once per epoch is enough: eval mode is only entered during the
        # validation pass further down.
        deeplab.train()

        epoch_loss = 0
        n_train = 0
        for n_train, batch in enumerate(train_loader, start=1):
            # Move the mini batch to the target device / dtype.
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            out_1 = deeplab(image_1)
            loss_1 = dice_loss_3(out_1, label_1)

            # Accumulate the mini-batch loss into the epoch total.
            epoch_loss += loss_1.item()

            # Standard update: clear old gradients, backpropagate, step.
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()

        if n_train == 0:
            raise RuntimeError('train_loader yielded no batches')

        avg_train_loss = epoch_loss / n_train
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, avg_train_loss) + '\n'

        logger['train'].append(avg_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        # ---- Validation pass (every VAL_EVERY epochs) ----
        if e % VAL_EVERY != 0:
            continue

        # Eval mode disables dropout and makes batch norm use running stats.
        deeplab.eval()

        valloss_1 = 0
        n_val = 0
        with torch.no_grad():
            for n_val, vbatch in enumerate(validation_loader, start=1):
                # Move data to the device; when the helper reports 4
                # dimensions, prepend one in place before inference.
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)

                out_1_val = deeplab(image_1_val)
                val_loss = dice_loss_3(out_1_val, label_1_val)

                valloss_1 += val_loss.item()

        if n_val == 0:
            raise RuntimeError('validation_loader yielded no batches')

        avg_val_loss = valloss_1 / n_val
        outstr = '------- 1st valloss={0:.4f}'.format(avg_val_loss) + '\n'

        logger['validation_1'].append(avg_val_loss)

        # The plateau scheduler only starts watching the validation loss
        # after SCHEDULER_START_EPOCH.
        if e > SCHEDULER_START_EPOCH:
            scheduler.step(avg_val_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        # Checkpoint on a new best validation loss, and additionally on every
        # CHECKPOINT_EVERY-th epoch regardless of the score.
        is_best = avg_val_loss < min_val
        if is_best:
            print(avg_val_loss, "less than", min_val)
            min_val = avg_val_loss
        if is_best or e % CHECKPOINT_EVERY == 0:
            save_1(SAVE_DIR, deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [23:07<1926:45:19, 1387.82s/it]

Epoch 1 finished ! Training Loss: 0.5194



  0%|          | 2/4999 [42:44<1838:28:05, 1324.49s/it]

Epoch 2 finished ! Training Loss: 0.4770

Epoch 3 finished ! Training Loss: 0.4699

------- 1st valloss=0.6226

0.6226444399875143 less than 1


  0%|          | 3/4999 [1:02:54<1790:21:19, 1290.09s/it]

Checkpoint 3 saved !


  0%|          | 4/4999 [1:22:35<1744:31:43, 1257.32s/it]

Epoch 4 finished ! Training Loss: 0.4599



  0%|          | 5/4999 [1:41:57<1704:31:14, 1228.73s/it]

Epoch 5 finished ! Training Loss: 0.4562

Epoch 6 finished ! Training Loss: 0.4505



  0%|          | 6/4999 [2:02:15<1699:51:04, 1225.61s/it]

------- 1st valloss=0.6399



  0%|          | 7/4999 [2:21:40<1674:25:41, 1207.52s/it]

Epoch 7 finished ! Training Loss: 0.4514



  0%|          | 8/4999 [2:41:17<1661:07:24, 1198.17s/it]

Epoch 8 finished ! Training Loss: 0.4428

Epoch 9 finished ! Training Loss: 0.4385

------- 1st valloss=0.5070

0.5069867165192313 less than 0.6226444399875143


  0%|          | 9/4999 [3:01:38<1670:21:47, 1205.07s/it]

Checkpoint 9 saved !


  0%|          | 10/4999 [3:21:20<1660:28:37, 1198.18s/it]

Epoch 10 finished ! Training Loss: 0.4410



  0%|          | 11/4999 [3:40:58<1651:49:51, 1192.18s/it]

Epoch 11 finished ! Training Loss: 0.4358

Epoch 12 finished ! Training Loss: 0.4293

------- 1st valloss=0.4991

0.49909668901692267 less than 0.5069867165192313


  0%|          | 12/4999 [4:01:13<1661:00:26, 1199.04s/it]

Checkpoint 12 saved !


  0%|          | 13/4999 [4:20:41<1647:50:16, 1189.77s/it]

Epoch 13 finished ! Training Loss: 0.4208



  0%|          | 14/4999 [4:40:10<1638:31:51, 1183.29s/it]

Epoch 14 finished ! Training Loss: 0.4039

Epoch 15 finished ! Training Loss: 0.3768

------- 1st valloss=0.6746



  0%|          | 15/4999 [5:00:27<1652:27:48, 1193.59s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [5:20:10<1647:36:20, 1190.32s/it]

Epoch 16 finished ! Training Loss: 0.3401



  0%|          | 17/4999 [5:39:42<1639:42:48, 1184.86s/it]

Epoch 17 finished ! Training Loss: 0.3149

Epoch 18 finished ! Training Loss: 0.2991

------- 1st valloss=0.4025

0.4024557924788931 less than 0.49909668901692267


  0%|          | 18/4999 [6:00:03<1654:30:27, 1195.79s/it]

Checkpoint 18 saved !


  0%|          | 19/4999 [6:19:40<1646:24:10, 1190.17s/it]

Epoch 19 finished ! Training Loss: 0.2818



  0%|          | 20/4999 [6:39:23<1642:46:51, 1187.79s/it]

Epoch 20 finished ! Training Loss: 0.2674

Epoch 21 finished ! Training Loss: 0.2651



  0%|          | 21/4999 [6:59:55<1660:46:29, 1201.04s/it]

------- 1st valloss=0.6718



  0%|          | 22/4999 [7:19:44<1655:28:44, 1197.45s/it]

Epoch 22 finished ! Training Loss: 0.2650



  0%|          | 23/4999 [7:39:19<1645:49:29, 1190.71s/it]

Epoch 23 finished ! Training Loss: 0.2467

Epoch 24 finished ! Training Loss: 0.2387



  0%|          | 24/4999 [7:59:53<1663:48:54, 1203.97s/it]

------- 1st valloss=0.6940



  1%|          | 25/4999 [8:19:23<1649:01:40, 1193.51s/it]

Epoch 25 finished ! Training Loss: 0.2403



  1%|          | 26/4999 [8:39:08<1645:19:10, 1191.06s/it]

Epoch 26 finished ! Training Loss: 0.2310

Epoch 27 finished ! Training Loss: 0.2184



  1%|          | 27/4999 [8:59:25<1655:42:36, 1198.82s/it]

------- 1st valloss=0.5413



  1%|          | 28/4999 [9:19:08<1648:54:59, 1194.15s/it]

Epoch 28 finished ! Training Loss: 0.2115



  1%|          | 29/4999 [9:38:32<1636:11:59, 1185.17s/it]

Epoch 29 finished ! Training Loss: 0.2133

Epoch 30 finished ! Training Loss: 0.2023

------- 1st valloss=0.4431



  1%|          | 30/4999 [9:58:43<1646:27:55, 1192.85s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [10:18:27<1642:22:09, 1190.12s/it]

Epoch 31 finished ! Training Loss: 0.2090



  1%|          | 32/4999 [10:38:11<1639:26:11, 1188.24s/it]

Epoch 32 finished ! Training Loss: 0.1898

Epoch 33 finished ! Training Loss: 0.1836



  1%|          | 33/4999 [10:58:25<1650:00:15, 1196.14s/it]

------- 1st valloss=0.5519



  1%|          | 34/4999 [11:18:06<1643:15:05, 1191.48s/it]

Epoch 34 finished ! Training Loss: 0.1908



  1%|          | 35/4999 [11:37:27<1630:14:08, 1182.28s/it]

Epoch 35 finished ! Training Loss: 0.1695

Epoch 36 finished ! Training Loss: 0.1834

------- 1st valloss=0.2741

0.27405652339043823 less than 0.4024557924788931


  1%|          | 36/4999 [11:57:33<1639:41:28, 1189.38s/it]

Checkpoint 36 saved !


  1%|          | 37/4999 [12:17:01<1630:47:34, 1183.16s/it]

Epoch 37 finished ! Training Loss: 0.1703



  1%|          | 38/4999 [12:36:49<1632:13:18, 1184.44s/it]

Epoch 38 finished ! Training Loss: 0.1687

Epoch 39 finished ! Training Loss: 0.1633

------- 1st valloss=0.1274

0.1274323175134866 less than 0.27405652339043823


  1%|          | 39/4999 [12:57:00<1642:59:47, 1192.50s/it]

Checkpoint 39 saved !


  1%|          | 40/4999 [13:16:31<1633:43:22, 1186.01s/it]

Epoch 40 finished ! Training Loss: 0.1613



  1%|          | 41/4999 [13:36:05<1628:23:43, 1182.38s/it]

Epoch 41 finished ! Training Loss: 0.1656

Epoch 42 finished ! Training Loss: 0.1503



  1%|          | 42/4999 [13:56:09<1637:05:00, 1188.92s/it]

------- 1st valloss=0.6367



  1%|          | 43/4999 [14:15:39<1629:01:51, 1183.32s/it]

Epoch 43 finished ! Training Loss: 0.1359



  1%|          | 44/4999 [14:35:06<1621:51:07, 1178.34s/it]

Epoch 44 finished ! Training Loss: 0.1427

Epoch 45 finished ! Training Loss: 0.1452

------- 1st valloss=0.6739



  1%|          | 45/4999 [14:55:31<1640:41:40, 1192.27s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [15:15:04<1632:36:48, 1186.64s/it]

Epoch 46 finished ! Training Loss: 0.1385



  1%|          | 47/4999 [15:34:43<1628:59:37, 1184.24s/it]

Epoch 47 finished ! Training Loss: 0.1383

Epoch 48 finished ! Training Loss: 0.1321



  1%|          | 48/4999 [15:55:06<1644:43:03, 1195.92s/it]

------- 1st valloss=0.2048



  1%|          | 49/4999 [16:14:52<1640:16:49, 1192.93s/it]

Epoch 49 finished ! Training Loss: 0.1287



  1%|          | 50/4999 [16:34:27<1632:39:51, 1187.63s/it]

Epoch 50 finished ! Training Loss: 0.1284

Epoch 51 finished ! Training Loss: 0.1388



  1%|          | 51/4999 [16:54:48<1645:52:21, 1197.48s/it]

------- 1st valloss=0.5500



  1%|          | 52/4999 [17:14:07<1629:58:42, 1186.16s/it]

Epoch 52 finished ! Training Loss: 0.1346



  1%|          | 53/4999 [17:33:38<1623:05:23, 1181.38s/it]

Epoch 53 finished ! Training Loss: 0.1278

Epoch 54 finished ! Training Loss: 0.1343



  1%|          | 54/4999 [17:54:09<1643:18:27, 1196.34s/it]

------- 1st valloss=0.5391



  1%|          | 55/4999 [18:13:35<1630:30:01, 1187.26s/it]

Epoch 55 finished ! Training Loss: 0.1266



  1%|          | 56/4999 [18:33:03<1622:04:02, 1181.36s/it]

Epoch 56 finished ! Training Loss: 0.1146

Epoch 57 finished ! Training Loss: 0.1168



  1%|          | 57/4999 [18:53:24<1638:02:25, 1193.23s/it]

------- 1st valloss=0.2326



  1%|          | 58/4999 [19:12:53<1628:02:32, 1186.19s/it]

Epoch 58 finished ! Training Loss: 0.1138



  1%|          | 59/4999 [19:32:20<1619:43:01, 1180.36s/it]

Epoch 59 finished ! Training Loss: 0.1090

Epoch 60 finished ! Training Loss: 0.1161

------- 1st valloss=0.5480



  1%|          | 60/4999 [19:52:45<1637:46:09, 1193.76s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [20:12:26<1631:59:38, 1189.79s/it]

Epoch 61 finished ! Training Loss: 0.1178



  1%|          | 62/4999 [20:32:01<1625:43:53, 1185.46s/it]

Epoch 62 finished ! Training Loss: 0.1165

Epoch 63 finished ! Training Loss: 0.1198



  1%|▏         | 63/4999 [20:52:26<1641:31:18, 1197.22s/it]

------- 1st valloss=0.5165



  1%|▏         | 64/4999 [21:12:02<1632:44:20, 1191.06s/it]

Epoch 64 finished ! Training Loss: 0.1192



  1%|▏         | 65/4999 [21:31:56<1633:19:44, 1191.73s/it]

Epoch 65 finished ! Training Loss: 0.1122

Epoch 66 finished ! Training Loss: 0.1065



  1%|▏         | 66/4999 [21:52:02<1639:09:50, 1196.23s/it]

------- 1st valloss=0.2130



  1%|▏         | 67/4999 [22:11:43<1632:14:30, 1191.42s/it]

Epoch 67 finished ! Training Loss: 0.1082



  1%|▏         | 68/4999 [22:31:02<1618:56:58, 1181.95s/it]

Epoch 68 finished ! Training Loss: 0.1098

Epoch 69 finished ! Training Loss: 0.1067



  1%|▏         | 69/4999 [22:51:30<1637:24:01, 1195.67s/it]

------- 1st valloss=0.5936



  1%|▏         | 70/4999 [23:11:04<1628:19:35, 1189.28s/it]

Epoch 70 finished ! Training Loss: 0.1098



  1%|▏         | 71/4999 [23:30:32<1618:57:47, 1182.68s/it]

Epoch 71 finished ! Training Loss: 0.1084

Epoch 72 finished ! Training Loss: 0.1057



  1%|▏         | 72/4999 [23:50:58<1636:24:45, 1195.67s/it]

------- 1st valloss=0.6118



  1%|▏         | 73/4999 [24:10:36<1629:07:09, 1190.59s/it]

Epoch 73 finished ! Training Loss: 0.1093



  1%|▏         | 74/4999 [24:30:26<1628:21:55, 1190.28s/it]

Epoch 74 finished ! Training Loss: 0.1056

Epoch 75 finished ! Training Loss: 0.1040

------- 1st valloss=0.1820



  2%|▏         | 75/4999 [24:50:40<1637:37:12, 1197.29s/it]

Checkpoint 75 saved !


  2%|▏         | 76/4999 [25:10:11<1626:37:31, 1189.49s/it]

Epoch 76 finished ! Training Loss: 0.1023



  2%|▏         | 77/4999 [25:29:44<1619:28:44, 1184.50s/it]

Epoch 77 finished ! Training Loss: 0.0961



In [None]:
# Final evaluation on the validation set using hard (one-hot) predictions.
# Batches whose per-class loss exceeds a threshold below are displayed for
# visual inspection.
BV_LOSS_THRESHOLD = 0.2
BD_LOSS_THRESHOLD = 0.1

deeplab.eval()

with torch.no_grad():

    # Running sums of the three per-class losses returned by
    # dice_loss_3_debug (reported below as background / body / bv).
    bgloss = 0
    bdloss = 0
    bvloss = 0
    n_val = 0

    # Wrapping the loader itself (rather than the enumerate object) lets tqdm
    # see its length and show real progress.
    for n_val, vbatch in enumerate(tqdm(validation_loader), start=1):
        # Move data to the device and convert to the configured dtype.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)

        # One-hot encode the arg-max along dim 1, staying on the device.
        # Unlike comparing against the per-position maximum, this marks
        # exactly one entry along dim 1 even when several values tie.
        class_idx = output.argmax(dim=1, keepdim=True)
        out_1 = torch.zeros_like(output).scatter_(1, class_idx, 1.0)

        loss_1 = dice_loss_3(out_1, label_1)
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)

        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

        # Show the input, the label and the raw network output for poorly
        # scoring batches.
        if bv.item() >= BV_LOSS_THRESHOLD or bd.item() >= BD_LOSS_THRESHOLD:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)

    if n_val == 0:
        raise RuntimeError('validation_loader yielded no batches')

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/n_val, bdloss/n_val, bvloss/n_val) + '\n'
    print(outstr)