In [1]:
# Imports
!pip install albumentations



In [20]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [3]:
!pip install rasterio



In [24]:
# Imports
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import multiprocessing
import glob
import re
import plotly.express as pxy
from matplotlib.patches import Rectangle
from matplotlib.colors import LinearSegmentedColormap
import torch

pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd
from torch.optim import lr_scheduler

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import segmentation_models_pytorch as smp
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from joblib import Parallel, delayed

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


#Importing Custom u-net model
import u_net


seed          = 42
model_name    = 'Unet'
train_bs      = 32
valid_bs      = train_bs*2
img_size      = (224, 224)
epochs        = 5
lr            = 2e-3
scheduler     = 'CosineAnnealingLR'
min_lr        = 1e-6
T_max         = int(30000/train_bs*epochs)+50
T_0           = 25
warmup_epochs = 0
wd            = 1e-6
n_accumulate  = max(1, 32//train_bs)
n_fold        = 5
fold_selected = 1
num_classes   = 3
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
multiprocessing.set_start_method("fork")

In [6]:
train_df = pd.read_csv("/Users/khushi/Downloads/uw-madison-gi-tract-image-segmentation/train.csv")

In [7]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 115488 entries, 0 to 115487
Data columns (total 3 columns):
 #   Column        Non-Null Count   Dtype 
---  ------        --------------   ----- 
 0   id            115488 non-null  object
 1   class         115488 non-null  object
 2   segmentation  33913 non-null   object
dtypes: object(3)
memory usage: 2.6+ MB


In [8]:
train_df.describe()

Unnamed: 0,id,class,segmentation
count,115488,115488,33913
unique,38496,3,33899
top,case123_day20_slice_0001,large_bowel,12629 10 12894 12 13158 15 13423 17 13688 19 1...
freq,3,38496,2


In [9]:
# Utility Functions
def get_image_path(base_path, df):
    '''Gets the case, day, slice_no and path of the dataset (either train or test).
    base_path: path to train image folder
    return :: modified dataframe'''
    
    digit_pat = r'[0-9]+'
    # Create case, day and slice columns
    df["case"] = df['id'].apply(lambda x: re.findall(digit_pat, x.split('_')[0])[0]) # df["id"].apply(lambda x: x.split("_")[0])
    df["day"] = df['id'].apply(lambda x: re.findall(digit_pat, x.split('_')[1])[0])  # df["id"].apply(lambda x: x.split("_")[1])
    df["slice_no"] = df["id"].apply(lambda x: x.split("_")[-1])

    df["path"] = 0
    
    n = len(df)

    # Loop through entire dataset
    for k in tqdm(range(n)):
        data = df.iloc[k, :]
        segmentation = data.segmentation

        # In case coordinates for healthy tissue are present
        case = "case"+data.case
        day = 'day'+data.day
        slice_no = data.slice_no
        # Change value to the correct one
        df.loc[k, "path"] = glob(f"{base_path}/{case}/{case}_{day}/scans/slice_{slice_no}*")[0]
    return df


def get_img_size(x, flag):
    
    if x != 0:
        split = x.split("_")
        width = split[3]
        height = split[4]
    
        if flag == "width":
            return int(width)
        elif flag == "height":
            return int(height)
    
    return 0

def get_pixel_size(x, flag):
    
    if x != 0:
        split = x.split("_")
        width = split[-2]
        height = ".".join(split[-1].split(".")[:-1])
    
        if flag == "width":
            return float(width)
        elif flag == "height":
            return float(height)
    
    return 0

def CustomCmap(rgb_color):

    r1,g1,b1 = rgb_color

    cdict = {'red': ((0, r1, r1),
                   (1, r1, r1)),
           'green': ((0, g1, g1),
                    (1, g1, g1)),
           'blue': ((0, b1, b1),
                   (1, b1, b1))}

    cmap = LinearSegmentedColormap('custom_cmap', cdict)
    return cmap


def show_sample_images(sample_paths):
    '''Displays simple images (without mask).'''

    # Get additional info from the path
    case_name = [info.split("_")[0][-7:] for info in sample_paths]
    day_name = [info.split("_")[1].split("/")[0] for info in sample_paths]
    slice_name = [info.split("_")[2] for info in sample_paths]


    # Plot
    fig, axs = plt.subplots(2, 5, figsize=(23, 8))
    axs = axs.flatten()

    for k, path in enumerate(sample_paths):
        
        title = f"{k+1}. {case_name[k]} - {day_name[k]} - {slice_name[k]}"
        axs[k].set_title(title, fontsize = 14, 
                         color = my_colors[-1], weight='bold')
        axs[k].imshow(img)
        axs[k].axis("off")

    plt.tight_layout()
    plt.show()
    
def mask_from_segmentation(segment, shape):
    segm = np.asarray(segment.split(), dtype=int)

   
    # Get start point and length between points
    start_point = segm[0::2] - 1
    length_point = segm[1::2]

    # Compute the location of each endpoint
    end_point = start_point + length_point

    # Create an empty list mask the size of the original image
    # take onl
    case_mask = np.zeros(shape[0]*shape[1], dtype=np.uint8)

    # Change pixels from 0 to 1 that are within the segmentation
    for start, end in zip(start_point, end_point):
        case_mask[start:end] = 1

    case_mask = case_mask.reshape((shape[0], shape[1]))    
    return case_mask

def plot_original_mask(img, mask, alpha=1):

    # Change pixels - when 1 make True, when 0 make NA
    mask = np.ma.masked_where(mask == 0, mask)

    # Split the channels
    mask_largeB = mask[:, :, 0]
    mask_smallB = mask[:, :, 1]
    mask_stomach = mask[:, :, 2]

    # Plot the 2 images (Original and with Mask)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

    # Original
    ax1.set_title("Original Image")
    ax1.imshow(img)
    ax1.axis("off")

    # With Mask
    ax2.set_title("Image with Mask")
    ax2.imshow(img)
    ax2.imshow(mask_largeB, interpolation='none', cmap=CMAP1, alpha=alpha)
    ax2.imshow(mask_smallB, interpolation='none', cmap=CMAP2, alpha=alpha)
    ax2.imshow(mask_stomach, interpolation='none',cmap=CMAP3, alpha=alpha)
    ax2.legend(legend_colors, ['large_bowel', 'small_bowel', 'stomach'])
    ax2.axis("off")
    
#     fig.savefig('foo.png', dpi=500)
    plt.show()

In [10]:
mask_colors = [(1.0, 0.7, 0.1), (1.0, 0.5, 1.0), (1.0, 0.22, 0.099)]
legend_colors = [Rectangle((0,0),1,1, color=color) for color in mask_colors]
labels = ["Large Bowel", "Small Bowel", "Stomach"]

CMAP1 = CustomCmap(mask_colors[0])
CMAP2 = CustomCmap(mask_colors[1])
CMAP3 = CustomCmap(mask_colors[2])

In [11]:
base_path = "/Users/khushi/Downloads/uw-madison-gi-tract-image-segmentation/train"

# # Prep and save file
train_data = get_image_path(base_path, df=train_df)

train_data["image_width"] = train_data["path"].apply(lambda x: int(x[:-4].rsplit("_",4)[1]))
train_data["image_height"] = train_data["path"].apply(lambda x: int(x[:-4].rsplit("_",4)[2]))

train_data["pixel_width"] = train_data["path"].apply(lambda x: int(x[:-4].rsplit("_",4)[1]))
train_data["pixel_height"] = train_data["path"].apply(lambda x: get_pixel_size(x, "height"))

train_data['slice_no'] = train_data['slice_no'].apply(lambda x: x[0])


train_data['case'] = train_data['case'].astype(int)

df_train = pd.DataFrame({'id':train_data['id'][::3]})

df_train['large_bowel'] = train_data['segmentation'][::3].values
df_train['small_bowel'] = train_data['segmentation'][1::3].values
df_train['stomach'] = train_data['segmentation'][2::3].values

df_train['path'] = train_data['path'][::3].values
df_train['case'] = train_data['case'][::3].values
df_train['day'] = train_data['day'][::3].values
df_train['slice'] = train_data['slice_no'][::3].values
df_train['width'] = train_data['image_width'][::3].values
df_train['height'] = train_data['image_height'][::3].values


df_train.reset_index(inplace=True,drop=True)
df_train.fillna('',inplace=True); 
df_train['count'] = np.sum(df_train.iloc[:,1:4]!='',axis=1).values
df_train.sample(5)

100%|██████████| 115488/115488 [00:35<00:00, 3256.50it/s]


Unnamed: 0,id,large_bowel,small_bowel,stomach,path,case,day,slice,width,height,count
4228,case145_day19_slice_0053,,,44155 1 44513 6 44871 9 45230 11 45589 13 4594...,/Users/khushi/Downloads/uw-madison-gi-tract-im...,145,19,0,360,310,1
17556,case7_day0_slice_0021,,,,/Users/khushi/Downloads/uw-madison-gi-tract-im...,7,0,0,266,266,0
3426,case88_day0_slice_0115,20675 2 21033 7 21389 12 21505 5 21742 19 2186...,17117 6 17468 16 17826 19 18184 23 18543 27 18...,,/Users/khushi/Downloads/uw-madison-gi-tract-im...,88,0,0,360,310,2
25776,case55_day20_slice_0065,,,,/Users/khushi/Downloads/uw-madison-gi-tract-im...,55,20,0,266,266,0
6152,case91_day0_slice_0025,,,,/Users/khushi/Downloads/uw-madison-gi-tract-im...,91,0,0,266,266,0


In [12]:
def rle_decode(mask_rle, shape):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)  # Needed to align to RLE direction


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = img.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [13]:
class BuildDataset(torch.utils.data.Dataset):
    def __init__(self, df, subset='train', transforms=None):
        self.df = df
        self.subset = subset
        self.transforms = transforms
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        masks = np.zeros((224,224,3), dtype=np.float32)
        img_path = self.df['path'].iloc[index]
        w = self.df['width'].iloc[index]
        h = self.df['height'].iloc[index]
        img = self.__load_img(img_path)
        if self.subset=='train':
            for k,j in zip([0,1,2],["large_bowel","small_bowel","stomach"]):
                rles=self.df[j].iloc[index]
                mask = rle_decode(rles, shape=(h, w, 1))
                mask = cv2.resize(mask, (224,224))
                masks[:,:,k] = mask
        
        masks = masks.transpose(2, 0, 1)
        img = img.transpose(2, 0, 1)
        
        if self.subset=='train': return torch.tensor(img), torch.tensor(masks)
        else: return torch.tensor(img)

    def __load_img(self, img_path):
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        img = (img - img.min())/(img.max() - img.min())*255.0 
        img = cv2.resize(img, (224,224))
        img = np.tile(img[...,None], [1, 1, 3]) # gray to rgb
        img = img.astype(np.float32) /255.
        return img
    
    def get_id_mask(self, idx, verbose=False):
        '''Returns a mask for each case ID. If no segmentation was found, the mask will be empty
        - meaning formed by only 0
        ID: the case ID from the train.csv file
        verbose: True if we want any prints
        return: segmentation mask'''

        train = self.df
        
        # ~~~ Get the data ~~~
        # Get the portion of dataframe where we have ONLY the speciffied ID
        #index_data = train[train['id']==id].reset_index(drop=True)
        for row in train[train.index==idx].iterrows():
            row = row[1]

            # Split the dataframe into 3 series of observations
            # each for one speciffic class - "large_bowel", "small_bowel", "stomach"
            observations = [index_data.loc[k, :] for k in range(3)]
        
        # ~~~ Create the mask ~~~
        # Get the maximum height out of all observations
        # if max == 0 then no class has a segmentation
        # otherwise we keep the length of the mask
        max_height = np.max([obs.image_height for obs in observations])
        max_width = np.max([obs.image_width for obs in observations])

        # Get shape of the image
        # 3 channels of color/classes
        shape = (max_height, max_width, 3)

        # Create an empty mask with the shape of the image
        mask = np.zeros(shape, dtype=np.uint8)

        # If there is at least 1 segmentation found in the group of 3 classes
        if max_height != 0:
            for k, location in enumerate(["large_bowel", "small_bowel", "stomach"]):
                observation = observations[k]
                segmentation = observation.segmentation

                # If a segmentation is found
                # Append a new channel to the mask
                if pd.isnull(segmentation) == False:
                    mask[..., k] = mask_from_segmentation(segmentation, shape)

        return mask

In [14]:
data_transforms = {
    "train": A.Compose([
        A.Resize(224,224, interpolation=cv2.INTER_NEAREST),
        A.Blur(blur_limit=(5, 5), p=1.0),
        A.Normalize(mean=0.5, std=2, max_pixel_value=255.0)], p=1.0),
    "valid": A.Compose([
        A.Resize(224,224, interpolation=cv2.INTER_NEAREST),
        ], p=1.0)
}

In [15]:
skf = StratifiedGroupKFold(n_splits=n_fold, shuffle=True, random_state=42)
for fold, (_, val_idx) in enumerate(skf.split(X=df_train, y=df_train['count'],groups =df_train['case']), 1):
    df_train.loc[val_idx, 'fold'] = fold

df_train['fold'] = df_train['fold'].astype(np.uint8)

train_ids = df_train[df_train["fold"]!=fold_selected].index
valid_ids = df_train[df_train["fold"]==fold_selected].index

# train_ids = train_data.index[:30000]
# valid_ids = train_data.index[30000:]

In [16]:
train_dataset = BuildDataset(df_train[df_train.index.isin(train_ids)], transforms=data_transforms['train'])
valid_dataset = BuildDataset(df_train[df_train.index.isin(valid_ids)], transforms=data_transforms['valid'])

train_loader = DataLoader(train_dataset,batch_size=32, num_workers=4, shuffle=True, pin_memory=True, drop_last=False)

valid_loader = DataLoader(valid_dataset, batch_size=64,num_workers=4, shuffle=False, pin_memory=True)

imgs, msks = next(iter(train_loader))
imgs.size(), msks.size()

(torch.Size([32, 3, 224, 224]), torch.Size([32, 3, 224, 224]))

In [17]:
pre_trained=False

In [25]:
if pre_trained==True:
    ENCODER = 'efficientnet-b4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['large_bowel', 'small_bowel', 'stomach']
    ACTIVATION = 'softmax' # could be None for logits or 'softmax2d' for multiclass segmentation

    # create segmentation model with pretrained encoder
    model = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),
        activation=ACTIVATION,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # model = UNet(in_channels=3, out_channels=1)
    model=u_net.UNet()
    model.to(device)

In [26]:
from torchsummary import summary

In [27]:
summary(model,(3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,928
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
        DoubleConv-7         [-1, 64, 224, 224]               0
         MaxPool2d-8         [-1, 64, 112, 112]               0
            Conv2d-9        [-1, 128, 112, 112]          73,856
      BatchNorm2d-10        [-1, 128, 112, 112]             256
             ReLU-11        [-1, 128, 112, 112]               0
           Conv2d-12        [-1, 128, 112, 112]         147,584
      BatchNorm2d-13        [-1, 128, 112, 112]             256
             ReLU-14        [-1, 128, 1

In [28]:
device = device
DiceLoss    = smp.losses.DiceLoss(mode='multilabel').to(device)
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')

from scipy.spatial.distance import directed_hausdorff

def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    den = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    dice = ((2*inter)/(den+epsilon)).mean(dim=(1,0))
    return dice



# The training loss goes to nan while evaluating the model through this hausdorff distance.
def hausdorff_distance(y_true, y_pred):
    difference = y_true - y_pred
    
    # Square distances using PyTorch einsum
    square_distances = torch.einsum("...i,...i->...", difference, difference)
    
    minimum_square_distance_a_to_b = torch.min(square_distances, dim=-1)[0]
    
    # Here we are outputting the mean hausdorff distance.
    return torch.mean(torch.sqrt(torch.max(minimum_square_distance_a_to_b, dim=-1).values))

def iou_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    union = (y_true + y_pred - y_true*y_pred).sum(dim=dim)
    iou = ((inter+epsilon)/(union+epsilon)).mean(dim=(1,0))
    return iou

def criterion(y_pred, y_true):
    return 0.6*BCELoss(y_true, y_pred) + 0.4*DiceLoss(y_pred, y_true)

def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.to(device)
    model.train()
    scaler = amp.GradScaler()
    
    dataset_size = 0
    running_loss = 0.0
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    for step, (images, masks) in pbar:         
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
                
        # zero the parameter gradients
        optimizer.zero_grad()

        with amp.autocast(enabled=True):
            y_pred = model(images)
            loss   = criterion(y_pred, masks)
            loss   = loss / max(1, 32//train_bs)
        
        scaler.scale(loss).backward()
    
        if (step+1)%(max(1, 32//train_bs)) == 0:
            scaler.step(optimizer)
            scaler.update()
            
            optimizer.zero_grad()
            
            if scheduler is not None:
                scheduler.step()
#         optimizer.step()

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                        lr=f'{current_lr:0.5f}',
                        gpu_mem=f'{mem:0.2f} GB')
    torch.cuda.empty_cache()
    gc.collect()
    
    return epoch_loss

In [29]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    val_scores = []
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:        
        images  = images.to(device, dtype=torch.float)
        masks   = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        y_pred  = model(images)
        loss    = criterion(y_pred, masks)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        # Why do we want Sigmoid here?
#         y_pred = nn.Sigmoid()(y_pred)  
        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        val_jaccard = iou_coef(masks, y_pred).cpu().detach().numpy()
#         val_hausdorff = hausdorff_distance(masks, y_pred).detach().cpu().numpy()
        val_scores.append([val_dice, val_jaccard])
            
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                        lr=f'{current_lr:0.5f}',
                        gpu_memory=f'{mem:0.2f} GB')
    val_scores  = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

In [30]:
def run_training(model, optimizer, scheduler, device, num_epochs):
    # To automatically log gradients
    cnt = 0
    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice      = -np.inf
    best_epoch     = -1
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                               dataloader=train_loader, 
                                               device=device, epoch=epoch)

        val_loss, val_scores = valid_one_epoch(model, valid_loader, 
                                                 device=device, 
                                                 epoch=epoch)
        val_dice, val_jaccard = val_scores
    
        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)
        
        # Log the metrics
        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')
        
        # deep copy the model
        if val_dice > best_dice:
            cnt = 0
#             print(f"{c_}Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            print(f"Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice    = val_dice
            best_jaccard = val_jaccard
            best_epoch   = epoch
            #run.summary["Best Dice"]    = best_dice
           # run.summary["Best Jaccard"] = best_jaccard
           # run.summary["Best Epoch"]   = best_epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"best_epoch-{fold:02d}.bin"
            torch.save(model.state_dict(), PATH)
            # Save a model file from the current directory
#             print(f"Model Saved{sr_}")
        else:
            cnt += 1
        
        if cnt>2:
            # Early stopping. 
            # Can also apply callback.
            return model, history
        
        last_model_wts = copy.deepcopy(model.state_dict())
        PATH = f"last_epoch-{fold:02d}.bin"
#         torch.save(model.state_dict(), PATH)
            
        print(); print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_jaccard))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

In [31]:
def fetch_scheduler(optimizer, scheduler):
    if scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=T_max, 
                                                   eta_min=min_lr)
    elif scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=T_0, 
                                                             eta_min=min_lr)
    elif scheduler == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=7,
                                                   threshold=0.0001,
                                                   min_lr=min_lr,)
    elif scheduler == 'ExponentialLR':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    elif scheduler == None:
        return None
        
    return scheduler

In [32]:
for fold in range(1):
    print(f'#'*35)
    print(f'######### Fold: {fold}')
    print(f'#'*35)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=2e-6)
    scheduler = fetch_scheduler(optimizer, scheduler)
    model, history = run_training(model, optimizer, scheduler,
                                  device=device,
                                  num_epochs=10)

###################################
######### Fold: 0
###################################
Epoch 1/10

Train :   0%|          | 0/956 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def load_model(path):
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [None]:
test_dataset = BuildDataset(df_train[df_train.index.isin(valid_ids)], 
                            transforms=data_transforms['valid'])
test_loader  = DataLoader(test_dataset, batch_size=5, 
                          num_workers=4, shuffle=False, pin_memory=True)

imgs, msks =  next(iter(test_loader))

imgs = imgs.to(device, dtype=torch.float)

preds = []
for fold in range(1):
    model = load_model(f"best_epoch-{fold:02d}.bin")
    with torch.no_grad():
        pred = model(imgs)
        pred = (nn.Sigmoid()(pred)>0.5).double()
    preds.append(pred)
    
imgs  = imgs.cpu().detach()
preds = torch.mean(torch.stack(preds, dim=0), dim=0).cpu().detach()

In [None]:
def plot_batch(imgs, msks, size=3):
    plt.figure(figsize=(5*5, 5))
    for idx in range(size):
        plt.subplot(1, 5, idx+1)
        img = imgs[idx,].permute((1, 2, 0)).numpy()*255.0
        img = img.astype('uint8')
        msk = msks[idx,].permute((1, 2, 0)).numpy()*255.0
        show_img_train(img, msk)
    plt.tight_layout()
    plt.show()

In [None]:
def show_img_train(img, mask=None):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
#     img = clahe.apply(img)
#     plt.figure(figsize=(10,10))
    plt.imshow(img, cmap='bone')

In [None]:
plt.imshow(preds[0].transpose(2,1))
plt.show()

In [None]:
plot_batch(imgs, preds, size=5)

In [None]:
preds

## Train On-Folds

In [18]:
import plotly.express as px
import numpy as np
tmp = np.load('pred_arr.txt')
# px.imshow(tmp)

In [19]:
px.imshow(tmp)

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

In [20]:
tmp

array('[[[[1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   ...\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]]\n\n  [[1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   ...\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]]\n\n  [[1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   ...\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]\n   [1. 1. 1. ... 1. 1. 1.]]]]',
      dtype='<U513')