In [133]:
import pandas as pd
import ttach as tta
import os
import shutil
import pathlib
import gc
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets
import cv2
from torch.cuda import amp

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid
import optuna
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

In [134]:
def seed_everything(seed=42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

In [135]:
seed_everything(42)

In [147]:
CFGS = [           
#     {
#         'model_name': 'Unet',
#         'backbone': 'efficientnet-b7',
#         'img_size': [256, 256],
#         'num_classes': 1,
#         'model_pth': '../../output/exp_pl2_new01/Unet/efficientnet-b7-256/checkpoint_dice.pth',
#         'threshold': 0.84,
#         'tta': True, 
#     },
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth': '/home/rohits/pv1/contrail/output/exp_s1_pl_03/pl_round2/efficientnet-b7-256-s2-2/checkpoint_dice.pth',
        'threshold': 0.9,
        'tta': True, 
    },
    
]

In [148]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [149]:
class ContrailDataset:
    def __init__(self, df, transform=None):
        self.df = df  
        self.images = df['image']
        self.labels = df['label']
        self.transform =transform        
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        image = np.load("../../input/" + self.images[idx]).astype(float)   
        label = np.load("../../input/" + self.labels[idx]).astype(float)
        
        
        if self.transform :
            data = self.transform(image=image, mask=label)
            image  = data['image']
            label  = data['mask']
            image = np.transpose(image, (2, 0, 1))
            label = np.transpose(label, (2, 0, 1))  
            
        class_label = 1 if label.sum() > 0 else 0
            
        return torch.tensor(image), torch.tensor(label), torch.tensor(class_label)

In [150]:
class Net(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        self.model = smp.Unet(
            encoder_name=cfg["backbone"],      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights=None,     # use `imagenet` pre-trained weights for encoder initialization
            in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=cfg["num_classes"],        # model output channels (number of classes in your dataset)
            activation=None
        )
        
    
    def forward(self, inputs):
        mask = self.model(inputs)
        return mask

In [151]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=1e-6):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum()
    den = y_true.sum() + y_pred.sum()
    dice = ((2*inter+epsilon)/(den+epsilon)).mean()
    
    return dice


def iou_coef(y_true, y_pred, thr=0.5, epsilon=1e-6):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum()
    union = (y_true + y_pred - y_true*y_pred).sum()
    iou = ((inter+epsilon)/(union+epsilon)).mean(0)
    return iou


def get_transform(img_size):
    transform = A.Compose([
        A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
    ], p=1.0)
    return transform




In [152]:
def get_preds(model, data_loader, img_size):
    
    model.to(device)
    model.eval()
    
    preds = []
    masks = []
    class_labels = []

    for index, (image, mask, class_label) in enumerate(tqdm(data_loader)):  
        image  = image.to(device, dtype=torch.float)
        mask  = mask.to(device, dtype=torch.float)
        class_label  = class_label.to(device, dtype=torch.int64)
        class_labels.append(class_label)
        
        if img_size != 256:
            mask = torch.nn.functional.interpolate(mask, size=256, mode='nearest')             
        masks.append(torch.squeeze(mask, dim=1))

        with torch.inference_mode():            
            pred = model(image)                
            pred = pred.sigmoid()

            if img_size != 256:
                pred = torch.nn.functional.interpolate(pred, size=256, mode='nearest') 
            preds.append(torch.squeeze(pred, dim=1))

    
    model_masks = torch.cat(masks, dim=0)
    model_preds = torch.cat(preds, dim=0)
    class_labels = torch.cat(class_labels, dim=0) 
    
    
    model_masks = torch.flatten(model_masks, start_dim=0, end_dim=1)
    model_preds = torch.flatten(model_preds, start_dim=0, end_dim=1)  
    class_labels = torch.flatten(class_labels)  
    
    best_threshold = 0.0
    best_dice_score = 0.0
    for threshold in [i / 100 for i in range(101)] :
        score = dice_coef(model_masks, model_preds, thr=threshold).cpu().detach().numpy() 
        if score > best_dice_score:
            best_dice_score = score
            best_threshold = threshold
            
            
    return model_masks, model_preds, best_dice_score, best_threshold, class_labels

In [153]:
def generate_stats(mask, pred, th, stats, append = ""):
    aucpr = iou_coef(mask, pred, thr=th, epsilon=1e-6)
    tp, fp, fn, tn = smp.metrics.get_stats(pred, mask.round().long(), mode='binary', threshold=th)
    iou_image = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
    iou_overall = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")

    stats += f"Global Dice Score {append}: {score}, TH: {th} | IOU_IMAGE_WISE: {iou_image},  IOU_OVERALL: {iou_overall}\n"
    return stats

In [154]:
# val_df.shape, val_df_0.shape, val_df_1.shape

In [155]:
val_df = pd.read_csv("../../input/val_df_filled.csv")
val_df_0 = val_df.loc[val_df['class'] == 0].reset_index(drop=True)
val_df_1 = val_df.loc[val_df['class'] == 1].reset_index(drop=True)


val_transform = get_transform([256, 256])
valid_dataset = ContrailDataset(val_df, transform=val_transform)  
valid_dataset0 = ContrailDataset(val_df_0, transform=val_transform)  
valid_dataset1 = ContrailDataset(val_df_1, transform=val_transform)  


valid_loader = DataLoader(
    valid_dataset, 
    batch_size = 32, #32, 
    shuffle = False, 
    num_workers = 2, 
    pin_memory = True, 
    drop_last = False
)

valid_loader0 = DataLoader(
    valid_dataset0, 
    batch_size = 32, #32, 
    shuffle = False, 
    num_workers = 2, 
    pin_memory = True, 
    drop_last = False
)

valid_loader1 = DataLoader(
    valid_dataset1, 
    batch_size = 32, #32, 
    shuffle = False, 
    num_workers = 2, 
    pin_memory = True, 
    drop_last = False
)


cfg = CFGS[0]

model = Net(cfg)
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(torch.load(cfg['model_pth'], map_location=torch.device('cpu'))['model'])
model_tta = tta.SegmentationTTAWrapper(model, tta.aliases.flip_transform(), merge_mode='mean')     
print()





In [156]:
stats = ""
mask, pred, score, th, cls_labels = get_preds(model, valid_loader, img_size=256)
stats = generate_stats(mask, pred, th, stats, append = "TTA:0")

mask, pred, score, th, cls_labels = get_preds(model, valid_loader1, img_size=256)
stats = generate_stats(mask, pred, th, stats, append = "POS TTA:0")

mask, pred, score, th, cls_labels = get_preds(model_tta, valid_loader, img_size=256)
stats = generate_stats(mask, pred, th, stats, append = "TTA:3")

mask, pred, score, th, cls_labels = get_preds(model_tta, valid_loader1, img_size=256)
stats = generate_stats(mask, pred, th, stats, append = "POS TTA:3")


  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

In [157]:
print("*"*100)
print(stats)
print("*"*100)

****************************************************************************************************
Global Dice Score TTA:0: 0.6755469441413879, TH: 0.75 | IOU_IMAGE_WISE: 0.9531981348991394,  IOU_OVERALL: 0.5080936551094055
Global Dice Score POS TTA:0: 0.6846645474433899, TH: 0.74 | IOU_IMAGE_WISE: 0.8537121415138245,  IOU_OVERALL: 0.5184818506240845
Global Dice Score TTA:3: 0.6779290437698364, TH: 0.66 | IOU_IMAGE_WISE: 0.9536939263343811,  IOU_OVERALL: 0.5107443928718567
Global Dice Score POS TTA:3: 0.6869837045669556, TH: 0.47 | IOU_IMAGE_WISE: 0.8542597889900208,  IOU_OVERALL: 0.5211914777755737

****************************************************************************************************


In [None]:
# Global Dice Score TTA:0: 0.6759867668151855, TH: 0.9 | IOU_IMAGE_WISE: 0.9537177681922913,  IOU_OVERALL: 0.5087465047836304
# Global Dice Score TTA:0: 0.6844078302383423, TH: 0.58 | IOU_IMAGE_WISE: 0.8528633117675781,  IOU_OVERALL: 0.5184430480003357
# Global Dice Score TTA:0: 0.6778432130813599, TH: 0.57 | IOU_IMAGE_WISE: 0.9537582993507385,  IOU_OVERALL: 0.5108285546302795
# Global Dice Score TTA:0: 0.6860638856887817, TH: 0.28 | IOU_IMAGE_WISE: 0.8528570532798767,  IOU_OVERALL: 0.5203468799591064

In [87]:
pred.shape

torch.Size([1856, 256, 256])

In [88]:
mask.shape

torch.Size([1856, 256, 256])

In [109]:
tp, fp, fn, tn = smp.metrics.get_stats(pred, mask.round().long(), mode='binary', threshold=th)

In [111]:
recall_image = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
precision_image = smp.metrics.precision(tp, fp, fn, tn, reduction="micro-imagewise")
iou_score_image = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")


recall_overall = smp.metrics.recall(tp, fp, fn, tn, reduction="macro")
precision_overall = smp.metrics.precision(tp, fp, fn, tn, reduction="macro")
iou_score_overall = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")


In [112]:
precision, precision_overall, recall, recall_overall , iou_score, iou_score_overall

(tensor(0.8831, device='cuda:0'),
 tensor(0.6825, device='cuda:0'),
 tensor(0.8766, device='cuda:0'),
 tensor(0.6682, device='cuda:0'),
 tensor(0.7898, device='cuda:0'),
 tensor(0.5092, device='cuda:0'))

In [None]:
(tensor(0.8831, device='cuda:0'),
 tensor(0.6670, device='cuda:0'),
 tensor(0.8766, device='cuda:0'),
 tensor(0.6833, device='cuda:0'),
 tensor(0.7898, device='cuda:0'),
 tensor(0.5090, device='cuda:0'))

In [115]:
smp.metrics.false_positive_rate(tp, fp, fn, tn, reduction="micro-imagewise")
smp.metrics.false_negative_rate(tp, fp, fn, tn, reduction="micro-imagewise")

tensor(0.8307, device='cuda:0')

In [69]:
pred.shape

torch.Size([1856, 256, 256])

In [71]:
cls_labels.shape

torch.Size([58, 32])

In [72]:
torch.flatten(cls_labels).shape

torch.Size([1856])

torch.Size([1856, 256])

In [141]:
val_df_0

Unnamed: 0,image,label,class
0,valid_data/3687499407028137410/image.npy,valid_data/3687499407028137410/label.npy,0
1,valid_data/7355354609194882312/image.npy,valid_data/7355354609194882312/label.npy,0
2,valid_data/7547747455642200110/image.npy,valid_data/7547747455642200110/label.npy,0
3,valid_data/8604370548989406919/image.npy,valid_data/8604370548989406919/label.npy,0
4,valid_data/4746167155668084215/image.npy,valid_data/4746167155668084215/label.npy,0
...,...,...,...
1299,valid_data/7206666542994541713/image.npy,valid_data/7206666542994541713/label.npy,0
1300,valid_data/2819090710836460710/image.npy,valid_data/2819090710836460710/label.npy,0
1301,valid_data/922629314296188212/image.npy,valid_data/922629314296188212/label.npy,0
1302,valid_data/3319793057592206418/image.npy,valid_data/3319793057592206418/label.npy,0
