In [None]:
from PIL import Image
import os
import pathlib
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import random
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle as sk_shuffle
from skimage.util import random_noise
import time
import os
from torch.utils import data
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# from prior_dataloader import RetraceDataLoader, retrace_parser, retrace_parser_synth
from calculus_dataloader import RetraceDataLoader, CalculusBase

from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
from custom_unets import NestedUNet, U_Net, DeepNestedUNet
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback, convert_model
# from kornia.losses import FocalLoss
from pywick.losses import BCEDiceFocalLoss

import glob2
import pdb
import ipdb

In [None]:
# Run on the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
import torchvision.models as models
import torch.nn as nn

# Build the network with DeepNestedUNet(1, 1); the two arguments are presumably
# the input/output channel counts -- TODO confirm against custom_unets.
rohan_unet = DeepNestedUNet(1, 1)

# Wrap in DataParallel as soon as at least one CUDA device is visible.
if torch.cuda.device_count() >= 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    rohan_unet = nn.DataParallel(rohan_unet)

# Move to the target device, run sync_batchnorm's convert_model on it, then move
# again in case the conversion created modules that are not yet on `device`.
rohan_unet = convert_model(rohan_unet.to(device)).to(device)

In [None]:
# Warm-start the model from an earlier checkpoint.
# map_location=device lets a checkpoint that was saved from a GPU be loaded on
# whatever device this session resolved to (including the CPU fallback).
pretrained_dict = torch.load(
    '/home/rohan/prior_seg/models/prior_256/256_model_epoch_65.0_f1_0.8814.pth',
    map_location=device,
)
print('Weight dict before = {} weights'.format(len(pretrained_dict)))
model_dict = rohan_unet.state_dict()

# 1. keep only the entries whose key exists in the current model with the same shape
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if (k in model_dict) and (model_dict[k].shape == v.shape)}
print('Weight dict after = {} weights'.format(len(pretrained_dict)))
if not pretrained_dict:
    # The filter above drops mismatching keys silently; make a total miss visible,
    # because the model would otherwise continue with its fresh initialisation.
    print('WARNING: no checkpoint tensor matched the model by name and shape; '
          'nothing was loaded (check for a "module." key-prefix mismatch).')

# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
rohan_unet.load_state_dict(model_dict)

In [None]:
# Freeze the parameters of every grandchild except the last one under each
# direct child of the model, so only that last sub-module keeps training.
# NOTE(review): which modules count as "children" depends on whether the model
# was wrapped (e.g. by DataParallel) in the construction cell -- confirm the
# intended layers are the ones left trainable.
for module in rohan_unet.children():
    all_but_last = list(module.children())[:-1]
    for layer in all_but_last:
        for param in layer.parameters():
            param.requires_grad_(False)

# for module in list(rohan_unet.children()):
#     for layer in list(module.children()):
#         for param in layer.parameters():
#             print(layer, param.requires_grad)

In [None]:
from torchsummary import summary
# summary(rohan_unet, input_size=(1,256,256))

In [None]:
# Seed every RNG this notebook imports. Previously `random_seed` was assigned
# but never applied, so no run was reproducible.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Real images: build the train and validation views from the same root folder.
# NOTE(review): both views are built by separate CalculusBase calls; whether they
# are guaranteed to be disjoint depends on how CalculusBase performs its split
# (not visible here) -- confirm it uses a fixed random_state.
mark4 = time.time()
root_dir = '/home/rohan/Datasets/prior_clean/train'
calculus_clean_train = CalculusBase(root_dir, oversample=2, issynthetic=False, test_size=0.1, im_range='all', isTrain=True)
calculus_clean_val = CalculusBase(root_dir, oversample=2, issynthetic=False, test_size=0.1, im_range='all', isTrain=False)
print('Real dataload time {:.6f}'.format(time.time() - mark4))

# Synthetic images: same construction, restricted to im_range [0, 88000].
mark3 = time.time()
synth_root_dir = '/home/rohan/Datasets/synthetic_prior_clean/train'
calculus_synthetic_train = CalculusBase(synth_root_dir, oversample=1, issynthetic=True, test_size=0.05, im_range=[0, 88000], isTrain=True)
calculus_synthetic_val = CalculusBase(synth_root_dir, oversample=1, issynthetic=True, test_size=0.05, im_range=[0, 88000], isTrain=False)
print('Synthetic dataload time {:.6f}'.format(time.time() - mark3))


In [None]:
# Keyword arguments that are identical for the training and validation wrappers.
_shared_dataset_kwargs = dict(
    image_size=(384, 384),
    length='all',  # pass 'all' for all
    crop=True,
    transform=None,
)

# Training wrapper: receives both the real and the synthetic base datasets.
train_dataset = RetraceDataLoader(
    calculus_clean_train,
    calculus_synthetic_train,
    mode='train',
    **_shared_dataset_kwargs
)

# Validation wrapper: real validation split only (None in the synthetic slot).
val_dataset = RetraceDataLoader(
    calculus_clean_val,
    None,
    mode='val',
    **_shared_dataset_kwargs
)

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler

# Seed value kept for reproducible sampling, and the mini-batch size that both
# data loaders below are built with.
random_seed, batch_size = 42, 256
# def get_sampler(dataset, shuffle=False):
    
#     # Creating data indices for training and validation splits:
#     dataset_size = len(dataset)
#     indices = list(range(dataset_size))
#     if shuffle :
#         np.random.seed(random_seed)
#         np.random.shuffle(indices)
#     sampler = SequentialSampler(indices)
    
#     return sampler

# # Creating PT data samplers and loaders:
# train_sampler = get_sampler(train_dataset, shuffle=True)
# valid_sampler = get_sampler(val_dataset, shuffle=False)

def worker_init_fn(worker_id):
    """Give each DataLoader worker its own NumPy / `random` seed.

    The seed is derived from ``torch.initial_seed()``. Inside a DataLoader
    worker that value is already ``base_seed + worker_id`` and the base seed is
    redrawn every time the workers are (re)started, so the streams differ
    between workers *and* between epochs. The previous version seeded from the
    parent's NumPy state, which is the same on every restart (so every epoch
    repeated the same random sequence) and could exceed the uint32 range that
    ``np.random.seed`` accepts.

    Args:
        worker_id: index of the worker, supplied by the DataLoader. It is not
            needed explicitly because torch has already folded it into
            ``torch.initial_seed()``.
    """
    # np.random.seed only accepts values in [0, 2**32 - 1]; torch seeds are 64-bit.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    
# Training loader: reshuffle the sample order every epoch (the earlier,
# commented-out sampler code shows shuffling was intended) and drop the last
# incomplete batch so every optimisation step sees a full batch.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    drop_last=True,
)

# Validation loader: fixed order, and keep the final partial batch so no
# validation sample is skipped and the loader cannot end up with zero batches
# when the validation split is smaller than `batch_size`.
valloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    drop_last=False,
)

# len(loader) is the number of batches per epoch, not the number of samples.
print ('Train size: ', len(trainloader))
print ('Validation size: ', len(valloader))

In [None]:
import time
import copy
import pdb
import pandas as pd

# Phase name -> loader; train_model looks loaders up by phase.
dataloaders = dict(train=trainloader, val=valloader)
# Phase name -> number of batches per epoch (len() of a DataLoader counts batches).
dataset_sizes = {phase: len(loader) for phase, loader in dataloaders.items()}


# Smoothing term added to numerator and denominator of the Dice ratio below.
SMOOTH = 1e-6


def dice_loss(input, target, smooth=None):
    """Soft Dice loss, ``1 - dice``, computed over all elements at once.

    Args:
        input: tensor of predictions (any shape).
        target: tensor with the same number of elements as ``input``.
        smooth: value added to numerator and denominator so an empty
            prediction/target pair does not divide by zero. Defaults to the
            module-level ``SMOOTH`` when omitted.

    Returns:
        A 0-dim tensor holding ``1 - (2*|X.Y| + smooth) / (|X| + |Y| + smooth)``.
    """
    if smooth is None:
        smooth = SMOOTH

    # reshape (unlike view) also accepts non-contiguous tensors such as slices
    # or transposed views.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

def dice_score(input, target, smooth=None):
    """Dice coefficient computed over all elements of the two tensors at once.

    Args:
        input: tensor of predictions (any shape).
        target: tensor with the same number of elements as ``input``.
        smooth: value added to numerator and denominator so an empty
            prediction/target pair does not divide by zero. Defaults to the
            module-level ``SMOOTH`` when omitted.

    Returns:
        A 0-dim tensor holding ``(2*|X.Y| + smooth) / (|X| + |Y| + smooth)``.
    """
    if smooth is None:
        smooth = SMOOTH

    # reshape (unlike view) also accepts non-contiguous tensors such as slices
    # or transposed views.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return ((2. * intersection + smooth) /
            (iflat.sum() + tflat.sum() + smooth))

def dice_per_channel(inputs, target, smooth=None):
    """Mean Dice coefficient over the channels (dimension 1) of a batch.

    A Dice value is computed for every channel, pooling all other dimensions,
    and the values are averaged. The previous version divided the sum by
    ``C - 1`` instead of ``C``, which over-stated the mean and divided by zero
    for single-channel tensors.

    Args:
        inputs: prediction tensor whose dimension 1 indexes channels.
        target: tensor with the same shape as ``inputs``.
        smooth: smoothing term of the Dice ratio; defaults to the module-level
            ``SMOOTH`` when omitted.

    Returns:
        A 0-dim tensor with the average per-channel Dice (0 when there are no
        channels).
    """
    if smooth is None:
        smooth = SMOOTH

    n_channels = inputs.shape[1]
    if n_channels == 0:
        return inputs.new_zeros(())

    # Bring the channel axis to the front and flatten everything else so each
    # row holds all elements that belong to one channel.
    pred = inputs.transpose(0, 1).reshape(n_channels, -1)
    true = target.transpose(0, 1).reshape(n_channels, -1)
    intersection = (pred * true).sum(dim=1)
    dice = (2. * intersection + smooth) / (pred.sum(dim=1) + true.sum(dim=1) + smooth)

    return dice.mean()

def dice_per_image(inputs, target, smooth=None):
    """Mean Dice coefficient over the images (dimension 0) of a batch.

    A Dice value is computed for every image, pooling all of its remaining
    dimensions, and the values are averaged. The previous version divided the
    sum by ``N - 1`` instead of ``N``, which inflated the metric (it could
    exceed 1) and divided by zero for a batch of one. All images are processed
    in one vectorised pass instead of a Python loop.

    Args:
        inputs: prediction tensor whose dimension 0 indexes images.
        target: tensor with the same shape as ``inputs``.
        smooth: smoothing term of the Dice ratio; defaults to the module-level
            ``SMOOTH`` when omitted.

    Returns:
        A 0-dim tensor with the average per-image Dice (0 for an empty batch).
    """
    if smooth is None:
        smooth = SMOOTH

    n_images = inputs.shape[0]
    if n_images == 0:
        return inputs.new_zeros(())

    # One row per image; reshape also copes with non-contiguous inputs.
    pred = inputs.reshape(n_images, -1)
    true = target.reshape(n_images, -1)
    intersection = (pred * true).sum(dim=1)
    dice = (2. * intersection + smooth) / (pred.sum(dim=1) + true.sum(dim=1) + smooth)

    return dice.mean()


def train_model(model, criterion, optimizer, scheduler, writer, num_epochs=15,
                checkpoint_dir='/home/rohan/prior_seg/models/calculus_frozen'):
    """Train ``model`` and keep the weights with the best validation Dice.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders`` dict. The optimised loss is
    ``0.5 * criterion(sigmoid(logits), labels) + 0.5 * dice_loss(...)``. Hard
    Dice metrics are computed on predictions thresholded at 0.5. Whenever the
    validation Dice improves, the weights are saved under ``checkpoint_dir``.

    Relies on names defined elsewhere in the notebook: ``dataloaders``,
    ``device``, ``dice_loss``, ``dice_score`` and ``dice_per_image``.

    Args:
        model: network mapping an image batch to logits.
        criterion: loss applied to the sigmoid probabilities and the labels.
        optimizer: optimizer stepping the model parameters.
        scheduler: learning-rate scheduler, stepped once per training epoch.
        writer: object with ``add_scalar(tag, value, step)`` used for logging.
        num_epochs: number of epochs to run.
        checkpoint_dir: directory receiving the best-so-far checkpoints; it is
            created if missing. Defaults to the previously hard-coded location.

    Returns:
        ``model`` with the best validation weights loaded.
    """
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_f1 = 0.0
    # Create the target folder up front so the first improvement cannot fail on save.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        ep_start = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Read the rate the optimizer will actually use; scheduler.get_lr() is
        # deprecated for this purpose and can report a wrong value.
        lrate = optimizer.param_groups[0]['lr']
        writer.add_scalar('Learning Rate', lrate, epoch)
        print('LR {:.5f}'.format(lrate))

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_f1 = 0.0
            running_f1_img = 0.0
            num_batches = 0

            # Iterate over data.
            for data in dataloaders[phase]:
                inputs = data['image'].to(device)
                # Cast on the active device instead of forcing a CUDA tensor
                # type, so the CPU fallback keeps working.
                labels = data['masks'].to(device=device, dtype=torch.float32)

                if is_train:
                    # zero the parameter gradients
                    optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = torch.sigmoid(outputs)

                    bce = criterion(preds, labels)
                    diceloss = dice_loss(preds, labels)
                    loss = bce * 0.5 + diceloss * (1 - 0.5)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Hard metrics never need gradients.
                with torch.no_grad():
                    bin_preds = (preds.detach() > 0.5).to(labels.dtype)
                    f1 = dice_score(bin_preds, labels)
                    f1_img = dice_per_image(bin_preds, labels)

                # Accumulate plain Python floats so no device tensors are kept
                # alive across the epoch.
                running_loss += loss.item()
                running_f1 += float(f1)
                running_f1_img += float(f1_img)
                num_batches += 1

            torch.cuda.empty_cache()
            if is_train:
                scheduler.step()

            # Metrics are per-batch means, so average over the batches seen;
            # the guard avoids a ZeroDivisionError for an empty loader.
            denom = max(num_batches, 1)
            epoch_loss = running_loss / denom
            epoch_f1 = running_f1 / denom
            epoch_f1_img = running_f1_img / denom

            writer.add_scalar('Loss/{}'.format(phase), epoch_loss, epoch)
            writer.add_scalar('Hard_Dice/{}'.format(phase), epoch_f1, epoch)
            writer.add_scalar('Hard_Dice_per_image/{}'.format(phase), epoch_f1_img, epoch)

            print('{} Loss: {:.4f} F1: {:.4f} F1/img: {:.4f}'.format(phase, epoch_loss, epoch_f1, epoch_f1_img))

            # deep copy the model
            if phase == 'val' and epoch_f1 > best_f1:
                best_f1 = epoch_f1
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, os.path.join(
                    checkpoint_dir,
                    'calculus_256_model_epoch_{:.1f}_f1_{:.4f}.pth'.format(epoch, best_f1)))

        # Log the running best once per epoch, after validation has updated it
        # (it used to be written twice per step, once with a stale value).
        writer.add_scalar('Hard_Dice/best_val', best_f1, epoch)

        print('Epoch completed in {:.4f} seconds'.format(time.time()-ep_start))
        torch.cuda.empty_cache()

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val F1: {:4f}'.format(best_f1))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable

# Binary cross-entropy on probabilities; train_model applies the sigmoid itself.
criterion = torch.nn.BCELoss()
optimizer = optim.Adam(rohan_unet.parameters(), lr=0.0001)
writer = SummaryWriter(log_dir='/home/rohan/prior_seg/logs/calculus_frozen', filename_suffix = '_calculus_256')
# Multiply the learning rate by 0.85 every 7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.85)
# rohan_unet.load_state_dict(torch.load('/home/rohan/caries_seg/models/test_model/128_nestedunet_model_epoch_0.0000_f1_0.1858.pth'))
try:
    model_trained = train_model(rohan_unet, criterion, optimizer, exp_lr_scheduler, writer = writer, num_epochs=20)
finally:
    # Flush and release the event file even when training is interrupted, so
    # the scalars logged so far are not lost.
    writer.close()