In [None]:
!pip install timm effdet

In [None]:
import warnings
# Install a blanket "ignore" filter: every warning raised after this cell is suppressed.
warnings.filterwarnings("ignore")

In [None]:
# NOTE(review): EVAL_CKPTS is not read in the visible part of the notebook;
# presumably it toggles checkpoint evaluation further down - confirm.
EVAL_CKPTS = True

# Checkpoint files (relative paths) of the trained models to use.
CKPTS_v = [
    '../input/vinbigdata-effdet-d2-f0f2-ckpts/F1_E79_ModelX_v4_T0.325_V0.410.ckpt', 
    '../input/vinbigdata-effdet-d2-f0f2-ckpts/F2_E82_ModelX_v4_T0.321_V0.409.ckpt',
]

In [None]:
import os, sys
import glob
import pickle
from collections import OrderedDict, namedtuple, deque
from copy import deepcopy
import colorsys



import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import h5py
from tqdm import tqdm

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler

from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain, DetBenchPredict

In [None]:
# Data directories list
# Candidate locations of the raw dataset; the first one that exists is used.
DS_DIR_v = [
    "../input/vinbigdata-chest-xray-abnormalities-detection",
]

# Walk the candidates in order and stop at the first existing directory.
# DS_DIR is left bound to that directory; the for/else fires only when no
# candidate exists.
for DS_DIR in DS_DIR_v:
    if not os.path.exists(DS_DIR):
        continue
    print(f' DS_DIR Found: "{DS_DIR}"')
    break
else:
    raise Exception(' Dataset not found.')

In [None]:
# NOTE(review): presumably the directory holding the cached train/test bbox dataset - confirm against its consumer.
DS_PATH = '../input/train-test-ds-bbox-cache'

In [None]:
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Validation Functions 
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #


class FastSMA():
    """Simple moving average (SMA) tracker that also records its full history.

    The mean over the last ``maxlen`` appended values is maintained
    incrementally with a running sum.  Every appended value, the SMA after
    that value and the step index it was appended at are kept in
    ``val_hist_v``, ``sma_hist_v`` and ``step_hist_v`` respectively.
    """

    def __init__(
        self,
        iterator=None,
        maxlen=1000,
        label='mean = ',
        print_format='0.02f',
        save_filename='loss_trn.fsma'):
        """Create the tracker and optionally pre-fill it from ``iterator``.

        Args:
            iterator: optional iterable of initial values.  If its first
                element is a list, tuple or ndarray, every element is unpacked
                as the arguments of ``append`` (i.e. ``(value, i_step)``);
                otherwise every element is appended as a plain value.
            maxlen: positive int, size of the averaging window.
            label: prefix used by ``__str__``.
            print_format: format spec applied to the mean by ``__str__``.
            save_filename: default file used by ``save`` and ``load``.
        """
        assert type(maxlen) is int and maxlen > 0, 'ERROR length must be a positive int.'
        self.maxlen = maxlen
        self.label = label
        self.print_format = print_format
        self.save_filename = save_filename
        self.clear()

        if iterator is not None:
            # Peek at the first element through the SAME iterator that is then
            # drained, so one-shot iterators/generators do not lose their first
            # element, and an empty iterable is simply a no-op.
            items = iter(iterator)
            missing = object()
            first = next(items, missing)

            if first is not missing:
                if type(first) in [list, tuple, np.ndarray]:
                    self.append(*first)
                    for item in items:
                        self.append(*item)
                else:
                    self.append(first)
                    for item in items:
                        self.append(item)

        return None

    def append(self, v, i_step=None):
        """Add value ``v`` and update the moving average.

        Args:
            v: value to add.
            i_step: step index recorded with the value; defaults to the
                previous recorded step + 1 (0 for the first value).
        """
        self.cumsum += v

        if i_step is None:
            i_step = self.step_hist_v[-1] + 1 if len(self.step_hist_v) else 0

        if len(self.history) == self.maxlen:
            # Window is full: the oldest value is about to be evicted by the
            # deque, so remove it from the running sum first.
            self.cumsum -= self.history[0]
            self.last_mean = self.cumsum / self.maxlen
        else:
            self.last_mean = self.cumsum / (len(self.history) + 1)

        self.history.append(v)

        self.sma_hist_v.append(self.last_mean)
        self.val_hist_v.append(v)
        self.step_hist_v.append(i_step)

        return None

    def mean(self):
        """Return the moving average after the last appended value (0.0 if empty)."""
        return self.last_mean

    def clear(self):
        """Reset the window, the running sum and all recorded histories."""
        self.history = deque(maxlen=self.maxlen)
        self.cumsum = 0.0
        self.last_mean = 0.0
        self.sma_hist_v = []
        self.val_hist_v = []
        self.step_hist_v = []

        # Attributes persisted by save() and restored by load().
        self.key2save_v = [
            'history',
            'cumsum',
            'last_mean',
            'sma_hist_v',
            'val_hist_v',
            'step_hist_v',
        ]
        return None

    def get_sma_history(self):
        """Return ``(steps, sma_values)`` as two numpy arrays."""
        return np.array(self.step_hist_v), np.array(self.sma_hist_v)

    def get_val_history(self):
        """Return ``(steps, raw_values)`` as two numpy arrays."""
        return np.array(self.step_hist_v), np.array(self.val_hist_v)

    def __str__(self):
        return ('{}{:' + self.print_format + '}').format(self.label, self.mean())

    def __repr__(self):
        return self.__str__()

    def save(self, filename=None):
        """Pickle the tracker state to ``filename`` (default: ``save_filename``)."""
        if filename is None:
            filename = self.save_filename

        to_save_d = {}
        for key2save in self.key2save_v:
            to_save_d[key2save] = deepcopy(getattr(self, key2save))

        with open(filename, 'wb') as f:
            f.write(pickle.dumps(to_save_d))

        print(f' Saved complete: "{filename}"')

        return None

    def load(self, filename=None):
        """Restore the state written by ``save`` and return ``self``.

        Only keys listed in ``key2save_v`` are restored; ``maxlen`` is then
        taken from the restored window.
        """
        if filename is None:
            filename = self.save_filename

        with open(filename, 'rb') as f:
            data = f.read()

        # SECURITY: unpickling executes arbitrary code - only load files
        # produced by save() from a trusted source.
        to_save_d = pickle.loads(data)

        for k, v in to_save_d.items():
            if k in self.key2save_v:
                setattr(self, k, v)

        self.maxlen = self.history.maxlen
        print(f' Restored: "{filename}"')

        return self

    def plot(self, label=None, do_show=False):
        """Plot the SMA history against the step index on the current figure."""
        x, y = self.get_sma_history()
        plt.plot(x, y, label=label)
        plt.grid()

        if do_show:
            plt.show()

        return None

    def plot_sma(self, label='', step=500):
        """Plot the raw values averaged over consecutive chunks of ``step`` samples.

        The oldest ``len % step`` samples are dropped so the rest splits evenly;
        each chunk is plotted at the step index of its last sample.  With the
        default empty ``label`` the name is derived from ``save_filename``.
        """
        if label == '':
            label = os.path.split(self.save_filename)[-1].replace('.fsma', '')

        step_v, loss_v = self.get_val_history()

        m = loss_v.shape[0] % step
        loss_v = loss_v[m:].reshape(-1, step).mean(axis=-1)
        step_v = step_v[m:].reshape(-1, step)[:, -1]

        plt.plot(step_v, loss_v, label=label)

        return None
    

def calc_iou(bb0, bb1):
    """Intersection-over-union of boxes given as (x0, y0, x1, y1).

    Each argument is a numpy array holding either a single box of shape (4,)
    or a batch of shape (n, 4); numpy broadcasting applies between the two.
    Every box must satisfy x0 < x1 and y0 < y1 (checked with asserts).

    Returns:
        The IoU value(s), forced to zero wherever the boxes do not overlap.
    """
    # Batches come row-wise; transpose so the four coordinates unpack as vectors.
    if bb0.ndim == 2:
        bb0 = bb0.T
    if bb1.ndim == 2:
        bb1 = bb1.T

    a_x0, a_y0, a_x1, a_y1 = bb0
    b_x0, b_y0, b_x1, b_y1 = bb1

    assert (a_x0 < a_x1).all()
    assert (a_y0 < a_y1).all()
    assert (b_x0 < b_x1).all()
    assert (b_y0 < b_y1).all()

    # Corners of the candidate intersection rectangle.
    left = np.maximum(a_x0, b_x0)
    top = np.maximum(a_y0, b_y0)
    right = np.minimum(a_x1, b_x1)
    bottom = np.minimum(a_y1, b_y1)

    # A pair overlaps unless the rectangle is inverted on either axis.
    overlaps = ~((right < left) | (bottom < top))

    inter_area = (right - left) * (bottom - top)
    a_area = (a_x1 - a_x0) * (a_y1 - a_y0)
    b_area = (b_x1 - b_x0) * (b_y1 - b_y0)

    ratio = inter_area / (a_area + b_area - inter_area)

    # The ratio is meaningless for non-overlapping pairs: mask it out.
    return ratio * overlaps


def join_preds(bbox_v, p_det_v=None, mode='p_det_weight'):
    """Collapse a group of boxes into a single box and score.

    Args:
        bbox_v: array of boxes, one per row.
        p_det_v: detection score per box; defaults to all ones.
        mode: 'p_det_weight' -> score-weighted mean box and mean score;
              'p_det_max'    -> the box with the highest score;
              'random'       -> a box picked uniformly at random.

    Returns:
        Tuple ``(bbox, p)``.

    Raises:
        Exception: if ``mode`` is not one of the values above.
    """
    if p_det_v is None:
        p_det_v = np.ones(bbox_v.shape[0])

    if mode == 'p_det_weight':
        weights = (p_det_v / p_det_v.sum())[:, None]
        return (bbox_v * weights).sum(axis=0), p_det_v.mean()

    # The remaining modes just select one of the input boxes.
    if mode == 'p_det_max':
        chosen = p_det_v.argmax()
    elif mode == 'random':
        chosen = np.random.randint(0, p_det_v.shape[0])
    else:
        raise Exception(f'Unknown mode "{mode}"')

    return bbox_v[chosen], p_det_v[chosen]
    

def clean_predictions(preds_v, iou_th=0.1, mode='p_det_weight', consensus_level=1):
    """Merge overlapping same-class boxes inside every prediction dict.

    For each class, boxes are clustered by IoU (``calc_iou`` > ``iou_th``,
    transitively) and every cluster is collapsed to a single box with
    ``join_preds``.

    Args:
        preds_v: iterable of dicts with key 'cls', boxes under 'bbox' or
            'bboxes', and optionally 'p_det' and 'rad_id' (all numpy arrays).
            Any other key is copied through unchanged.
        iou_th: IoU above which two boxes of the same class are linked.
        mode: merge strategy forwarded to ``join_preds``.
        consensus_level: minimum number of boxes a cluster needs to be kept.
            A class that occurs only once is dropped when this is > 1, except
            class -1, which is checked against the number of 'rad_id' entries
            instead (only when 'rad_id' is present).

    Returns:
        A new list of dicts with the same keys as the inputs.
    """
    ret_preds_v = []
    for pred_d in preds_v:

        cls_v = pred_d['cls']

        bbox_key = 'bbox' if 'bbox' in pred_d.keys() else 'bboxes'
        bbox_v = pred_d[bbox_key]

        # Scores are optional: use ones internally but only return them if given.
        ret_p_det = 'p_det' in pred_d.keys()
        p_det_v = pred_d['p_det'] if ret_p_det else np.ones(pred_d['cls'].shape)

        ret_rad_id = 'rad_id' in pred_d.keys()
        if ret_rad_id:
            rad_id_v = pred_d['rad_id']

        new_cls_v = []
        new_bbox_v = []
        new_p_det_v = []
        new_rad_id_v = []
        for i_c in np.unique(cls_v):
            f_c = (cls_v == i_c)

            n_c = f_c.sum()
            if n_c == 1:
                if consensus_level > 1 and i_c != -1:
                    continue

                if ret_rad_id:
                    if i_c == -1:
                        n_rads = rad_id_v.size

                        if n_rads < consensus_level:
                            continue

                        else:
                            if n_rads > 1:
                                new_rad_id_v.append( np.concatenate(rad_id_v, axis=-1) )
                            else:
                                new_rad_id_v.append( rad_id_v[f_c][0] )
                    else:
                        new_rad_id_v.append( rad_id_v[f_c][0] )

                new_cls_v.append( i_c )
                new_bbox_v.append( bbox_v[f_c][0] )
                new_p_det_v.append( p_det_v[f_c][0] )

            else:
                f_bbox_v = bbox_v[f_c]
                f_p_det_v = p_det_v[f_c]
                if ret_rad_id:
                    f_rad_id_v = rad_id_v[f_c]

                # Cluster the boxes of this class.  The groups are kept pairwise
                # disjoint: a box overlapping several existing groups fuses ALL
                # of them (not just the first one found), so no box index can
                # end up in two output boxes.
                to_join_idxs_v = []
                for i_box in range(0, n_c):
                    idxs_s = set( np.argwhere( calc_iou(f_bbox_v[i_box], f_bbox_v) > iou_th ).T[0] )

                    merged_pos = None
                    kept_v = []
                    for group_s in to_join_idxs_v:
                        if len( idxs_s.intersection(group_s) ) > 0:
                            if merged_pos is None:
                                # First hit: the merged group keeps this position.
                                merged_pos = len(kept_v)
                                kept_v.append( group_s.union(idxs_s) )
                            else:
                                kept_v[merged_pos] = kept_v[merged_pos].union(group_s)
                        else:
                            kept_v.append(group_s)

                    if merged_pos is None:
                        kept_v.append(idxs_s)

                    to_join_idxs_v = kept_v

                for to_join_idxs in to_join_idxs_v:
                    to_join_idxs = list(to_join_idxs)

                    if len(to_join_idxs) < consensus_level:
                        continue

                    bbox, p_det = join_preds(
                        f_bbox_v[to_join_idxs],
                        f_p_det_v[to_join_idxs],
                        mode=mode,
                    )

                    new_cls_v.append( i_c )
                    new_bbox_v.append( bbox )
                    new_p_det_v.append( p_det )

                    if ret_rad_id:
                        new_rad_id_v.append( np.concatenate(f_rad_id_v[to_join_idxs], axis=-1))

        ret_preds_d = {
            'cls': np.array(new_cls_v),
            bbox_key: np.array(new_bbox_v),
        }

        if ret_p_det:
            ret_preds_d['p_det'] = np.array(new_p_det_v)

        if ret_rad_id:
            ret_preds_d['rad_id'] = np.array(new_rad_id_v)

        # Pass every other key (metadata) through untouched.
        for k in pred_d.keys():
            if k not in ['cls', bbox_key, 'p_det', 'rad_id']:
                ret_preds_d[k] = pred_d[k]

        ret_preds_v.append(ret_preds_d)

    return ret_preds_v


def evalueate_dataset(
    ds,
    model,
    det_th=0.25,
    unscale_bboxes=True,
    batch_size=16,
    num_workers=8,
    pin_memory=True,
    do_clean_predictions=True,
    clean_iou_th=0.10,
    clean_mode='p_det_weight',
):
    """Run ``model.predict`` over a dataset and collect per-sample predictions.

    Batches are drawn in order (no shuffling) through a ``DataLoader`` that
    collates with ``batch_merge``.  Each batch's predictions are optionally
    passed through ``clean_predictions`` and then tagged with the batch's
    'sample_id' and 'original_shape' entries.

    Args:
        ds: dataset handed to the ``DataLoader``.
        model: object exposing ``predict(batch, det_th=..., unscale_bboxes=...)``.
        det_th, unscale_bboxes: forwarded to ``model.predict``.
        batch_size, num_workers, pin_memory: forwarded to the ``DataLoader``.
        do_clean_predictions: whether to call ``clean_predictions``.
        clean_iou_th, clean_mode: forwarded to ``clean_predictions``.

    Returns:
        Flat list with the prediction dicts of all batches.
    """
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge)

    collected_v = []
    for batch in tqdm(loader):
        batch_pred_v = model.predict(
            batch,
            det_th=det_th,
            unscale_bboxes=unscale_bboxes)

        if do_clean_predictions:
            batch_pred_v = clean_predictions(
                batch_pred_v,
                iou_th=clean_iou_th,
                mode=clean_mode)

        # Attach the identifying metadata of each sample to its prediction.
        for pos, sample_pred_d in enumerate(batch_pred_v):
            sample_pred_d['sample_id'] = batch['sample_id'][pos]
            sample_pred_d['original_shape'] = batch['original_shape'][pos]

        collected_v.extend(batch_pred_v)

    return collected_v
    
    
    
def pred_to_str(pred_d):
    """Format one prediction dict as a submission string.

    Args:
        pred_d: dict with numpy arrays under 'cls', 'bbox' and 'p_det'.

    Returns:
        Space separated ``cls p_det x0 y0 x1 y1`` groups, one per box, with the
        class and the rounded box coordinates written as integers.  When there
        are no boxes the fixed string ``'14 1 0 0 1 1'`` is returned.
    """
    cls_v = pred_d['cls']
    bbox_v = pred_d['bbox']
    p_det_v = pred_d['p_det']

    if len(cls_v) == 0:
        return '14 1 0 0 1 1'

    # `np.int` was only an alias of the builtin `int` and no longer exists in
    # NumPy >= 1.24, so cast with the builtin directly.
    int_cls_v = cls_v.astype(int)
    int_bbox_v = np.round(bbox_v).astype(int)

    s_v = [
        '{} {:0.05} {} {} {} {}'.format(int(cls), p_det, *bbox)
        for cls, p_det, bbox in zip(int_cls_v, p_det_v, int_bbox_v)
    ]

    return ' '.join(s_v)


def predictions_to_df(
    preds_v,
    save_path=None,
):
    """Build the submission table from a list of prediction dicts.

    Args:
        preds_v: prediction dicts; each needs 'sample_id' plus the keys
            required by ``pred_to_str``.
        save_path: if given, the table is also written there as CSV
            (without the index column).

    Returns:
        DataFrame with columns 'image_id' and 'PredictionString'.
    """
    image_id_v = []
    pred_str_v = []
    for pred_d in preds_v:
        pred_str_v.append( pred_to_str(pred_d) )
        image_id_v.append( pred_d['sample_id'] )

    pred_summary_df = pd.DataFrame({
        'image_id': image_id_v,
        'PredictionString': pred_str_v,
    })

    if save_path is None:
        return pred_summary_df

    pred_summary_df.to_csv(save_path, index=None)
    print(f' Saved submission: "{save_path}"')

    return pred_summary_df




def voc_ap(recall, precision, use_07_metric=False):
    """Average precision from a precision/recall curve (PASCAL VOC style).

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    Both arrays must be ordered by descending detection score.

    Args:
        recall: ndarray of shape (n,).
        precision: ndarray of shape (n,).
        use_07_metric: if True use the VOC07 11-point interpolation,
            otherwise integrate the area under the precision envelope.

    Returns:
        The average precision as a float.
    """
    if use_07_metric:
        # 11-point metric: average the best precision reachable at
        # recall >= 0.0, 0.1, ..., 1.0.
        ap = 0.
        for recall_level in np.arange(0., 1.1, 0.1):
            reached = recall >= recall_level
            if np.sum(reached) == 0:
                best_precision = 0
            else:
                best_precision = np.max(precision[reached])
            ap = ap + best_precision / 11.
        return ap

    # Pad both curves with sentinel values at each end.
    rec_ext = np.concatenate(([0.], recall, [1.]))
    pre_ext = np.concatenate(([0.], precision, [0.]))

    # Precision envelope: sweep right-to-left keeping the running maximum.
    for pos in reversed(range(1, pre_ext.size)):
        pre_ext[pos - 1] = np.maximum(pre_ext[pos - 1], pre_ext[pos])

    # Only the points where recall changes contribute area.
    change_idx = np.where(rec_ext[1:] != rec_ext[:-1])[0]

    # Sum of (delta recall) * precision.
    ap = np.sum((rec_ext[change_idx + 1] - rec_ext[change_idx]) * pre_ext[change_idx + 1])
    return ap


def compute_overlaps(boxes, one_box):
    """IoU between every row of ``boxes`` and a single box.

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    Coordinates are [xmin, ymin, xmax, ymax] and widths/heights are computed
    as ``max - min + 1``.

    Args:
        boxes: ndarray of shape (n, 4).
        one_box: ndarray of shape (4,).

    Returns:
        ndarray of shape (n,) with the IoU of each row against ``one_box``.
    """
    gt_x0, gt_y0, gt_x1, gt_y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    ob_x0, ob_y0, ob_x1, ob_y1 = one_box[0], one_box[1], one_box[2], one_box[3]

    # Intersection: width/height clamp to zero when the boxes are disjoint.
    inter_w = np.maximum(np.minimum(gt_x1, ob_x1) - np.maximum(gt_x0, ob_x0) + 1., 0.)
    inter_h = np.maximum(np.minimum(gt_y1, ob_y1) - np.maximum(gt_y0, ob_y0) + 1., 0.)
    inters = inter_w * inter_h

    # Union = sum of both areas minus the shared part.
    boxes_area = (gt_x1 - gt_x0 + 1.) * (gt_y1 - gt_y0 + 1.)
    one_box_area = (ob_x1 - ob_x0 + 1.) * (ob_y1 - ob_y0 + 1.)

    return inters / (one_box_area + boxes_area - inters)


def voc_eval(class_recs: dict,
             detect: dict,
             iou_thresh: float = 0.5,
             use_07_metric: bool = False):
    """PASCAL VOC evaluation of the detections of a single class.

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    precision = tp / (tp + fp), recall = tp / (tp + fn).

    Args:
        class_recs: ground truth of ONE class, ``{image_id: {'bbox': ndarray}}``.
            A 'det' list (one matched-flag per ground-truth box) is written
            into every entry, i.e. the dict is modified in place.
        detect: ``{'image_ids': [...], 'bbox': ndarray, 'confidence': ndarray}``.
        iou_thresh: a detection matches when its best IoU is strictly above
            this value.
        use_07_metric: forwarded to ``voc_ap`` (VOC07 11-point AP).

    Returns:
        dict with 'true_positive_number', 'false_positive_number',
        'positive_number', the cumulative 'recall' and 'precision' arrays and
        'average_precision'.

    Raises:
        TypeError: if a ground-truth 'bbox', the detections' 'confidence' or
            the detections' 'bbox' is not an ndarray.
    """
    # Count the ground-truth boxes and reset their matched-flags.
    npos = 0
    for gt_rec in class_recs.values():
        if not isinstance(gt_rec['bbox'], np.ndarray):
            raise TypeError
        n_gt = gt_rec['bbox'].shape[0]
        npos += n_gt
        gt_rec['det'] = [False] * n_gt

    image_ids = detect['image_ids']
    confidence = detect['confidence']
    det_boxes = detect['bbox']
    if not isinstance(confidence, np.ndarray):
        raise TypeError
    if not isinstance(det_boxes, np.ndarray):
        raise TypeError

    # Process the detections from the most to the least confident.
    order = np.argsort(-confidence)
    det_boxes = det_boxes[order]
    image_ids = [image_ids[pos] for pos in order]

    n_det = len(image_ids)
    tp = np.zeros(n_det)
    fp = np.zeros(n_det)
    for rank, image_id in enumerate(image_ids):
        gt_rec = class_recs[image_id]
        det_box = det_boxes[rank, :].astype(float)
        gt_boxes = gt_rec['bbox'].astype(float)

        best_iou = -np.inf
        if gt_boxes.size > 0:
            overlaps = compute_overlaps(gt_boxes, det_box)
            best_iou = np.max(overlaps)
            best_idx = np.argmax(overlaps)

        # True positive only for the first detection that claims a
        # ground-truth box; duplicates and low-IoU detections are false positives.
        if best_iou > iou_thresh and not gt_rec['det'][best_idx]:
            tp[rank] = 1.
            gt_rec['det'][best_idx] = 1
        else:
            fp[rank] = 1.

    # Cumulative counts give the precision/recall curve.
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    true_positive_number = tp[-1] if len(tp) > 0 else 0
    false_positive_number = fp[-1] if len(fp) > 0 else 0

    # The eps floor avoids dividing by zero.
    recall = tp / np.maximum(float(npos), np.finfo(np.float64).eps)
    precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    average_precision = voc_ap(recall, precision, use_07_metric)

    return {
        'true_positive_number': true_positive_number,
        'false_positive_number': false_positive_number,
        'positive_number': npos,
        'recall': recall,
        'precision': precision,
        'average_precision': average_precision,
    }



def make_class_detection_d(preds_v, n_classes=15):
    """Regroup per-image predictions into per-class detection tables.

    Args:
        preds_v: prediction dicts with 'sample_id', 'bbox', 'p_det' and 'cls'.
        n_classes: number of classes; the last index (``n_classes - 1``)
            receives a dummy box ``[0, 0, 1, 1]`` with confidence 1.0 for
            every image that has no boxes.

    Returns:
        List of ``n_classes`` dicts ``{'image_ids', 'bbox', 'confidence'}``
        whose values are numpy arrays.
    """
    detection_v = [ {
        'image_ids': [],
        'bbox': [],
        'confidence': [],
    } for _ in range(n_classes)]

    def _add(cls, sample_id, bbox, p_det):
        # Register one detection in the table of its class.
        detection_v[cls]['image_ids'].append(sample_id)
        detection_v[cls]['bbox'].append(bbox)
        detection_v[cls]['confidence'].append(p_det)

    for pred_d in preds_v:
        sample_id = pred_d['sample_id']
        if len(pred_d['bbox']) > 0:
            for bbox, p_det, cls in zip(pred_d['bbox'], pred_d['p_det'], pred_d['cls']):
                # Class ids may arrive as floats (float dtype 'cls' arrays);
                # a list index has to be an int.
                _add(int(cls), sample_id, bbox, p_det)
        else:
            _add(n_classes - 1, sample_id, np.array([0.0, 0.0, 1.0, 1.0]), 1.0)

    for cls in range(len(detection_v)):
        for k in detection_v[cls].keys():
            detection_v[cls][k] = np.array(detection_v[cls][k])

    return detection_v


def make_class_gt_d(gt_df, n_classes=15):
    """Regroup a ground-truth table into per-class, per-image box arrays.

    Args:
        gt_df: DataFrame with columns 'image_id', 'class_id', 'x_min',
            'y_min', 'x_max' and 'y_max'.
        n_classes: number of classes.

    Returns:
        List of ``n_classes`` dicts ``{image_id: {'bbox': ndarray}}``.  Every
        image id of ``gt_df`` appears in every class dict, with an empty
        array when the image has no box of that class.
    """
    image_id_v = gt_df.image_id.unique()

    # One independent {image_id: {'bbox': []}} mapping per class.
    class_gt_v = [
        {image_id: {'bbox': []} for image_id in image_id_v}
        for _ in range(n_classes)
    ]

    col_v = ['image_id', 'class_id', 'x_min', 'y_min', 'x_max', 'y_max']
    for row in gt_df[col_v].values:
        image_id, class_id = row[0], row[1]
        class_gt_v[class_id][image_id]['bbox'].append( tuple(row[2:]) )

    # Freeze the collected box lists into arrays.
    for class_gt_d in class_gt_v:
        for image_gt_d in class_gt_d.values():
            image_gt_d['bbox'] = np.array(image_gt_d['bbox'])

    return class_gt_v
        

        
def calc_metrics(preds_v, gt_df, iou_thresh=0.4, show_summay=True, n_classes=15, use_07_metric=False):
    """Per-class VOC metrics and mAP of a set of predictions.

    Args:
        preds_v: prediction dicts (see ``make_class_detection_d``).
        gt_df: ground-truth table (see ``make_class_gt_d``).
        iou_thresh: IoU threshold forwarded to ``voc_eval``.
        show_summay: print the summary table and the mAP.
        n_classes: number of classes to evaluate.
        use_07_metric: forwarded to ``voc_eval``.

    Returns:
        Tuple ``(metrics_v, summary_df, mAP)``: the ``voc_eval`` result of
        every class, a DataFrame indexed by 'Cls' with columns
        TP/FP/P/Prec/Rec/AP, and the mean of the per-class APs.
    """
    class_recs_v = make_class_gt_d(gt_df, n_classes=n_classes)
    detect_v     = make_class_detection_d(preds_v, n_classes=n_classes)

    mAP = 0.0
    metrics_v = []
    for i_class in range(n_classes):
        class_val_d = voc_eval(
            class_recs_v[i_class],
            detect_v[i_class],
            iou_thresh=iou_thresh,
            use_07_metric=use_07_metric)

        metrics_v.append(class_val_d)

        mAP += class_val_d['average_precision']

    mAP /= n_classes

    # Collect all rows first and build the frame once: DataFrame.append no
    # longer exists in pandas >= 2.0 (and was quadratic when used in a loop).
    row_v = [
        {
            'Cls': i_c,
            'TP': m_d['true_positive_number'],
            'FP': m_d['false_positive_number'],
            'P': m_d['positive_number'],
            # Last point of the cumulative curves = value over all detections.
            'Prec': m_d['precision'][-1] if len(m_d['precision']) > 0 else 0.0,
            'Rec': m_d['recall'][-1] if len(m_d['recall']) > 0 else 0.0,
            'AP': m_d['average_precision'],
        }
        for i_c, m_d in enumerate(metrics_v)
    ]

    summary_df = pd.DataFrame(
        row_v,
        columns=['Cls', 'TP', 'FP', 'P', 'Prec', 'Rec', 'AP']
    )

    # Builtin `int` replaces the `np.int` alias removed in NumPy >= 1.24.
    summary_df = summary_df.astype({
            'Cls': int,
            'TP': int,
            'FP': int,
            'P': int,
            'Prec': np.float32,
            'Rec': np.float32,
            'AP': np.float32,
        })

    summary_df = summary_df.set_index( 'Cls' )

    if show_summay:
        print(' Summary:')
        print( summary_df )
        print(f'mAP = {mAP:0.04f}')

    return metrics_v, summary_df, mAP




def plot_PvsR_curve(metrics_v, class2color_v):
    """Draw the precision-vs-recall curve of every class on one figure.

    Args:
        metrics_v: per-class dicts holding 'recall' and 'precision' arrays.
        class2color_v: per-class colours with channels in the 0-255 range
            (they are divided by 255 before being handed to matplotlib).
    """
    plt.figure(0, figsize=(20,10))

    for class_idx, class_metrics_d in enumerate(metrics_v):
        line_color = np.array(class2color_v[class_idx])/255
        plt.plot(
            class_metrics_d['recall'],
            class_metrics_d['precision'],
            label=f'cls={class_idx}',
            c=line_color )

    plt.xlabel('recall')
    plt.ylabel('precision')

    # Both axes are ratios: fix them to the unit square.
    plt.xlim( (0,1) )
    plt.ylim( (0,1) )

    plt.legend()
    plt.show()

    return None
    
    
    
def filter_det_th(preds_v, det_th=0.15):
    """Drop the detections whose score is below a threshold.

    Args:
        preds_v: prediction dicts with numpy arrays under 'cls', 'p_det'
            and 'bbox'; other keys are copied through unchanged.
        det_th: either one scalar threshold (Python or NumPy int/float) used
            for every detection, or a per-class container indexed by class id
            (``det_th[cls]``).

    Returns:
        New list of dicts where 'cls', 'p_det' and 'bbox' only keep the
        detections with ``p_det >= threshold``.  Dicts without detections are
        copied as they are.
    """
    # Any real scalar counts as a global threshold, not only the exact builtin
    # `float` type (ints and NumPy scalars used to be indexed as if they were
    # per-class tables and crashed).
    is_scalar_th = isinstance(det_th, (int, float, np.integer, np.floating))

    new_pred_v = []
    for pred_d in preds_v:

        if len(pred_d['p_det']) > 0:
            if is_scalar_th:
                det_th_v = det_th
            else:
                # Look up the threshold of each detection's class.
                det_th_v = np.zeros(pred_d['p_det'].shape[0])

                for i_c, i_cls in enumerate(pred_d['cls']):
                    det_th_v[i_c] = det_th[i_cls]

            f = pred_d['p_det'] >= det_th_v
        else:
            f = None

        new_pred_d = {}
        for k in pred_d.keys():
            if (f is not None) and (k in ['cls', 'p_det', 'bbox']):
                new_pred_d[k] = pred_d[k][f]

            else:
                new_pred_d[k] = pred_d[k]

        new_pred_v.append(new_pred_d)

    return new_pred_v

In [None]:
# Human readable name of every class id (the list index is the class id).
class2str_v = [
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Consolidation',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Nodule/Mass',
    'Other lesion',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
    'Nothing',
]


# One RGB colour (0-255 ints) per class: hues evenly spaced around the HSV
# wheel at full saturation and value.
class2color_v = [
    tuple(
        round(255 * channel)
        for channel in colorsys.hsv_to_rgb(class_idx / len(class2str_v), 1, 1)
    )
    for class_idx in range(len(class2str_v))
]

In [None]:
def save_obj(obj, filename):
    """Pickle ``obj`` into ``filename`` (overwriting it) and report the path."""
    payload = pickle.dumps(obj)
    with open(filename, 'wb') as out_file:
        out_file.write(payload)
    print(f'Saved: {filename}')
    return None

def load_obj(filename):
    """Read ``filename`` and return the object pickled in it.

    Unpickling can execute arbitrary code, so only use it on trusted files.
    """
    with open(filename, 'rb') as in_file:
        payload = in_file.read()
    restored = pickle.loads(payload)
    print(f'Loaded: {filename}')
    return restored

In [None]:
def read_dicom_image(
    path,
    voi_lut=True,
    fix_monochrome=True,
    do_norm=True):
    """Read a DICOM file and return its pixel data as a float32 array.

    Original from: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way

    Args:
        path: path of the DICOM file.
        voi_lut: apply the VOI LUT stored in the file (if any) to turn the raw
            data into the "human-friendly" view.
        fix_monochrome: invert the image when its PhotometricInterpretation
            is "MONOCHROME1".
        do_norm: divide by the largest value representable with the file's
            ``BitsStored`` so the result lies in [0, 1].

    Returns:
        The image as ``np.float32``.

    Raises:
        AssertionError: if ``do_norm`` is set and the scaled data falls
            outside [0, 1].
    """
    # `dcmread` is pydicom's reader; `read_file` is only its legacy alias.
    dicom = pydicom.dcmread(path)

    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # Depending on this value the X-ray may look inverted - fix that.
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.max(data) - data

    if do_norm:
        max_value = (2 ** dicom.BitsStored) - 1
        data = data / max_value

        assert (data.max() <= 1.0) and (data.min() >= 0.0), f'Normalization ERROR in file: "{path}"'

    return data.astype(np.float32)

In [None]:
def calc_image_features(image):
    """Stack four per-pixel feature planes computed from one image.

    Channels of the result, in order: the input image, the image cubed, the
    histogram-equalised image and its Canny edge map (thresholds 50/130), the
    last two rescaled from 0-255 to 0-1.

    # assumes `image` is a 2-D array with values in [0, 1] - TODO confirm

    Returns:
        Array with the four planes concatenated along a new last axis.
    """
    as_uint8 = (255 * image).astype(np.uint8)
    equalized = cv2.equalizeHist(as_uint8)
    edges = cv2.Canny(equalized, 50, 130)

    plane_v = [
        image,
        image ** 3.0,
        equalized.astype(np.float32) / 255,
        edges.astype(np.float32) / 255,
    ]

    # Give every plane a trailing channel axis and join them along it.
    return np.concatenate(
        [plane[:, :, None] for plane in plane_v],
        axis=-1)

In [None]:
class FoldDataset():
    def __init__(
        self,
        ds_path='./',
        ds_name='test',
        images_dir='./test_images',
        
        model_resolution=(512, 512),
        df_path=None,
        
        mode='none',
        
        i_fold=0,
        n_folds=5,
        test_split=0.1,
        
        do_augmentation=True,
        
        downsample_factor=2,
        remove_classes_v=None,
        select_classes_v=None,
        show_warnings=True,
        
        do_random_shuffle=True,
        random_seed=3128,
        sample_bbox_and_class=False,
        clean_boxes=True,
        clean_mode='random',
        clean_iou_th=0.1,
        
        use_img_cache=False,
    ):
        
        
        self.ds_path = ds_path
        self.ds_name = ds_name
        
        self.model_resolution = model_resolution
        self.bboxes_df_path   = df_path
        self.i_fold           = i_fold
        self.n_folds          = n_folds
        self.do_augmentation  = do_augmentation
        
        self.images_dir        = images_dir
        self.downsample_factor = downsample_factor
        self.remove_classes_v  = remove_classes_v
        self.select_classes_v  = select_classes_v
        self.show_warnings     = show_warnings
        self.test_split        = test_split
        
        self.do_random_shuffle = do_random_shuffle
        self.random_seed       = random_seed
        
        self.sample_bbox_and_class = sample_bbox_and_class
        
        self.clean_boxes = clean_boxes
        self.clean_mode = clean_mode
        self.clean_iou_th = clean_iou_th
        
        
        self.use_img_cache = use_img_cache
        
        
        if (self.select_classes_v is not None) and (14 in self.select_classes_v):
            self.select_classes_v.remove(14)
            self.select_classes_v.append(-1)
                    
        
        self.mode = mode.lower()
        posible_modes_v = ['cv_trn', 'cv_val', 'cv_tst', 'none']
        assert self.mode in posible_modes_v, f'Parameter "mode" must be in: {posible_modes_v}'
        
        if mode == 'cv_trn':
            self.do_CV  = True
            self.iter_trn = True
            self.iter_val = False
            self.iter_tst = False
            
        elif mode == 'cv_val':
            self.do_CV  = True
            self.iter_trn = False
            self.iter_val = True
            self.iter_tst = False
            
        elif mode == 'cv_tst':
            self.do_CV  = True
            self.iter_trn = False
            self.iter_val = False
            self.iter_tst = True
        
        elif mode == 'none':
            self.do_CV  = False
            self.iter_trn = False
            self.iter_val = False
            self.iter_tst = False
        
        
        
        self.h5_path = os.path.join(self.ds_path, f'{self.ds_name}.h5df')
        self.misc_path = os.path.join(self.ds_path, f'{self.ds_name}.pickle')
        self.cache_path = os.path.join(self.ds_path, f'{self.ds_name}.cache')
        
        self.f_h5 = None
        
        if self.use_img_cache:
            if not os.path.exists(self.h5_path) or not os.path.exists(self.misc_path):
                self._gen_images_ds()

            self.open_h5_file()
            
        self.read_misc_d()
        
        self.bbox_df = None
        
        self._set_all_samples_ids()
        
        self._build_albumentations()
        
        if self.bboxes_df_path is not None:
            if not os.path.exists(self.cache_path):
                self.read_bboxes_df()
                self._save_ds_state()
                
            else:
                self._load_ds_state()
                
            
        self.update_fold_filter()
        
        return None
    
    def _save_ds_state(self):    
        to_save_d = {
            'bbox_df': self.bbox_df,
            'class_to_sample_v': self.class_to_sample_v,
            'bbox_d': self.bbox_d,
        }
        
        save_obj(to_save_d, self.cache_path)
        
        return None
    
    
    def _load_ds_state(self):
        state_d = load_obj( self.cache_path )
        
        for k, v in state_d.items():
            self.__setattr__(k, v)
        
        return None
    
    def _set_all_samples_ids(self):
        self.all_sample_ids = np.array( sorted( self.misc_d.keys() ) )
        
        if self.bbox_df is not None:
            assert ( self.all_sample_ids == np.array(sorted(self.bbox_df['image_id'].unique()) ) ).all()
        
        
        if self.do_random_shuffle:
            np.random.seed(self.random_seed)
            
        return None
    
    
    def _build_bbox_d(self):
        self.bbox_d  = {}
        self.class_to_sample_v = [ [] for i in range( self.bbox_df.class_id.max() + 1)]

        sample_it = tqdm(self.all_sample_ids, desc='Building bboxes')
        for s_id in sample_it:
            sample_bboxes_df = self.bbox_df[self.bbox_df.image_id == s_id]

            cls = sample_bboxes_df['class_id'].values
            bboxes = sample_bboxes_df[['x_min', 'y_min', 'x_max', 'y_max']].values
            rad_id = sample_bboxes_df[ ['rad_id'] ].values
            
            for i_c in np.unique(cls):
                self.class_to_sample_v[i_c].append( s_id )
                
            # Class 14 filter
            f_14 = (cls==14)
            if f_14.any():
                i_14 = np.argmax(f_14)
                
                orig_img_h, orig_img_w = self.misc_d[s_id]
                bboxes[i_14] = np.array([0, 0, orig_img_w, orig_img_h])
                cls[i_14]    = -1  # if I add 1 will be the class  0 (backfround for effdet)
                
                f_14 = ~f_14
                f_14[i_14] = True
                bboxes = bboxes[f_14]
                cls    = cls[f_14]
                
            
            self.bbox_d[s_id] = {
                'bboxes': bboxes,
                'cls':    cls,
                'rad_id': rad_id,
            }

        for i_c in range(len(self.class_to_sample_v)):
            self.class_to_sample_v[i_c] = np.array(self.class_to_sample_v[i_c])

        return None



    def read_bboxes_df(self):
        self.bbox_df = pd.read_csv(
            self.bboxes_df_path,
            dtype={
                'x_min':np.float32,
                'y_min':np.float32,
                'x_max':np.float32,
                'y_max':np.float32,
            })
        
#         self.bbox_df.fillna(0s.0)
        
        self._build_bbox_d()
        return None
        
        
    def update_fold_filter(self):
        assert self.i_fold >= 0 and self.i_fold < self.n_folds, 'ERROR, incorrect i_fold.'
        
        # Apply sample class filtering
        if self.select_classes_v is None:
            self.selected_sample_ids = self.all_sample_ids
            
        else:
            self.selected_sample_ids = set()
            for i_c in self.select_classes_v:
                self.selected_sample_ids = self.selected_sample_ids.union( self.class_to_sample_v[i_c] )
            
            self.selected_sample_ids = np.array( sorted(self.selected_sample_ids) )
            
            
        if self.remove_classes_v is not None:
            self.selected_sample_ids = set(self.selected_sample_ids)
            
            for i_c in self.remove_classes_v:
                self.selected_sample_ids = self.selected_sample_ids.difference( self.class_to_sample_v[i_c] )
            
            self.selected_sample_ids = np.array( sorted(self.selected_sample_ids) )
        # # # # # # # # # # # # # # # # #
        
        n_samples = len( self.selected_sample_ids )
        n_samples_cv = int( (1.0-self.test_split) * n_samples)
        
        if self.do_CV:
            if self.iter_tst:
                f_samples = np.zeros(n_samples, dtype=np.bool)
                f_samples[n_samples_cv:] = True
                
            else:
                n_samples_per_fold = n_samples_cv // self.n_folds
                f_samples = np.zeros(n_samples, dtype=np.bool)

                if self.i_fold < self.n_folds - 1:
                    f_samples[ self.i_fold * n_samples_per_fold: (self.i_fold+1) * n_samples_per_fold] = True
                else:
                    f_samples[ self.i_fold * n_samples_per_fold: n_samples_cv] = True

                if self.iter_trn:
                    f_samples[:n_samples_cv]  = ~(f_samples[:n_samples_cv])
                    
        else:
            f_samples = np.ones(n_samples, dtype=np.bool)
            
        self.fold_sample_filter = f_samples 
        self.fold_samples = self.selected_sample_ids[self.fold_sample_filter]

        return None
    
    
    def get_image(self, sample_id):
        
        if self.use_img_cache:
            image = self.f_h5[sample_id][:]
        
        else:
            file_path = os.path.join(self.images_dir, f'{sample_id}.dicom')
            image, image_shape = self.read_and_downsample_dicom_image(
                file_path
            )
        
        return image
    
    def __getitem__(self, index):
        
        if index < 0:
            index = self.__len__() + index
        
        sample_id = self.fold_samples[index]
        
        image = self.get_image(sample_id)
        
        original_shape = self.misc_d[sample_id]
        
        img_h, img_w = image.shape
        orig_img_h, orig_img_w = original_shape
        
        scale_h = img_h / orig_img_h
        scale_w = img_w / orig_img_w
        
        image = calc_image_features(image)
        
        sample = {
            'sample_id': sample_id,
            'image': image,
            'original_shape': self.misc_d[sample_id],
        }
        
        
        if self.bbox_df is not None:
            # The images are already resized.
            # Format at this point is (x0,y0,x1,y1)
            scale_v = np.array([scale_w, scale_h, scale_w, scale_h])
            sample['bboxes'] = self.bbox_d[sample_id]['bboxes'] * scale_v
            sample['cls'] = self.bbox_d[sample_id]['cls']
            
            if self.select_classes_v is not None:
                f = np.zeros(sample['cls'].shape[0], dtype=np.bool)
                    
                for i in self.select_classes_v:
                    f[sample['cls'] == i] = True

                sample['bboxes'] = sample['bboxes'][f]
                sample['cls']    = sample['cls'][f]            
            
            if self.sample_bbox_and_class:
                idx = np.random.randint(0, sample['cls'].shape[0])
                
                (x0, y0, x1, y1) = sample['bboxes'][idx].astype(np.int)
                
                d_x = np.random.randint(0, (x1-x0)//4 )
                d_y = np.random.randint(0, (y1-y0)//4 )
                x0 = max(0, x0-d_x)
                y0 = max(0, y0-d_y)
                
                x1 = min(sample['image'].shape[1], x1+d_x)
                y1 = min(sample['image'].shape[0], y1+d_y)
                
                sample['image'] = sample['image'][y0:y1, x0:x1]
                sample['cls'] = sample['cls'][idx:idx+1]
                sample['bboxes'] = np.array( [ (0, 0, sample['image'].shape[1], sample['image'].shape[0]) ] )
                
            if self.clean_boxes:
                to_clean_d = {'bbox': sample['bboxes'], 'cls':sample['cls']}
                cleaned_d = clean_predictions(
                    [to_clean_d],
                    iou_th=self.clean_iou_th,
                    mode=self.clean_mode)[0]
                
                sample['bboxes'] = cleaned_d['bbox']
                sample['cls'] = cleaned_d['cls']
                
                
                
        else:
            sample['bboxes'] = np.array( [ (0, 0, image.shape[1], image.shape[0]) ] )
            sample['cls']    = np.array( [-1] )
        
        

        if self.do_augmentation:
            for i_try in range(10):
                try:
                    transform_sample = self.TR_trn(**sample)
                    if len(transform_sample['bboxes']) > 0:
                        # Updating sample
                        sample = transform_sample
                        break

                except Exception as e:
                    pass

            else:
                if self.show_warnings:
                    print(f' - WARNING: Imposible to Augmentate image, idx={index}.', file=sys.stderr)
                    
                # doing a basic transform
                sample = self.TR_val(**sample)
                
        else:
            sample = self.TR_val(**sample)
                
        
        # Format (x0,y0,x1,y1) to (y0,x0,y1,x1) for EffDet
        sample['bboxes']     = torch.tensor(sample['bboxes'],     dtype=torch.float32)[:, [1,0,3,2]]  
        sample['cls']        = torch.tensor(sample['cls'],        dtype=torch.int64)
#         sample['extra']      = torch.tensor([sample['s_norm'], sample['a_norm'], sample['d_norm']], dtype=torch.float32).T
        
        
        return sample
    
    
    def open_h5_file(self, mode='r'):
        if self.f_h5 is not None:
            self.close_file()
            
        if mode == 'w':
            if os.path.exists(self.h5_path):
                print(f' The file "{self.h5_path}" already exists, if you continue the file will be deleted. Contunue (y/n) ?')
                r = input()
                if r.lower() != 'y':
                    print('Operation aborted.')
                    sys.exit()
            
        self.f_h5 = h5py.File(self.h5_path, mode)
        return None
    
    
    def close_file(self):
        if self.f_h5 is not None:
            self.f_h5.close()
            
        return None
    
    
    def read_misc_d(self):
        self.misc_d = load_obj( self.misc_path )
        return None
        
        
    def save_misc_d(self, warn=True):
        if warn and self.misc_d:
            if os.path.exists(self.misc_d):
                print(f' The file "{self.misc_d}" already exists, if you continue the file will be deleted. Contunue (y/n) ?')
                r = input()
                if r.lower() != 'y':
                    print('Operation aborted.')
                    sys.exit()
            
        
        save_obj(self.misc_d, self.misc_path)
        return None
    
    
    def downsample_img(
        self,
        image,
        downsample_factor=2):
        
        new_dims_v = (
            image.shape[1] // downsample_factor,
            image.shape[0] // downsample_factor )
        
        image_rs = cv2.resize(
            image,
            new_dims_v
        )
        
        return image_rs
    
    
    def read_and_downsample_dicom_image(self, file_path):
        image = read_dicom_image(file_path)
        image_shape = image.shape

        image_rs = self.downsample_img(
            image,
            downsample_factor=self.downsample_factor)
        
        return image_rs, image_shape
        
        
    def _gen_images_ds(self):
        all_files_v = glob.glob(os.path.join(self.images_dir, '*.dicom') )
        
        self.open_h5_file(mode='w')
        self.misc_d = {}
        
        file_it = tqdm( all_files_v )
        for file_path in file_it:
            file_id = os.path.splitext(os.path.basename(file_path))[0]
            file_it.set_description(file_id)
            
            self.read_dicom_image(file_path)
            
            image_rs, image_shape = self.read_and_downsample_dicom_image(file_path)
                        
            self.f_h5.create_dataset(file_id, data=image_rs)
            self.misc_d[file_id] = image_shape
        
        self.close_file()
        self.save_misc_d(warn=False)
        
        return None
    
    
    def __len__(self):
        return len(self.fold_samples)
    
    
    def _build_albumentations(self):
        self.TR_trn = A.Compose(
            [
#                 A.RandomSizedCrop(
#                     min_max_height=[int(0.7*self.image_resolution[0]), int(1.0*self.image_resolution[0])],
#                     height=int(0.9*self.image_resolution[0]),
#                     width=int(0.9*self.image_resolution[1]), 
#                     p=0.5),
                
                
                A.Crop(
                    x_min=128,
                    y_min=128,
                    x_max=128,
                    y_max=128,
                    p=0.5),

#                     A.OneOf([
#                         A.HueSaturationValue(
#                             hue_shift_limit=0.2,
#                             sat_shift_limit= 0.2,
#                             val_shift_limit=0.2,
#                             p=0.9),

#                         A.RandomBrightnessContrast(
#                             brightness_limit=0.2, 
#                             contrast_limit=0.2,
#                             p=0.9),

#                     ],
#                         p=0.9),

#                     A.ToGray(p=0.01),

                A.HorizontalFlip(p=0.5),

    #             A.VerticalFlip(p=0.5),

    #             A.RandomRotate90(p=0.5),

                A.Rotate(
                    limit=15,
                    p=0.6,
                ),

#                     A.Transpose(p=0.5),

                A.Blur(blur_limit=3, p=0.6),

#                     A.OneOf([
#                         A.Blur(blur_limit=3, p=1.0),
#                         A.MedianBlur(blur_limit=3, p=1.0)
#                     ],
#                         p=0.1),

                A.Cutout(
                    num_holes=20,
                    max_h_size=64,
                    max_w_size=64,
                    fill_value=0,
                    p=0.5),

                A.Resize(
                    height=self.model_resolution[0],
                    width=self.model_resolution[1],
                    p=1.0),

                ToTensorV2(p=1.0),
            ],

            p=1.0, 
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0, 
                min_visibility=0,
                label_fields=['cls']
            )
            )
        
        if self.sample_bbox_and_class:
            self.TR_trn = A.Compose(
            [

                A.HorizontalFlip(p=0.5),

                A.Rotate(
                    limit=15,
                    p=0.6,
                ),

#                     A.Transpose(p=0.5),

                A.Blur(blur_limit=3, p=0.6),

                A.Cutout(
                    num_holes=10,
                    max_h_size=5,
                    max_w_size=5,
                    fill_value=0,
                    p=0.5),

                A.Resize(
                    height=self.model_resolution[0],
                    width=self.model_resolution[1],
                    p=1.0),

                ToTensorV2(p=1.0),
            ],

            p=1.0, 
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0, 
                min_visibility=0,
                label_fields=['cls']
            )
            )
            
        
        
        self.TR_val = A.Compose(
            [
                A.Resize(
                    height=self.model_resolution[0],
                    width=self.model_resolution[1],
                    p=1.0),

                ToTensorV2(p=1.0),
            ], 

            p=1.0, 
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0, 
                min_visibility=0,
                label_fields=['cls']
            )
        )

        
        return None

    
    def get_GT_Dataframe(self, merge_mode='mean'):
        """ Reeturns the fold GT Dataframe"""
        gt_df = self.bbox_df[self.bbox_df.image_id.isin(self.fold_samples)].copy()
        gt_df.fillna(0, inplace=True)
        gt_df.loc[gt_df["class_id"] == 14, ['x_max', 'y_max']] = 1.0
        
        
        
        gb = gt_df.groupby(['image_id', 'class_id'])

        if merge_mode == 'mean':
            gt_df = gb.agg(lambda x: ' '.join(np.unique(x.values)) if type(x.values[0]) is str else x.values.mean() ).reset_index()

        elif merge_mode == 'first':
            gt_df = gb.first().reset_index()

        elif merge_mode is None:
            pass

        else:
            raise Exception(f'Mode "{merge_mode}" unknown.')

        return gt_df


# Training view of fold 1 (of 5, with 10% of the samples held out as test)
# over the cached 'train' set, with every sample containing class 14 removed.
ds = FoldDataset(
    # Dataset locations
    ds_name='train',
    ds_path='../input/train-test-ds-bbox-cache/',
    images_dir=os.path.join(DS_DIR, 'train'),
    df_path=os.path.join(DS_DIR, 'train.csv'),

    # Split definition
    mode='cv_trn',
    n_folds=5,
    i_fold=1,
    test_split=0.1,
    random_seed=3128,

    # Sample selection
    select_classes_v=None,
    remove_classes_v=[14],

    # Image handling
    model_resolution=(512, 512),
    downsample_factor=2,
    do_augmentation=True,
    show_warnings=False,
)

In [None]:
class ModelX(nn.Module):
    def __init__(
        self,
        model_resolution=(768, 512), 
        n_input_channels=4,
        n_classes=14,
        n_extras=0,
        extra_loss_weight=1.0,

        init_lr=1e-4,
        optimizer_name='adam',
        clip_grad_norm=5.0,
        weight_decay=0.0,
        
        use_pretrained_model=True,
        backbone_name='tf_efficientdet_d4',
        
        checkpoint_base_path='./model_checkpoint',
        model_name=' ModelX_v1',
        device=None,
        ):
        """Build the detector, its optimizer and the checkpoint directory.

        Args:
            model_resolution: Input size handed to the EfficientDet config as `image_size`.
            n_input_channels: Number of channels of the replacement stem convolution.
            n_classes: Value written to the config's `num_classes`.
            n_extras: Value written to the config's `extra_variables`; when
                positive, `predict` also forwards `data['extra']` as a target.
            extra_loss_weight: Value written to the config's `extra_loss_weight`.
            init_lr: Learning rate of the optimizer.
            optimizer_name: Optimizer to build; only 'adam' is implemented.
            clip_grad_norm: Gradient-norm clipping threshold used by
                `train_step`; values <= 0 disable clipping.
            weight_decay: Weight decay of the optimizer.
            use_pretrained_model: Forwarded to EfficientDet as `pretrained_backbone`.
            backbone_name: EfficientDet config name (lower-cased before use).
            checkpoint_base_path: Directory for checkpoints; created if missing.
            model_name: Name printed at construction time.
            device: Device name, or None to pick automatically (see `set_device`).
        """
        super().__init__()

        # Detector configuration
        self.model_resolution     = np.array(model_resolution)
        self.n_input_channels     = n_input_channels
        self.n_classes            = n_classes
        self.n_extras             = n_extras
        self.extra_loss_weight    = extra_loss_weight
        self.use_pretrained_model = use_pretrained_model
        self.backbone_name        = backbone_name.lower()

        # Optimisation settings
        self.lr             = init_lr
        self.optimizer_name = optimizer_name
        self.clip_grad_norm = clip_grad_norm
        self.weight_decay   = weight_decay

        # Bookkeeping
        self.checkpoint_base_path = checkpoint_base_path
        self.model_name           = model_name

        print(f'New Model: "{self.model_name}"')

        self.set_device(device)

        # Make sure checkpoints have somewhere to go.
        if not os.path.exists(self.checkpoint_base_path):
            print(f'Creating save dir: "{self.checkpoint_base_path}"')
            os.makedirs(self.checkpoint_base_path)

        # The network must exist before the optimizer, which collects its parameters.
        self._build_backbone()
        self.build_optimizer()

        self.to(self.device)

        # Prints the model size.
        self.calc_total_weights()
    
    
    @torch.jit.ignore
    def build_optimizer(self, params_v=None):
        """Create the optimizer, store it in `self.optimizer` and return it.

        Args:
            params_v: Parameters to optimise; defaults to `self.parameters()`.

        Raises:
            Exception: If `self.optimizer_name` is not 'adam' (case-insensitive).
        """
        trainable = self.parameters() if params_v is None else params_v

        if self.optimizer_name.lower() != 'adam':
            raise Exception(f'Un implemented optimizer: {self.optimizer_name}')

        self.optimizer = optim.Adam(
            trainable,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        return self.optimizer
    
    

    @torch.jit.ignore
    def _build_backbone(self):
        
        self.backbone_config = get_efficientdet_config(self.backbone_name)
        self.backbone_config.num_classes = self.n_classes
        self.backbone_config.image_size  = tuple([int(i) for i in self.model_resolution])
        self.backbone_config.extra_variables = self.n_extras
        self.backbone_config.extra_loss_weight = self.extra_loss_weight
        
        self.backbone = EfficientDet(
            self.backbone_config,
            pretrained_backbone=self.use_pretrained_model)
        
        first_conv = self.backbone.backbone.conv_stem

        # Updating first ConvHead
        self.backbone.backbone.conv_stem = nn.Conv2d(
            self.n_input_channels,
            first_conv.out_channels,
            kernel_size=first_conv.kernel_size,
            stride=first_conv.stride,
            padding=first_conv.padding,
            bias=False,
        )
        
        # Deleting unused conv
        del(first_conv)
        
        self.net_labeler_train   = DetBenchTrain(self.backbone)
        self.net_labeler_predict = DetBenchPredict(self.backbone)
            
        return None
    
    
    
    def forward(self, x):
        """Run the wrapped EfficientDet network on `x` and return its raw output."""
        raw_output = self.backbone(x)
        return raw_output
        
        
    @torch.jit.ignore
    def save_checkpoint(
        self,
        step=0,
        loss=None,
        file_name='weights.ckpt',
        verbose=True,
        ):
        """Write model and optimizer state to `<checkpoint_base_path>/<file_name>`.

        The saved dict has the keys 'step', 'model_state_dict',
        'optimizer_state_dict' (None when there is no optimizer) and 'loss'.

        Args:
            step: Training step stored alongside the weights.
            loss: Loss value stored alongside the weights.
            file_name: Name of the checkpoint file inside the checkpoint directory.
            verbose: Print the path after saving.
        """
        PATH = os.path.join(self.checkpoint_base_path, file_name)

        optimizer_state = None
        if self.optimizer is not None:
            optimizer_state = self.optimizer.state_dict()

        checkpoint = {
            'step': step,
            'model_state_dict':     self.state_dict(),
            'optimizer_state_dict': optimizer_state,
            'loss': loss,
        }
        torch.save(checkpoint, PATH)

        if verbose:
            print(f' Saved checkpoint: {PATH}.')
    
    @torch.jit.ignore
    def restore_checkpoint(self, PATH, verbose=True):
        """Load as much of a checkpoint as fits the current model.

        Every tensor whose key exists in the checkpoint with a matching shape
        is taken from it; any other entry keeps its current value and a
        warning is printed to stderr. The optimizer state is only restored
        when all entries matched; if restoring it fails, the optimizer is
        rebuilt.

        NOTE: `torch.load` unpickles the file; only load trusted checkpoints.

        Args:
            PATH: Checkpoint file written by `save_checkpoint`.
            verbose: Print the path after restoring.

        Returns:
            The loaded checkpoint dict.
        """
        checkpoint = torch.load(PATH, map_location=self.device)

        saved_state_dict   = checkpoint['model_state_dict']
        current_state_dict = self.state_dict()

        all_matched = True
        merged_state_dict = OrderedDict()
        for key, current_tensor in current_state_dict.items():
            if key not in saved_state_dict:
                all_matched = False
                print(f' - WARNING: key="{key}" not found in saved checkpoint.\n   Weights will not be loaded.', file=sys.stderr)
                merged_state_dict[key] = current_tensor
                continue

            saved_tensor = saved_state_dict[key]
            if saved_tensor.shape != current_tensor.shape:
                all_matched = False
                s0 = tuple(saved_tensor.shape)
                s1 = tuple(current_tensor.shape)
                print(f' - WARNING: shapes mismatch in "{key}": {s0} vs {s1}.\n   Weights will not be loaded.', file=sys.stderr)
                merged_state_dict[key] = current_tensor
                continue

            merged_state_dict[key] = saved_tensor

        self.load_state_dict( merged_state_dict )

        if self.optimizer is not None:
            if not all_matched:
                # The saved optimizer state belongs to a different parameter set.
                print(' - WARNING: Optimizer will not be loaded.', file=sys.stderr)
            else:
                try:
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                except Exception:
                    print(' - WARNING: ERROR while loading the optimizer. The optimizer will be reseted.', file=sys.stderr)
                    self.build_optimizer()

        if verbose:
            print(f' Restored checkpoint: {PATH}.')

        return checkpoint
    
    @torch.jit.ignore
    def calc_total_weights(self, verbose=True):
        """Count the scalar weights over all parameters of the model.

        Args:
            verbose: Print the total, in millions.

        Returns:
            The total number of weights.
        """
        n_w = 0
        for p in self.parameters():
            # np.int64 replaces the `np.int` alias, which NumPy deprecated in
            # 1.20 and removed in 1.24.
            n_w += np.prod(p.shape, dtype=np.int64)

        if verbose:
            print(' - Total weights: {:0.02f}M'.format(n_w/1e6))

        return n_w
    
    
    @torch.jit.export
    def flip(self, tensor, dim=1):
        """Return `tensor` with the order of its entries along `dim` reversed.

        The reversed index is created on `self.device`.
        """
        last = tensor.size(dim) - 1
        reversed_idx = torch.arange(last, -1, -1).long().to(self.device)
        return tensor.index_select(dim, reversed_idx)
    
    @torch.jit.ignore
    def set_device(self, device=None):
        """Choose the device the model runs on, store it in `self.device` and return it.

        Args:
            device: None to use "cuda:0" when CUDA is available and "cpu"
                otherwise; a device string such as "cpu" or "cuda:1"; or an
                existing `torch.device`.

        Raises:
            Exception: If `device` is of any other type.
        """
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, (str, torch.device)):
            # torch.device accepts both a device string and another torch.device.
            self.device = torch.device( device )
        else:
            raise Exception('Not implemented')

        print(f' - Selecting device: {self.device}')
        return self.device
    
    
    @torch.jit.ignore
    def get_bboxes(self, outputs, det_th, data=None, unscale_bboxes=False):
        """Turn raw detector output into per-image dicts of confident detections.

        Each row of `outputs[i]` is read as four box coordinates, a detection
        score and a class index.

        Args:
            outputs: Tensor indexed as [image, detection, value].
            det_th: Only detections scoring strictly above this are kept.
            data: Batch dict; when given, its 'sample_id' length must match
                the batch size. Required if `unscale_bboxes` is set, as
                'original_shape' is read from it.
            unscale_bboxes: Map boxes from model resolution back to the
                original image resolution.

        Returns:
            One dict per image with 'bbox', 'p_det' and 'cls'; the class is
            the predicted index minus 1.
        """
        if data is not None:
            assert len(data['sample_id']) == outputs.shape[0]

        outputs_np = outputs.detach().cpu().numpy()

        preds_v = []
        for i_sample in range(len(outputs_np)):
            img_pred = outputs_np[i_sample]

            # Column layout: 0-3 box, 4 score, 5 class index.
            bboxes  = img_pred[:, :4]
            scores  = img_pred[:, 4]
            classes = img_pred[:, 5]
            keep = scores > det_th

            if unscale_bboxes:
                model_img_h, model_img_w = self.model_resolution
                orig_img_h, orig_img_w = data['original_shape'][i_sample]

                scale_w = model_img_w / orig_img_w
                scale_h = model_img_h / orig_img_h

                # Boxes are (x0, y0, x1, y1): x scales with width, y with height.
                bboxes = bboxes / np.array([scale_w, scale_h, scale_w, scale_h])

            preds_v.append({
                'bbox':  bboxes[keep],
                'p_det': scores[keep],
                'cls':   classes[keep].astype(int) - 1, # We must substract 1 to the class number
            })

        return preds_v

    

    @torch.jit.ignore
    def predict(self, data, det_th=0.4, output_losses=False, training=False, filter_boxes=True, unscale_bboxes=False):
        """Run the detector on a batch, optionally computing losses.

        NOTE: this switches the module between train / eval mode and sets
        torch's global grad mode accordingly; neither is restored afterwards.

        Args:
            data: Batch dict with 'image'; with `output_losses` also 'bboxes'
                and 'cls' (plus 'extra' when `n_extras > 0`).
            det_th: Score threshold handed to `get_bboxes`.
            output_losses: Use the training wrapper, which receives the targets.
            training: Train mode with gradients if True, eval mode without otherwise.
            filter_boxes: Post-process detections with `get_bboxes`.
            unscale_bboxes: Forwarded to `get_bboxes`.

        Returns:
            Without `output_losses`: the predict wrapper's output, passed
            through `get_bboxes` if `filter_boxes`. With `output_losses`: the
            train wrapper's output dict, whose 'detections' entry is replaced
            by the `get_bboxes` result when not training and `filter_boxes`.
        """
        if training:
            self.train()
        else:
            self.eval()
        torch.set_grad_enabled(bool(training))

        images = data['image'].to(self.device)

        if not output_losses:
            outputs = self.net_labeler_predict.forward(
                images,
            )
            if filter_boxes:
                outputs = self.get_bboxes(outputs, det_th, data, unscale_bboxes)
            return outputs

        target_d = {
            'bbox': [boxes.to(self.device) for boxes in data['bboxes']],
            # We must sum 1 to the class number
            'cls':  [labels.to(self.device) + 1 for labels in data['cls']],
        }
        if self.n_extras > 0:
            target_d['extra'] = [extra.to(self.device) for extra in data['extra']]

        # dict with: 'loss', 'class_loss', 'box_loss'
        outputs = self.net_labeler_train.forward(
            images,
            target_d
        )

        if filter_boxes and not training:
            outputs['detections'] = self.get_bboxes(outputs['detections'], det_th, data, unscale_bboxes)

        return outputs

    
    @torch.jit.ignore
    def train_step(self, data):
        """Run one optimisation step on a batch.

        The batch goes through `predict` in training mode with losses enabled;
        the 'loss' entry is back-propagated, gradients are clipped to
        `self.clip_grad_norm` when that is positive, and the optimizer is stepped.

        Returns:
            The output dict of `predict`.
        """
        outputs = self.predict(data=data, output_losses=True, training=True)

        # Backward pass
        self.optimizer.zero_grad()
        outputs['loss'].backward()

        if self.clip_grad_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_norm)

        self.optimizer.step()

        return outputs


    @torch.jit.ignore
    def put_boxes(self, preds_v, data, do_plot=False):
        """Draw ground-truth and predicted boxes on every image of a batch.

        Ground-truth boxes are drawn with a thick outline (thickness 2);
        predictions are drawn with a thin outline (thickness 1) plus a text
        label built from ``class2str_v`` and the detection score.

        Args:
            preds_v: one prediction dict per image, each providing the keys
                ``'bbox'``, ``'cls'`` and ``'p_det'``. Predicted boxes are
                unpacked as ``(x0, y0, x1, y1)``.
            data: batch dict with ``'image'`` (indexed per sample and
                transposed from channel-first to channel-last), ``'bboxes'``
                and ``'cls'``. Ground-truth boxes are unpacked as
                ``(y0, x0, y1, x1)``.
            do_plot: when True, each rendered image is shown with matplotlib.

        Returns:
            ``np.array`` built from the list of rendered ``uint8`` images.

        Raises:
            AssertionError: if ``len(preds_v)`` differs from the batch size.
        """
        BS = len( preds_v )

        assert BS == data['image'].shape[0], 'Wrong Batch Size'

        box_image_v = []
        for i_b in range(BS):

            pred_d = preds_v[i_b]

            # Channel-first tensor -> channel-last array, scaled by 255 for display.
            # NOTE(review): assumes pixel values lie in [0, 1]; values outside that
            # range wrap around in the uint8 cast below — confirm upstream scaling.
            image = data['image'][i_b].cpu().numpy().transpose([1,2,0])
            box_image = (image * 255)

            # Reduce to something OpenCV can draw on: keep the first three channels
            # when there are more than three, otherwise use channel 0 replicated to
            # three identical channels.
            if len(box_image.shape) == 3 and box_image.shape[-1] > 3:
                box_image = box_image[:,:,:3]

            elif len(box_image.shape) == 3:
                box_image = box_image[:,:,0]

            if len(box_image.shape) == 2:
                box_image = box_image[:,:, None] * np.ones(3)

            # ``copy`` yields a contiguous buffer, which cv2 drawing calls require.
            box_image = box_image.copy().astype(np.uint8)

            # Ground-truth boxes.
            gt_bboxes = data['bboxes'][i_b].cpu().numpy()
            gt_classes = data['cls'][i_b].cpu().numpy()
            for bbox, idx_class in zip(gt_bboxes, gt_classes):
                # ``np.int`` no longer exists in NumPy >= 1.24; built-in ints are
                # also what OpenCV expects for point coordinates.
                y0, x0, y1, x1 = (int(v) for v in bbox)

                box_image = cv2.rectangle(
                    box_image,
                    (x0,y0),
                    (x1,y1),
                    class2color_v[idx_class],
                    thickness=2,
                )

            # Predicted boxes, labelled with the class name and score.
            for bbox, idx_class, p_det in zip( pred_d['bbox'], pred_d['cls'], pred_d['p_det']):

                x0, y0, x1, y1 = (int(v) for v in bbox)

                box_image = cv2.rectangle(
                    box_image,
                    (x0,y0),
                    (x1,y1),
                    class2color_v[idx_class],
                    thickness=1,
                    )

                box_image = cv2.putText(
                    box_image,
                    class2str_v[idx_class] + f'({p_det:0.1f})',
                    org=(x0, y0-3),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.3,
                    color=class2color_v[idx_class],
                    thickness=1,
                    lineType=cv2.LINE_AA,
                    bottomLeftOrigin=False
                    )

            box_image_v.append(box_image)

            if do_plot:
                plt.figure(0, figsize=(20,20) )
                plt.imshow(box_image)
                plt.show()

        return np.array( box_image_v )

# Building Model

In [None]:
# EfficientDet variant index: selects the `tf_efficientdet_d{MODEL_D}` backbone
# and is embedded in the checkpoint directory name.
MODEL_D = 2
# Number of cross-validation folds.
N_FOLDS = 5
# Fold handled by this run (dataset split selection and checkpoint directory name).
I_FOLD  = 1
# Resolution handed to both the model and the fold datasets.
# NOTE(review): axis order (H, W) vs (W, H) is not visible in this cell — confirm.
MODEL_RESOLUTION = (768, 512)

In [None]:
# Keyword arguments forwarded verbatim to ``ModelX(**model_cfg_d)``.
model_cfg_d = dict(
    # Input geometry
    model_resolution=MODEL_RESOLUTION,
    n_input_channels=4,

    # Detection targets
    n_classes=14,
    n_extras=0,
    extra_loss_weight=6.0,

    # Backbone
    use_pretrained_model=True,
    backbone_name=f'tf_efficientdet_d{MODEL_D}',

    # Optimisation and checkpointing
    init_lr=1e-4,
    optimizer_name='adam',
    clip_grad_norm=3.0,
    weight_decay=1e-5,
    model_name='ModelX_v4',
    checkpoint_base_path=f'./EffDet_d{MODEL_D}_F{I_FOLD}_v5',

    device=None,
)

In [None]:

model = ModelX(**model_cfg_d)

# Window length of every moving-average loss tracker.
N_HIST = 1000

# Loss terms tracked during training and validation.
loss_names_v = ['loss', 'box_loss', 'class_loss'] #, 'extra_loss']

# One FastSMA tracker per loss term for the training split ...
trn_fsma_d = {
    k: FastSMA(
        maxlen=N_HIST,
        label='mean = ',
        print_format='0.02f',
        save_filename=os.path.join(model.checkpoint_base_path, f'{k}_trn.fsma'),
    )
    for k in loss_names_v
}

# ... and one per loss term for the validation split.
val_fsma_d = {
    k: FastSMA(
        maxlen=N_HIST,
        label='mean = ',
        print_format='0.02f',
        save_filename=os.path.join(model.checkpoint_base_path, f'{k}_val.fsma'),
    )
    for k in loss_names_v
}

# Train only the Stem Conv, FPN and Detection Heads

In [None]:
# Parameters handed to the optimizer: the backbone stem convolution, both
# prediction heads and the FPN (the extra head stays excluded).
trainable_params_v = [
    param
    for module in (
        model.backbone.backbone.conv_stem,
        model.backbone.class_net,
        model.backbone.box_net,
        model.backbone.fpn,
    )
    for param in module.parameters()
]

# Total number of scalar weights across the selected parameters.
n_w = sum(np.prod(param.shape) for param in trainable_params_v)

print(f' Total trainable parameters: {n_w/1e6:0.02f}M')

_ = model.build_optimizer(trainable_params_v)


# Loading Model

In [None]:
_ = model.restore_checkpoint(
    '../input/vinbigdata-effdet-d2-f0f2-ckpts/F1_E79_ModelX_v4_T0.325_V0.410.ckpt'
)

# Best-effort reload of the moving-average loss histories; a failure only
# produces a warning so the notebook keeps running.
try:
    for fsma_d in (trn_fsma_d, val_fsma_d):
        for fsma in fsma_d.values():
            fsma.load()
except Exception:
    print(' WARNING: FastSMA not loaded.')

# Training Functions

In [None]:
def batch_merge(samples_v):
    """Collate a list of sample dicts into a single batch dict.

    Every key of the first sample becomes a key of the batch whose value is
    the list of per-sample values, in sample order. Only ``'image'`` is
    stacked into one tensor; all other fields stay as Python lists.

    Args:
        samples_v: non-empty list of dicts, each containing an ``'image'``
            tensor.

    Returns:
        dict with the same keys as ``samples_v[0]``.
    """
    merged_d = {field: [] for field in samples_v[0]}

    for one_sample_d in samples_v:
        for field, value in one_sample_d.items():
            merged_d[field].append(value)

    # Images share a common shape, so they can be stacked along a new batch axis.
    merged_d['image'] = torch.stack(merged_d['image'])
    return merged_d


def train_one_epoch(
    model,
    i_fold,
    i_epoch,
    ds_trn,
    trn_fsma_d,
    batch_size=1,
    shuffle=True,
    num_workers=16,
    pin_memory=False):
    """Train ``model`` for one pass over ``ds_trn``.

    Each batch is passed to ``model.train_step``. For every key of
    ``trn_fsma_d`` the matching loss value is appended to its moving-average
    tracker and to a per-epoch list, and the progress-bar description is
    refreshed with the current moving averages.

    Args:
        model: object exposing ``train_step(batch)`` that returns a dict of
            losses supporting ``.item()``.
        i_fold: fold index (only shown in the progress bar).
        i_epoch: epoch index (only shown in the progress bar).
        ds_trn: training dataset given to ``DataLoader``.
        trn_fsma_d: dict loss name -> tracker with ``append`` and ``mean``;
            must contain ``'loss'``, ``'box_loss'`` and ``'class_loss'``.
        batch_size, shuffle, num_workers, pin_memory: forwarded to
            ``DataLoader``.

    Returns:
        dict loss name -> ``np.mean`` of that loss over the epoch's steps.
    """
    loader = DataLoader(
        ds_trn,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge)
    trn_iter = tqdm(loader)

    # Per-step loss values, one list per tracked loss term.
    step_losses_d = {name: [] for name in trn_fsma_d}

    for trn_data in trn_iter:
        trn_loss_d = model.train_step(trn_data)

        for name, fsma in trn_fsma_d.items():
            step_loss = trn_loss_d[name].item()
            fsma.append(step_loss)
            step_losses_d[name].append(step_loss)

        trn_iter.set_description('[TRN_F={:d}E={:d}_L={:0.03f}_B={:0.03f}_C={:0.03f}_X={:0.03f}]'.format(
            i_fold,
            i_epoch,
            trn_fsma_d['loss'].mean(),
            trn_fsma_d['box_loss'].mean(),
            trn_fsma_d['class_loss'].mean(),
            trn_fsma_d['extra_loss'].mean() if 'extra_loss' in trn_fsma_d else 0.0,
            ) )

    return {name: np.mean(values) for name, values in step_losses_d.items()}


def validate_one_epoch(
    model,
    i_fold,
    i_epoch,
    ds_val,
    val_fsma_d,
    batch_size=1,
    shuffle=False,
    num_workers=16,
    pin_memory=False,
    
    det_th=0.15,
    clear_predictions=True,
    clear_iou_th=0.10,
    clear_mode='p_det_weight',
):
    """Evaluate ``model`` on ``ds_val`` for one epoch.

    For every batch the model is run through ``model.predict`` with losses
    and filtered, unscaled detections enabled. Loss values update the
    moving-average trackers; detections are collected (optionally cleaned
    with ``clean_predictions``) and finally scored with ``calc_metrics``.

    Args:
        model: object exposing ``predict(...)`` returning a dict with the
            tracked losses and a ``'detections'`` list.
        i_fold: fold index (only shown in the progress bar).
        i_epoch: epoch index (only shown in the progress bar).
        ds_val: validation dataset; must provide ``get_GT_Dataframe()``.
        val_fsma_d: dict loss name -> tracker with ``append`` and ``mean``;
            must contain ``'loss'``, ``'box_loss'`` and ``'class_loss'``.
        batch_size, shuffle, num_workers, pin_memory: forwarded to
            ``DataLoader``.
        det_th: detection threshold forwarded to ``model.predict``.
        clear_predictions: when True the detections of each batch are passed
            once through ``clean_predictions``; when False they are used as
            returned by the model.
        clear_iou_th: ``iou_th`` argument of ``clean_predictions``.
        clear_mode: ``mode`` argument of ``clean_predictions``.

    Returns:
        dict loss name -> ``np.mean`` of that loss over the epoch's steps.
    """
    gt_df = ds_val.get_GT_Dataframe()
    
    val_iter = tqdm(DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge))

    # Validation step
    epoch_losses_d = {k: [] for k in val_fsma_d.keys()}
    
    all_preds_v = []
    for val_data in val_iter:
        val_loss_d = model.predict(
            val_data,
            output_losses=True,
            training=False,
            det_th=det_th,
            filter_boxes=True,
            unscale_bboxes=True)
        
        # Honour ``clear_predictions``: previously the detections were cleaned
        # unconditionally and then a second time when the flag was set.
        pred_v = val_loss_d['detections']
        if clear_predictions:
            pred_v = clean_predictions(
                pred_v,
                iou_th=clear_iou_th,
                mode=clear_mode)
            
        # Attach the identifiers needed to match predictions with ground truth.
        for i_s, pred_d in enumerate(pred_v):
            pred_d['sample_id']      = val_data['sample_id'][i_s]
            pred_d['original_shape'] = val_data['original_shape'][i_s]
            
        all_preds_v += pred_v
            
        for k, fsma in val_fsma_d.items():
            step_loss = val_loss_d[k].item()
            fsma.append(step_loss)
            epoch_losses_d[k].append(step_loss)
                
        val_iter.set_description('[VAL_F={:d}E={:d}_L={:0.03f}_B={:0.03f}_C={:0.03f}_X={:0.03f}]'.format(
            i_fold,
            i_epoch,
            val_fsma_d['loss'].mean(),
            val_fsma_d['box_loss'].mean(),
            val_fsma_d['class_loss'].mean(),
            val_fsma_d['extra_loss'].mean() if 'extra_loss' in val_fsma_d.keys() else 0.0,
            ) )
    
    for k in epoch_losses_d.keys():
        epoch_losses_d[k] = np.mean(epoch_losses_d[k])
    
    # Metric computation is best-effort: a failure must not abort training, but
    # the cause is reported and KeyboardInterrupt/SystemExit are not swallowed.
    try:
        metrics_v, summary_df, mAP = calc_metrics(
            all_preds_v,
            gt_df,
            iou_thresh=0.4,
            show_summay=True)

        plot_PvsR_curve(metrics_v, class2color_v)
        
    except Exception as err:
        print('Problems with evaluation')
        print(f'  -> {type(err).__name__}: {err}')
    
    return epoch_losses_d



def save_model(
    model,
    trn_fsma_d,
    val_fsma_d,
    i_epoch,
    i_fold):
    """Write a checkpoint and persist every moving-average loss tracker.

    The checkpoint file name encodes the fold, the epoch, the model name and
    the current moving averages of the training and validation ``'loss'``.

    Args:
        model: object with ``model_name`` and ``save_checkpoint(...)``.
        trn_fsma_d: dict loss name -> training tracker (``mean``, ``save``).
        val_fsma_d: dict loss name -> validation tracker (``mean``, ``save``).
        i_epoch: epoch index stored in the checkpoint and in the file name.
        i_fold: fold index stored in the checkpoint and in the file name.

    Returns:
        None.
    """
    f_sma_trn, f_sma_val = trn_fsma_d['loss'], val_fsma_d['loss']

    model_filename = f'F{i_fold:}_E{i_epoch:}_{model.model_name}_T{f_sma_trn.mean():0.03f}_V{f_sma_val.mean():0.03f}.ckpt'

    model.save_checkpoint(
        step=dict(i_epoch=i_epoch, i_fold=i_fold),
        loss=dict(trn=f_sma_trn.mean(), val=f_sma_val.mean()),
        file_name=model_filename,
        verbose=True,
    )

    # Persist the training trackers first, then the validation ones.
    for fsma_d in (trn_fsma_d, val_fsma_d):
        for fsma in fsma_d.values():
            fsma.save()

# Loading datasets

In [None]:
def load_fold_ds(i_fold, N_FOLDS=5, SELECT_CLASSES=None, DS_PATH='./'):
    """Build the four ``FoldDataset`` objects used for one cross-validation fold.

    Images and ``train.csv`` are read from the notebook-level ``DS_DIR`` and
    the resolution comes from the notebook-level ``MODEL_RESOLUTION``.

    Args:
        i_fold: fold index used for every dataset created here (the
            out-of-fold and test datasets previously read the global
            ``I_FOLD`` instead of this argument).
        N_FOLDS: total number of folds.
        SELECT_CLASSES: forwarded as ``select_classes_v`` to the train,
            validation and out-of-fold datasets; the test dataset always
            receives ``None``.
        DS_PATH: forwarded as ``ds_path`` to every dataset.

    Returns:
        Tuple ``(ds_trn, ds_val, ds_tst, ds_tst_oof)``.
    """
    # Training split: augmentation on, box cleaning on.
    ds_trn = FoldDataset(
        ds_path=DS_PATH,
        ds_name='train',
        images_dir=os.path.join(DS_DIR, 'train'),
        model_resolution=MODEL_RESOLUTION,
        df_path=os.path.join(DS_DIR, 'train.csv'),
        
        mode='cv_trn',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=0.1,
        
        do_augmentation=True,
        downsample_factor=2,
        remove_classes_v=[],
        select_classes_v=SELECT_CLASSES,
        show_warnings=False,
        do_random_shuffle=True,
        random_seed=3128,
        
        clean_boxes=True,
        clean_mode='random',
        clean_iou_th=0.1,
    )

    # Validation split: same settings without augmentation.
    ds_val = FoldDataset(
        ds_path=DS_PATH,
        ds_name='train',
        images_dir=os.path.join(DS_DIR, 'train'),
        model_resolution=MODEL_RESOLUTION,
        df_path=os.path.join(DS_DIR, 'train.csv'),
        
        mode='cv_val',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=0.1,
        
        do_augmentation=False,
        downsample_factor=2,
        remove_classes_v=[],
        select_classes_v=SELECT_CLASSES,
        show_warnings=False,
        do_random_shuffle=True,
        random_seed=3128,
        
        clean_boxes=True,
        clean_mode='random',
        clean_iou_th=0.1,
    )
    
    # Held-out part of the training data ('cv_tst' mode), no shuffling.
    ds_tst_oof = FoldDataset(
        ds_path=DS_PATH,
        ds_name='train',
        images_dir=os.path.join(DS_DIR, 'train'),
        model_resolution=MODEL_RESOLUTION,
        df_path=os.path.join(DS_DIR, 'train.csv'),

        mode='cv_tst',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=0.1,

        do_augmentation=False,
        downsample_factor=2,
        remove_classes_v=[],
        select_classes_v=SELECT_CLASSES,
        show_warnings=False,
        do_random_shuffle=False,
        random_seed=3128,
    )
        
    # Test images: no annotation file and no class selection.
    ds_tst = FoldDataset(
        ds_path=DS_PATH,
        ds_name='test',
        images_dir=os.path.join(DS_DIR, 'test'),
        model_resolution=MODEL_RESOLUTION,
        df_path=None,

        mode='none',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=0.1,

        do_augmentation=False,
        downsample_factor=2,
        remove_classes_v=[],
        select_classes_v=None,
        show_warnings=False,
        do_random_shuffle=False,
        random_seed=3128,
    )

    return ds_trn, ds_val, ds_tst, ds_tst_oof

# Training

In [None]:
# Training is skipped entirely when the notebook runs in checkpoint-evaluation mode.
if not EVAL_CKPTS:
    n_epochs = 100
    N_WORKERS = 4
    BATCH_SIZE = 5
    PIN_MEMORY = True

    SELECT_CLASSES = None #[2, 14]

    for i_fold in range(I_FOLD, I_FOLD+1):
        # load_fold_ds returns four datasets and needs the cache path; the
        # previous in-loop call used the undefined name N_FOLD, unpacked only
        # two values and fell back to the default DS_PATH.
        ds_trn, ds_val, ds_tst, ds_tst_oof = load_fold_ds(
            i_fold=i_fold,
            N_FOLDS=N_FOLDS,
            SELECT_CLASSES=SELECT_CLASSES,
            DS_PATH=DS_PATH
        )

        # NOTE(review): epochs start at 3, presumably to resume an earlier
        # run — confirm against the restored checkpoint.
        for i_epoch in range(3, n_epochs):
            print(f'Starting new epoch: Epoch = {i_epoch}  Fold = {i_fold}')
            trn_loss_epoch_d = train_one_epoch(
                model,
                i_fold,
                i_epoch,
                ds_trn,
                trn_fsma_d,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=N_WORKERS,
                pin_memory=PIN_MEMORY)

            val_loss_epoch_d = validate_one_epoch(
                model,
                i_fold,
                i_epoch,
                ds_val,
                val_fsma_d,
                batch_size=BATCH_SIZE,
                shuffle=False,
                num_workers=N_WORKERS,
                pin_memory=PIN_MEMORY)

            print(f' Epoch {i_epoch} Summary:')
            print(' - trn_loss_epoch_d:')
            for k, v in trn_loss_epoch_d.items():
                print(f'  |-> {k} = {v:0.04f}')

            print()
            print(' - val_loss_epoch_d:')
            for k, v in val_loss_epoch_d.items():
                print(f'  |-> {k} = {v:0.04f}')

            print()
            # A checkpoint is written after every epoch.
            save_model(
                model,
                trn_fsma_d,
                val_fsma_d,
                i_epoch,
                i_fold)

            print()

# Inference

In [None]:
# Run every checkpoint listed in CKPTS_v over the test set and export its raw
# (un-thresholded, un-cleaned) predictions.
if EVAL_CKPTS:
    N_WORKERS = 4
    
    for ckpt_path in CKPTS_v:
        # Checkpoint files are named 'F<fold>_E<epoch>_...'; take every digit
        # after the leading 'F' so folds >= 10 parse correctly as well.
        I_FOLD = int(os.path.basename(ckpt_path).split('_')[0][1:])
        
        _ = model.restore_checkpoint(ckpt_path)
        
        ds_trn, ds_val, ds_tst, ds_tst_oof = load_fold_ds(
            i_fold=I_FOLD,
            N_FOLDS=5,
            SELECT_CLASSES=None,
            DS_PATH=DS_PATH
        )
        
        # det_th=0.00 together with do_clean_predictions=False keeps every
        # detection, matching the 'noTH_noClean' suffix of the output file.
        all_preds_v = evalueate_dataset(
            ds_tst,
            model,
            det_th=0.00,
            unscale_bboxes=True,
            batch_size=16,
            num_workers=N_WORKERS,
            pin_memory=True,
            do_clean_predictions=False,
            clean_iou_th=0.1,
            clean_mode='p_det_weight',
        )

        # NOTE(review): predictions_to_df presumably also writes the CSV named
        # by its second argument — confirm in its definition.
        preds_df = predictions_to_df(
            all_preds_v,
            f'ds_tst_F{I_FOLD}_noTH_noClean.csv')
