In [1]:
import os, sys, pdb, shutil, random, math, datetime
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

from torchvision import transforms
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from collections import OrderedDict

from dataloaders import Segmentation_transforms as tr
from dataloaders.Segmentation_Image import OCT_image_segmentation
from networks.segmentation.deeplab_xception import DeepLabv3_plus_xception
from networks.segmentation.deeplab_resnet import DeepLabv3_plus_resnet
from networks.segmentation.deeplab_resnet_random import DeepLabv3_plus_resnet_random
from networks.segmentation.coordconv_unet import UNet_coordconv
from networks.segmentation.coordconv_unet_sn_elu import UNet_coordconv_sn_elu

from dataloaders.Image_utils import decode_segmap, decode_segmap_sequence
from utils import get_logger, get_dice_score, lr_poly, aic_fundus_lesion_segmentation
from losses import CrossEntropy2D, DiceLoss2D
from tqdm import tqdm

In [2]:
class Config(object):
    def __init__(self):
        
        self.train_batch = 8 # ResNet101: 20, ResNet50: 28, ResNet34: 44, UNet: 20
        self.val_batch = 7
        
        self.nepoch = 50
        
        self.h = 1024 # original is 1024
        self.w = 512
        self.lr = 1.0*1e-2 # 1.0*1e-7
        self.num_classes = 4
        self.class_weight = [1, 1.5, 1.5, 10] # [1, 1.5, 1.5, 6]
        self.wd = 5e-4
        self.momentum = 0.9
        self.nAveGrad = 1
         
        self.dataset = "Edema" # Edema | defined1 | defined2
        self.network = "DeepLabv3_plus_resnet"
        self.net_config = {"keep_probs": [0.8]*4}
        self.os = 16
        self.backbone = "ResNet50" # ResNet34 | ResNet50 | ResNet101
        self.pretrain_checkpoint = "./pretrained/resnet50.pth" #/root/.torch/models/resnet50-19c8e357.pth | checkpoint/edema_PED/ResNet101_pretrain/aug/epoch7.pth
        self.ignore_prefixs = ["conv1", "fc", "gap"] # ["conv1", "fc", "gap"]
        self.criterion = "cross_entropy" # cross_entropy | dice
        self.scale_min = 0.75
        self.scale_max = 1.5
        self.rotation = 15
        self.denoising = False
        
        self.task = "segmentation"
        self.suffix = "Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01"
        self.checkpoint = None
        self.included_pixels = [0, 255, 191, 128]
        self.label_dict = OrderedDict([(0, 0), (255, 1), (191, 2), (128, 3)]) #OrderedDict([(0, 0), (255, 1), (191, 2), (128, 3)])
        self.aug_dict = None #OrderedDict([(128, 7)])
        
        self.gpus = "0, 1"
        self.num_workers =2
        
        self.manualSeed = None
        

config = Config()

In [3]:
log_path = os.path.join('logs', config.task, config.network, '{}.log'.format(config.suffix))
if os.path.exists(log_path):
    delete_log = input("The log file %s exist, delete it or not (y/n) \n"%(log_path))
    if delete_log in ['y', 'Y']:
        os.remove(log_path)
    else:
        log_path = os.path.join('logs', config.task, config.network, '{}_{}.log'.format(config.suffix, datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))

checkpoint_path = os.path.join('checkpoint', config.task, config.network, config.suffix)
if os.path.exists(checkpoint_path):
    delete_checkpoint_path = input("The checkpoint folder %s exist, delete it or not (y/n) \n"%(checkpoint_path))
    if delete_checkpoint_path in ['y', 'Y']:
        shutil.rmtree(checkpoint_path)
    else:
        checkpoint_path = os.path.join("checkpoint", config.task, config.network, config.suffix+"_"+datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
else:
    os.makedirs(checkpoint_path)

summary_path = os.path.join("summaries", config.task, config.network, config.suffix)
if os.path.exists(summary_path):
    delete_summary = input("The tf_summary folder %s exist, delete it or not (y/n) \n"%(summary_path))
    if delete_summary in ['y', 'Y']:
        shutil.rmtree(summary_path)
    else:
        summary_path = os.path.join("summaries", config.task, config.network, config.suffix+"_"+datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
else:
    os.makedirs(summary_path)
    
logger = get_logger(log_path)
writer = SummaryWriter(summary_path)
logger.info(config.__dict__)

if config.manualSeed is None:
    config.manualSeed = random.randint(1, 10000)
logger.info("Random Seed: {}".format(config.manualSeed))
np.random.seed(config.manualSeed)
random.seed(config.manualSeed)
torch.manual_seed(config.manualSeed)

The checkpoint folder checkpoint/segmentation/DeepLabv3_plus_resnet/Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01 exist, delete it or not (y/n) 
y
The tf_summary folder summaries/segmentation/DeepLabv3_plus_resnet/Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01 exist, delete it or not (y/n) 
y


{'train_batch': 8, 'val_batch': 7, 'nepoch': 50, 'h': 1024, 'w': 512, 'lr': 0.01, 'num_classes': 4, 'class_weight': [1, 1.5, 1.5, 10], 'wd': 0.0005, 'momentum': 0.9, 'nAveGrad': 1, 'dataset': 'Edema', 'network': 'DeepLabv3_plus_resnet', 'net_config': {'keep_probs': [0.8, 0.8, 0.8, 0.8]}, 'os': 16, 'backbone': 'ResNet50', 'pretrain_checkpoint': './pretrained/resnet50.pth', 'ignore_prefixs': ['conv1', 'fc', 'gap'], 'criterion': 'cross_entropy', 'scale_min': 0.75, 'scale_max': 1.5, 'rotation': 15, 'denoising': False, 'task': 'segmentation', 'suffix': 'Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01', 'checkpoint': None, 'included_pixels': [0, 255, 191, 128], 'label_dict': OrderedDict([(0, 0), (255, 1), (191, 2), (128, 3)]), 'aug_dict': None, 'gpus': '0, 1', 'num_workers': 2, 'manualSeed': None}
Random Seed: 4312


<torch._C.Generator at 0x7fb56a5592d0>

In [4]:
def log_best_metric(metric_list, cur_epoch_idx, logger, state, save_path, save_model=True, metric = "AUC"):
    if len(metric_list) == 0:
        return
    else:
        best_idx = np.argmax(metric_list)
        best_metric = metric_list[best_idx]
        if best_idx == cur_epoch_idx:
            logger.info("Epoch: %d, Validation %s improved to %.4f"%(cur_epoch_idx, metric, best_metric))
            if save_model:
                dir_path = os.path.dirname(save_path)  # get parent path
                if not os.path.exists(dir_path):
                    os.makedirs(dir_path)
                torch.save(state, save_path)
                logger.info("Model saved in file: %s"%(save_path))
        else:
            logger.info("Epoch: %d, Validation %s didn't improve. Best is %.4f in epoch %d"%(cur_epoch_idx, metric, best_metric, best_idx))

def train(model, device, data_loader, criterion, nAveGrad, optimizer, epoch, writer):
    model.train()
    losses = []
    aveGrad = 0
    with tqdm(len(data_loader)) as pbar:
        for batch_idx, sample_batched in enumerate(data_loader):
            inputs, labels = sample_batched['image'], sample_batched['label']
            inputs = inputs.float()
            inputs = inputs.to(device)
            labels = labels.to(device).long()
            
            outputs = model(inputs)
            n_classes = outputs.size(1)
            loss = criterion(outputs, labels)
            
            loss /= nAveGrad # TODO why loss needed divided by the nAveGrad ?
            loss.backward()
            aveGrad += 1
            losses.append(loss.item())
            
            if aveGrad % nAveGrad == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0
                
            pbar.update(1)
            pbar.set_description("Epoch %d, Batch %d/%d, Train loss: %.4f"%(epoch, batch_idx+1, len(data_loader), np.mean(losses)))
    
    ave_loss = np.mean(losses)
    input_images = vutils.make_grid(inputs, padding = 5, normalize=False)
    gt_idxs = np.expand_dims(labels.detach().cpu().numpy(), 1)
    gt_images = vutils.make_grid(decode_segmap_sequence(gt_idxs), padding = 5, normalize=False, range=(0, 255))
    predicted_idxs = np.expand_dims(torch.max(outputs, 1)[1].detach().cpu().numpy(), 1)
    predicted_images = vutils.make_grid(decode_segmap_sequence(predicted_idxs), padding = 5, normalize=False, range=(0, 255))
    
    writer.add_image('train/input_images', input_images, epoch)
    writer.add_image('train/ground_truth', gt_images, epoch)
    writer.add_image('train/predictions', predicted_images, epoch)
    writer.add_scalar('train/epoch_loss', ave_loss, epoch)
    return ave_loss

def validate(model, device, data_loader, criterion, epoch, writer):
    losses, sample_predictions, sample_ground_truth = [], [], []
    model.eval()
    with torch.no_grad():
        for batch_idx, sample_batched in enumerate(tqdm(data_loader)):
            inputs, labels = sample_batched['image'], sample_batched['label']
            inputs = inputs.float()
            inputs = inputs.to(device)
            labels = labels.to(device).long()

            outputs = model(inputs)
            predictions = torch.max(outputs, 1)[1]
            sample_predictions.append(predictions.cpu().numpy().astype(np.int16))

            n_classes = outputs.size(1)
            loss = criterion(outputs, labels)
            losses.append(loss.item())
            sample_ground_truth.append(labels.cpu().numpy().astype(np.int16))
        
        # only write the last batch in testset 
        input_images = vutils.make_grid(inputs, padding = 5, normalize=False)
        gt_idxs = np.expand_dims(labels.detach().cpu().numpy(), 1)
        gt_images = vutils.make_grid(decode_segmap_sequence(gt_idxs), padding = 5, normalize=False, range=(0, 255))
        predicted_idxs = np.expand_dims(torch.max(outputs, 1)[1].detach().cpu().numpy(), 1)
        predicted_images = vutils.make_grid(decode_segmap_sequence(predicted_idxs), padding = 5, normalize=False, range=(0, 255))

        writer.add_image('test/input_images', input_images, epoch)
        writer.add_image('test/ground_truth', gt_images, epoch)
        writer.add_image('test/predictions', predicted_images, epoch)
    
    ave_loss = np.mean(losses)
    writer.add_scalar('test/epoch_loss', ave_loss, epoch)
    return ave_loss, np.concatenate(sample_predictions, 0), np.squeeze(np.concatenate(sample_ground_truth, 0))

def mean_dice_persample(all_outputs, all_labels, num_image = 128):
    sample_dices = []
    outputs, labels = [], []
    for i in range(len(all_outputs)):
        outputs.append(all_outputs[i])
        labels.append(all_labels[i])
        if (i+1) % 128 == 0:
            sample_dices.append(aic_fundus_lesion_segmentation(np.array(labels), np.array(outputs)))
            outputs, labels = [], []
    valid_dices = [[], [], [], []]
    for sample_dice in sample_dices:
        for i, dice_value in enumerate(sample_dice):
            if not math.isnan(dice_value):
                valid_dices[i].append(dice_value)
    return [round(np.mean(dice_values), 5) for dice_values in valid_dices]

In [6]:
train_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomSized([config.h, config.w], config.scale_min, config.scale_max), # h, w
        tr.RandomRotate(config.rotation),
        tr.Normalize_divide(255.0),
        tr.ToTensor()]) 

val_tr = transforms.Compose([
#         tr.FixedResize([config.h, config.w]), # h, w
        tr.Normalize_divide(255.0),
        tr.ToTensor()])

trainset = OCT_image_segmentation("./data/{}_trainingset/original_images".format(config.dataset), 
                     "./data/{}_trainingset/label_images".format(config.dataset),
                     included_pixels = config.included_pixels,
                     label_dict = config.label_dict,
                     aug_dict = config.aug_dict,
                     denoising = config.denoising,
                     transform = train_tr)
valset = OCT_image_segmentation("./data/{}_validationset/original_images".format(config.dataset), 
                     "./data/{}_validationset/label_images".format(config.dataset),
                     included_pixels = config.included_pixels,
                     label_dict = config.label_dict,
                     aug_dict = None,
                     denoising = config.denoising,
                     transform = val_tr)

trainset_loader = DataLoader(trainset, batch_size=config.train_batch, drop_last=True,shuffle=True, num_workers=config.num_workers)
valset_loader = DataLoader(valset, batch_size=config.val_batch, shuffle=False, num_workers=config.num_workers)

100%|██████████| 70/70 [00:17<00:00,  3.93it/s]
100%|██████████| 15/15 [00:32<00:00,  2.17s/it]


In [7]:
if config.network.startswith("DeepLab"):
    model = globals()[config.network](net_config = config.net_config, nInputChannels=1, n_classes=config.num_classes, os=config.os, 
                                      backbone=config.backbone, checkpoint=config.pretrain_checkpoint, ignore_prefixs =config.ignore_prefixs)
elif config.network.startswith("UNet"):
    model = globals()[config.network](n_channels=1, n_classes=config.num_classes)
else:
    raise("Unknown network: {}".format(config.network))
    
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.wd)
lr_lambda = lambda epoch: (1 - float(epoch) / config.nepoch)** 0.9

start_epoch = -1
if config.checkpoint is not None:
    checkpoint = torch.load(config.checkpoint)
    state_dict = checkpoint["state_dict"]
    start_epoch = checkpoint["epoch"]
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "") # remove `module.`
        new_state_dict[name] = v
    logger.info("Resuming from checkpoint: {} at epoch{}".format(config.checkpoint, start_epoch))
    model.load_state_dict(new_state_dict)
    optimizer.load_state_dict(checkpoint["optimizer"])
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()
    optimizer.param_groups[0]["initial_lr"] = config.lr # pytorch 0.4.0 need explicitly set initial_lr when resuming optimizer

gpus = list(map(int, config.gpus.split(",")))

if len(gpus) > 1:
    model = nn.DataParallel(model, gpus)
device = torch.device("cuda:{}".format(gpus[0]))
model.to(device)

if config.criterion == "cross_entropy":
    criterion = CrossEntropy2D(weight=config.class_weight, size_average=True, batch_average=True)
elif config.criterion == "dice":
    criterion = DiceLoss2D(n_classes = config.num_classes, smooth = 1)
else:
    raise("Unknown criterion: {}".format(config.criterion))

lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch = -1 if start_epoch == -1 else start_epoch)

Constructing DeepLabv3+ model...
Number of classes: 4
Output stride: 16
Number of Input Channels: 1
load bn1.running_mean
load bn1.running_var
load bn1.weight
load bn1.bias
load layer1.0.conv1.weight
load layer1.0.bn1.running_mean
load layer1.0.bn1.running_var
load layer1.0.bn1.weight
load layer1.0.bn1.bias
load layer1.0.conv2.weight
load layer1.0.bn2.running_mean
load layer1.0.bn2.running_var
load layer1.0.bn2.weight
load layer1.0.bn2.bias
load layer1.0.conv3.weight
load layer1.0.bn3.running_mean
load layer1.0.bn3.running_var
load layer1.0.bn3.weight
load layer1.0.bn3.bias
load layer1.0.downsample.0.weight
load layer1.0.downsample.1.running_mean
load layer1.0.downsample.1.running_var
load layer1.0.downsample.1.weight
load layer1.0.downsample.1.bias
load layer1.1.conv1.weight
load layer1.1.bn1.running_mean
load layer1.1.bn1.running_var
load layer1.1.bn1.weight
load layer1.1.bn1.bias
load layer1.1.conv2.weight
load layer1.1.bn2.running_mean
load layer1.1.bn2.running_var
load layer1.1.bn



In [10]:
torch.cuda.empty_cache()

In [8]:
metric_list = []
for epoch in range(start_epoch+1, config.nepoch):
    lr_scheduler.step()
    logger.info("Epoch: %d, Learning rate: %.10f"%(epoch, lr_scheduler.get_lr()[0]))
    train_loss = train(model, device, trainset_loader, criterion, config.nAveGrad, optimizer, epoch, writer)
    logger.info("Epoch: %d, Train Loss: %.4f"%(epoch, train_loss))
    
    val_loss, val_predictions, val_ground_truths = validate(model, device, valset_loader, criterion, epoch, writer)
    class_mean_dices = mean_dice_persample(val_predictions, val_ground_truths)
    avg_score = np.mean(class_mean_dices)
    metric_list.append(avg_score)
    logger.info("Epoch: %d, Validation Loss: %.4f, Validation Dice Score: %.4f, %s"%(epoch, val_loss, avg_score, class_mean_dices))
    
    log_best_metric(metric_list, epoch, logger, 
                    {'epoch': epoch,
                     'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()},
                     '{}/epoch{}.pth'.format(checkpoint_path, epoch),
                    save_model=True,
                    metric = "Dice score")

Epoch: 0, Learning rate: 0.0098198187
Epoch 0, Batch 1119/1119, Train loss: 0.0749: : 1119it [09:56,  1.88it/s]
Epoch: 0, Train Loss: 0.0749
100%|██████████| 275/275 [00:54<00:00,  5.03it/s]
Epoch: 0, Validation Loss: 0.1010, Validation Dice Score: 0.6840, [0.97679, 0.67323, 0.63378, 0.45208]
Epoch: 0, Validation Dice score improved to 0.6840
Model saved in file: checkpoint/segmentation/DeepLabv3_plus_resnet/Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01/epoch0.pth
Epoch: 1, Learning rate: 0.0096392692
Epoch 1, Batch 1119/1119, Train loss: 0.0431: : 1119it [10:10,  1.83it/s]
Epoch: 1, Train Loss: 0.0431
100%|██████████| 275/275 [00:54<00:00,  5.01it/s]
Epoch: 1, Validation Loss: 0.0688, Validation Dice Score: 0.7433, [0.98963, 0.71918, 0.74735, 0.51696]
Epoch: 1, Validation Dice score improved to 0.7433
Model saved in file: checkpoint/segmentation/DeepLabv3_plus_resnet/Edema_sizeAvg_aug_1024x512_cross_entropy_scale_0.75_1.5_lr_0.01/epoch1.pth
Epoch: 2, Learning rate: 0

KeyboardInterrupt: 

In [None]:
for ii, sample in enumerate(trainset_loader):
    img = sample['image'].numpy()
    gt = sample['label'].numpy()
    for jj in range(sample["image"].size(0)):
        tmp = np.array(gt[jj]).astype(np.uint8)
        segmap = decode_segmap(tmp, label_colours = OrderedDict([(0, 0), (1, 255), (2, 191), (3, 128)]))
        img_tmp = np.squeeze(img[jj], 0)
        plt.figure()
        plt.title('display')
        plt.subplot(121)
        plt.imshow(img_tmp, cmap = "gray")
        plt.subplot(122)
        plt.imshow(segmap, cmap = "gray")
        plt.show()


NameError: name 'val_predictions' is not defined