In [1]:
import os
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np

import os

from skmultilearn.model_selection import IterativeStratification
from sklearn.metrics import f1_score
import tifffile as tiff
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

from cvcore.modeling.loss import binary_dice_metric, binary_iou_metric
from cvcore.data.endocv_dataset import EDDDataset
from cvcore.configs import get_cfg_defaults

%load_ext autoreload
%autoreload 2

## Grab external data

In [None]:
img_ids = []
labels = []

for img in os.listdir('data/Abnormal/'):
    img_ids.append(os.path.join('Abnormal', img))
    labels.append([0, 1, 0, 0, 0])

for img in os.listdir('data/ETIS-LaribPolypDB/'):
    img_ids.append(os.path.join('ETIS-LaribPolypDB/', img))
    labels.append([0, 0, 0, 0, 1])

for img in os.listdir('data/Kvasir-SEG/images/'):
    img_ids.append(os.path.join('Kvasir-SEG/images/', img))
    labels.append([0, 0, 0, 0, 1])
    
ext_df = pd.concat([pd.Series(img_ids), pd.DataFrame(labels)], 1)
ext_df.columns = ['img', 'BE', 'suspicious', 'HGD', 'cancer', 'polyp']
ext_df.to_csv('./data/external_data.csv')

## Split original data

In [None]:
imgs_dir = "data/originalImages/"
masks_dir = "data/masks/"

In [None]:
classes = [
    "BE",
    "suspicious",
    "HGD",
    "cancer",
    "polyp",
]

In [None]:
img_labels = []
img_ids = []

for img in os.listdir(imgs_dir):
    img_ids.append(img)
    img_path = os.path.join(imgs_dir, img)
    img_label = []
    for cls in classes:
        mask_path = os.path.join(masks_dir, img.replace(".jpg", f"_{cls}.tif"))
        if os.path.exists(mask_path):
            img_label.append(1)
        else:
            img_label.append(0)
    img_labels.append(img_label)

In [None]:
df = pd.concat([pd.Series(img_ids), 
                pd.DataFrame(img_labels)], axis=1)
df.columns = ["img"] + classes

In [None]:
for cls in classes:
    print(f"Class {cls} - num. samples {df[cls].value_counts()[0]}")
    
NUM_FOLDS = 5
SEED = 2709

iterkfold = IterativeStratification(n_splits=5, random_state=SEED)

x, y = df.iloc[:, 0].values, df.iloc[:, 1:].values
for i, (train, test) in enumerate(iterkfold.split(x, y)):
    print(x[train].shape, x[test].shape)
    df.loc[train].to_csv(f"data/train_fold{i}.csv", index=False)
    df.loc[test].to_csv(f"data/valid_fold{i}.csv", index=False)

## Search thresholds

In [2]:
seg_threshold = [0.5] * 5

In [3]:
_grid_thresholds = np.linspace(0.1, 0.6, 100)
classes = [
    "BE",
    "suspicious",
    "HGD",
    "cancer",
    "polyp"
]

In [4]:
models = ['b4_unet', 'b3_unet', 'resnet50_fpn']

In [5]:
def search_threshold(inputs, targets,
    grid_thresholds=np.linspace(0.1, 0.6, 100), 
    metric_func=f1_score):
    num_classes = inputs.shape[1]
    best_cls_thresholds = []
    for i in range(num_classes):
        class_inp = inputs[:, i]
        class_tar = targets[:, i]
        grid_scores = []
        for thresh in _grid_thresholds:
            grid_scores.append(metric_func(class_tar, class_inp > thresh))
        best_t = grid_thresholds[np.argmax(grid_scores)]
        best_score = np.max(grid_scores)
        best_cls_thresholds.append(best_t)
    return best_cls_thresholds

In [6]:
ens_mask_output_list = []
valid_mask_list = []

for seg_weights in [
    [1., 0., 0.],
    [0., 1., 0.],
    [0., 0., 1.],
#     [0.5, 0.4, 0.1],
#     [0.6, 0.3, 0.1],
#     [0.7, 0.2, 0.1],
    [1./3, 1./3, 1./3]
    
]:
    for i in range(5):
        valid_mask = torch.load(f'thresholds_tuning/mask_{i}.pth')
        ens_mask_output = 0
        for model, w in zip(models, seg_weights):
            if model == 'b4_unet':
                model_output = torch.load(f'thresholds_tuning/{model}_fold{i}.pth')
                model_output = F.interpolate(model_output, 384, 
                    mode='bilinear', align_corners=False)
            else:
                model_output = torch.load(f'thresholds_tuning/{model}_fold{i}.pth')
            ens_mask_output += model_output * w
        ens_mask_output_list.append(ens_mask_output)
        valid_mask_list.append(valid_mask)
    ens_mask_output = torch.cat(ens_mask_output_list, 0)
    valid_mask = torch.cat(valid_mask_list, 0)
    dice_score = binary_dice_metric(ens_mask_output, valid_mask, seg_threshold).mean().item()
    iou = binary_iou_metric(ens_mask_output, valid_mask, seg_threshold).mean().item()
    print(f'\nWeights: {seg_weights} Dice score: {dice_score} - IoU: {iou}')


Weights: [1.0, 0.0, 0.0] Dice score: 0.8522069454193115 - IoU: 0.8279612064361572

Weights: [0.0, 1.0, 0.0] Dice score: 0.8528971672058105 - IoU: 0.8286389708518982

Weights: [0.0, 0.0, 1.0] Dice score: 0.8534929156303406 - IoU: 0.8297587633132935

Weights: [0.3333333333333333, 0.3333333333333333, 0.3333333333333333] Dice score: 0.8470149040222168 - IoU: 0.8231855630874634


In [None]:
seg_weights = [0.5, 0.4, 0.1]

In [None]:
num_folds = 5

ens_mask_pred = 0

for m, seg_w in zip(models, seg_weights):
    single_mask_pred = 0
    for f in range(num_folds): # folds
        if m == "b4_unet":
            single_mask_pred += F.interpolate(
                torch.load(f'./thresholds_tuning/{m}_fold{f}_test.pth'),
                384, align_corners=False, mode='bilinear'
                ) / num_folds
        else:
            single_mask_pred += torch.load(f'./thresholds_tuning/{m}_fold{f}_test.pth') / num_folds
    ens_mask_pred += single_mask_pred * seg_w

In [None]:
ens_mask_pred = torch.where(ens_mask_pred!=0, 
                            torch.sigmoid(ens_mask_pred), ens_mask_pred)
ens_mask_pred = torch.stack([
    ens_mask_pred[:, i, ...] > th
    for i, th in enumerate(best_seg_thresholds)], 1)
ens_mask_pred = ens_mask_pred.float()

for out, i, o_sz in zip(ens_mask_pred, img_id, orig_size):
    out = F.interpolate(out.unsqueeze(0), o_sz,
        mode="bilinear", align_corners=False)
    out = out.squeeze(0)
    out = out.cpu().numpy().astype(np.uint8) * 255
    save_path = os.path.join(mask_pred_dir, i.replace(".jpg", ".tif"))
    tiff.imwrite(save_path, out)

## TRANQUANGDAT ATOMIC BOMB

In [None]:
best_seg_thresholds = [0.5] * 5

# models = ['rx101-x448', 'rx50-x384-iter-focal']
# seg_weights = [.7, .3]

# models = ['rx101-x448', 'rx50-x384-iter-focal', 'rx101-fpn']
# seg_weights = [.5, .3, .2]

models = ['rx101-x448', 'rx50-x384-iter-focal', 'rx101-fpn', 'b4-fpn']
seg_weights = [.45, .3, .2, .05]

out_dir = 'dattran2346_kfold/'
img_id = torch.load(f'{out_dir}test_img_ids.pth')
orig_size = torch.load(f'{out_dir}test_sizes.pth')
mask_pred_dir = out_dir

In [None]:
num_folds = 3
ens_mask_pred = 0

for m, seg_w in zip(models, seg_weights):
    single_mask_pred = 0
    for f in range(num_folds): # folds
        if m == "rx101-x448" or m == "rx101-fpn" or m == "b4-fpn":
            single_mask_pred += F.interpolate(
                torch.load(f'{out_dir}{m}_test_{f}.pth'),
                384, align_corners=False, mode='bilinear'
                ) / num_folds
        else:
            single_mask_pred += torch.load(f'{out_dir}{m}_test_{f}.pth') / num_folds
    ens_mask_pred += single_mask_pred * seg_w

In [None]:
ens_mask_pred = torch.where(ens_mask_pred!=0, 
                            torch.sigmoid(ens_mask_pred), ens_mask_pred)
ens_mask_pred = torch.stack([
    ens_mask_pred[:, i, ...] > th
    for i, th in enumerate(best_seg_thresholds)], 1)
ens_mask_pred = ens_mask_pred.float()

In [None]:
min_ins_ratio = 0.000927
min_art_ratio =  0.000293
min_sat_ratio = 0.000380

for out, i, o_sz in zip(ens_mask_pred, img_id, orig_size):
    out = F.interpolate(out.unsqueeze(0), o_sz,
        mode="bilinear", align_corners=False)
    out = out.squeeze(0)

    area = np.prod(out.shape[1:])
    instrument_area = out[0].sum()
    artefact_area = out[2].sum()
    saturation_area = out[-1].sum()
    
    if instrument_area > 0:
        ins_ratio = instrument_area / area
        if ins_ratio < min_ins_ratio: # less than min area in training set
            print('Instrument ', i)
            out[0] = 0
    
    if artefact_area > 0:
        art_ratio = artefact_area / area
        if art_ratio < min_art_ratio:
            print('Artefact ', i)
            out[2] = 0

    if saturation_area > 0:
        sat_ratio = saturation_area / area
        if sat_ratio < min_sat_ratio:
            print('Saturation ', i)
            out[-1] = 0
            
    out = out.cpu().numpy().astype(np.uint8) * 255
    save_path = os.path.join(mask_pred_dir, i+'.tif')
    tiff.imwrite(save_path, out)

In [None]:
min_ins_ratio = 0.000927
min_art_ratio =  0.000293
min_sat_ratio = 0.000380

for out, i, o_sz in zip(ens_mask_pred, img_id, orig_size):
    out = F.interpolate(out.unsqueeze(0), o_sz,
        mode="bilinear", align_corners=False)
    out = out.squeeze(0)

    area = np.prod(out.shape[1:])
    instrument_area = out[0].sum()
    artefact_area = out[2].sum()
    saturation_area = out[-1].sum()
    
    if instrument_area > 0:
        ins_ratio = instrument_area / area
        if ins_ratio < min_ins_ratio: # less than min area in training set
            print('Instrument ', i)
            out[0] = 0
    
    if artefact_area > 0:
        art_ratio = artefact_area / area
        if art_ratio < min_art_ratio:
            print('Artefact ', i)
            out[2] = 0

    if saturation_area > 0:
        sat_ratio = saturation_area / area
        if sat_ratio < min_sat_ratio:
            print('Saturation ', i)
            out[-1] = 0
            
    out = out.cpu().numpy().astype(np.uint8) * 255
    save_path = os.path.join(mask_pred_dir, i+'.tif')
    tiff.imwrite(save_path, out)

## Search Segmentation threshold

In [None]:
# best_seg_thresholds = []

# for i in range(5): # 5 classes
#     cls_out = ens_mask_output[:, i, ...].unsqueeze(1)
#     cls_mask = valid_mask[:, i, ...].unsqueeze(1)
#     _grid_dice_scores = []
#     _grid_ious = []
#     for thresh in _grid_thresholds:
#         _grid_dice_scores.append(binary_dice_metric(cls_out, cls_mask, thresh).mean().item())
#         _grid_ious.append(binary_iou_metric(cls_out, cls_mask, thresh).mean().item())
#     best_t = _grid_thresholds[np.argmax(_grid_dice_scores)]
# #     best_t = _grid_thresholds[np.argmax(_grid_ious)]
#     best_dice = np.max(_grid_dice_scores)
#     best_iou = np.max(_grid_ious)
#     best_seg_thresholds.append(best_t)

# # for i in range(5):
# for i in range(1):
#     valid_dice = binary_dice_metric(
#         ens_mask_output, valid_mask, best_seg_thresholds)
#     valid_iou = binary_iou_metric(
#         ens_mask_output, valid_mask, best_seg_thresholds)
    
#     print(f'Dice Score - Fold {i}: ', valid_dice.mean(0).mean(0).item())
#     print(f'IoU - Fold {i}: ', valid_iou.mean(0).mean(0).item())