In [1]:
import os, shutil
import torch
import pickle
from Models import single_model as net
import numpy as np
import Transforms as myTransforms
from Dataset import Dataset
from parallel import DataParallelModel, DataParallelCriterion
import time
from argparse import ArgumentParser
from IoUEval import IoUEval
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler
from torch.nn.parallel import gather
import torch.nn as nn
import torch.nn.functional as F

In [2]:


def BCEDiceLoss(inputs, targets):
    """Binary cross-entropy plus soft-Dice loss.

    ``inputs`` must hold probabilities in [0, 1] (they are passed straight to
    ``F.binary_cross_entropy``) and ``targets`` must have the same shape.
    The Dice term sums over every element of the tensors at once, i.e. it is
    pooled across the whole batch rather than computed per sample.

    Returns:
        Scalar tensor ``BCE + (1 - Dice)``.
    """
    smooth = 1e-5  # keeps the Dice ratio finite when both tensors are all zeros
    overlap = torch.sum(inputs * targets)
    denominator = torch.sum(inputs) + torch.sum(targets) + smooth
    dice_coefficient = (2 * overlap + smooth) / denominator
    cross_entropy = F.binary_cross_entropy(inputs, targets)
    return cross_entropy + 1 - dice_coefficient

class CrossEntropyLoss(nn.Module):
    """Sum of ``BCEDiceLoss`` over every output channel of the prediction.

    Despite the name this is not softmax cross-entropy: each channel
    ``inputs[:, c, :, :]`` is treated as an independent probability map,
    compared against the same ``target`` with ``BCEDiceLoss``, and the
    per-channel losses are summed. Any channel count C >= 1 is supported
    (C == 1 yields a single loss, C == 5 the sum over five side outputs).
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, inputs, target):
        """Return the summed BCE+Dice loss.

        Args:
            inputs: prediction tensor indexed as ``inputs[:, c, :, :]``; values
                must be probabilities in [0, 1].
            target: ground-truth tensor, or a tuple whose first element is the
                ground truth; expected to match the shape of one channel of
                ``inputs``.

        Raises:
            ValueError: if ``inputs`` has no channels (previously this case, and
                any count other than 1 or 5, silently returned ``None``).
        """
        if isinstance(target, tuple):
            target = target[0]
        num_channels = inputs.shape[1]
        if num_channels == 0:
            raise ValueError('CrossEntropyLoss expects at least one output channel')
        # Start from channel 0 and accumulate left to right, matching the
        # summation order of an explicit loss1 + loss2 + ... expression.
        loss = BCEDiceLoss(inputs[:, 0, :, :], target)
        for channel in range(1, num_channels):
            loss = loss + BCEDiceLoss(inputs[:, channel, :, :], target)
        return loss


class FLoss(nn.Module):
    """Differentiable F-measure loss, summed over every output channel.

    Per sample, with soft true positives ``TP = sum(p * t)``, the score is
    ``F = (1 + beta) * TP / (beta * sum(t) + sum(p))``. Here ``beta`` plays the
    role of beta**2 in the usual F_beta definition. The loss is ``1 - F``, or
    ``-log(F)`` when ``log_like``, averaged over the batch. ``forward`` applies
    this to each channel ``inputs[:, c]`` against the same ``target`` and sums
    the results.

    Args:
        beta: weight in the formula above (default 0.3).
        log_like: use ``-log(F)`` instead of ``1 - F``.
    """

    def __init__(self, beta=0.3, log_like=False):
        super(FLoss, self).__init__()
        self.beta = beta
        self.log_like = log_like

    def _compute_loss(self, prediction, target):
        """Return the batch-mean F-measure loss for one prediction map.

        ``prediction`` and ``target`` are expected to have the same shape with
        the batch on dim 0; all remaining dims are flattened per sample.
        """
        EPS = 1e-10  # avoids division by zero when both maps are empty
        N = prediction.size(0)
        # reshape (not view) so tensors that cannot be viewed as (N, -1),
        # e.g. permuted/transposed ones, are accepted as well.
        TP = (prediction * target).reshape(N, -1).sum(dim=1)
        H = self.beta * target.reshape(N, -1).sum(dim=1) + prediction.reshape(N, -1).sum(dim=1)
        fmeasure = (1 + self.beta) * TP / (H + EPS)
        if self.log_like:
            loss = -torch.log(fmeasure)
        else:
            loss = 1 - fmeasure
        return loss.mean()

    def forward(self, inputs, target):
        """Sum ``_compute_loss`` over all channels of ``inputs`` (indexed ``[:, c, :, :]``).

        Raises:
            ValueError: if ``inputs`` has no channels.
        """
        num_channels = inputs.shape[1]
        if num_channels == 0:
            raise ValueError('FLoss expects at least one output channel')
        # Accumulate left to right from channel 0 (same order as the former
        # explicit five-term sum).
        loss = self._compute_loss(inputs[:, 0, :, :], target)
        for channel in range(1, num_channels):
            loss = loss + self._compute_loss(inputs[:, channel, :, :], target)
        return loss



In [3]:
@torch.no_grad()
def val(args, val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        args: namespace; only ``args.gpu`` is read (truthy -> batches are moved
            to CUDA).
        val_loader: iterable of ``(input, target)`` batches supporting ``len()``.
        model: network; it is switched to eval mode and left in eval mode.
        criterion: callable ``criterion(output, target)`` returning a scalar
            tensor.

    Returns:
        ``(mean_loss, IoU, MAE, auc_roc)``. ``mean_loss`` is the average
        per-batch loss. The other values come from ``IoUEval.get_metric()`` and
        ``IoUEval.get_auc_roc()``.

    Raises:
        ValueError: if ``val_loader`` has no batches (the mean loss would be
            undefined).
    """
    total_batches = len(val_loader)
    if total_batches == 0:
        raise ValueError('val_loader is empty; nothing to evaluate')

    model.eval()
    sal_eval_val = IoUEval()
    batch_losses = []

    for batch_idx, (images, target) in enumerate(val_loader):
        start_time = time.time()

        if args.gpu:
            images = images.cuda()
            target = target.cuda()
        # torch.autograd.Variable is a deprecated wrapper; plain tensors suffice.
        target = target.float()

        output = model(images)
        loss = criterion(output, target)
        # Not CUDA-synchronised, so this is an approximate per-batch time.
        time_taken = time.time() - start_time

        batch_losses.append(loss.item())

        # With several GPUs the output is presumably a per-device list (e.g.
        # from DataParallelModel); gather it onto device 0 before scoring.
        if args.gpu and torch.cuda.device_count() > 1:
            output = gather(output, 0, dim=0)
        sal_eval_val.add_batch(output[:, 0, :, :], target)
        if batch_idx % 50 == 0 or batch_idx == total_batches - 1:
            print('[%d/%d] loss: %.3f time: %.3f' % (batch_idx, total_batches, loss.item(), time_taken))

    average_epoch_loss_val = sum(batch_losses) / len(batch_losses)
    IoU, MAE = sal_eval_val.get_metric()

    auc_roc_score = sal_eval_val.get_auc_roc()

    return average_epoch_loss_val, IoU, MAE, auc_roc_score

In [4]:
# Per-channel normalisation statistics, each shaped (1, 1, 3) so they broadcast
# over an H x W x 3 image: element 0 is the mean, element 1 the standard deviation.
# NOTE(review): these are the ImageNet statistics in reversed channel order,
# presumably because the Dataset yields BGR images — confirm in Dataset.py.
NORMALISE_PARAMS = [
    np.array(channel_stats, dtype=np.float32).reshape((1, 1, 3))
    for channel_stats in ([0.406, 0.456, 0.485],   # MEAN
                          [0.225, 0.224, 0.229])   # STD
]

# Spatial size every sample is rescaled to, plus loader settings.
height = width = 512
data_dir = './data/IDRID'
num_workers = 8
batch_size = 1

# compose the data with transforms
# Crop parameter passed to RandomCropResize, scaled from 7 px at 224 px width
# (16 for width == 512). NOTE(review): exact meaning lives in Transforms.py.
RANDOM_CROP_AMOUNT = int(7. / 224. * width)

# Main training pipeline: normalise, rescale, random crop-and-resize, random
# flip, then convert to a tensor.
trainDataset_main = myTransforms.Compose([
    myTransforms.Normalize(*NORMALISE_PARAMS),
    myTransforms.Scale(width, height),
    myTransforms.RandomCropResize(RANDOM_CROP_AMOUNT),
    myTransforms.RandomFlip(),
    myTransforms.ToTensor()
])

# Extra-scale training pipelines. NOTE(review): no DataLoader in the visible
# cells uses either of them. Unlike the others, trainDataset_scale1 applies no
# Normalize and no augmentation — confirm this is intended before re-enabling it.
trainDataset_scale1 = myTransforms.Compose([
    myTransforms.Scale(352, 352),
    myTransforms.ToTensor()
])

# NOTE(review): the crop amount is derived from `width` (512), not from this
# pipeline's 448 px size.
trainDataset_scale2 = myTransforms.Compose([
    myTransforms.Normalize(*NORMALISE_PARAMS),
    myTransforms.Scale(448, 448),
    myTransforms.RandomCropResize(RANDOM_CROP_AMOUNT),
    myTransforms.RandomFlip(),
    myTransforms.ToTensor()
])

# Evaluation pipeline: deterministic — normalise, rescale, to tensor.
valDataset = myTransforms.Compose([
    myTransforms.Normalize(*NORMALISE_PARAMS),
    myTransforms.Scale(width, height),
    myTransforms.ToTensor()
])

# Training loader: shuffled; drop_last keeps every training batch full.
trainLoader_main = torch.utils.data.DataLoader(
    Dataset(data_dir, 'train', transform=trainDataset_main),
    batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False, drop_last=True)

# Evaluation loader over the 'test' split.
# drop_last=False: evaluation must see every test image. With drop_last=True, a
# final partial batch (whenever len(dataset) % batch_size != 0) would be
# silently skipped. shuffle=False keeps batch order fixed, so predictions can be
# matched to targets by batch index.
valLoader = torch.utils.data.DataLoader(
    Dataset(data_dir, 'test', transform=valDataset),
    batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False, drop_last=False)

# Only the main-scale training loader is active (extra-scale loaders disabled).
max_batches = len(trainLoader_main)
print('max_batches {}'.format(max_batches))


max_batches 54


In [6]:
device = "cuda:0"  # NOTE(review): later cells shown call .cuda() directly instead of using this


In [7]:
model = net.JCS()
checkpoint_path = './snapshots/18_08/single_pretrained_full_seg_idrid_full_sg_mean_60epoch/checkpoint.pth.tar'
# map_location='cpu': deserialise onto the CPU first so loading does not depend
# on which GPU the checkpoint was saved from; load_state_dict copies the weights
# into `model`, which is then moved to the GPU.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code —
# only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model = model.cuda()
model.eval()

JCS(
  (vgg16): VGG16BN(
    (conv1_1): ConvBNReLU(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): FrozenBatchNorm2d(64)
      (act): ReLU(inplace=True)
    )
    (conv1_2): ConvBNReLU(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): FrozenBatchNorm2d(64)
      (act): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2_1): ConvBNReLU(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): FrozenBatchNorm2d(128)
      (act): ReLU(inplace=True)
    )
    (conv2_2): ConvBNReLU(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): FrozenBatchNorm2d(128)
      (act): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3_1): ConvBNReLU(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), s

In [8]:
# Run the model over the validation set and keep channel 0 of every prediction
# (moved to CPU as numpy) for the metric cells that follow.
# NOTE(review): sal_eval_val is created but not updated here; it is kept so the
# name still exists for other cells.
output_vals = []
sal_eval_val = IoUEval()
model.eval()  # don't rely on eval() having been called in the loading cell
# Inference only: no autograd graph is built, which keeps GPU memory flat
# without needing torch.cuda.empty_cache() on every batch.
with torch.no_grad():
    for images, _targets in valLoader:
        output = model(images.cuda())
        output_vals.append(output[:, 0, :, :].cpu().numpy())
    

In [26]:
np.shape(output_vals[0])  # first stored batch, channel 0 only: (batch, H, W)

(1, 512, 512)

In [11]:
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support, average_precision_score, roc_auc_score
import numpy as np

# Mean per-batch, pixel-level ROC AUC of the stored predictions against the
# ground-truth masks.
# NOTE: despite its name, `tot_auc_pr` accumulates ROC AUC (roc_auc_score), not
# precision-recall AUC; the name is kept for compatibility.
# Targets are re-read from valLoader, so matching them with output_vals by batch
# index relies on valLoader not shuffling.
data_len = len(valLoader.dataset)
assert len(output_vals) == len(valLoader), 'output_vals does not match valLoader; re-run the inference cell'
tot_auc_pr = 0.0
n_scored = 0
n_single_class = 0
for batch_idx, (_images, target) in enumerate(valLoader):
    labels = target.flatten().numpy()
    scores = output_vals[batch_idx].flatten()
    # roc_auc_score raises ValueError when only one class is present (e.g. a
    # mask with no positive pixels); ROC AUC is undefined for such batches.
    if np.unique(labels).size < 2:
        n_single_class += 1
        continue
    tot_auc_pr += roc_auc_score(labels, scores)
    n_scored += 1
if n_single_class:
    print('skipped {} batch(es) whose mask contains a single class'.format(n_single_class))
# Average over the batches actually scored. This equals data_len when
# batch_size == 1 and no batch is skipped; dividing by data_len would be wrong
# for larger batches.
tot_auc_pr / n_scored if n_scored else float('nan')

0.96140996384955