In [1]:
from PIL import Image
import os
import pathlib
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import random
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle as sk_shuffle
from skimage.util import random_noise
import time
import os
from torch.utils import data
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from majority_dl import RetraceDataLoader, retrace_parser, retrace_parser_synth
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
from custom_unets import NestedUNet, U_Net, DeepNestedUNet
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback, convert_model
# from kornia.losses import FocalLoss
from pywick.losses import BCEDiceFocalLoss, BinaryFocalLoss
import segmentation_models_pytorch as smp

import glob2
import pdb
import ipdb

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
import torchvision.models as models
import torch.nn as nn

# rohan_unet = DeepNestedUNet(1,33)
rohan_unet = smp.FPN(encoder_name= "densenet121",
        encoder_depth= 5,
        encoder_weights= None,
        decoder_pyramid_channels= 256,
        decoder_segmentation_channels= 128,
        decoder_merge_policy= "add",
        decoder_dropout= 0.2,
        in_channels= 1,
        classes= 3,
        activation= None,
        upsampling= 4
    )


if torch.cuda.device_count() > 0:
      print("Let's use", torch.cuda.device_count(), "GPUs!")
      # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
      rohan_unet = nn.DataParallel(rohan_unet)
# rohan_unet.load_state_dict(torch.load('/home/rohan/prior_seg/models/prior_fpn_1/fpn_model_epoch_17.0_f1_0.8538.pth'))
rohan_unet = rohan_unet.to(device)
rohan_unet = convert_model(rohan_unet)
rohan_unet = rohan_unet.to(device)

Let's use 2 GPUs!


In [4]:
from torchsummaryX import summary
# summary(rohan_unet, input_size=(1,128,128))

In [5]:
root_dir = '/home/rohan/Datasets/prior_clean/train/'
syn_root_dir = '/home/rohan/Datasets/synthetic_prior_clean/train/'

# prior_data = RetraceDataLoader(root_dir, syn_root_dir, length = 100)
teeth_dataset = RetraceDataLoader(root_dir=root_dir,
                                  root_dir_synth=syn_root_dir,
                                  image_size=(256,256),
                                  length = 'all',# pass 'all' for all
                                  transform=None)

Dataset length:  60178
Loaded dataset length: 60178
Dataset length synthetic:  300885


In [6]:
# teeth_dataset[7]

In [7]:
from torch.utils.data.sampler import SubsetRandomSampler

validation_split = .1
shuffle_dataset = True
random_seed= 42
batch_size = 80

# Creating data indices for training and validation splits:
dataset_size = len(teeth_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

trainloader = torch.utils.data.DataLoader(
    teeth_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
    sampler=train_sampler,
    worker_init_fn=worker_init_fn,
    pin_memory = True,
    drop_last =True
)
valloader = torch.utils.data.DataLoader(
    teeth_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
    sampler=valid_sampler,
    worker_init_fn=worker_init_fn,
    pin_memory = True,
    drop_last =True
)
print ('Train size: ', len(trainloader))
print ('Validation size: ', len(valloader))

Train size:  677
Validation size:  75


In [8]:
import time
import copy
import pdb
import pandas as pd

dataloaders = {'train': trainloader,'val':valloader}
dataset_sizes = {'train':len(trainloader), 'val':len(valloader)}


SMOOTH = 1e-6


def dice_loss(input, target):
    smooth = SMOOTH
    
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

def dice_score(input, target):
    smooth = SMOOTH
#     print(input.shape)
#     ipdb.set_trace()
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

def dice_per_channel(inputs, target):
    
    dice_ch = 0.0
    for i in range(0, inputs.shape[1]):
        inp = inputs[:,i,:,:]
        inp = inp.contiguous()
        targs = target[:,i,:,:]
        targs = targs.contiguous()
        dice_chl = dice_score(inp,targs)
        dice_ch +=dice_chl
    
    return dice_ch / (inputs.shape[1])

def dice_per_image(inputs, target):
    
    dice_img = 0.0
    for i in range(0, inputs.shape[0]):
        inp = inputs[i,:,:,:]
        inp = inp.contiguous()
        targs = target[i,:,:,:]
        targs = targs.contiguous()
        dice_im = dice_score(inp,targs)
        dice_img +=dice_im
    
    return dice_img / (inputs.shape[0])


def train_model(model, criterion, optimizer, scheduler, writer, num_epochs=15):
    start = time.time()
    save_dict={}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10.0
    best_iou = 0.0
    best_f1 = 0.0
    best_f1_ch = 0.0
    best_f1_img = 0.0

    for epoch in range(num_epochs):
        ep_start = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        lrate = scheduler.get_lr()[0]
        writer.add_scalar('Learning Rate', lrate, epoch)
        print('LR {:.5f}'.format(lrate))
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_ious = 0.0
            running_f1 = 0.0
            running_f1_ch = 0.0
            running_f1_img = 0.0

            # Iterate over data.
            for data in dataloaders[phase]:
#                 ipdb.set_trace()
                inputs = data['image'][:,:,:,:]
                labels = data['masks'][:,:,:,:]
#               labels = labels.unsqueeze(0)
#                 labels = labels.float()
                
                inputs = inputs.to(device)
                labels = labels.to(device)
                labels = labels.type(torch.cuda.FloatTensor)
                 # zero the parameter gradients
                optimizer.zero_grad()
#               torch.autograd.set_detect_anomaly(True)
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    
                    outputs = model(inputs)
                    fl = criterion(outputs, labels)
                    preds = torch.sigmoid(outputs)
#                     ipdb.set_trace()
                    
                    diceloss = dice_loss(preds,labels)
                    loss = fl * 0.8 + diceloss * (1 - 0.8)
                    
                    bin_preds = preds.clone().detach()
                    bin_preds[bin_preds<=0.5]= 0.0
                    bin_preds[bin_preds>0.5]= 1.0
                    
                    f1 = dice_score(bin_preds, labels)
                    f1_ch = dice_per_channel(bin_preds,labels)
                    f1_img = dice_per_image(bin_preds,labels)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    # statistics
                    running_loss += loss.data.cpu().numpy() # * inputs.size(0)
#                     running_ious += iou.data.cpu().numpy() # * inputs.size(0)
                    running_f1 += f1
                    running_f1_ch += f1_ch
                    running_f1_img += f1_img
                    
            torch.cuda.empty_cache()
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_f1 = running_f1 / dataset_sizes[phase]
            epoch_f1_ch = running_f1_ch / dataset_sizes[phase]
            epoch_f1_img = running_f1_img / dataset_sizes[phase]

            if phase == 'train':
                writer.add_scalar('Loss/train', epoch_loss, epoch)
                writer.add_scalar('Hard_Dice/train', epoch_f1, epoch)
                writer.add_scalar('Hard_Dice_per_channel/train', epoch_f1_ch, epoch)
                writer.add_scalar('Hard_Dice_per_image/train', epoch_f1_img, epoch)
            else:
                writer.add_scalar('Loss/val', epoch_loss, epoch)
                writer.add_scalar('Hard_Dice/val', epoch_f1, epoch)
                writer.add_scalar('Hard_Dice_per_channel/val', epoch_f1_ch, epoch)
                writer.add_scalar('Hard_Dice_per_image/val', epoch_f1_img, epoch)

            print('{} Loss: {:.4f} F1: {:.4f} F1/ch: {:.4f} F1/img: {:.4f}'.format(phase, epoch_loss, epoch_f1, epoch_f1_ch, epoch_f1_img))

            # deep copy the model
            if phase == 'val' and epoch_f1 > best_f1:
                best_loss = epoch_loss
                best_f1 = epoch_f1
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, '/home/rohan/prior_seg/models/major_model2_contd/restorative_model_epoch_{:.1f}_f1_{:.4f}.pth'.format(epoch, best_f1))
            writer.add_scalar('Hard_Dice/best_val', best_f1, epoch)
            

        print('Epoch completed in {:.4f} seconds'.format(time.time()-ep_start))
        torch.cuda.empty_cache()
        

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val F1: {:4f}'.format(best_f1))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable

criterion = BinaryFocalLoss(gamma=2.0, alpha=0.25) #torch.nn.BCELoss()
print("Focal Loss alpha = {:.2f} gamma = {:.1f}".format(criterion.alpha, criterion.gamma))
optimizer = optim.Adam(rohan_unet.parameters(), lr=0.00038)
writer = SummaryWriter(log_dir='/home/rohan/prior_seg/logs/major_model2_contd', filename_suffix = '_restorative')
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.95)
rohan_unet.load_state_dict(torch.load('/home/rohan/prior_seg/models/major_model2/restorative_model_epoch_55.0_f1_0.9380.pth'))
model_trained = train_model(rohan_unet, criterion, optimizer, exp_lr_scheduler, writer = writer, num_epochs=30)

Focal Loss alpha = 0.25 gamma = 2.0
Epoch 0/29
----------
LR 0.00038
train Loss: 0.0128 F1: 0.9493 F1/ch: 0.9588 F1/img: 0.8993
val Loss: 0.0176 F1: 0.9299 F1/ch: 0.9433 F1/img: 0.8820
Epoch completed in 451.6647 seconds
Epoch 1/29
----------
LR 0.00038
train Loss: 0.0130 F1: 0.9484 F1/ch: 0.9582 F1/img: 0.8969
val Loss: 0.0206 F1: 0.9198 F1/ch: 0.9359 F1/img: 0.8650
Epoch completed in 451.1109 seconds
Epoch 2/29
----------
LR 0.00038
train Loss: 0.0125 F1: 0.9503 F1/ch: 0.9594 F1/img: 0.9000
val Loss: 0.0171 F1: 0.9332 F1/ch: 0.9456 F1/img: 0.8789
Epoch completed in 450.4279 seconds
Epoch 3/29
----------
LR 0.00034
train Loss: 0.0122 F1: 0.9516 F1/ch: 0.9607 F1/img: 0.9027
val Loss: 0.0169 F1: 0.9343 F1/ch: 0.9463 F1/img: 0.8756
Epoch completed in 454.6396 seconds
Epoch 4/29
----------
LR 0.00036
train Loss: 0.0122 F1: 0.9514 F1/ch: 0.9604 F1/img: 0.9029
val Loss: 0.0175 F1: 0.9311 F1/ch: 0.9441 F1/img: 0.8723
Epoch completed in 451.9986 seconds
Epoch 5/29
----------
LR 0.00036
train 