In [None]:
!pip install --no-deps '../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl' -q
!pip install '../input/mean-average-precision-for-boxes/map_boxes-1.0.5-py3-none-any.whl' -q
!pip install '../input/pytorchlightning/pytorch_lightning-1.2.4-py3-none-any.whl' -q

In [None]:
import sys

sys.path.insert(0, "../input/omegaconf/omegaconf-master")
sys.path.insert(0, "../input/efficientdetpytorch/efficientdet-pytorch-master")
sys.path.insert(0, "../input/vbd-chest-xray-script")
sys.path.insert(0, "../input/weightedboxesfusion")
sys.path.insert(0, "../input/timm-pytorch-image-models/pytorch-image-models-master")

In [None]:
import os

import numpy as np
import pandas as pd

import torch
import pytorch_lightning as pl

from tqdm import tqdm

from models import XrayClassifier, XrayDetector
from datamodule import XrayTestDataModule, XrayTestEnsembleDataModule
from ensemble_boxes import *

In [None]:
def make_clf_preds(model, image_size, test_loader, device, debug=False):
    """Run one classifier over ``test_loader`` and collect its per-image outputs.

    Args:
        model: Callable applied to each image batch; its output must support
            ``.detach().cpu().numpy()``.
        image_size: Used to build the sample key ``"image_<image_size>"``.
        test_loader: Iterable of ``(sample, image_id, _, _)`` batches. ``sample``
            is indexed by the image key and ``image_id`` is the sequence of ids
            of the batch; the last two items are ignored here.
        device: Target passed to ``image.to(...)`` before the forward pass.
        debug: If true, print each batch's key/shape and stop after three batches.

    Returns:
        ``(image_ids, preds)``: two lists, each extended once per batch.
    """
    image_key = "image_" + str(image_size)

    image_ids = []
    preds = []

    for index, (sample, image_id, _, _) in enumerate(tqdm(test_loader)):
        if debug and index > 2:
            break
        image = sample[image_key]
        if debug: print(image_key, image.shape)

        pred = model(image.to(device))
        image_ids.extend(image_id)
        # squeeze() collapses a batch of one into a 0-d array, which
        # list.extend cannot iterate; atleast_1d keeps it a length-1 vector
        # and leaves every larger batch exactly as before.
        preds.extend(np.atleast_1d(pred.detach().cpu().numpy().squeeze()))

    return image_ids, preds

In [None]:
def make_clf_preds_ensemble(model, image_size_list, test_loader, device, debug=False):
    """Average the outputs of several classifiers for every image in ``test_loader``.

    Args:
        model: Sequence of callables, paired positionally (via ``zip``) with
            ``image_size_list``.
        image_size_list: Sizes used to build the sample keys
            ``"image_<size>"``, one per model.
        test_loader: Iterable of ``(sample, image_id, _, _)`` batches; ``sample``
            is indexed by the image keys.
        device: Target passed to ``image.to(...)`` before each forward pass.
        debug: If true, print each input shape and stop after three batches.

    Returns:
        ``(image_ids, preds)``: ids and the per-image mean over all models.
    """
    image_key_list = ["image_" + str(x) for x in image_size_list]

    image_ids = []
    preds = []

    for index, (sample, image_id, _, _) in enumerate(tqdm(test_loader)):
        if debug and index > 2:
            break

        pred_list = []
        for key, m in zip(image_key_list, model):
            image = sample[key]
            if debug: print(image.shape)
            pred = m(image.to(device))
            pred_list.append(pred)

        # Put the models side by side on dim 1, then average across them.
        pred_concat = torch.cat(pred_list, dim=1)
        pred_mean = torch.mean(pred_concat, dim=1)

        image_ids.extend(image_id)
        # squeeze() collapses a batch of one into a 0-d array, which
        # list.extend cannot iterate; atleast_1d keeps it a length-1 vector.
        preds.extend(np.atleast_1d(pred_mean.detach().cpu().numpy().squeeze()))

    return image_ids, preds

In [None]:
def make_clf_preds_df(model, image_size, test_loader, device, debug=False):
    """Collect classifier predictions into a two-column DataFrame.

    A list of more than one model is run as an ensemble; a one-element list is
    unwrapped (together with ``image_size``); anything else is treated as a
    single model.

    Returns:
        DataFrame with columns ``image_id_dicom`` and ``preds``.
    """
    given_list = isinstance(model, list)

    if given_list and len(model) > 1:
        image_ids, preds = make_clf_preds_ensemble(model, image_size, test_loader, device, debug)
    else:
        # A one-element list carries its size in a list as well.
        single_model = model[0] if given_list else model
        single_size = image_size[0] if given_list else image_size
        image_ids, preds = make_clf_preds(single_model, single_size, test_loader, device, debug)

    # Built from row tuples on purpose: this keeps pandas' dtype inference for
    # the "preds" column identical to what downstream cells were written for.
    rows = list(zip(image_ids, preds))
    return pd.DataFrame(data=rows, columns=["image_id_dicom", "preds"])

In [None]:
def convert_batch_pred(prediction, height, width, resize_height, resize_width):
    """Split a batched detector output and rescale its boxes to raw-image pixels.

    Args:
        prediction: Tensor indexed as ``[image, detection, field]`` where fields
            0-3 are the box, 4 the score and 5 the 1-based label.
        height: Tensor with one raw image height per batch element.
        width: Tensor with one raw image width per batch element.
        resize_height: Height of the resized network input the boxes refer to.
        resize_width: Width of the resized network input the boxes refer to.

    Returns:
        ``(boxes, scores, labels)`` as NumPy arrays: ``boxes`` is int32 and
        clipped to ``[0, size - 1]`` per axis, ``labels`` is int32 and 0-based.
        ``prediction`` itself is left untouched.
    """
    # On CPU, .numpy() returns a view of the tensor's memory. Copy the boxes
    # so the in-place rescale below cannot rewrite the caller's tensor.
    boxes = prediction[:, :, :4].detach().cpu().numpy().copy()
    scores = prediction[:, :, 4].detach().cpu().numpy()
    labels = prediction[:, :, 5].detach().cpu().numpy().astype(np.int32)

    # 1-index to 0-index
    labels -= 1

    # Shape (batch, 1) so each image's size broadcasts over its detections.
    height = np.expand_dims(height.detach().cpu().numpy(), axis=1)
    width = np.expand_dims(width.detach().cpu().numpy(), axis=1)

    # Columns 0 and 2 follow the width ratio, columns 1 and 3 the height ratio.
    axis_params = ((width, resize_width), (height, resize_height)) * 2

    for col, (raw_size, resized_size) in enumerate(axis_params):
        boxes[:, :, col] = boxes[:, :, col] * raw_size / resized_size

    boxes = boxes.astype(np.int32)

    for col, (raw_size, _) in enumerate(axis_params):
        boxes[:, :, col] = boxes[:, :, col].clip(min=0, max=raw_size - 1)

    return boxes, scores, labels

In [None]:
def make_det_preds(model, image_size, test_loader, device, debug=False, downscale_factor = 1):
    """Run a single detector over ``test_loader`` and gather its converted outputs.

    Args:
        model: Callable applied to each image batch.
        image_size: Builds the sample key ``"image_<image_size>"`` and is also
            passed to ``convert_batch_pred`` as both resize dimensions.
        test_loader: Iterable of ``(sample, image_id, height_raw, width_raw)``.
        device: Target passed to ``.to(...)`` before the forward pass.
        debug: If true, print each batch's key/shape and stop after three batches.
        downscale_factor: Accepted but not used by this function.

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` lists with one
        entry per image.
    """
    sample_key = "image_" + str(image_size)

    all_ids, all_boxes, all_scores, all_labels = [], [], [], []

    for batch_no, (sample, batch_ids, raw_heights, raw_widths) in enumerate(tqdm(test_loader)):
        if debug and batch_no > 2:
            break

        batch_images = sample[sample_key]
        if debug:
            print(sample_key, batch_images.shape)

        converted = convert_batch_pred(
            model(batch_images.to(device)),
            height=raw_heights,
            width=raw_widths,
            resize_height=image_size,
            resize_width=image_size,
        )

        all_ids.extend(batch_ids)
        # converted is (boxes, scores, labels); append each to its own list.
        for sink, values in zip((all_boxes, all_scores, all_labels), converted):
            sink.extend(values)

    return all_ids, all_boxes, all_scores, all_labels

In [None]:
def make_det_preds_ensemble(model, image_size_list, test_loader, device, debug=False, downscale_factor=1, max_det_per_image=None, method="nms"):
    """Run several detectors per batch and fuse their boxes image by image.

    Args:
        model: Sequence of detectors, paired positionally (via ``zip``) with
            ``image_size_list``.
        image_size_list: Sizes used to build the sample keys ``"image_<size>"``.
        test_loader: Iterable of ``(sample, image_id, height_raw, width_raw)``
            batches; ``height_raw``/``width_raw`` are tensors of raw image sizes.
        device: Target passed to ``image.to(...)`` before each forward pass.
        debug: If true, print each input shape and stop after three batches.
        downscale_factor: Accepted but not used by this function.
        max_det_per_image: If set, keep only this many highest-scoring fused
            detections per image.
        method: ``"nms"`` or ``"wbf"``; selects ``nms`` or
            ``weighted_boxes_fusion`` (presumably provided by the notebook's
            ``ensemble_boxes`` star import).

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` lists with one
        entry per image; boxes are floats in raw-image pixel coordinates.

    Raises:
        ValueError: If ``method`` is neither ``"nms"`` nor ``"wbf"``.
    """
    # Validate up front so a typo fails before any forward pass, rather than
    # after the whole ensemble has already run on the first batch.
    if method not in ("nms", "wbf"):
        raise ValueError("method should be 'nms' or 'wbf'")

    image_key_list = ["image_" + str(x) for x in image_size_list]

    image_ids = []
    boxes_preds = []
    scores_preds = []
    labels_preds = []

    iou_thr = 0.5
    skip_box_thr = 0.0001

    for batch_index, (sample, image_id, height_raw, width_raw) in enumerate(tqdm(test_loader)):
        if debug and batch_index > 2:
            break

        boxes_list = []
        scores_list = []
        labels_list = []

        # Shape (batch, 1) so each image's size broadcasts over its detections.
        height_raw_np = np.expand_dims(height_raw.detach().cpu().numpy(), axis=1)
        width_raw_np = np.expand_dims(width_raw.detach().cpu().numpy(), axis=1)

        for key, m in zip(image_key_list, model):
            image = sample[key]
            image_height = image.shape[2]
            image_width = image.shape[3]

            if debug: print(image.shape)
            pred = m(image.to(device))

            boxes, scores, labels = convert_batch_pred(
                pred,
                height=height_raw,
                width=width_raw,
                resize_height=image_height,
                resize_width=image_width,
            )

            # normalize boxes in range 0 to 1
            boxes = boxes.astype(np.float32)

            boxes[:, :, 0] = boxes[:, :, 0] / width_raw_np
            boxes[:, :, 1] = boxes[:, :, 1] / height_raw_np
            boxes[:, :, 2] = boxes[:, :, 2] / width_raw_np
            boxes[:, :, 3] = boxes[:, :, 3] / height_raw_np

            boxes_list.append(boxes)
            scores_list.append(scores)
            labels_list.append(labels)

        # Leading axis = model, second axis = image within the batch.
        boxes_model = np.stack(boxes_list, axis=0)
        scores_model = np.stack(scores_list, axis=0)
        labels_model = np.stack(labels_list, axis=0)

        images_in_batch = labels_model.shape[1]

        # A separate name keeps the batch counter above from being shadowed.
        for image_index in range(images_in_batch):
            if method == "nms":
                boxes, scores, labels = nms(
                    boxes_model[:, image_index, :, :],
                    scores_model[:, image_index, :],
                    labels_model[:, image_index, :],
                    weights=None,
                    iou_thr=iou_thr
                )
            else:
                boxes, scores, labels = weighted_boxes_fusion(
                    boxes_model[:, image_index, :, :],
                    scores_model[:, image_index, :],
                    labels_model[:, image_index, :],
                    weights=None,
                    iou_thr=iou_thr,
                    skip_box_thr=skip_box_thr,
                )

            if max_det_per_image is not None:
                if boxes.shape[0] > max_det_per_image:
                    # Highest scores first, then truncate.
                    ind = np.argsort(scores)[::-1][:max_det_per_image]
                    boxes = boxes[ind, :]
                    scores = scores[ind]
                    labels = labels[ind]

            labels = labels.astype(np.int64)

            # transform boxes to target size
            boxes[:, 0] = boxes[:, 0] * width_raw_np[image_index]
            boxes[:, 1] = boxes[:, 1] * height_raw_np[image_index]
            boxes[:, 2] = boxes[:, 2] * width_raw_np[image_index]
            boxes[:, 3] = boxes[:, 3] * height_raw_np[image_index]

            image_ids.append(image_id[image_index])
            boxes_preds.append(boxes)
            scores_preds.append(scores)
            labels_preds.append(labels)

    return image_ids, boxes_preds, scores_preds, labels_preds

In [None]:
def format_pred(labels: np.ndarray, boxes: np.ndarray, scores: np.ndarray) -> str:
    """Render detections as one space-separated submission string.

    Each detection contributes ``"<label> <score> <xmin> <ymin> <xmax> <ymax>"``;
    box coordinates are cast to int64 (truncating any fractional part).
    """
    def _render(label, score, bbox) -> str:
        xmin, ymin, xmax, ymax = bbox.astype(np.int64)
        return f"{label} {score} {xmin} {ymin} {xmax} {ymax}"

    return " ".join(
        _render(label, score, bbox)
        for label, score, bbox in zip(labels, scores, boxes)
    )

In [None]:
def make_det_preds_df(model, image_size, test_loader, device, debug=False, method="nms", max_det_per_image=None):
    """Run detection (single model or ensemble) and format a submission-style frame.

    Args:
        model: A detector, or a list of detectors. More than one entry runs the
            ensemble path; a one-element list is unwrapped.
        image_size: Image size, or list of sizes matching ``model``.
        test_loader: Batches forwarded to the prediction helper.
        device: Device forwarded to the prediction helper.
        debug: Forwarded debug flag.
        method: Box-fusion method for the ensemble path (``"nms"`` or ``"wbf"``).
        max_det_per_image: Optional cap on fused detections per image
            (ensemble path only).

    Returns:
        DataFrame with columns ``image_id`` (text before the first ``"."`` of
        each id) and ``PredictionString``.
    """
    if isinstance(model, list):
        if len(model) > 1:
            # Keyword arguments are required here: the ensemble helper has a
            # `downscale_factor` slot right after `debug`, so passing these
            # positionally would drop `method` into it and ignore the
            # requested fusion method.
            image_ids, boxes_preds, scores_preds, labels_preds = make_det_preds_ensemble(
                model,
                image_size,
                test_loader,
                device,
                debug=debug,
                method=method,
                max_det_per_image=max_det_per_image,
            )
        else:
            image_ids, boxes_preds, scores_preds, labels_preds = make_det_preds(
                model[0], image_size[0], test_loader, device, debug
            )
    else:
        image_ids, boxes_preds, scores_preds, labels_preds = make_det_preds(
            model, image_size, test_loader, device, debug
        )

    # class, confidence, xmin, ymin, xmax, ymax
    records = [
        (image_id.split(".")[0], format_pred(labels, boxes, scores))
        for image_id, boxes, scores, labels in zip(
            image_ids, boxes_preds, scores_preds, labels_preds
        )
    ]

    return pd.DataFrame(data=records, columns=["image_id", "PredictionString"])

In [None]:
def make_combined_df(det_df, finding_df):
    """Append a "No finding" (class 14) entry to every detection string.

    ``finding_df`` is left-joined onto ``det_df`` by ``image_id`` (text before
    the first ``"."`` of ``image_id_dicom``). Each row's ``PredictionString``
    then gets ``" 14 <1 - preds> 0 0 1 1"`` appended.

    Returns:
        The merged frame with the extended ``PredictionString`` column.
    """
    def append_no_finding(row):
        no_finding_prob = 1 - row["preds"]
        row["PredictionString"] = row["PredictionString"] + f" 14 {no_finding_prob} 0 0 1 1"
        return row

    # assign() works on a copy, so the caller's finding_df is not modified.
    findings_keyed = finding_df.assign(
        image_id=finding_df["image_id_dicom"].str.split(".").str[0]
    )
    merged = det_df.merge(findings_keyed, on="image_id", how="left")

    return merged.apply(append_no_finding, axis=1)

In [None]:
def make_normal_df(image_ids):
    """Build submission rows that mark every given image as "No finding".

    Each id is cut at its first ``"."`` and paired with the fixed prediction
    string ``"14 1 0 0 1 1"``.

    Returns:
        DataFrame with columns ``image_id`` and ``PredictionString``.
    """
    stems = [image_id.split(".")[0] for image_id in image_ids]
    no_finding_strings = ["14 1 0 0 1 1"] * len(stems)

    return pd.DataFrame(
        data=zip(stems, no_finding_strings),
        columns=["image_id", "PredictionString"],
    )

# First Models: 2-Stages / CLF + DET

In [None]:
# ----------
# debug mode
# ----------
DEBUG = False

In [None]:
# ----------
# settings
# ----------
# Seed once, before any data module or model is created.
pl.seed_everything(0)

# Debug runs use a tiny batch; the prediction helpers also stop after a few batches.
batch_size = 16 if not DEBUG else 2
num_workers = 2

dataset_dir = "../input/vinbigdata-chest-xray-abnormalities-detection"

# "No finding" quantile threshold.
# This is a quantile level, not a probability: it is turned into a score
# threshold with clf_df.preds.quantile(PRED_THR), and every image scoring at or
# below that threshold is reported as "No finding".
# 0.65 means 65% of outputs are "No finding"
PRED_THR = 0.65

In [None]:
# default image size
# b0: 224, b1: 240, b2: 260, b3: 300
# b4: 380, b5: 456, b6: 528, b7: 600, b8: 672

# d0: 512, d1: 640, d2: 768, d3: 896
# d4: 1024, d5: 1280, d6: 1280, d7: 1536

# ----------
# checkpoint
# ----------
clf_checkpoint = []
# One input size per classifier checkpoint, in the same order as the
# clf_checkpoint.append(...) calls below: the two lists are paired by position
# at inference time, so an entry added to one must be mirrored in the other.
clf_image_size = [
    456,  # b5-456
    1024,  # b5-1024
    600, 600, 600, 600, 600,  # b5-600 (5 checkpoints)
    528, 528, 528, 528, 528,  # b6-528 (5 checkpoints)
    600, 600, 600,  # resnet200d-600 (3 checkpoints)
]

# b5-456
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-456-timm-bn-5folds-0_VIN-384_checkpoints_xray-classifier-epoch034-val_loss0.5986.ckpt")

# b5-1024
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-1024-timm-bn-5folds-0_VIN-410_checkpoints_xray-classifier-epoch041-val_loss0.5980.ckpt")

# b5-600
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-600-timm-bn-5folds-0_VIN-397_checkpoints_xray-classifier-epoch043-val_loss0.5975.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-600-timm-bn-5folds-1_VIN-427_checkpoints_xray-classifier-epoch036-val_loss0.5989.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-600-timm-bn-5folds-2_VIN-433_checkpoints_xray-classifier-epoch038-val_loss0.5980.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-600-timm-bn-5folds-3_VIN-435_checkpoints_xray-classifier-epoch027-val_loss0.5967.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b5-600-timm-bn-5folds-4_VIN-439_checkpoints_xray-classifier-epoch049-val_loss0.5983.ckpt")

# b6-528
clf_checkpoint.append("../input/vbd-final-checkpoint/b6-528-timm-bn-5folds-0_VIN-349_checkpoints_xray-classifier-epoch042-val_loss0.5979.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b6-528-timm-bn-5folds-1_VIN-351_checkpoints_xray-classifier-epoch039-val_loss0.6011.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b6-528-timm-bn-5folds-2_VIN-352_checkpoints_xray-classifier-epoch034-val_loss0.5997.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b6-528-timm-bn-5folds-3_VIN-354_checkpoints_xray-classifier-epoch041-val_loss0.5975.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/b6-528-timm-bn-5folds-4_VIN-355_checkpoints_xray-classifier-epoch042-val_loss0.5996.ckpt")

# resnet200d-600
clf_checkpoint.append("../input/vbd-final-checkpoint/resnet200d-600-timm-bn-5folds-0_VIN-424_checkpoints_xray-classifier-epoch043-val_loss0.5986.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/resnet200d-600-timm-bn-5folds-1_VIN-440_checkpoints_xray-classifier-epoch039-val_loss0.6002.ckpt")
clf_checkpoint.append("../input/vbd-final-checkpoint/resnet200d-600-timm-bn-5folds-3_VIN-442_checkpoints_xray-classifier-epoch049-val_loss0.5969.ckpt")


det_checkpoint = []
# One input size per detector checkpoint, in the same order as the
# det_checkpoint.append(...) calls below: the two lists are zipped when the
# detectors are loaded and again at inference time, so they must stay aligned.
det_image_size = [
    1024,  # d3-1024
    1024,  # d4-1024
    896,  # d4-896 best LB
    896, 896, 896, 896, 896,  # d4-896 (5 checkpoints)
    1024,  # d5-1024
    768, 768, 768, 768, 768,  # d5-768 (5 checkpoints)
    896, 896, 896, 896,  # d5-896 (4 checkpoints)
]

# d3-1024
det_checkpoint.append("../input/vbd-final-checkpoint/d3-1024-fin-aug-bn-nms-v2-5folds-0_VIN-403_checkpoints_xray-detector-epoch042-val_loss0.7453.ckpt")

# d4-1024
det_checkpoint.append("../input/vbd-final-checkpoint/d4-1024-aug-bn-nms-v2-5folds-0_VIN-325_checkpoints_xray-detector-epoch040-val_loss0.7330.ckpt")

# d4-896 best LB
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-aug-nms-v2-5folds-0_VIN-269_checkpoints_xray-detector-epoch047-val_loss0.7300.ckpt")

# d4-896
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-fin-aug-bn-nms-v2-5folds-0_VIN-377_checkpoints_xray-detector-epoch049-val_loss0.7309.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-fin-aug-bn-nms-v2-5folds-1_VIN-379_checkpoints_xray-detector-epoch039-val_loss0.7414.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-fin-aug-bn-nms-v2-5folds-2_VIN-382_checkpoints_xray-detector-epoch039-val_loss0.7460.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-fin-aug-bn-nms-v2-5folds-3_VIN-383_checkpoints_xray-detector-epoch043-val_loss0.7534.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d4-896-fin-aug-bn-nms-v2-5folds-4_VIN-386_checkpoints_xray-detector-epoch046-val_loss0.7629.ckpt")

# d5-1024
det_checkpoint.append("../input/vbd-final-checkpoint/d5-1024-fin-aug-bn-nms-v2-5folds-0_VIN-414_checkpoints_xray-detector-epoch038-val_loss0.7259.ckpt")

# d5-768
det_checkpoint.append("../input/vbd-final-checkpoint/d5-768-aug-bn-nms-v2-5folds-0_VIN-328_checkpoints_xray-detector-epoch047-val_loss0.7264.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-768-aug-bn-nms-v2-5folds-1_VIN-420_checkpoints_xray-detector-epoch040-val_loss0.7390.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-768-aug-bn-nms-v2-5folds-2_VIN-421_checkpoints_xray-detector-epoch039-val_loss0.7381.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-768-aug-bn-nms-v2-5folds-3_VIN-422_checkpoints_xray-detector-epoch049-val_loss0.7481.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-768-aug-bn-nms-v2-5folds-4_VIN-423_checkpoints_xray-detector-epoch040-val_loss0.7526.ckpt")

# d5-896
det_checkpoint.append("../input/vbd-final-checkpoint/d5-896-aug-bn-nms-v2-5folds-0_VIN-368_checkpoints_xray-detector-epoch038-val_loss0.7235.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-896-aug-bn-nms-v2-5folds-1_VIN-436_checkpoints_xray-detector-epoch034-val_loss0.7374.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-896-aug-bn-nms-v2-5folds-2_VIN-438_checkpoints_xray-detector-epoch049-val_loss0.7312.ckpt")
det_checkpoint.append("../input/vbd-final-checkpoint/d5-896-aug-bn-nms-v2-5folds-3_VIN-441_checkpoints_xray-detector-epoch038-val_loss0.7518.ckpt")

In [None]:
# ----------
# device
# ----------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
print(f"device {device}")

# Inference only: autograd is switched off globally for every cell that follows.
torch.set_grad_enabled(False)

In [None]:
# -------------------------
# Stage 1: Classification
# -------------------------
print("Stage 1: Classification - finding vs no-finding(normal)")
# No image_ids are passed here (stage 2 passes an explicit subset).
# NOTE(review): presumably the module then covers the whole test set and
# provides one "image_<size>" entry per size in image_size_list - confirm
# against datamodule.py.
dm_clf = XrayTestEnsembleDataModule(
    dataset_dir=dataset_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    image_size_list=clf_image_size,
)

dm_clf.prepare_data()
dm_clf.setup()

In [None]:
clf_list = []

# One classifier per checkpoint, kept in checkpoint order so that index i lines
# up with clf_image_size[i] when the two lists are zipped for inference.
for ckpt in clf_checkpoint:
    clf_list.append(XrayClassifier.load_from_checkpoint(ckpt, pretrained=False))

# Move every classifier to the inference device and switch it to eval mode.
for clf in clf_list:
    clf.to(device)
    clf.eval()

In [None]:
clf_df = make_clf_preds_df(
    clf_list, clf_image_size, dm_clf.test_dataloader(), device, debug=DEBUG
)

In [None]:
pred_thres = clf_df.preds.quantile(PRED_THR)

In [None]:
print(f"Finding    prob thr: {pred_thres}")
print(f"No finding prob thr: {1 - pred_thres}")

In [None]:
clf_df[clf_df.preds<=pred_thres].shape

In [None]:
# Split at the quantile score: images strictly above it go on to the detectors,
# images at or below it are later written out as "No finding" only.
finding_df = clf_df[clf_df.preds > pred_thres]
no_finding_df = clf_df[clf_df.preds <= pred_thres]

In [None]:
# --------------------
# Stage 2: Detection
# --------------------

# Only the images the classifier stage kept as "finding" are passed on.
image_ids = finding_df["image_id_dicom"].tolist()

print("Stage 2: Detection")
dm_det = XrayTestEnsembleDataModule(
    dataset_dir=dataset_dir,
    image_ids=image_ids,
    batch_size=batch_size,
    num_workers=num_workers,
    image_size_list=det_image_size,
)

dm_det.prepare_data()
dm_det.setup()

In [None]:
det_list = []

# Each detector is built for the input size listed at the same position in
# det_image_size, so the two lists must have the same length and order
# (zip would silently drop any unmatched tail).
for ckpt, image_size in zip(det_checkpoint, det_image_size):
    det_list.append(
        XrayDetector.load_from_checkpoint(
            ckpt,
            pretrained=False,
            pretrained_backbone=False,
            image_size=image_size
        )
    )

# Move every detector to the inference device and switch it to eval mode.
for det in det_list:
    det.to(device)
    det.eval()

In [None]:
len(det_list)

In [None]:
det_df = make_det_preds_df(
    det_list, det_image_size, dm_det.test_dataloader(), device, debug=DEBUG, method="nms", max_det_per_image=None,
)

In [None]:
det_combined_df = make_combined_df(det_df, finding_df)
det_combined_df.head()

In [None]:
det_combined_df.iloc[0, 1][-110:]

In [None]:
normal_df = make_normal_df(no_finding_df.image_id_dicom.tolist())

In [None]:
sub1_df = pd.concat([det_combined_df[["image_id", "PredictionString"]], normal_df])
sub1_df.head()

# Second Models: 1-Stage / DET for all classes

In [None]:
def make_preds(model, image_size, test_loader, device, debug=False, downscale_factor = 1):
    """Run a single detector over ``test_loader`` (one-stage pipeline entry point).

    This used to be a line-for-line copy of ``make_det_preds``; it now forwards
    to that function so the two pipelines cannot drift apart.

    Args:
        model: Callable applied to each image batch.
        image_size: Builds the sample key ``"image_<image_size>"`` and is used
            as both resize dimensions when rescaling boxes.
        test_loader: Iterable of ``(sample, image_id, height_raw, width_raw)``.
        device: Target passed to ``.to(...)`` before the forward pass.
        debug: If true, print each batch's key/shape and stop after three batches.
        downscale_factor: Accepted and forwarded; not used by ``make_det_preds``.

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` lists with one
        entry per image.
    """
    return make_det_preds(
        model,
        image_size,
        test_loader,
        device,
        debug=debug,
        downscale_factor=downscale_factor,
    )

In [None]:
def make_preds_ensemble(model, image_size_list, test_loader, device, debug=False, downscale_factor=1, max_det_per_image=None, method="nms"):
    """Run several detectors per batch and fuse their boxes (one-stage pipeline).

    This used to be a line-for-line copy of ``make_det_preds_ensemble``; it now
    forwards to that function so the two pipelines cannot drift apart.

    Args:
        model: Sequence of detectors, paired positionally with ``image_size_list``.
        image_size_list: Sizes used to build the sample keys ``"image_<size>"``.
        test_loader: Iterable of ``(sample, image_id, height_raw, width_raw)``.
        device: Target passed to ``.to(...)`` before each forward pass.
        debug: If true, print each input shape and stop after three batches.
        downscale_factor: Accepted and forwarded; not used downstream.
        max_det_per_image: Optional cap on fused detections per image.
        method: ``"nms"`` or ``"wbf"``.

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` lists with one
        entry per image.
    """
    # Keywords throughout: the optional parameters are easy to misplace
    # positionally (`downscale_factor` sits between `debug` and the rest).
    return make_det_preds_ensemble(
        model,
        image_size_list,
        test_loader,
        device,
        debug=debug,
        downscale_factor=downscale_factor,
        max_det_per_image=max_det_per_image,
        method=method,
    )

In [None]:
def remove_duplicate_nofinding(image_ids, boxes_preds, scores_preds, labels_preds):
    """Keep at most one "No finding" (label 14) detection per image.

    For every image, all label-14 detections except the highest-scoring one are
    removed from boxes, scores and labels. Other detections are untouched and
    ``image_ids`` is returned as given.

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` with the
        filtered per-image arrays in new lists.
    """
    kept_boxes, kept_scores, kept_labels = [], [], []

    for boxes, scores, labels in zip(boxes_preds, scores_preds, labels_preds):
        nofinding_rows = np.argwhere(labels == 14).ravel()

        if nofinding_rows.size != 0:
            # Spare the best-scoring "No finding" row; everything else in
            # nofinding_rows is a duplicate to drop.
            best = np.argmax(scores[nofinding_rows])
            duplicates = np.delete(nofinding_rows, best)

            boxes = np.delete(boxes, duplicates, axis=0)
            scores = np.delete(scores, duplicates, axis=0)
            labels = np.delete(labels, duplicates, axis=0)

        kept_boxes.append(boxes)
        kept_scores.append(scores)
        kept_labels.append(labels)

    return image_ids, kept_boxes, kept_scores, kept_labels

In [None]:
def get_nofinding_probs(labels_preds, scores_preds):
    """Return one "No finding" (label 14) score per image as a flat float array.

    Args:
        labels_preds: Per-image 1-D label arrays.
        scores_preds: Per-image 1-D score arrays, aligned with ``labels_preds``.

    Returns:
        1-D float64 array with one value per image: the label-14 score (the
        highest one if several are present) or ``0.0`` when the image has none.
    """
    probs = []
    for labels, scores in zip(labels_preds, scores_preds):
        nofinding_scores = np.asarray(scores)[np.asarray(labels) == 14]
        # One scalar per image keeps the result rectangular. Mixing (k, 1)
        # score slices with a (1,) default produced a ragged array, which
        # np.asarray rejects on current NumPy.
        probs.append(float(nofinding_scores.max()) if nofinding_scores.size else 0.0)

    return np.asarray(probs, dtype=np.float64)

In [None]:
def print_quantiles(data):
    """Print the 0.0-0.9 quantiles of ``data`` in steps of 0.1 (the maximum is not shown)."""
    levels = [step * .1 for step in range(10)]
    values = np.quantile(data, levels)

    report = ["--- quantiles ---"]
    report.extend(f"{level:.2f}: {value:.6f}" for level, value in zip(levels, values))
    report.append("-----------------")

    for line in report:
        print(line)

In [None]:
def clear_nofinding_det(image_ids, boxes_preds, scores_preds, labels_preds, prob_thr=1.0):
    """Reduce confident "No finding" images to their single label-14 detection.

    For each image whose label-14 score is strictly greater than ``prob_thr``,
    every other detection is dropped and the remaining score is replaced by
    ``1``. Images without a label-14 entry, or at/below the threshold, are
    passed through unchanged.

    Args:
        image_ids: Returned as given.
        boxes_preds, scores_preds, labels_preds: Per-image arrays; each image
            may contain at most one label-14 entry (run
            ``remove_duplicate_nofinding`` first).
        prob_thr: Score threshold for treating an image as "No finding".

    Returns:
        ``(image_ids, boxes_preds, scores_preds, labels_preds)`` with the
        filtered per-image arrays in new lists.

    Raises:
        ValueError: If an image holds more than one label-14 detection.
    """
    filtered_boxes_preds = []
    filtered_scores_preds = []
    filtered_labels_preds = []

    for boxes, scores, labels in zip(boxes_preds, scores_preds, labels_preds):
        # Find indices of "No finding".
        ind = np.argwhere(labels==14)

        # Check explicitly: with duplicates the threshold comparison below is
        # ambiguous and would fail with an unrelated NumPy error before the
        # intended message could ever be raised.
        if ind.shape[0] > 1:
            raise ValueError("Size of ind should be 0 or 1")

        if ind.size != 0 and scores[ind] > prob_thr:
            # (1, 1) -> (1,), so the indexing below keeps a leading axis.
            ind = ind.squeeze(axis=0)

            # Delete all "finding" detection of "No finding" with prob > prob_thr
            boxes = boxes[ind]
            # The surviving "No finding" entry is reported with confidence 1.
            scores = np.array([1])
            labels = labels[ind]

        filtered_boxes_preds.append(boxes)
        filtered_scores_preds.append(scores)
        filtered_labels_preds.append(labels)

    return image_ids, filtered_boxes_preds, filtered_scores_preds, filtered_labels_preds

In [None]:
def convert_nofinding_box(labels: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Replace the box of every "No finding" (label 14) entry with ``[0, 0, 1, 1]``.

    All other boxes are kept as they are; the result is a new array.
    """
    return np.asarray([
        np.array([0, 0, 1, 1]) if label == 14 else bbox
        for label, bbox in zip(labels, boxes)
    ])

In [None]:
def make_preds_df(model, image_size, test_loader, device, debug=False, method="nms", max_det_per_image=None, quantile_thr=0.4):
    """Run the one-stage detector(s) and format a submission-style frame.

    After prediction, duplicate "No finding" (label 14) entries are removed,
    the per-image "No finding" scores are summarised, and images scoring above
    the ``quantile_thr`` quantile of those scores are reduced to their
    "No finding" entry alone.

    Args:
        model: A detector, or a list of detectors. More than one entry runs the
            ensemble path; a one-element list is unwrapped.
        image_size: Image size, or list of sizes matching ``model``.
        test_loader: Batches forwarded to the prediction helper.
        device: Device forwarded to the prediction helper.
        debug: Forwarded debug flag.
        method: Box-fusion method for the ensemble path (``"nms"`` or ``"wbf"``).
        max_det_per_image: Optional cap on fused detections per image
            (ensemble path only).
        quantile_thr: Quantile level of the "No finding" scores used as the
            clearing threshold.

    Returns:
        DataFrame with columns ``image_id`` (text before the first ``"."`` of
        each id) and ``PredictionString``.
    """
    if isinstance(model, list):
        if len(model) > 1:
            # Keyword arguments are required here: the ensemble helper has a
            # `downscale_factor` slot right after `debug`, so passing these
            # positionally would drop `method` into it and ignore the
            # requested fusion method.
            image_ids, boxes_preds, scores_preds, labels_preds = make_preds_ensemble(
                model,
                image_size,
                test_loader,
                device,
                debug=debug,
                method=method,
                max_det_per_image=max_det_per_image,
            )
        else:
            image_ids, boxes_preds, scores_preds, labels_preds = make_preds(
                model[0], image_size[0], test_loader, device, debug
            )
    else:
        image_ids, boxes_preds, scores_preds, labels_preds = make_preds(
            model, image_size, test_loader, device, debug
        )

    # At most one "No finding" entry may remain per image for the steps below.
    image_ids, boxes_preds, scores_preds, labels_preds = remove_duplicate_nofinding(
        image_ids, boxes_preds, scores_preds, labels_preds
    )

    nofinding_probs = get_nofinding_probs(labels_preds, scores_preds)
    print_quantiles(nofinding_probs)

    # Turn the quantile level into an absolute score threshold.
    prob_thr = np.quantile(nofinding_probs, quantile_thr)
    print(f"prob_thr: {prob_thr}")

    image_ids, boxes_preds, scores_preds, labels_preds = clear_nofinding_det(
        image_ids, boxes_preds, scores_preds, labels_preds, prob_thr
    )

    ids = []
    pred_string_list = []

    # class, confidence, xmin, ymin, xmax, ymax
    for image_id, boxes, scores, labels in zip(
        image_ids, boxes_preds, scores_preds, labels_preds
    ):
        ids.append(image_id.split(".")[0])

        # "No finding" entries carry the fixed 0 0 1 1 box.
        boxes = convert_nofinding_box(labels, boxes)
        pred_string_list.append(format_pred(labels, boxes, scores))

    df = pd.DataFrame(
        data=(zip(ids, pred_string_list)), columns=["image_id", "PredictionString"]
    )

    return df

In [None]:
# ----------
# settings
# ----------
# Quantile level (not a probability) for filtering by the "No finding" score:
# it is passed as quantile_thr to make_preds_df, which converts it into a score
# threshold with np.quantile and keeps only the "No finding" entry for images
# scoring above that threshold.
QUANTILE_THR = 0.70

# default image size
# d0: 512, d1: 640, d2: 768, d3: 896
# d4: 1024, d5: 1280, d6: 1280, d7: 1536

# ----------
# checkpoint
# ----------
checkpoint = []
image_size_list = [896, 896]
checkpoint.append("../input/vbd-final-checkpoint/d4-896-all-aug-bn-nms-v2-5folds-0_VIN-412_checkpoints_xray-detector-epoch046-val_loss0.3329.ckpt")
checkpoint.append("../input/vbd-final-checkpoint/d4-896-all-aug-bn-nms-v2-5folds-1_VIN-431_checkpoints_xray-detector-epoch046-val_loss0.3230.ckpt")

In [None]:
# --------------------
# Prediction
# --------------------
print("Prediction")
# No image_ids are passed here, unlike the stage-2 module above.
# NOTE(review): presumably this makes the module cover the whole test set -
# confirm against datamodule.py.
dm = XrayTestEnsembleDataModule(
    dataset_dir=dataset_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    image_size_list=image_size_list,
)

dm.prepare_data()
dm.setup()

In [None]:
models = []

# Each detector is built for the input size listed at the same position in
# image_size_list, so the two lists must have the same length and order
# (zip would silently drop any unmatched tail).
for ckpt, image_size in zip(checkpoint, image_size_list):
    models.append(
        XrayDetector.load_from_checkpoint(
            ckpt,
            pretrained=False,
            pretrained_backbone=False,
            image_size=image_size
        )
    )

# Move every detector to the inference device and switch it to eval mode.
for m in models:
    m.to(device)
    m.eval()

In [None]:
len(models)

In [None]:
sub2_df = make_preds_df(
    models,
    image_size_list,
    dm.test_dataloader(),
    device,
    debug=DEBUG,
    method="nms",
    max_det_per_image=None,
    quantile_thr=QUANTILE_THR,
)

# Ensemble

In [None]:
def get_pred_dict(df):
    """Parse a two-column submission frame into per-image NumPy arrays.

    Args:
        df: DataFrame whose rows are ``(image_id, PredictionString)``. The
            string is a whitespace-separated sequence of 6-token groups:
            label, score and four box coordinates.

    Returns:
        Dict mapping each image id to ``{"labels", "scores", "boxes"}`` float
        arrays; ``boxes`` always has shape ``(n, 4)``, including ``n == 0``.
    """
    df_dict = {}

    for img_id, pred_str in df.itertuples(index=False):
        tokens = pred_str.split()

        # Every detection occupies six consecutive tokens.
        box_tokens = [tokens[k:k + 4] for k in range(2, len(tokens), 6)]

        # The builtin `float` replaces the `np.float` alias, which was
        # removed from NumPy; the resulting dtype (float64) is the same.
        df_dict[img_id] = dict(
            labels=np.asarray(tokens[0::6], dtype=float),
            scores=np.asarray(tokens[1::6], dtype=float),
            # reshape keeps the (n, 4) layout even when there are no boxes.
            boxes=np.asarray(box_tokens, dtype=float).reshape(-1, 4),
        )
    return df_dict

In [None]:
def make_ensemble_pred_dict(*args, iou_thr=0.5, weights=None, norm_factor=10_000):
    """Fuse several per-image prediction dicts with ``nms``.

    Args:
        *args: Prediction dicts as produced by ``get_pred_dict``. The image ids
            of the first dict drive the loop and must exist in all the others.
        iou_thr: IoU threshold forwarded to ``nms``.
        weights: Per-source weights forwarded to ``nms``.
        norm_factor: Boxes are divided by this before ``nms`` and multiplied
            back afterwards (presumably to bring pixel coordinates into the
            0-1 range ``nms`` works on - confirm against ``ensemble_boxes``).

    Returns:
        Dict mapping each image id to ``{"labels", "scores", "boxes"}`` where
        ``labels`` and ``boxes`` are integer arrays.
    """
    pred_dict = {}

    for img_id in args[0].keys():
        per_source = [df_dict[img_id] for df_dict in args]

        labels_list = [pred["labels"] for pred in per_source]
        scores_list = [pred["scores"] for pred in per_source]
        boxes_list = [pred["boxes"] / norm_factor for pred in per_source]

        boxes, scores, labels = nms(boxes_list, scores_list, labels_list, weights=weights, iou_thr=iou_thr)

        # Round before the integer cast: x / norm_factor * norm_factor can land
        # just below x in floating point, and plain truncation would then move
        # the coordinate down by one pixel.
        # The builtin `int` replaces the `np.int` alias, which was removed
        # from NumPy.
        boxes = np.rint(boxes * norm_factor).astype(int)
        labels = labels.astype(int)

        pred_dict[img_id] = dict(
            labels=labels,
            scores=scores,
            boxes=boxes,
        )
    return pred_dict

In [None]:
def make_submission_df(pred_dict):
    """Turn a prediction dict into the final two-column submission frame.

    Args:
        pred_dict: Maps image id to ``{"labels", "scores", "boxes"}`` sequences
            of equal length.

    Returns:
        DataFrame with columns ``image_id`` and ``PredictionString``; each
        string is the space-separated ``label score box...`` groups of that
        image (empty when the image has no detections).
    """
    image_ids = []
    pred_strings = []

    for img_id, pred in pred_dict.items():
        pred_str = "".join(
            str(label) + " " + str(score) + " " + " ".join(map(str, box)) + " "
            for label, score, box in zip(pred["labels"], pred["scores"], pred["boxes"])
        ).strip()

        image_ids.append(img_id)
        pred_strings.append(pred_str)

    # Build the frame in one go. Writing cells through chained indexing
    # (df[col][mask] = value) does not update the frame under pandas
    # Copy-on-Write and rescans the whole frame for every image.
    return pd.DataFrame(
        {"image_id": image_ids, "PredictionString": pred_strings},
        columns=["image_id", "PredictionString"],
    )

In [None]:
sub1_dict = get_pred_dict(sub1_df)
sub2_dict = get_pred_dict(sub2_df)

In [None]:
ensemble_pred_dict = make_ensemble_pred_dict(sub1_dict, sub2_dict)

In [None]:
submission_df = make_submission_df(ensemble_pred_dict)

In [None]:
submission_df.head()

In [None]:
submission_df.to_csv("submission.csv", index=False)