In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# float32 needs half the memory of float64 (double) and is precise enough here.
dtype = torch.float32

if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')
    # Enable cuDNN and let it benchmark/cache the fastest convolution algorithms;
    # this pays off when input shapes stay the same between iterations.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')

using GPU for training


In [3]:
# Training set with data augmentation: random_affine(90, 15) followed by
# random_filp(0.5) (presumably a random flip with probability 0.5 — TODO confirm
# against data_utility).
train_dataset = pyramid_dataset(data_type='nii_train',
                                transform=transforms.Compose([
                                    random_affine(90, 15),
                                    random_filp(0.5)]))

# Validation set: no augmentation, so the metric reflects the untouched data.
validation_dataset = pyramid_dataset(data_type='nii_test',
                                     transform=None)

# DataLoaders provide batching and multi-process loading (NUM_WORKERS workers).
# Training batches are reshuffled every epoch. Validation batches are NOT:
# the validation loss is averaged per batch, so with a per-batch Dice loss the
# reported number would otherwise depend on how samples happen to be grouped.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS)  # consider drop_last for train_loader
# loaders come with auto batch division and multi-thread acceleration

In [4]:
def build_fresh_deeplab(lr=1e-2, patience=5):
    """Build an untrained DeepLab plus optimizer and LR scheduler, for training from scratch.

    This is the alternative to resuming from a checkpoint (next cell); it is not
    called by default.

    Args:
        lr: initial Adam learning rate.
        patience: ReduceLROnPlateau patience (epochs without improvement before
            the LR is multiplied by 0.1).

    Returns:
        (deeplab, optimizer, scheduler, epoch) with epoch == 0.
    """
    # NOTE(review): uses DeepLab's default output_stride, whereas the resume cell
    # builds DeepLab(output_stride=2) — confirm which one is intended.
    deeplab = DeepLab()
    deeplab = nn.DataParallel(deeplab)
    deeplab = convert_model(deeplab)
    # Models are float by default; move to the configured device and dtype.
    deeplab = deeplab.to(device=device, dtype=dtype)

    # Adam optimizes every parameter of the model.
    optimizer = optim.Adam(deeplab.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                                     patience=patience)
    return deeplab, optimizer, scheduler, 0


# To train from scratch instead of resuming, run this and skip the checkpoint cell:
# deeplab, optimizer, scheduler, epoch = build_fresh_deeplab()

In [None]:
# Resume training from a saved checkpoint.
# Earlier candidates, kept for reference:
#   '../deeplab_save/2019-07-29 04:00:14.630172.pth'  # second best
#   '../deeplab_save/2019-07-28 23:47:36.279119.pth'  # second best
#   '../deeplab_save/2019-07-29 00:15:49.271222.pth'  # best
#   '../deeplab_save/2019-07-29 00:44:11.825872.pth'
CHECKPOINT_PATH = '../deeplab_save/2019-08-12 17:24:36.436884 epoch: 290.pth'  # latest one

deeplab = DeepLab(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# Load onto the CPU first. This lets a GPU-saved checkpoint open on a CPU-only
# kernel and avoids a second copy of the weights on the GPU.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

# The lr given here is only a placeholder: load_state_dict restores the saved
# param_groups, including their learning rate. Adam state is moved to the
# parameters' device on load.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-5)
optimizer.load_state_dict(checkpoint['optimizer'])

# NOTE(review): the scheduler starts fresh; its saved state is not restored.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
#scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)

# Mean of the validation losses stored in the checkpoint's logger (nan if none recorded).
val_history = checkpoint['logger']['validation_1']
print(sum(val_history) / len(val_history) if val_history else float('nan'))
for param_group in optimizer.param_groups:
    print(param_group['lr'])

290
0.07887190903294021
1.0000000000000002e-06


In [None]:
epochs = 5000           # train until this (exclusive) epoch number
min_val = .07           # a checkpoint is saved whenever validation loss drops below this running best
VALIDATE_EVERY = 1      # run validation every N epochs
SAVE_EVERY = 10         # also checkpoint every N epochs when there is no new best

# NOTE(review): `logger` starts empty instead of from checkpoint['logger'], so
# checkpoints written by this run only contain this run's history — confirm intended.
logger = {'train': [], 'validation_1': []}


def _train_one_epoch():
    """Make one pass over train_loader, taking an optimizer step per mini-batch.

    Returns the mean per-batch dice_loss_3, or nan if the loader yields no batches.
    """
    # Train mode (dropout on, batchnorm uses batch statistics). Set once per epoch:
    # nothing switches the model to eval until validation runs.
    deeplab.train()
    epoch_loss = 0
    n_batches = 0
    for batch in train_loader:
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)

        out_1 = deeplab(image_1)
        loss_1 = dice_loss_3(out_1, label_1)
        epoch_loss += loss_1.item()
        n_batches += 1

        optimizer.zero_grad()
        loss_1.backward()
        optimizer.step()
    return epoch_loss / n_batches if n_batches else float('nan')


def _validate():
    """Return the mean per-batch dice_loss_3 over validation_loader (nan if it is empty)."""
    # Eval mode (dropout off, batchnorm uses running statistics).
    deeplab.eval()
    valloss_1 = 0
    n_batches = 0
    with torch.no_grad():
        for vbatch in validation_loader:
            image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
            label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
            # Tensors that get_dimensions reports as 4-D get a leading dimension added in place.
            if get_dimensions(image_1_val) == 4:
                image_1_val.unsqueeze_(0)
            if get_dimensions(label_1_val) == 4:
                label_1_val.unsqueeze_(0)

            out_1_val = deeplab(image_1_val)
            valloss_1 += dice_loss_3(out_1_val, label_1_val).item()
            n_batches += 1
    return valloss_1 / n_batches if n_batches else float('nan')


# `with` guarantees the log file is closed even if training is interrupted.
with open('train_deeplab.txt', 'a+') as record:
    for e in tqdm(range(epoch + 1, epochs)):
        train_loss = _train_one_epoch()
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n'
        logger['train'].append(train_loss)
        print(outstr)
        record.write(outstr)
        record.flush()

        if e % VALIDATE_EVERY == 0:
            avg_val_loss = _validate()
            outstr = '------- 1st valloss={0:.4f}'.format(avg_val_loss) + '\n'
            logger['validation_1'].append(avg_val_loss)
            scheduler.step(avg_val_loss)
            print(outstr)
            record.write(outstr)
            record.flush()

            if avg_val_loss < min_val:
                # New best validation loss: remember it and checkpoint.
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
                save_1('deeplab_save', deeplab, optimizer, logger, e, scheduler)
            elif e % SAVE_EVERY == 0:
                # Periodic checkpoint regardless of validation loss.
                save_1('deeplab_save', deeplab, optimizer, logger, e, scheduler)

  0%|          | 0/4709 [00:00<?, ?it/s]

Epoch 291 finished ! Training Loss: 0.1361



  0%|          | 1/4709 [14:11<1114:06:56, 851.92s/it]

------- 1st valloss=0.0768

Epoch 292 finished ! Training Loss: 0.1308



  0%|          | 2/4709 [27:10<1085:03:55, 829.88s/it]

------- 1st valloss=0.0750

Epoch 293 finished ! Training Loss: 0.1350



  0%|          | 3/4709 [39:54<1058:57:12, 810.08s/it]

------- 1st valloss=0.1283

Epoch 294 finished ! Training Loss: 0.1314

------- 1st valloss=0.0604

0.06040680019751839 less than 0.07


  0%|          | 4/4709 [52:42<1042:15:08, 797.47s/it]

Checkpoint 294 saved !
Epoch 295 finished ! Training Loss: 0.1297



  0%|          | 5/4709 [1:05:37<1033:25:38, 790.89s/it]

------- 1st valloss=0.0779

Epoch 296 finished ! Training Loss: 0.1244



  0%|          | 6/4709 [1:18:24<1023:35:07, 783.52s/it]

------- 1st valloss=0.0902

Epoch 297 finished ! Training Loss: 0.1366



  0%|          | 7/4709 [1:31:14<1018:10:21, 779.55s/it]

------- 1st valloss=0.0749

Epoch 298 finished ! Training Loss: 0.1303



  0%|          | 8/4709 [1:43:56<1011:01:17, 774.23s/it]

------- 1st valloss=0.1099

Epoch 299 finished ! Training Loss: 0.1257



  0%|          | 9/4709 [1:56:48<1010:07:31, 773.71s/it]

------- 1st valloss=0.1014

Epoch 300 finished ! Training Loss: 0.1305

------- 1st valloss=0.0688



  0%|          | 10/4709 [2:09:39<1008:50:44, 772.90s/it]

Checkpoint 300 saved !
Epoch 301 finished ! Training Loss: 0.1200



  0%|          | 11/4709 [2:22:31<1008:13:10, 772.58s/it]

------- 1st valloss=0.0733

Epoch 302 finished ! Training Loss: 0.1262



  0%|          | 12/4709 [2:35:25<1008:29:38, 772.96s/it]

------- 1st valloss=0.0903

Epoch 303 finished ! Training Loss: 0.1330



  0%|          | 13/4709 [2:48:18<1008:30:18, 773.13s/it]

------- 1st valloss=0.0757

Epoch 304 finished ! Training Loss: 0.1375



  0%|          | 14/4709 [3:01:02<1004:25:53, 770.17s/it]

------- 1st valloss=0.0938

Epoch 305 finished ! Training Loss: 0.1375



  0%|          | 15/4709 [3:13:58<1006:40:36, 772.06s/it]

------- 1st valloss=0.0798

Epoch 306 finished ! Training Loss: 0.1310



  0%|          | 16/4709 [3:26:47<1005:21:31, 771.21s/it]

------- 1st valloss=0.0748

Epoch 307 finished ! Training Loss: 0.1424



  0%|          | 17/4709 [3:39:35<1003:40:40, 770.09s/it]

------- 1st valloss=0.0739

Epoch 308 finished ! Training Loss: 0.1186



  0%|          | 18/4709 [3:52:22<1002:10:58, 769.10s/it]

------- 1st valloss=0.0739

Epoch 309 finished ! Training Loss: 0.1426



  0%|          | 19/4709 [4:05:11<1002:11:09, 769.27s/it]

------- 1st valloss=0.0750

Epoch 310 finished ! Training Loss: 0.1274

------- 1st valloss=0.0603

0.06033968342387158 less than 0.06040680019751839


  0%|          | 20/4709 [4:18:10<1005:44:32, 772.16s/it]

Checkpoint 310 saved !
Epoch 311 finished ! Training Loss: 0.1427



  0%|          | 21/4709 [4:31:06<1007:00:36, 773.30s/it]

------- 1st valloss=0.0731

Epoch 312 finished ! Training Loss: 0.1320



  0%|          | 22/4709 [4:44:07<1009:34:00, 775.43s/it]

------- 1st valloss=0.0926

Epoch 313 finished ! Training Loss: 0.1347



  0%|          | 23/4709 [4:57:02<1009:29:10, 775.53s/it]

------- 1st valloss=0.0830

Epoch 314 finished ! Training Loss: 0.1353



  1%|          | 24/4709 [5:09:55<1008:13:26, 774.73s/it]

------- 1st valloss=0.0951

Epoch 315 finished ! Training Loss: 0.1307



  1%|          | 25/4709 [5:22:50<1007:56:59, 774.68s/it]

------- 1st valloss=0.0804

Epoch 316 finished ! Training Loss: 0.1422



  1%|          | 26/4709 [5:35:35<1003:52:01, 771.71s/it]

------- 1st valloss=0.0841

Epoch 317 finished ! Training Loss: 0.1352



  1%|          | 27/4709 [5:48:24<1002:52:52, 771.12s/it]

------- 1st valloss=0.0698

Epoch 318 finished ! Training Loss: 0.1344



  1%|          | 28/4709 [6:01:14<1002:06:02, 770.68s/it]

------- 1st valloss=0.0613

Epoch 319 finished ! Training Loss: 0.1418

------- 1st valloss=0.0599

0.059900194730447685 less than 0.06033968342387158


  1%|          | 29/4709 [6:14:05<1002:10:33, 770.90s/it]

Checkpoint 319 saved !
Epoch 320 finished ! Training Loss: 0.1349

------- 1st valloss=0.0841



  1%|          | 30/4709 [6:26:54<1001:13:01, 770.33s/it]

Checkpoint 320 saved !
Epoch 321 finished ! Training Loss: 0.1291



  1%|          | 31/4709 [6:39:44<1000:35:36, 770.02s/it]

------- 1st valloss=0.0822

Epoch 322 finished ! Training Loss: 0.1306



  1%|          | 32/4709 [6:52:35<1000:59:33, 770.49s/it]

------- 1st valloss=0.0651

Epoch 323 finished ! Training Loss: 0.1310



  1%|          | 33/4709 [7:05:22<999:20:47, 769.39s/it] 

------- 1st valloss=0.0743

Epoch 324 finished ! Training Loss: 0.1378



  1%|          | 34/4709 [7:18:12<999:08:04, 769.39s/it]

------- 1st valloss=0.1039

Epoch 325 finished ! Training Loss: 0.1423



  1%|          | 35/4709 [7:31:08<1001:52:39, 771.66s/it]

------- 1st valloss=0.0784

Epoch 326 finished ! Training Loss: 0.1428



  1%|          | 36/4709 [7:43:56<1000:10:00, 770.51s/it]

------- 1st valloss=0.0644

Epoch 327 finished ! Training Loss: 0.1359



  1%|          | 37/4709 [7:56:56<1003:26:06, 773.19s/it]

------- 1st valloss=0.0839

Epoch 328 finished ! Training Loss: 0.1368



  1%|          | 38/4709 [8:09:47<1002:30:49, 772.65s/it]

------- 1st valloss=0.0748

Epoch 329 finished ! Training Loss: 0.1243

------- 1st valloss=0.0598

0.05983590950136599 less than 0.059900194730447685


  1%|          | 39/4709 [8:22:37<1001:14:40, 771.84s/it]

Checkpoint 329 saved !
Epoch 330 finished ! Training Loss: 0.1430

------- 1st valloss=0.0766



  1%|          | 40/4709 [8:35:29<1001:08:52, 771.93s/it]

Checkpoint 330 saved !
Epoch 331 finished ! Training Loss: 0.1407



  1%|          | 41/4709 [8:48:19<1000:09:32, 771.33s/it]

------- 1st valloss=0.0611

Epoch 332 finished ! Training Loss: 0.1444



  1%|          | 42/4709 [9:01:10<999:36:19, 771.07s/it] 

------- 1st valloss=0.0612

Epoch 333 finished ! Training Loss: 0.1251



  1%|          | 43/4709 [9:14:09<1002:37:48, 773.57s/it]

------- 1st valloss=0.0630

Epoch 334 finished ! Training Loss: 0.1308



  1%|          | 44/4709 [9:27:02<1002:03:47, 773.30s/it]

------- 1st valloss=0.0613

Epoch 335 finished ! Training Loss: 0.1355



  1%|          | 45/4709 [9:39:59<1003:27:32, 774.54s/it]

------- 1st valloss=0.0646

Epoch 336 finished ! Training Loss: 0.1304



  1%|          | 46/4709 [9:52:54<1003:12:52, 774.52s/it]

------- 1st valloss=0.0608

Epoch 337 finished ! Training Loss: 0.1473

------- 1st valloss=0.0592

0.059154244706682534 less than 0.05983590950136599


  1%|          | 47/4709 [10:05:48<1003:08:03, 774.62s/it]

Checkpoint 337 saved !
Epoch 338 finished ! Training Loss: 0.1319



  1%|          | 48/4709 [10:18:37<1000:44:04, 772.93s/it]

------- 1st valloss=0.0829

Epoch 339 finished ! Training Loss: 0.1319



  1%|          | 49/4709 [10:31:30<1000:17:43, 772.76s/it]

------- 1st valloss=0.0826

Epoch 340 finished ! Training Loss: 0.1317

------- 1st valloss=0.0737



  1%|          | 50/4709 [10:44:23<1000:20:57, 772.97s/it]

Checkpoint 340 saved !
Epoch 341 finished ! Training Loss: 0.1480



  1%|          | 51/4709 [10:57:13<998:59:08, 772.08s/it] 

------- 1st valloss=0.0825

Epoch 342 finished ! Training Loss: 0.1265



  1%|          | 52/4709 [11:10:13<1001:35:06, 774.26s/it]

------- 1st valloss=0.0818

Epoch 343 finished ! Training Loss: 0.1418



  1%|          | 53/4709 [11:23:03<999:53:16, 773.11s/it] 

------- 1st valloss=0.0732

Epoch 344 finished ! Training Loss: 0.1198



  1%|          | 54/4709 [11:35:55<999:20:47, 772.86s/it]

------- 1st valloss=0.0808

Epoch 345 finished ! Training Loss: 0.1365



  1%|          | 55/4709 [11:48:43<997:08:17, 771.31s/it]

------- 1st valloss=0.1103

Epoch 346 finished ! Training Loss: 0.1250



  1%|          | 56/4709 [12:01:27<994:11:18, 769.20s/it]

------- 1st valloss=0.1128

Epoch 347 finished ! Training Loss: 0.1211



  1%|          | 57/4709 [12:14:19<994:46:05, 769.81s/it]

------- 1st valloss=0.0788

Epoch 348 finished ! Training Loss: 0.1189



  1%|          | 58/4709 [12:27:14<996:48:24, 771.56s/it]

------- 1st valloss=0.0756

Epoch 349 finished ! Training Loss: 0.1359



  1%|▏         | 59/4709 [12:40:04<995:51:04, 770.98s/it]

------- 1st valloss=0.0861

Epoch 350 finished ! Training Loss: 0.1256

------- 1st valloss=0.0851



  1%|▏         | 60/4709 [12:52:50<993:55:39, 769.66s/it]

Checkpoint 350 saved !
Epoch 351 finished ! Training Loss: 0.1429



  1%|▏         | 61/4709 [13:05:39<993:10:43, 769.24s/it]

------- 1st valloss=0.0869

Epoch 352 finished ! Training Loss: 0.1316



  1%|▏         | 62/4709 [13:18:31<994:01:41, 770.07s/it]

------- 1st valloss=0.0646

Epoch 353 finished ! Training Loss: 0.1302



  1%|▏         | 63/4709 [13:31:28<996:46:52, 772.37s/it]

------- 1st valloss=0.0778

Epoch 354 finished ! Training Loss: 0.1362



  1%|▏         | 64/4709 [13:44:18<995:36:00, 771.62s/it]

------- 1st valloss=0.0892

Epoch 355 finished ! Training Loss: 0.1405



  1%|▏         | 65/4709 [13:57:15<997:17:49, 773.10s/it]

------- 1st valloss=0.0831



In [None]:
# Per-class Dice breakdown on the validation set, using hard (one-hot) predictions.
# Batches whose body/bv loss exceeds these thresholds are displayed for inspection.
BV_SHOW_THRESHOLD = 0.2
BD_SHOW_THRESHOLD = 0.1

deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0
    n_batches = 0

    # tqdm wraps the loader directly so it can show a total (DataLoader has a len()).
    for vbatch in tqdm(validation_loader):
        # Move data to the device and convert to the working dtype.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)
        # Same 4-D handling as the training cell's validation loop on this loader.
        if get_dimensions(image_1) == 4:
            image_1.unsqueeze_(0)
        if get_dimensions(label_1) == 4:
            label_1.unsqueeze_(0)

        output = deeplab(image_1)
        # Hard prediction: 1 wherever a channel holds the maximum along dim 1, else 0
        # (ties give several 1s). Computed on-device instead of via NumPy.
        out_1 = (output == output.max(dim=1, keepdim=True)[0]).to(device=device, dtype=dtype)

        loss_1 = dice_loss_3(out_1, label_1)
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()
        n_batches += 1

        if bv.item() >= BV_SHOW_THRESHOLD or bd.item() >= BD_SHOW_THRESHOLD:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)

    def _mean(total):
        # Per-batch mean; nan rather than a NameError/ZeroDivisionError on an empty loader.
        return total / n_batches if n_batches else float('nan')

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(_mean(bgloss), _mean(bdloss), _mean(bvloss)) + '\n'
    print(outstr)