In [1]:
from PIL import Image
import os
import pathlib
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import random
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle as sk_shuffle
from skimage.util import random_noise
import time
import os
from torch.utils import data
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from prior_dataloader import RetraceDataLoader
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
from custom_unets import NestedUNet, U_Net
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback, convert_model

import glob2
import pdb
import ipdb

In [2]:
# Prefer the first CUDA device when one is visible; otherwise everything runs on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if _use_cuda else torch.device("cpu")

In [3]:
import torchvision.models as models
import torch.nn as nn

# Trained NestedUNet checkpoint (model built with NestedUNet(1, 33)).
CHECKPOINT_PATH = '/home/rohan/prior_seg/models/prior_model1/1st_epoch_42.0000_f1_0.8835.pth'

rohan_unet = NestedUNet(1, 33)
n_gpus = torch.cuda.device_count()
if n_gpus > 0:
    print("Let's use", n_gpus, "GPUs!")
    # DataParallel splits each batch along dim 0, e.g. [30, ...] -> 3 x [10, ...] on 3 GPUs.
    rohan_unet = nn.DataParallel(rohan_unet)
# Move to `device` only after convert_model, in case the conversion builds new
# modules/buffers on the CPU.
rohan_unet = convert_model(rohan_unet).to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine as well.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
# The checkpoint keys presumably carry the 'module.' prefix added by DataParallel
# (it loads cleanly into the wrapped model). When the model is not wrapped (no GPU),
# strip that prefix so the keys line up.
if not any(key.startswith('module.') for key in rohan_unet.state_dict()):
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }
rohan_unet.load_state_dict(state_dict)

Let's use 1 GPUs!


<All keys matched successfully>

In [4]:
# Inference only: put the network in evaluation mode (this changes the behaviour of
# layers such as BatchNorm/Dropout, if the model contains them).
# NOTE: the checkpoint keys carry a 'module.' prefix (it loads into the
# DataParallel-wrapped model). To load it into an unwrapped model, strip that prefix
# from every key before calling load_state_dict.
rohan_unet = rohan_unet.eval()

In [5]:
from torchsummary import summary
# summary(rohan_unet, input_size=(1,128,128))

In [None]:
# Dataset directories. The synthetic directory is defined, but the loader below is
# built with root_dir_synth=None, so it is not used.
root_dir = '/home/rohan/Datasets/prior_clean/train/'
syn_root_dir = '/home/rohan/Datasets/synthetic_prior_clean/train/'

dataset_kwargs = dict(
    root_dir=root_dir,
    root_dir_synth=None,     # previously syn_root_dir; None presumably disables synthetic data
    image_size=(128, 128),
    length='all',            # 'all' loads every sample; a number (e.g. 100) presumably caps it
    transform=None,
)
teeth_dataset = RetraceDataLoader(**dataset_kwargs)

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

# Fraction of the dataset held out as the validation split.
validation_split = .1
shuffle_dataset = True
random_seed= 42
batch_size = 48

# Creating data indices for training and validation splits:
dataset_size = len(teeth_dataset)
indices = list(range(dataset_size))
# Number of validation samples (rounded down by np.floor).
split = int(np.floor(validation_split * dataset_size))

# NOTE(review): this re-seeds NumPy's *global* RNG, and which images end up held out
# depends on exactly this seed + shuffle sequence. Presumably it mirrors the split used
# at training time, so that the 'val' images were never trained on. TODO confirm against
# the training notebook before changing anything here.
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` shuffled indices form the validation set; the rest are training.
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
# Each sampler yields its own subset of indices in random order on every pass.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

def worker_init_fn(worker_id):
    """Give each DataLoader worker its own NumPy RNG seed.

    Inside a worker, ``torch.initial_seed()`` is the per-worker torch seed that the
    DataLoader assigns (a fresh base seed per iteration plus the worker id). Deriving
    the NumPy seed from it gives distinct NumPy streams per worker *and* per epoch,
    and ``% 2**32`` keeps it inside the range ``np.random.seed`` accepts.

    Args:
        worker_id: index of the worker process. It is already folded into
            ``torch.initial_seed()``; kept for the ``worker_init_fn`` callback signature.
    """
    np.random.seed(torch.initial_seed() % 2**32)

def _make_loader(sampler, drop_last):
    """Build a DataLoader over ``teeth_dataset`` whose indices come from `sampler`.

    Uses the module-level ``batch_size`` and ``worker_init_fn``. Memory pinning is
    requested only when CUDA is available, because it only helps host-to-GPU copies.
    """
    return torch.utils.data.DataLoader(
        teeth_dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,  # ordering is delegated to the sampler (the two are exclusive)
        sampler=sampler,
        worker_init_fn=worker_init_fn,
        pin_memory=torch.cuda.is_available(),
        drop_last=drop_last,
    )


# Training keeps full batches only. Validation keeps its final partial batch so that
# every held-out image is evaluated; with drop_last=True up to batch_size - 1
# validation images would be skipped.
trainloader = _make_loader(train_sampler, drop_last=True)
valloader = _make_loader(valid_sampler, drop_last=False)
# NOTE: len(DataLoader) is the number of batches, not the number of images.
print('Train size: ', len(trainloader))
print('Validation size: ', len(valloader))

In [None]:
import time
import copy
import pdb
import pandas as pd

# Loaders keyed by phase name, looked up as dataloaders[phase].
dataloaders = {'train': trainloader, 'val': valloader}
# NOTE: len(DataLoader) counts *batches*, so these are batch counts, not image counts.
dataset_sizes = {phase: len(loader) for phase, loader in dataloaders.items()}

# Additive smoothing term in the Dice ratio; keeps 0/0 (both masks empty) finite.
SMOOTH = 1e-6


def dice_loss(input, target, smooth=None):
    """Soft Dice loss, ``1 - (2*sum(x*y) + s) / (sum(x) + sum(y) + s)``.

    Both tensors are flattened, so any shape works as long as they hold the same
    number of elements. Values are used as-is (probabilities or 0/1 masks).

    Args:
        input: predicted tensor.
        target: ground-truth tensor with the same number of elements.
        smooth: additive smoothing term ``s``; defaults to the module-level ``SMOOTH``.

    Returns:
        0-d tensor; 0 when `input` and `target` are identical binary masks.
    """
    if smooth is None:
        smooth = SMOOTH
    # reshape (unlike view) also accepts non-contiguous tensors such as slices/transposes.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

def dice_score(input, target, smooth=None):
    """Dice coefficient ``(2*sum(x*y) + s) / (sum(x) + sum(y) + s)`` over all elements.

    Both tensors are flattened, so any shape works as long as they hold the same
    number of elements (e.g. a whole batch, or a single image).

    Args:
        input: predicted tensor (e.g. a binarised prediction).
        target: ground-truth tensor with the same number of elements.
        smooth: additive smoothing term ``s``; defaults to the module-level ``SMOOTH``.

    Returns:
        0-d tensor; 1 for identical binary masks, including two empty ones.
    """
    if smooth is None:
        smooth = SMOOTH
    # reshape (unlike view) also accepts non-contiguous tensors such as slices/transposes.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return ((2. * intersection + smooth) /
            (iflat.sum() + tflat.sum() + smooth))

def dice_per_channel(inputs, target, smooth=None):
    """Mean Dice over the foreground channels (channel 0 is skipped).

    For every channel ``c >= 1`` the Dice coefficient
    ``(2*sum(x*y) + s) / (sum(x) + sum(y) + s)`` (the same formula as ``dice_score``)
    is computed over the whole batch for that channel. The per-channel values are
    then averaged with equal weight.

    Args:
        inputs: tensor indexed as (batch, channel, ...), e.g. (B, C, H, W) binary predictions.
        target: tensor of the same shape.
        smooth: additive smoothing term ``s``; defaults to the module-level ``SMOOTH``.

    Returns:
        0-d tensor.

    Raises:
        ValueError: if there is no foreground channel (``inputs.shape[1] < 2``).
    """
    if smooth is None:
        smooth = SMOOTH
    n_channels = inputs.shape[1]
    if n_channels < 2:
        raise ValueError(
            'dice_per_channel needs at least 2 channels (channel 0 is skipped), '
            'got {}'.format(n_channels))
    n_fg = n_channels - 1
    # Channel axis first, everything else flattened: one row per foreground channel.
    fg_inputs = inputs[:, 1:].transpose(0, 1).reshape(n_fg, -1)
    fg_target = target[:, 1:].transpose(0, 1).reshape(n_fg, -1)
    intersection = (fg_inputs * fg_target).sum(dim=1)
    per_channel = (2. * intersection + smooth) / (
        fg_inputs.sum(dim=1) + fg_target.sum(dim=1) + smooth)
    return per_channel.mean()


def infer_model(model, output_df, num_epochs=15, phase='val'):
    """Run `model` over ``dataloaders[phase]`` and collect per-image Dice results.

    No training happens: the model is put in eval mode and gradients are disabled.
    Predictions are binarised as ``sigmoid(model(inputs)) > 0.5``.

    For every pass ("epoch") over the loader this prints
      * F1: mean over batches of the batch-level ``dice_score`` (all channels pooled),
      * F1 per channel: mean over batches of ``dice_per_channel``,
      * Mean Image F1: mean over batches of the mean per-image ``dice_score``,
    and appends one row per image to `output_df`. The columns are ``hard_dice``
    (float), ``input_image`` (channel 0 of the input), ``mask`` (labels) and ``pred``
    (binary prediction), with the last three stored as NumPy arrays.

    Args:
        model: network whose output has the same shape as the masks
            (presumably (B, C, H, W); TODO confirm).
        output_df: DataFrame the per-image rows are appended to.
        num_epochs: number of passes over the loader; rows are appended on every pass.
        phase: key into the module-level ``dataloaders`` dict.

    Returns:
        ``(model, output_df)`` with the new rows appended.

    Raises:
        ValueError: if the loader yields no batches (averages would divide by zero).
    """
    model.eval()
    loader = dataloaders[phase]

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        start = time.time()
        # Accumulators are reset every pass so each epoch reports its own averages.
        running_f1 = 0.0
        running_f1_ch = 0.0
        running_img_f1 = 0.0
        n_batches = 0
        records = []

        with torch.no_grad():
            for batch in loader:
                n_batches += 1
                inputs = batch['image'].to(device)
                # .float() keeps the labels on `device`, so this also works on CPU.
                labels = batch['masks'].to(device).float()

                outputs = model(inputs)
                bin_preds = (torch.sigmoid(outputs) > 0.5).float()

                running_f1 += dice_score(bin_preds, labels)
                running_f1_ch += dice_per_channel(bin_preds, labels)

                img_f1 = 0.0
                for img, lbl, bin_pred in zip(inputs, labels, bin_preds):
                    hard_dice = dice_score(bin_pred, lbl)
                    img_f1 += hard_dice
                    records.append({
                        'hard_dice': hard_dice.item(),
                        'input_image': img[0, :, :].cpu().numpy(),
                        'mask': lbl.cpu().numpy(),
                        'pred': bin_pred.cpu().numpy(),
                    })
                running_img_f1 += img_f1 / labels.shape[0]

        torch.cuda.empty_cache()
        if n_batches == 0:
            raise ValueError('dataloaders[{!r}] yielded no batches'.format(phase))
        # Build the new rows in one go: per-row DataFrame.append is quadratic and
        # was removed in pandas 2.0.
        output_df = pd.concat([output_df, pd.DataFrame(records)], ignore_index=True)

        print('F1: {:.4f} '.format(running_f1 / n_batches))
        print('F1 per channel: {:.4f} '.format(running_f1_ch / n_batches))
        print('Mean Image F1: {:.4f} '.format(running_img_f1 / n_batches))
        print('Epoch completed in {:.4f} seconds'.format(time.time() - start))
    return model, output_df

In [None]:
import pandas as pd

# infer_model appends one row per evaluated image to this frame.
result_columns = ['hard_dice', 'input_image', 'mask', 'pred']
output_df = pd.DataFrame(columns=result_columns)
model_out, output_df = infer_model(rohan_unet, output_df, num_epochs=1)

In [None]:
# Ascending by per-image hard Dice (worst predictions first), re-indexed 0..n-1.
sorted_df = output_df.sort_values(by='hard_dice', ascending=True, ignore_index=True)

In [None]:
# Keep only images whose ground-truth mask is not constant (more than one distinct
# value), preserving the hard_dice ordering of sorted_df. A boolean mask replaces the
# per-row DataFrame.append loop, which was quadratic and no longer exists in pandas 2.0.
mask_is_nonconstant = sorted_df['mask'].map(lambda m: np.unique(m).size > 1).astype(bool)
new_df = sorted_df[mask_is_nonconstant].reset_index(drop=True)

new_df.head()

In [None]:
# Rows per gallery. The *_100 names are historical; each frame holds up to N_SHOW rows.
N_SHOW = 50

low_100 = sorted_df.head(N_SHOW)                 # lowest hard Dice (sorted ascending)
top_100 = sorted_df.tail(N_SHOW)                 # highest hard Dice
conditional_100 = sorted_df.loc[sorted_df.hard_dice > 0.6].head(N_SHOW)
has_caries_top100 = new_df.tail(N_SHOW)          # highest hard Dice among non-constant masks

In [None]:
# sorted_df.pred[1].shape


In [None]:
# preds_numpy = preds.detach().numpy()
def plotter(ax, input_img, gt_mask, pred, f1, idx):
    """Draw one gallery row: input image, ground-truth labels, predicted labels.

    Args:
        ax: 2-D grid of Axes indexed as ``ax[row, col]`` with 3 columns; row `idx` is drawn.
        input_img: image shown in grayscale.
        gt_mask, pred: channel-first arrays (channel axis 0). Each is collapsed to a
            label map by ``argmax`` over the channel-reversed array. Ties (including
            all-zero pixels) therefore resolve to the highest channel, and the plotted
            value is ``C - 1 - channel``. The colour scale is fixed to 0..32.
        f1: hard Dice for this image, shown in the prediction title.
        idx: row index; row 0 also gets the column titles.
    """
    def _to_labels(channels_first):
        # Reverse the channel axis, then take the arg-max channel per pixel.
        return np.squeeze(np.argmax(np.flip(channels_first, axis=0), axis=0))

    image_ax, truth_ax, pred_ax = ax[idx, 0], ax[idx, 1], ax[idx, 2]
    label_style = dict(vmin=0, vmax=32, cmap='flag')

    image_ax.imshow(input_img, cmap='gray')
    truth_ax.imshow(_to_labels(gt_mask), **label_style)
    pred_ax.imshow(_to_labels(pred), **label_style)

    if idx != 0:
        pred_ax.set_title('Dice = {:.4f}'.format(f1))
        return
    image_ax.set_title('Input Image')
    truth_ax.set_title('Ground Truth Mask')
    pred_ax.set_title('Prediction Image \n Dice = {:.4f}'.format(f1))
    

def plot_df(df, title, row_height=5):
    """Show one gallery row (input, ground truth, prediction) per row of `df`.

    Args:
        df: frame with ``input_image``, ``mask``, ``pred`` and ``hard_dice`` columns.
        title: figure super-title.
        row_height: figure height per row in inches (5 gives a 24 x 250 inch
            figure for 50 rows).
    """
    n_rows = len(df)
    if n_rows == 0:
        print('plot_df: nothing to plot for {!r}'.format(title))
        return
    # squeeze=False keeps `ax` 2-D even for a single row, since plotter indexes ax[idx, col].
    f, ax = plt.subplots(n_rows, 3, figsize=(24, row_height * n_rows), squeeze=False)
    # Positional iteration: unaffected by the frame's index labels.
    for i, row in enumerate(df[['input_image', 'mask', 'pred', 'hard_dice']].itertuples(index=False)):
        plotter(ax, row.input_image, row.mask, row.pred, row.hard_dice, i)
    f.tight_layout()
    f.suptitle(title, y=1.001, fontsize=16)
    plt.show()


plot_df(low_100, 'Images with low hard dice value')

In [None]:
# Gallery of the best-scoring images (highest hard Dice).
plot_df(df=top_100, title='Images with high hard dice value')

In [None]:
plot