In [None]:
!pip install timm effdet

In [None]:
import os, sys
import glob
import pickle
from collections import OrderedDict, namedtuple, deque
from copy import deepcopy
import copy
import colorsys

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import h5py
from tqdm import tqdm

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


import torch
from torch import nn, optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain, DetBenchPredict



import torch
import torchvision
import torch.nn.functional as F
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN


from torch import nn, optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

# sys.path.append('./yolov5')
# from utils.loss import ComputeLoss
# from utils.general import non_max_suppression

from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain, DetBenchPredict


In [None]:
# # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # #  Global Configs   # # # # # # # # 
# # # # # # # # # # # # # # # # # # # # # # # #

# Epoch budget. The names suggest a warm-up phase followed by a decay phase
# of the learning-rate schedule -- TODO confirm against the scheduler code.
GLOBAL_N_WARMUP_EPOCHS = 5
GLOBAL_N_DECAY_EPOCHS = 60
# Total number of epochs: warm-up + decay + two extra spans (15 and 50) = 130.
GLOBAL_N_EPOCHS = GLOBAL_N_WARMUP_EPOCHS + GLOBAL_N_DECAY_EPOCHS + 15 + 50

# NOTE(review): presumably resumes from an existing checkpoint when True -- confirm.
GLOBAL_CONTINUE_TRAINING = False

# Learning rate (presumably the base/peak value handed to the optimizer -- confirm).
GLOBAL_LR = 1e-4

# NOTE(review): GLOBAL_GRAD_STEPS looks like a gradient-accumulation step count -- confirm.
GLOBAL_GRAD_STEPS = 5
GLOBAL_BATCH_SIZE = 3 #4
GLOBAL_N_WORKERS = 5
GLOBAL_DEVICE = None # uses cuda:0 if it is available

# NOTE(review): presumably the cross-validation fold index and an optional
# single-class restriction (None = all classes) -- confirm where they are read.
GLOBAL_I_FOLD = 0
GLOBAL_CLS = None

# Model input resolution; the (width, height) vs (height, width) order is not
# visible in this section -- TODO confirm.
GLOBAL_MODEL_RESOLUTION = (1024, 768) 


EVAL_CKPTS = True # True=Inference,  False=Training

# USE_NIH selects the 'train_nih' dataset name instead of 'train' (see TRAIN_DS_NAME).
USE_NIH = False
# NOTE(review): DOWNSAMPLE_FACTOR is not used in this section; meaning to be confirmed.
DOWNSAMPLE_FACTOR = 2

# Checkpoint files (Kaggle input paths). The file names appear to encode
# fold, epoch and train/validation scores -- assumption, not verified here.
CKPTS_v = [
    '../input/effdet-d2-v2-ckpts/F0_E62_ModelX_V19_T0.344_V0.381.ckpt', 
    '../input/effdet-d2-v2-ckpts/F0_E74_ModelX_V19_T0.348_V0.381.ckpt',
    '../input/effdet-d2-v2-ckpts/F0_E82_ModelX_V19_T0.361_V0.383.ckpt',
]

# # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # #

In [None]:
# Name of the training split; switches to the NIH variant when USE_NIH is set.
TRAIN_DS_NAME = 'train_nih' if USE_NIH else 'train'


# Data directories list
# Candidate locations are probed in order; the first one that exists wins.
DS_DIR_v = [
    "../input/vinbigdata-chest-xray-abnormalities-detection"]

for DS_DIR in DS_DIR_v:
    if os.path.exists(DS_DIR):
        print(f' DS_DIR Found: "{DS_DIR}"')
        break
else:
    # for/else: this branch runs only when the loop finished without `break`,
    # i.e. none of the candidate directories exists.
    raise Exception(' Dataset not found.')
    
    


# NOTE(review): presumably the location of a pre-computed dataset/bbox cache
# (going by the path name) -- confirm where DS_PATH is consumed.
DS_PATH = '../input/train-test-ds-bbox-cache'

In [None]:
# Human-readable class names; the list index is the integer class id
# (15 entries, the last one -- id 14 -- being 'No Finding').
class2str_v = [
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Consolidation',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Nodule/Mass',
    'Other lesion',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
    'No Finding',
]


# One colour per class as a tuple of three ints in 0..255: hues are spread
# evenly over the HSV wheel (full saturation and value) and converted with
# colorsys.hsv_to_rgb.
class2color_v = [
    tuple(round(i * 255) for i in colorsys.hsv_to_rgb(i_c/len(class2str_v), 1, 1))
    for i_c in range(len(class2str_v))
]

In [None]:
class FastSMA():
    """Simple moving average (SMA) over the last ``maxlen`` values, with history.

    Every appended value is stored together with its step index and with the
    moving average at that moment, so the curves can be plotted or saved later.
    The window sum is updated incrementally, so ``append`` never re-scans the
    window.
    """

    def __init__(
        self,
        iterator=None,
        maxlen=1000,
        label='mean = ',
        print_format='0.02f',
        save_filename='loss_trn.fsma'):
        """Create the tracker and optionally pre-fill it.

        Args:
            iterator: optional iterable of initial samples. An item that is a
                list, tuple or ndarray is unpacked into ``append(*item)``
                (i.e. ``(value, i_step)``); any other item is appended as a
                plain value.
            maxlen: window length of the moving average (positive int).
            label: prefix used by ``__str__``.
            print_format: format spec applied to the mean in ``__str__``.
            save_filename: default path used by ``save`` / ``load``.
        """
        assert type(maxlen) == int and maxlen > 0, 'ERROR length must be a positive int.'
        self.maxlen = maxlen
        self.label = label
        self.print_format = print_format
        self.save_filename = save_filename
        self.clear()

        if iterator is not None:
            # Decide per item instead of peeking with next(iter(...)): peeking
            # consumed (and lost) the first element of generators/iterators and
            # raised StopIteration on an empty iterable.
            for item in iterator:
                if type(item) in [list, tuple, np.ndarray]:
                    self.append(*item)
                else:
                    self.append(item)

        return None

    def append(self, v, i_step=None):
        """Add value ``v`` at step ``i_step`` and update the moving average.

        When ``i_step`` is None the previous step + 1 is used (0 for the first
        value).
        """
        self.cumsum += v

        if i_step is None:
            i_step = self.step_hist_v[-1] + 1 if len(self.step_hist_v)  else 0

        if len(self.history) == self.maxlen:
            # Window is full: the oldest value is about to be evicted by the
            # deque, so remove it from the running sum first.
            self.cumsum -= self.history[0]
            self.last_mean = self.cumsum / self.maxlen
        else:
            # Window still filling: +1 accounts for `v`, appended just below.
            self.last_mean = self.cumsum / (len(self.history) + 1)

        self.history.append( v )

        self.sma_hist_v.append(self.last_mean)
        self.val_hist_v.append(v)
        self.step_hist_v.append( i_step )

        return None

    def mean(self):
        """Return the current moving average (0.0 before any value is added)."""
        return self.last_mean

    def clear(self):
        """Reset the window, the running sum and all recorded histories."""
        self.history = deque(maxlen=self.maxlen)
        self.cumsum = 0.0
        self.last_mean = 0.0
        self.sma_hist_v = []
        self.val_hist_v = []
        self.step_hist_v = []

        # Attribute names persisted by save() and restored by load().
        self.key2save_v = [
            'history',
            'cumsum',
            'last_mean',
            'sma_hist_v',
            'val_hist_v',
            'step_hist_v',
        ]
        return None

    def get_sma_history(self):
        """Return ``(steps, moving_averages)`` as two numpy arrays."""
        return np.array(self.step_hist_v), np.array(self.sma_hist_v)

    def get_val_history(self):
        """Return ``(steps, raw_values)`` as two numpy arrays."""
        return np.array(self.step_hist_v), np.array(self.val_hist_v)

    def __str__(self):
        return ('{}{:'+self.print_format+'}').format(self.label, self.mean())

    def __repr__(self):
        return self.__str__()

    def save(self, filename=None):
        """Pickle the tracker state to ``filename`` (default: ``save_filename``)."""
        if filename is None:
            filename = self.save_filename

        to_save_d = {}
        for key2save in self.key2save_v:
            to_save_d[key2save] = deepcopy( getattr(self, key2save) )

        with open(filename, 'wb') as f:
            f.write( pickle.dumps(to_save_d) )

        print(f' Saved complete: "{filename}"')

        return None

    def load(self, filename=None):
        """Restore a state written by ``save`` and return ``self``.

        SECURITY: unpickling can execute arbitrary code; only load files you
        created yourself.
        """
        if filename is None:
            filename = self.save_filename

        with open(filename, 'rb') as f:
            data = f.read()

        to_save_d = pickle.loads(data)

        # Only known attributes are restored; unexpected keys are ignored.
        for k, v in to_save_d.items():
            if k in self.key2save_v:
                setattr(self, k, v)

        # The restored deque carries its own window length; keep them in sync.
        self.maxlen = self.history.maxlen
        print(f' Restored: "{filename}"')

        return self

    def plot(self, label=None, do_show=False):
        """Plot the moving-average history on the current matplotlib axes."""
        x, y = self.get_sma_history()
        plt.plot(x, y, label=label)
        plt.grid()

        if do_show:
            plt.show()

        return None

    def plot_sma(self, label='', step=500):
        """Plot block means of the raw values, one point per ``step`` values.

        The first ``len % step`` values are dropped so the remainder splits
        into whole blocks; each block is drawn at the step index of its last
        value. An empty ``label`` falls back to the save file's base name
        without the ``.fsma`` suffix.
        """
        if label == '':
            label = os.path.split( self.save_filename )[-1].replace('.fsma', '')

        step_v, loss_v = self.get_val_history()

        m = loss_v.shape[0] % step
        loss_v = loss_v[m:].reshape(-1, step).mean(axis=-1)
        step_v = step_v[m:].reshape(-1, step)[:,-1]

        plt.plot(step_v, loss_v, label=label)

        return None

In [None]:
def batch_merge(samples_v):
    """Collate a list of per-sample dicts into one batch dict.

    The keys of the first sample define the batch keys. Every value is
    collected into a list (one entry per sample) except ``'image'``, whose
    entries are stacked into a single numpy array along a new leading axis.
    """
    batch_d = {}
    for key in samples_v[0].keys():
        batch_d[key] = []

    for sample_d in samples_v:
        for key, value in sample_d.items():
            batch_d[key].append(value)

    # Images are stacked here; the other fields stay as plain lists.
    batch_d['image'] = np.stack(batch_d['image'])

    return batch_d
    



def data2tensor(data, device=None, pin_memory=False):
    """Convert the numpy content of a batch dict to torch tensors, in place.

    ``data['image']`` becomes one tensor. For every other key whose first
    element is a numpy array, each per-sample entry is converted individually.
    ``device`` and ``pin_memory`` are forwarded to ``torch.tensor``.
    Returns the same (mutated) dict.
    """
    tensor_kwargs = dict(device=device, pin_memory=pin_memory)

    data['image'] = torch.tensor(data['image'], **tensor_kwargs)
    batch_len = len(data['image'])

    for key, values in data.items():
        # Non-array fields (ids, shapes, the image tensor itself) are left alone.
        if type(values[0]) is not np.ndarray:
            continue

        for i_s in range(batch_len):
            values[i_s] = torch.tensor(values[i_s], **tensor_kwargs)

    return data
    
    

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
    

def calc_iou(bb0, bb1):
    """Intersection-over-union of axis-aligned boxes given as (x0, y0, x1, y1).

    Each argument is either one box of shape (4,) or an (n, 4) array of boxes;
    2-D inputs are transposed so the coordinates can be unpacked, and the
    result broadcasts accordingly. Boxes must have strictly positive width and
    height (checked with asserts). Non-overlapping pairs yield 0.
    """
    if len(bb0.shape) == 2:
        bb0 = bb0.T
    if len(bb1.shape) == 2:
        bb1 = bb1.T

    a_x0, a_y0, a_x1, a_y1 = bb0
    b_x0, b_y0, b_x1, b_y1 = bb1

    # Reject degenerate boxes up front.
    for low, high in ((a_x0, a_x1), (a_y0, a_y1), (b_x0, b_x1), (b_y0, b_y1)):
        assert (low < high).all()

    # Corners of the candidate intersection rectangle.
    inter_x0 = np.maximum(a_x0, b_x0)
    inter_y0 = np.maximum(a_y0, b_y0)
    inter_x1 = np.minimum(a_x1, b_x1)
    inter_y1 = np.minimum(a_y1, b_y1)

    # False wherever the rectangle is inverted, i.e. the boxes do not overlap.
    overlap_mask = ~((inter_x1 < inter_x0) | (inter_y1 < inter_y0))

    inter_area = (inter_x1 - inter_x0) * (inter_y1 - inter_y0)
    area_a = (a_x1 - a_x0) * (a_y1 - a_y0)
    area_b = (b_x1 - b_x0) * (b_y1 - b_y0)

    iou = inter_area / (area_a + area_b - inter_area)

    # The mask zeroes the meaningless values computed for disjoint boxes.
    return iou * overlap_mask


def join_preds(bbox_v, p_det_v=None, mode='p_det_weight'):
    """Fuse a group of boxes (and their scores) into a single box and score.

    Args:
        bbox_v: (n, 4) array of boxes to merge.
        p_det_v: (n,) array of scores; defaults to all ones.
        mode: fusion rule --
            'p_det_weight' / 'p_det_weight_pmean': score-weighted mean box, mean score;
            'p_det_weight_psum': score-weighted mean box, summed score;
            'median' / 'median_pmean': coordinate-wise median box, mean score;
            'p_det_max': the box with the highest score and that score;
            'random': a box picked with np.random.randint and its score.

    Returns:
        (bbox, p) tuple.

    Raises:
        Exception: for an unknown ``mode``.
    """
    if p_det_v is None:
        p_det_v = np.ones(bbox_v.shape[0])

    if mode in ('p_det_weight', 'p_det_weight_pmean', 'p_det_weight_psum'):
        # Weights are the scores cast to the box dtype and normalised to sum 1.
        weights = p_det_v.astype(bbox_v.dtype)
        weights = (weights / weights.sum())[:, None]

        bbox = (bbox_v * weights).sum(axis=0)
        p = p_det_v.sum() if mode == 'p_det_weight_psum' else p_det_v.mean()

    elif mode in ('median', 'median_pmean'):
        bbox = np.median(bbox_v, axis=0)
        p = p_det_v.mean()

    elif mode in ('p_det_max', 'random'):
        if mode == 'p_det_max':
            i_pick = p_det_v.argmax()
        else:
            i_pick = np.random.randint(0, p_det_v.shape[0])

        bbox = bbox_v[i_pick]
        p = p_det_v[i_pick]

    else:
        raise Exception(f'Unknown mode "{mode}"')

    return bbox, p
    

def clean_predictions(preds_v, iou_th=0.1, mode='p_det_weight', consensus_level=1, n_models2ensemble=1):
    """Merge overlapping same-class boxes inside every prediction dict.

    For each class present in a prediction, boxes are grouped greedily: a box
    joins the first existing group that shares at least one index with its set
    of IoU neighbours (IoU > ``iou_th``), otherwise it opens a new group. Each
    group is fused with ``join_preds``.

    Args:
        preds_v: list of dicts with ``'cls'`` and ``'bbox'`` (or ``'bboxes'``)
            arrays and optionally ``'p_det'``, ``'rad_id'`` and ``'model_id'``.
            Any other key is copied to the output unchanged.
        iou_th: IoU above which two boxes of the same class are linked.
        mode: fusion rule forwarded to ``join_preds``.
        consensus_level: minimum number of boxes a group needs to be kept.
            A class with a single box is dropped when this is > 1, except for
            class id -1 (presumably a "no finding" marker -- confirm), which
            is handled through ``rad_id`` instead.
        n_models2ensemble: when > 1, the fused score is multiplied by the
            fraction of distinct ``model_id`` values present in the group.

    Returns:
        A new list of dicts with the cleaned ``'cls'`` / boxes, plus ``'p_det'``
        and ``'rad_id'`` when the input had them. ``'model_id'`` is not
        propagated.
    """
    ret_preds_v = []
    for pred_d in preds_v:

        cls_v = pred_d['cls']

        # Boxes may live under either key; the same key is used for the output.
        bbox_key = 'bbox' if 'bbox' in pred_d.keys() else 'bboxes'
        bbox_v = pred_d[bbox_key]

        ret_p_det = 'p_det' in pred_d.keys()
        if ret_p_det:
            p_det_v = pred_d['p_det']
        else:
            # Missing scores are treated as 1 and not written to the output.
            p_det_v = np.ones(pred_d['cls'].shape)

        ret_rad_id = 'rad_id' in pred_d.keys()
        if ret_rad_id:
            rad_id_v = pred_d['rad_id']

        if 'model_id' in pred_d.keys():
            model_id_v = pred_d['model_id']
        else:
            # Builtin int: the np.int alias was removed in NumPy 1.24 and
            # raised AttributeError here.
            model_id_v = np.zeros(pred_d['cls'].shape, dtype=int)

        new_cls_v = []
        new_bbox_v = []
        new_p_det_v = []
        new_rad_id_v = []
        for i_c in np.unique(cls_v):
            f_c = (cls_v == i_c)

            n_c = f_c.sum()
            if n_c == 1:
                # A lone box cannot reach a consensus level above 1.
                if consensus_level > 1 and i_c != -1:
                    continue

                if ret_rad_id:
                    if i_c == -1:
                        # For class -1 the consensus is counted over all
                        # rad_id entries of the sample.
                        n_rads = rad_id_v.size

                        if n_rads < consensus_level:
                            continue

                        else:
                            if n_rads > 1:
                                new_rad_id_v.append( np.concatenate(rad_id_v, axis=-1) )
                            else:
                                new_rad_id_v.append( rad_id_v[f_c][0] )
                    else:
                        new_rad_id_v.append( rad_id_v[f_c][0] )

                new_cls_v.append( i_c )
                new_bbox_v.append( bbox_v[f_c][0] )
                new_p_det_v.append( p_det_v[f_c][0] )

            else:
                f_bbox_v = bbox_v[f_c]
                f_p_det_v = p_det_v[f_c]
                f_model_id_v = model_id_v[f_c]

                if ret_rad_id:
                    f_rad_id_v = rad_id_v[f_c]

                # Greedy grouping of boxes linked by IoU > iou_th.
                to_join_idxs_v = []
                for i_box in range(0, n_c):
                    idxs_s = set( np.argwhere( calc_iou(f_bbox_v[i_box], f_bbox_v) > iou_th ).T[0] )

                    for i_group in range(len(to_join_idxs_v)):
                        if len( idxs_s.intersection(to_join_idxs_v[i_group]) ) > 0:
                            to_join_idxs_v[i_group] = to_join_idxs_v[i_group].union(idxs_s)
                            break

                    else:
                        # No existing group shares an index: start a new one.
                        to_join_idxs_v.append(idxs_s)

                for to_join_idxs in to_join_idxs_v:
                    to_join_idxs = list(to_join_idxs)

                    if len(to_join_idxs) < consensus_level:
                        continue

                    bbox, p_det = join_preds(
                        f_bbox_v[to_join_idxs],
                        f_p_det_v[to_join_idxs],
                        mode=mode,
                    )

                    if n_models2ensemble > 1:
                        # Down-weight groups that only some of the models agree on.
                        ens_prop = len( np.unique(f_model_id_v[to_join_idxs]) ) / n_models2ensemble
                        p_det = p_det * ens_prop

                    new_cls_v.append( i_c )
                    new_bbox_v.append( bbox )
                    new_p_det_v.append( p_det )

                    if ret_rad_id:
                        new_rad_id_v.append( np.concatenate(f_rad_id_v[to_join_idxs], axis=-1))

        ret_preds_d = {
            'cls': np.array(new_cls_v),
            bbox_key: np.array(new_bbox_v),
        }

        if ret_p_det:
            ret_preds_d['p_det'] = np.array(new_p_det_v)

        if ret_rad_id:
            ret_preds_d['rad_id'] = np.array(new_rad_id_v, dtype=object)

        # Carry over sample-level metadata (ids, shapes, ...) untouched.
        for k in pred_d.keys():
            if k not in ['cls', bbox_key, 'p_det', 'rad_id', 'model_id']:
                ret_preds_d[k] = pred_d[k]

        ret_preds_v.append(ret_preds_d)

    return ret_preds_v



# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def evalueate_dataset(
    ds,
    model,
    det_th=0.25,
    unscale_bboxes=True,
    batch_size=16,
    num_workers=8,
    pin_memory=True,
    do_clean_predictions=True,
    clean_iou_th=0.10,
    clean_mode='p_det_weight',
    do_TTA=False,
    TTA_clean_iou_th=0.2,
    TTA_clean_mode='median_pmean',
):
    """Run ``model`` over the whole dataset and collect per-sample predictions.

    Batches come from an unshuffled DataLoader collated with ``batch_merge``.
    Each batch is passed to ``model.predict`` (or ``model.predict_TTA`` when
    ``do_TTA`` is set); both are assumed to return one prediction dict per
    sample -- their implementation is not part of this section. Predictions
    are optionally de-duplicated with ``clean_predictions`` and tagged with
    the batch's ``'sample_id'`` and ``'original_shape'``.

    Returns:
        A flat list of prediction dicts, in dataset order.
    """
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge)

    all_preds_v = []
    for batch_d in tqdm(loader):
        if not do_TTA:
            batch_preds_v = model.predict(
                batch_d,
                det_th=det_th,
                unscale_bboxes=unscale_bboxes)
        else:
            batch_preds_v = model.predict_TTA(
                batch_d,
                det_th=det_th,
                unscale_bboxes=unscale_bboxes,
                TTA_clean_iou_th=TTA_clean_iou_th,
                TTA_clean_mode=TTA_clean_mode,
                )

        if do_clean_predictions:
            batch_preds_v = clean_predictions(
                batch_preds_v,
                iou_th=clean_iou_th,
                mode=clean_mode)

        # Attach the identifiers needed later to build the submission.
        for i_s in range(len(batch_preds_v)):
            batch_preds_v[i_s]['sample_id'] = batch_d['sample_id'][i_s]
            batch_preds_v[i_s]['original_shape'] = batch_d['original_shape'][i_s]

        all_preds_v.extend(batch_preds_v)

    return all_preds_v
    
    
    
def pred_to_str(pred_d):
    """Format one prediction dict as a submission ``PredictionString``.

    Each detection becomes ``"<cls> <p_det> <x0> <y0> <x1> <y1>"`` with the box
    coordinates rounded to integers; detections are joined with single spaces.
    A prediction without detections yields the fixed string ``'14 1 0 0 1 1'``.

    Args:
        pred_d: dict with ``'cls'``, ``'bbox'`` and ``'p_det'`` arrays.
    """
    cls_v = pred_d['cls']
    bbox_v = pred_d['bbox']
    p_det_v = pred_d['p_det']

    if len(cls_v) == 0:
        return '14 1 0 0 1 1'

    # Builtin int instead of np.int: the alias was removed in NumPy 1.24.
    int_cls_v = cls_v.astype(int)
    int_bbox_v = np.round(bbox_v).astype(int)

    s_v = [
        '{} {:0.05} {} {} {} {}'.format(int(cls), p_det, *bbox)
        for cls, p_det, bbox in zip(int_cls_v, p_det_v, int_bbox_v)
    ]

    return ' '.join(s_v)


def predictions_to_df(
    preds_v,
    save_path=None,
):
    """Build the submission table from a list of prediction dicts.

    The result has one row per prediction with the columns ``image_id`` (the
    prediction's ``'sample_id'``) and ``PredictionString`` (from
    ``pred_to_str``). When ``save_path`` is given the table is also written
    there as CSV without the index.
    """
    image_id_v = []
    pred_str_v = []
    for pred_d in preds_v:
        pred_str_v.append(pred_to_str(pred_d))
        image_id_v.append(pred_d['sample_id'])

    pred_summary_df = pd.DataFrame({
        'image_id': image_id_v,
        'PredictionString': pred_str_v,
    })

    if save_path is None:
        return pred_summary_df

    pred_summary_df.to_csv(save_path, index=None)
    print(f' Saved submission: "{save_path}"')

    return pred_summary_df


def read_prediction_csv(filename='./ds_tst_F0_V6_JustCLS0_1.25x.csv'):
    """Parse a two-column submission CSV back into prediction dicts.

    Every row is ``(sample_id, PredictionString)``; the string is split on
    whitespace and read in groups of six numbers
    ``cls p_det x_min y_min x_max y_max``. Entries of class 14 are skipped.

    Returns:
        A list of dicts with ``'sample_id'`` and the numpy arrays ``'cls'``,
        ``'bbox'`` and ``'p_det'``.
    """
    sub_df = pd.read_csv(filename)

    preds_v = []
    for sample_id, pred_str in sub_df.values:
        tokens = pred_str.split()

        cls_l, bbox_l, p_det_l = [], [], []
        for start in range(0, len(tokens), 6):
            cls, p_det, x_min, y_min, x_max, y_max = [float(tok) for tok in tokens[start:start+6]]
            cls = int(cls)

            # Class 14 entries carry no real box and are not kept.
            if cls == 14:
                continue

            cls_l.append(cls)
            bbox_l.append(np.array([x_min, y_min, x_max, y_max]))
            p_det_l.append(p_det)

        preds_v.append({
            'sample_id': sample_id,
            'cls': np.array(cls_l),
            'bbox': np.array(bbox_l),
            'p_det': np.array(p_det_l),
        })

    return preds_v
    

# Validation functions

def voc_ap(recall, precision, use_07_metric=False):
    """Compute the PASCAL VOC average precision from a PR curve.

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    ``recall`` and ``precision`` must be ordered by decreasing detection score.

    Args:
        recall: (n,) ndarray.
        precision: (n,) ndarray.
        use_07_metric: if True use the VOC07 11-point interpolation,
            otherwise the area under the precision envelope (default).

    Returns:
        The average precision as a float.
    """
    if use_07_metric:
        # 11-point metric: mean of the best precision at recall >= 0, 0.1, ..., 1.
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            reached = recall >= t
            p = np.max(precision[reached]) if np.sum(reached) != 0 else 0
            ap = ap + p / 11.
        return ap

    # Pad the curve with sentinel points at both ends.
    mrec = np.concatenate(([0.], recall, [1.]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # Precision envelope: running maximum taken from the right-hand side.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Integrate only where recall changes: sum of (delta recall) * precision.
    change_idx = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[change_idx + 1] - mrec[change_idx]) * mpre[change_idx + 1])


def compute_overlaps(boxes, one_box):
    """IoU between every row of ``boxes`` and a single box (VOC convention).

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    Coordinates are [xmin, ymin, xmax, ymax] and treated as inclusive pixel
    indices, hence the ``+ 1`` in every width/height.

    Args:
        boxes: (n, 4) ndarray, e.g. the ground-truth boxes.
        one_box: (4,) ndarray, e.g. one detection.

    Returns:
        (n,) ndarray with the IoU of each row against ``one_box``.
    """
    box_x0, box_y0 = boxes[:, 0], boxes[:, 1]
    box_x1, box_y1 = boxes[:, 2], boxes[:, 3]

    # Intersection extent, clamped at zero for disjoint boxes.
    inter_w = np.maximum(
        np.minimum(box_x1, one_box[2]) - np.maximum(box_x0, one_box[0]) + 1., 0.)
    inter_h = np.maximum(
        np.minimum(box_y1, one_box[3]) - np.maximum(box_y0, one_box[1]) + 1., 0.)
    inters = inter_w * inter_h

    # Union = both areas minus the shared part.
    boxes_area = (box_x1 - box_x0 + 1.) * (box_y1 - box_y0 + 1.)
    one_box_area = (one_box[2] - one_box[0] + 1.) * (one_box[3] - one_box[1] + 1.)

    return inters / (one_box_area + boxes_area - inters)


def voc_eval(class_recs: dict,
             detect: dict,
             iou_thresh: float = 0.5,
             use_07_metric: bool = False):
    """PASCAL VOC evaluation of the detections of a single class.

    Reference:
        https://github.com/wang-tf/pascal_voc_tools/blob/master/pascal_voc_tools/Evaluater/tools.py

    Detections are visited by decreasing confidence. A detection is a true
    positive when its best-overlapping ground-truth box in the same image has
    IoU > ``iou_thresh`` and was not matched before; otherwise it is a false
    positive.
        precision = tp / (tp + fp)
        recall = tp / (number of ground-truth boxes)

    Args:
        class_recs: ground truth of one class, ``class_recs[image_name] =
            {'bbox': ndarray}``. NOTE: a ``'det'`` list of matched flags is
            written into every record (the argument is mutated).
        detect: detections of that class, ``{'image_ids': [...],
            'bbox': ndarray, 'confidence': ndarray}``.
        iou_thresh: overlap threshold (default 0.5).
        use_07_metric: whether to use VOC07's 11-point AP (default False).

    Returns:
        dict with ``true_positive_number``, ``false_positive_number``,
        ``positive_number``, ``recall``, ``precision`` and
        ``average_precision``.

    Raises:
        TypeError: if a ground-truth ``'bbox'``, the detection boxes or the
            confidences are not np.ndarray.
    """
    # Ground truth: count the positives and reset the "already matched" flags.
    npos = 0
    for rec in class_recs.values():
        gt_boxes = rec['bbox']
        if not isinstance(gt_boxes, np.ndarray):
            raise TypeError
        n_gt = gt_boxes.shape[0]
        npos += n_gt
        rec['det'] = [False] * n_gt

    image_ids = detect['image_ids']
    confidence = detect['confidence']
    BB = detect['bbox']
    if not isinstance(confidence, np.ndarray):
        raise TypeError
    if not isinstance(BB, np.ndarray):
        raise TypeError

    # Most confident detections first.
    order = np.argsort(-confidence)
    BB = BB[order]
    image_ids = [image_ids[x] for x in order]

    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d, image_id in enumerate(image_ids):
        R = class_recs[image_id]
        bb = BB[d, :].astype(float)
        BBGT = R['bbox'].astype(float)

        iou_max = -np.inf
        if BBGT.size > 0:
            overlaps = compute_overlaps(BBGT, bb)
            iou_max = np.max(overlaps)
            iou_max_index = np.argmax(overlaps)

        # iou_max_index is only read when iou_max passed the threshold,
        # which implies at least one ground-truth box exists.
        if iou_max > iou_thresh and not R['det'][iou_max_index]:
            tp[d] = 1.
            R['det'][iou_max_index] = 1
        else:
            # Too little overlap, or that ground-truth box was already claimed.
            fp[d] = 1.

    # Cumulative counts give one precision/recall point per detection.
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    true_positive_number = tp[-1] if len(tp) > 0 else 0
    false_positive_number = fp[-1] if len(fp) > 0 else 0

    eps = np.finfo(np.float64).eps
    recall = tp / np.maximum(float(npos), eps)
    # eps guards the division when tp + fp is zero.
    precision = tp / np.maximum(tp + fp, eps)
    average_precision = voc_ap(recall, precision, use_07_metric)

    return {
        'true_positive_number': true_positive_number,
        'false_positive_number': false_positive_number,
        'positive_number': npos,
        'recall': recall,
        'precision': precision,
        'average_precision': average_precision,
    }



def make_class_detection_d(preds_v, n_classes=15):
    """Regroup per-image predictions into per-class detection tables.

    Returns a list with one dict per class id, each holding the parallel numpy
    arrays ``'image_ids'``, ``'bbox'`` and ``'confidence'`` (the layout read by
    ``voc_eval``). An image without any box is recorded as one detection of the
    last class (``n_classes - 1``) with box [0, 0, 1, 1] and confidence 1.0.
    """
    detection_v = []
    for _ in range(n_classes):
        detection_v.append({'image_ids': [], 'bbox': [], 'confidence': []})

    def _register(cls, sample_id, bbox, p_det):
        # Append one detection to the table of class `cls`.
        class_d = detection_v[cls]
        class_d['image_ids'].append(sample_id)
        class_d['bbox'].append(bbox)
        class_d['confidence'].append(p_det)

    for pred_d in preds_v:
        sample_id = pred_d['sample_id']

        if len(pred_d['bbox']) > 0:
            for bbox, p_det, cls in zip(pred_d['bbox'], pred_d['p_det'], pred_d['cls']):
                _register(cls, sample_id, bbox, p_det)
        else:
            # Placeholder detection for an image with no boxes.
            _register(n_classes - 1, sample_id, np.array([0.0, 0.0, 1.0, 1.0]), 1.0)

    # Freeze the accumulated lists into arrays.
    for class_d in detection_v:
        for key in class_d.keys():
            class_d[key] = np.array(class_d[key])

    return detection_v


def make_class_gt_d(gt_df, n_classes=15):
    """Regroup a ground-truth table into per-class, per-image box arrays.

    Args:
        gt_df: DataFrame with the columns ``image_id``, ``class_id``,
            ``x_min``, ``y_min``, ``x_max`` and ``y_max`` (one row per box).
        n_classes: number of class ids to allocate.

    Returns:
        A list indexed by class id; each entry maps every image id of the
        table to ``{'bbox': ndarray}`` (an empty array when the image has no
        box of that class).
    """
    image_ids = gt_df.image_id.unique()

    # Every class gets an entry for every image, even without boxes.
    class_gt_v = [
        {image_id: {'bbox': []} for image_id in image_ids}
        for _ in range(n_classes)
    ]

    box_columns = ['image_id', 'class_id', 'x_min', 'y_min', 'x_max', 'y_max']
    for image_id, class_id, x_min, y_min, x_max, y_max in gt_df[box_columns].values:
        class_gt_v[class_id][image_id]['bbox'].append((x_min, y_min, x_max, y_max))

    for class_gt_d in class_gt_v:
        for rec_d in class_gt_d.values():
            rec_d['bbox'] = np.array(rec_d['bbox'])

    return class_gt_v
        






        
def calc_metrics(preds_v, gt_df, iou_thresh=0.4, show_summay=True, n_classes=15, use_07_metric=False):
    """Evaluate predictions against ground truth with per-class VOC AP.

    Args:
        preds_v: list of prediction dicts (see ``make_class_detection_d``).
        gt_df: ground-truth DataFrame (see ``make_class_gt_d``).
        iou_thresh: IoU a detection must exceed to count as a match.
        show_summay: print the per-class table and the mAP when True.
        n_classes: number of classes to evaluate.
        use_07_metric: use the VOC07 11-point AP instead of the area metric.

    Returns:
        ``(metrics_v, summary_df, mAP)``: the raw ``voc_eval`` result per
        class, a DataFrame indexed by ``Cls`` with the columns
        TP / FP / P / Prec / Rec / AP, and the mean AP over all classes.
    """
    class_recs_v = make_class_gt_d(gt_df, n_classes=n_classes)
    detect_v     = make_class_detection_d(preds_v, n_classes=n_classes)

    mAP = 0.0
    metrics_v = []
    for i_class in range(n_classes):
        class_val_d = voc_eval(
            class_recs_v[i_class],
            detect_v[i_class],
            iou_thresh=iou_thresh,
            use_07_metric=use_07_metric)

        metrics_v.append(class_val_d)
        mAP += class_val_d['average_precision']

    mAP /= n_classes

    # Build all rows first and create the frame once: DataFrame.append was
    # removed in pandas 2.0 (and was quadratic when used in a loop).
    records = [
        {
            'Cls': i_c,
            'TP': m_d['true_positive_number'],
            'FP': m_d['false_positive_number'],
            'P': m_d['positive_number'],
            # Precision/recall after the last detection; 0.0 without detections.
            'Prec': m_d['precision'][-1] if len(m_d['precision']) > 0 else 0.0,
            'Rec': m_d['recall'][-1] if len(m_d['recall']) > 0 else 0.0,
            'AP': m_d['average_precision'],
        }
        for i_c, m_d in enumerate(metrics_v)
    ]
    summary_df = pd.DataFrame(
        records,
        columns=['Cls', 'TP', 'FP', 'P', 'Prec', 'Rec', 'AP'],
    )

    # Builtin int replaces np.int, an alias removed in NumPy 1.24.
    summary_df = summary_df.astype({
            'Cls': int,
            'TP': int,
            'FP': int,
            'P': int,
            'Prec': np.float32,
            'Rec': np.float32,
            'AP': np.float32,
        })

    summary_df = summary_df.set_index( 'Cls' )

    if show_summay:
        print(' Summary:')
        print( summary_df )
        print(f'mAP = {mAP:0.04f}')

    return metrics_v, summary_df, mAP



def plot_PvsR_curve(metrics_v, class2color_v, do_show=True):
    """Draw one precision-vs-recall curve per class on figure 0.

    Args:
        metrics_v: per-class result dicts providing ``'recall'`` and
            ``'precision'`` arrays (as returned by ``voc_eval``).
        class2color_v: per-class RGB colours in 0..255, indexed by class id.
        do_show: call ``plt.show()`` before returning.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(0, figsize=(20,10))

    for i_c in range(len(metrics_v)):
        curve_d = metrics_v[i_c]
        # matplotlib expects colour components in 0..1.
        color = np.array(class2color_v[i_c])/255
        plt.plot(
            curve_d['recall'],
            curve_d['precision'],
            label=f'cls={i_c}',
            c=color)

    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.xlim( (0,1) )
    plt.ylim( (0,1) )
    plt.legend()

    if do_show:
        plt.show()

    return fig
    
    
    
def filter_det_th(preds_v, det_th=0.15):
    """Keep only the detections whose score reaches the detection threshold.

    Args:
        preds_v: list of prediction dicts with ``'cls'`` and ``'p_det'`` arrays
            (plus ``'bbox'`` and optionally ``'model_id'``).
        det_th: either one number applied to every detection, or a per-class
            container (list, ndarray, dict, ...) indexed by class id.

    Returns:
        A new list of new dicts. The per-detection arrays ``'cls'``,
        ``'p_det'``, ``'bbox'`` and ``'model_id'`` are filtered; every other
        entry is passed through as is. With a scalar threshold of 0 a deep
        copy of the input is returned.
    """
    # Any real scalar counts as a global threshold. The previous
    # `type(det_th) is float` test rejected ints and NumPy scalars, and
    # `det_th == 0.0` raised on an ndarray of per-class thresholds.
    is_scalar_th = isinstance(det_th, (int, float, np.integer, np.floating)) or (
        isinstance(det_th, np.ndarray) and det_th.ndim == 0)

    if is_scalar_th and det_th == 0.0:
        return copy.deepcopy(preds_v)

    # Arrays holding one entry per detection; they must stay aligned.
    per_detection_keys = ('cls', 'p_det', 'bbox', 'model_id')

    new_pred_v = []
    for pred_d in preds_v:

        if len(pred_d['p_det']) > 0:
            if is_scalar_th:
                det_th_v = det_th
            else:
                # Look up the threshold of each detection's class.
                det_th_v = np.zeros(pred_d['p_det'].shape[0])
                for i_d, i_cls in enumerate(pred_d['cls']):
                    det_th_v[i_d] = det_th[int(i_cls)]

            keep_f = pred_d['p_det'] >= det_th_v
        else:
            # Nothing to filter for an empty prediction.
            keep_f = None

        new_pred_d = {}
        for k in pred_d.keys():
            if (keep_f is not None) and (k in per_detection_keys):
                new_pred_d[k] = pred_d[k][keep_f]
            else:
                new_pred_d[k] = pred_d[k]

        new_pred_v.append(new_pred_d)

    return new_pred_v




def join_predictions(preds_v_list=[], add_model_id=False):
    """Join the predictions of different models, sample by sample.

    Args:
        preds_v_list: one prediction list per model. All lists must cover the
            same samples; they are aligned by sorting on ``'sample_id'``.
            The inputs are deep-copied and never modified.
        add_model_id: when True (and the predictions do not already carry a
            ``'model_id'``), tag every detection with the index of the model
            it came from, i.e. its position in ``preds_v_list``.

    Returns:
        One dict per sample with ``'sample_id'`` and the concatenated
        ``'bbox'``, ``'cls'`` and ``'p_det'`` arrays (plus ``'model_id'`` when
        present or requested). A sample without any detection gets empty
        arrays.
    """
    # Work on sorted deep copies so callers' data is left untouched.
    sorted_preds_v_list = []
    for preds_v in preds_v_list:
        assert len(preds_v_list[0]) == len(preds_v), 'ERROR, pred_v have differetn sizes'
        sorted_preds_v_list.append(
            sorted(copy.deepcopy(preds_v), key=lambda l: l['sample_id']))

    ret_pred_v = []
    for p_v in zip(*sorted_preds_v_list):
        for p in p_v[1:]:
            assert p_v[0]['sample_id'] == p['sample_id']

        sample_id = p_v[0]['sample_id']

        # Remember each prediction's model index BEFORE dropping the empty
        # ones; numbering after the filter shifted the ids whenever an earlier
        # model had no detection for this sample.
        indexed_p_v = [(i_m, p) for i_m, p in enumerate(p_v) if len(p['cls']) > 0]

        keys_v = ['bbox', 'cls', 'p_det']
        if len(indexed_p_v) and 'model_id' in indexed_p_v[0][1].keys():
            # Already tagged upstream: keep the existing ids.
            keys_v.append('model_id')

        elif add_model_id:
            keys_v.append('model_id')

            for i_m, p in indexed_p_v:
                p['model_id'] = i_m * np.ones_like(p['cls'])

        pred_d = {
            'sample_id': sample_id
        }

        if len(indexed_p_v) > 0:
            for k in keys_v:
                pred_d[k] = np.concatenate([p[k] for _, p in indexed_p_v], axis=0)

        else:
            for k in keys_v:
                pred_d[k] = np.array([])

        ret_pred_v.append(pred_d)

    return ret_pred_v


def norm_p_det(pred_v):
    """Rescale the detection scores so that the largest one becomes 1.0.

    The normalization is applied only when the maximum 'p_det' over all the
    samples is greater than 1.0; otherwise (or when no sample has detections)
    the input list itself is returned untouched.

    Args:
        pred_v: list of prediction dicts, each with a 'p_det' array.

    Returns:
        A deep copy of `pred_v` with every 'p_det' divided by the global
        maximum, or `pred_v` itself when the normalization is skipped.
    """
    p_det_v = [pred_d['p_det'] for pred_d in pred_v if len(pred_d['p_det']) > 0]

    # np.concatenate() raises on an empty list: nothing to normalize.
    if len(p_det_v) == 0:
        print('skipping norm_p_det')
        return pred_v

    p_det_max = np.concatenate(p_det_v).max()

    print('p_det_max =', p_det_max)
    if not p_det_max > 1.0:
        print('skipping norm_p_det')
        return pred_v

    ret_pred_v = copy.deepcopy(pred_v)
    for pred_d in ret_pred_v:
        if len(pred_d['p_det']) > 0:
            pred_d['p_det'] = pred_d['p_det'] / p_det_max

    return ret_pred_v




def fix_boxes(preds_v, img_shapes_d=None):
    """
    Sanitize the predictions IN PLACE (nothing is returned).

    Detections are dropped when:
    - p_det <= 0.0 or p_det > 1.0
    - xmax - xmin <= 1
    - ymax - ymin <= 1

    if img_shapes_d is provided: the bboxes are clipped first using the
    entry stored for the sample. A '.' is printed for every sample that
    loses at least one detection.
    """
    EPS = 0.0

    for preds_d in preds_v:
        # Samples without detections need no fixing.
        if len(preds_d['cls']) == 0:
            continue

        if img_shapes_d is not None:
            # NOTE(review): the entry is unpacked as (x_max, y_max); other code
            # in this file stores shapes as (height, width) - confirm the order.
            x_max, y_max = img_shapes_d[preds_d['sample_id']]

            box_dtype = preds_d['bbox'].dtype
            lower_v = np.array([EPS] * 4, dtype=box_dtype)
            upper_v = np.array([x_max - EPS, y_max - EPS] * 2, dtype=box_dtype)
            preds_d['bbox'] = np.clip(preds_d['bbox'], lower_v, upper_v)

        box_wh = preds_d['bbox'][:, 2:] - preds_d['bbox'][:, :2]

        bad_size_f = (box_wh <= 1).any(axis=-1)
        bad_score_f = (preds_d['p_det'] <= 0) | (preds_d['p_det'] > 1.0)

        if bad_size_f.any() or bad_score_f.any():
            print('.', end='')
            keep_f = ~(bad_size_f | bad_score_f)
            for k in ['p_det', 'bbox', 'cls']:
                preds_d[k] = preds_d[k][keep_f]

    return None




def add_class_14(
    preds_v,
    pred_clf_c14_filename='2-cls test pred.csv',
    low_threshold=0.00,
    high_threshold=0.99,
    rm_preds_high_th=True
    ):
    """Merge the output of a binary classifier into the detections as class 14.

    The CSV file must have exactly two columns per row: the sample id and a
    probability `p_cls`; the class-14 probability is taken as `1 - p_cls`.

    For every sample:
    - p_14 <  low_threshold : the detections are kept untouched.
    - p_14 >= high_threshold: when `rm_preds_high_th` is True all detections
      are replaced by a single class-14 box with score 1.0; otherwise the
      sample follows the default case with score 1.0.
    - default case: a class-14 box ([0, 0, 1, 1]) with score p_14 is appended
      to the detections. When the sample has no detections, or already
      contains class 14, the detections are REPLACED by that single box.

    Args:
        preds_v: list of prediction dicts ('sample_id', 'bbox', 'cls', 'p_det').
        pred_clf_c14_filename: path of the classifier CSV.
        low_threshold: lower bound described above.
        high_threshold: upper bound described above.
        rm_preds_high_th: see the high-threshold case.

    Returns:
        A deep copy of `preds_v` with the class-14 predictions merged in.
    """

    cls_df = pd.read_csv(pred_clf_c14_filename)

    class_14_d = {}
    for sample_id, p_cls in cls_df.values:
        p_14 = 1.0 - p_cls

        if p_14 < low_threshold:
            # Keep, do nothing.
            class_14_d[sample_id] = 0.0

        elif p_14 >= high_threshold:
            # Replace, remove all "det" preds.
            class_14_d[sample_id] = 1.0

        else:
            # Add, keep "det" preds and add normal pred.
            class_14_d[sample_id] = p_14

    ret_preds_v = copy.deepcopy(preds_v)

    for pred_d in tqdm(ret_preds_v):
        default_case = False
        p_14 = class_14_d[ pred_d['sample_id'] ]

        if p_14 == 1:
            if rm_preds_high_th:
                pred_d['bbox']  = np.array([[0.0, 0.0, 1.0, 1.0]])
                pred_d['cls']   = np.array([14])
                pred_d['p_det'] = np.array([1.0])

            else:
                default_case = True

        elif p_14 == 0.0:
            continue

        else:
            default_case = True

        if default_case:
            if len(pred_d['bbox']) > 0 and 14 not in pred_d['cls']:
                pred_d['bbox'] = np.append(pred_d['bbox'], np.array([[0.0, 0.0, 1.0, 1.0]]), axis=0)
                pred_d['cls']  = np.append(pred_d['cls'], 14)
                pred_d['p_det'] = np.append(pred_d['p_det'], p_14)

            else:
                pred_d['bbox'] = np.array([[0, 0, 1, 1]])
                # The `np.int` alias was removed from NumPy; the builtin `int`
                # is the very same dtype it used to point to.
                pred_d['cls']  = np.array([14], dtype=int)
                pred_d['p_det'] = np.array([p_14])

    return ret_preds_v




def plot_sample(img, bbox_v=None, cls_v=None, original_shape=None, ax=None, class2color_v=None):
    """Show an image with its bounding boxes drawn on top.

    Args:
        img: numpy array or torch tensor. An array whose first dimension is
            smaller than 8 is treated as channels-first and transposed; only
            the first 3 channels are shown; images with max() <= 1.0 are
            converted to uint8 in the 0-255 range.
        bbox_v: optional array/tensor of boxes, one (x0, y0, x1, y1) per row.
        cls_v: optional array/tensor of class indices, one per box; boxes are
            drawn only when both `bbox_v` and `cls_v` are given.
        original_shape: optional (height, width) the boxes refer to; when given
            the boxes are rescaled to the size of `img`.
        ax: matplotlib axis to draw on. When None a new figure is created and
            shown.
        class2color_v: list of RGB tuples indexed by class; defaults to 15
            evenly spaced hues.

    Returns:
        None.
    """

    if class2color_v is None:
        class2color_v = [
                tuple(round(i * 255) for i in colorsys.hsv_to_rgb(i_c/15, 1, 1))
                for i_c in range(15)
            ]

    if type(img) is torch.Tensor:
        img = img.clone().detach().cpu().numpy()

    if type(bbox_v) is torch.Tensor:
        bbox_v = bbox_v.clone().detach().cpu().numpy()

    if type(cls_v) is torch.Tensor:
        cls_v = cls_v.clone().detach().cpu().numpy()

    if img.shape[0] < 8:
        img = img.transpose( (1,2,0) )

    if len(img.shape) == 3 and img.shape[2] > 3:
        img = img[...,:3]

    if img.max() <= 1.0:
        img = (255*img).astype(np.uint8)

    if ax is None:
        do_show = True
        f, ax = plt.subplots(1, figsize=(15,15))
    else:
        do_show = False

    if original_shape is not None:
        org_h, org_w = original_shape
        img_h, img_w = img.shape[:2]

        scale = np.array([
            img_w/org_w,
            img_h/org_h,
            img_w/org_w,
            img_h/org_h,
        ])

    else:
        scale = np.ones(4)

    if (bbox_v is not None) and (cls_v is not None):
        # cv2 draws in place and needs a contiguous buffer; the transpose and
        # the channel slice above may leave a strided view.
        img = np.ascontiguousarray(img)

        # `np.int` was removed from NumPy: use the builtin `int` instead.
        for cls, (x0, y0, x1, y1) in zip(cls_v, (bbox_v * scale).astype(int)):
            img = cv2.rectangle(
                img,
                (int(x0), int(y0)),
                (int(x1), int(y1)),
                class2color_v[cls])

    ax.imshow(img)

    if do_show:
        plt.show()

    return None




def scale_bboxes(bbox, scale_wh=[1.0, 1.0], original_shape=None):
    """Grow or shrink the boxes around their centers.

    Args:
        bbox: array with one (x0, y0, x1, y1) box per row.
        scale_wh: (width, height) scale factors, or a single float applied to
            both dimensions.
        original_shape: optional (height, width); when given the scaled boxes
            are clipped to that image size.

    Returns:
        The scaled boxes. Boxes that end up with a non-positive width or
        height are dropped, so fewer rows than the input may be returned.
    """
    scale_wh = scale_wh * np.ones(2) if type(scale_wh) is float else np.array(scale_wh)

    lower_left, upper_right = bbox[:, :2], bbox[:, 2:]

    center = 0.5 * (upper_right + lower_left)
    half_size = 0.5 * scale_wh * (upper_right - lower_left)

    new_bbox = np.concatenate([center - half_size, center + half_size], axis=-1)

    if original_shape is not None:
        img_h, img_w = original_shape[0], original_shape[1]
        new_bbox = np.clip(
            new_bbox,
            np.zeros(4),
            np.array([img_w, img_h, img_w, img_h]))

    # Discard the boxes that collapsed to a line or a point
    keep_f = ((new_bbox[:, 2:] - new_bbox[:, :2]) > 0).all(axis=-1)
    return new_bbox[keep_f]





def wbf_ensemble(
    to_ens_preds_v,
    samples_shapes_d,
    iou_th=0.6,
    skip_box_th=0.01,
    conf_type='avg'
):
    """Ensemble the predictions of several models with `weighted_boxes_fusion`.

    Args:
        to_ens_preds_v: list with one prediction list per model (see
            `join_predictions`).
        samples_shapes_d: dict sample_id -> (height, width), used to normalize
            the boxes before the fusion and to scale them back afterwards.
        iou_th: forwarded as `iou_thr`.
        skip_box_th: forwarded as `skip_box_thr`.
        conf_type: forwarded as `conf_type`.

    Returns:
        A list with one prediction dict per sample ('sample_id', 'bbox', 'cls',
        'p_det'). Samples without detections are passed through unchanged
        (without the 'model_id' entry).
    """

    n_models = len(to_ens_preds_v)

    ens_preds_v = join_predictions(
        to_ens_preds_v,
        add_model_id=True)

    ret_preds_v = []
    for pred_d in ens_preds_v:
        if len( pred_d['cls'] ) > 0:
            h, w = samples_shapes_d[pred_d['sample_id']]
            scale = np.array([w, h, w, h])

            boxes_v  = []
            scores_v = []
            labels_v = []

            # Split the joined detections back into one list per model
            for i_m in range(n_models):
                f_m = (pred_d['model_id'] == i_m)

                boxes_v.append( pred_d['bbox'][f_m] / scale )
                scores_v.append( pred_d['p_det'][f_m] )
                labels_v.append( pred_d['cls'][f_m] )

            boxes, scores, labels = weighted_boxes_fusion(
                boxes_list=boxes_v,
                scores_list=scores_v,
                labels_list=labels_v,
                weights=None,
                iou_thr=iou_th,
                skip_box_thr=skip_box_th,
                conf_type=conf_type,
                allows_overflow=False,
            )

            ret_pred_d = {
                'sample_id': pred_d['sample_id'],
                'bbox':      boxes * scale,
                # `np.int` was removed from NumPy: use the builtin `int`.
                'cls':       labels.astype(int),
                'p_det':     scores,
            }

        else:
            ret_pred_d = copy.deepcopy( pred_d )
            ret_pred_d.pop('model_id', None)

        ret_preds_v.append( ret_pred_d )

    return ret_preds_v


In [None]:
def save_obj(obj, filename):
    """Pickle `obj` into `filename` (overwriting it) and report the path."""
    with open(filename, 'wb') as out_file:
        pickle.dump(obj, out_file)

    print(f'Saved: {filename}')
    return None

def load_obj(filename):
    """Read back an object pickled in `filename` and report the path.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filename, 'rb') as in_file:
        loaded = pickle.load(in_file)

    print(f'Loaded: {filename}')
    return loaded

In [None]:
def read_dicom_image(
    path,
    voi_lut=True,
    fix_monochrome=True,
    do_norm=True):
    """Read a DICOM file and return its pixels as a float32 array.

    Args:
        path: path of the DICOM file.
        voi_lut: apply the VOI LUT stored in the file (if any) to obtain the
            "human-friendly" view of the raw data.
        fix_monochrome: invert the image when the photometric interpretation
            is "MONOCHROME1".
        do_norm: divide the pixels by `2 ** BitsStored - 1`; the result is
            asserted to lie inside [0, 1].

    Returns:
        The image as a np.float32 array.
    """

    # Original from: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way
    # `pydicom.read_file` is a deprecated alias (removed in pydicom 3.0) of `dcmread`.
    dicom = pydicom.dcmread(path)

    # VOI LUT (if available by DICOM device) is used to transform raw DICOM data to
    # "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.max(data) - data

    if do_norm:
        max_value = (2 ** dicom.BitsStored) - 1
        data = data / max_value

        assert (data.max() <= 1.0) and (data.min() >= 0.0), f'Normalization ERROR in file: "{path}"'

    return data.astype(np.float32)

def read_image_file(
    path,
    do_norm=True,
    monocrome=True,
):
    """Read a regular image file (e.g. PNG) as a float32 array.

    Args:
        path: path of the image file.
        do_norm: divide the pixels by 255; the result is asserted to lie
            inside [0, 1].
        monocrome: keep only the first channel of multi-channel images.

    Returns:
        The image as a np.float32 array.

    Raises:
        IOError: when the file can not be read or decoded.
    """

    data = cv2.imread(path)

    # cv2.imread() does not raise on a missing/corrupt file, it returns None.
    if data is None:
        raise IOError(f'Could not read image file: "{path}"')

    if do_norm:
        max_value = 255
        data = data / max_value

        assert (data.max() <= 1.0) and (data.min() >= 0.0), f'Normalization ERROR in file: "{path}"'

    if monocrome:
        if len(data.shape) == 3:
            data = data[:,:,0]

    return data.astype(np.float32)

In [None]:
def calc_image_features(image):
    """Build a 4-channel (H, W, 4) feature image from a 2-D image.

    Channels, in order: the input image, the image raised to the 3rd power,
    its histogram-equalized version and the Canny edges of the equalized
    image (the last two rescaled by 1/255).
    """
    as_uint8 = (255 * image).astype(np.uint8)
    equalized = cv2.equalizeHist(as_uint8)
    edges = cv2.Canny(equalized, 50, 130)

    channels_v = [
        image,
        image ** 3.0,
        equalized.astype(np.float32) / 255,
        edges.astype(np.float32) / 255,
    ]

    # Stack the 2-D channels along a new trailing axis
    return np.stack(channels_v, axis=-1)

# Dataset handler

In [None]:
class FoldDataset():
    def __init__(
        self,
        ds_path='./',
        ds_name='test',
        images_dir='./test_images',
        
        model_resolution=(512, 512),
        df_path=None,
        
        mode='none',
        
        i_fold=0,
        n_folds=5,
        test_split=0.1,
        
        do_augmentation=True,
        
        downsample_factor=2,
        remove_classes_v=None,
        select_classes_v=None,
        reclass_samples=True,
        
        
        show_warnings=True,
        
        do_random_shuffle=True,
        random_seed=3128,
        sample_bbox_and_class=False,
        clean_boxes=True,
        clean_mode='random',
        clean_iou_th=0.1,
        consensus_level=1,
        
        effdet_xy_invert=False,
        
        no_finding_cls=class2str_v.index('No Finding'),
        
        use_img_cache=False,
    ):
        """Dataset of (optionally cached) images with their bounding boxes.

        Args:
            ds_path, ds_name: location and base name of the dataset files
                (`<ds_name>.h5df`, `<ds_name>.pickle`, `<ds_name>.cache`).
            images_dir: folder holding the source `*.dicom` / `*.png` images.
            model_resolution: (height, width) every sample is resized to.
            df_path: CSV file with the annotated boxes; None for unlabeled data.
            mode: 'cv_trn', 'cv_val', 'cv_tst' or 'none' (case-insensitive).
            i_fold, n_folds: cross-validation fold selection.
            test_split: fraction of the samples reserved for the test split.
            do_augmentation: apply the training augmentations in `__getitem__`.
            downsample_factor: integer factor used when reading the images.
            remove_classes_v: samples containing these classes are discarded.
            select_classes_v: only samples/boxes of these classes are kept
                (class 14 is handled internally as -1). The list passed by the
                caller is not modified.
            reclass_samples: remap the selected classes to new indices.
            show_warnings: print a warning when the augmentation fails.
            do_random_shuffle, random_seed: seed the global NumPy generator.
            sample_bbox_and_class: return the crop of one random box per sample.
            clean_boxes, clean_mode, clean_iou_th, consensus_level: parameters
                forwarded to `clean_predictions`.
            effdet_xy_invert: return the boxes as (y0, x0, y1, x1).
            no_finding_cls: class id of the "No Finding" label.
            use_img_cache: read the images from the HDF5 cache (it is generated
                when missing) instead of reading the source files.
        """
        
        self.ds_path = ds_path
        self.ds_name = ds_name
        
        self.model_resolution = model_resolution
        self.bboxes_df_path   = df_path
        self.i_fold           = i_fold
        self.n_folds          = n_folds
        self.do_augmentation  = do_augmentation
        
        self.images_dir        = images_dir
        self.downsample_factor = downsample_factor
        self.remove_classes_v  = remove_classes_v
        # Work on a private copy: the list is edited below and the caller's
        # object must not change as a side effect.
        self.select_classes_v  = None if select_classes_v is None else list(select_classes_v)
        self.reclass_samples   = reclass_samples
        
        self.show_warnings     = show_warnings
        self.test_split        = test_split
        
        self.do_random_shuffle = do_random_shuffle
        self.random_seed       = random_seed
        
        self.sample_bbox_and_class = sample_bbox_and_class
        
        self.clean_boxes  = clean_boxes
        self.clean_mode   = clean_mode
        self.clean_iou_th = clean_iou_th
        self.consensus_level = 1 if consensus_level is None else consensus_level
        
        self.effdet_xy_invert = effdet_xy_invert
        self.no_finding_cls = no_finding_cls
        
        self.use_img_cache = use_img_cache
        
        # Class 14 is stored internally as -1 (see `_build_bbox_d`)
        if (self.select_classes_v is not None) and (14 in self.select_classes_v):
            self.select_classes_v.remove(14)
            self.select_classes_v.append(-1)
            
        if (self.select_classes_v is not None) and self.reclass_samples:
            self.cls_orig2new_d = {}
            self.cls_new2orig_d = {}
            
            for i_cls_new, i_cls in enumerate(sorted(self.select_classes_v)):
                if -1 in self.select_classes_v:
                    i_cls_new -= 1
                
                self.cls_orig2new_d[14] = len(self.select_classes_v) - 1
                self.cls_new2orig_d[len(self.select_classes_v) - 1] = 14
                    
                self.cls_orig2new_d[i_cls]     = i_cls_new
                self.cls_new2orig_d[i_cls_new] = i_cls
                    
        
        self.mode = mode.lower()
        posible_modes_v = ['cv_trn', 'cv_val', 'cv_tst', 'none']
        assert self.mode in posible_modes_v, f'Parameter "mode" must be in: {posible_modes_v}'
        
        # Use the normalized (lower-case) mode, the same one validated above;
        # comparing the raw argument left the flags undefined for e.g. 'CV_TRN'.
        self.do_CV    = (self.mode != 'none')
        self.iter_trn = (self.mode == 'cv_trn')
        self.iter_val = (self.mode == 'cv_val')
        self.iter_tst = (self.mode == 'cv_tst')
        
        
        self.h5_path = os.path.join(self.ds_path, f'{self.ds_name}.h5df')
        self.misc_path = os.path.join(self.ds_path, f'{self.ds_name}.pickle')
        self.cache_path = os.path.join(self.ds_path, f'{self.ds_name}.cache')
        
        self.f_h5 = None
        
        if self.use_img_cache:
            if not os.path.exists(self.h5_path) or not os.path.exists(self.misc_path):
                self._gen_images_ds()

            self.open_h5_file()
            
        self.read_misc_d()
        
        self.bbox_df = None
        
        self._set_all_samples_ids()
        
        self._build_albumentations()

        self._read_bboxes_df_or_cache()        
            
        self.update_fold_filter()
        
        return None
    
    
    def _read_bboxes_df_or_cache(self):
        """Load the annotations, from the cache file when it already exists.

        Nothing is done when no annotation file was configured. On a cache
        miss the CSV is parsed and the resulting state is saved to the cache.
        """
        if self.bboxes_df_path is None:
            return None

        if os.path.exists(self.cache_path):
            self._load_ds_state()
        else:
            self.read_bboxes_df()
            self._save_ds_state()

        return None
    
    def _save_ds_state(self):
        """Pickle the parsed annotations state into `self.cache_path`."""
        save_obj(
            dict(
                bbox_df=self.bbox_df,
                class_to_sample_v=self.class_to_sample_v,
                bbox_d=self.bbox_d,
            ),
            self.cache_path)

        return None
    
    
    def _load_ds_state(self):
        """Restore every attribute stored in the cache file `self.cache_path`."""
        cached_state_d = load_obj(self.cache_path)

        for attr_name in cached_state_d:
            setattr(self, attr_name, cached_state_d[attr_name])

        return None
    
    def _set_all_samples_ids(self):
        """Collect the sorted ids of all the samples listed in `self.misc_d`.

        When the annotations are already loaded, their image ids are asserted
        to match. If `do_random_shuffle` is set, the global NumPy random
        generator is seeded with `random_seed` (no shuffling happens here).
        """
        self.all_sample_ids = np.array(sorted(self.misc_d))

        if self.bbox_df is not None:
            annotated_ids = np.array(sorted(self.bbox_df['image_id'].unique()))
            assert (self.all_sample_ids == annotated_ids).all()

        if self.do_random_shuffle:
            np.random.seed(self.random_seed)

        return None
    
    
    def _build_bbox_d(self):
        """Build the per-sample annotations and the class -> samples index.

        Fills:
        - `self.bbox_d[sample_id]`: dict with the 'bboxes', 'cls' and 'rad_id'
          arrays of the sample, taken from `self.bbox_df`.
        - `self.class_to_sample_v[class_id]`: array with the ids of the samples
          that contain that class (using the class ids found in the dataframe,
          i.e. before the "no finding" class is rewritten to -1).
        """
        self.bbox_d  = {}
        # One list per class id found in the dataframe (0 .. max class id)
        self.class_to_sample_v = [ [] for i in range( self.bbox_df.class_id.max() + 1)]
        
        sample_it = tqdm(self.all_sample_ids, desc='Building bboxes')
        for s_id in sample_it:
            # Rows of the dataframe that belong to this sample
            sample_bboxes_df = self.bbox_df[self.bbox_df.image_id == s_id]

            cls = sample_bboxes_df['class_id'].values
            bboxes = sample_bboxes_df[['x_min', 'y_min', 'x_max', 'y_max']].values
            rad_id = sample_bboxes_df[ ['rad_id'] ].values
            
            for i_c in np.unique(cls):
                self.class_to_sample_v[i_c].append( s_id )
                
            # "No finding" filter: keep a single row of that class per sample
            f_14 = (cls==self.no_finding_cls)
            if f_14.any():
                # argmax() of a boolean mask is the index of its first True
                i_14 = f_14.argmax()
                
                # The kept row gets a box covering the whole original image
                orig_img_h, orig_img_w = self.misc_d[s_id]
                bboxes[i_14] = np.array([0, 0, orig_img_w, orig_img_h])
                cls[i_14]    = -1  # per the original author's note: adding 1 gives class 0 (background for effdet)
                
                # Drop every other "no finding" row, keep the row i_14
                f_14 = ~f_14
                f_14[i_14] = True
                bboxes = bboxes[f_14]
                cls    = cls[f_14]
                # NOTE(review): `rad_id` is not filtered with f_14, so it can
                # have more rows than `bboxes`/`cls` - confirm this is intended.
            
            # Boxes with NaN coordinates are replaced by a full-image box
            f_nan = np.isnan(bboxes).any(axis=-1)
            if f_nan.any():
                orig_img_h, orig_img_w = self.misc_d[s_id]
                bboxes[f_nan] = np.array([0, 0, orig_img_w, orig_img_h])
                
            self.bbox_d[s_id] = {
                'bboxes': bboxes,
                'cls':    cls,
                'rad_id': rad_id,
            }
            
        # Freeze the per-class lists of sample ids as arrays
        for i_c in range(len(self.class_to_sample_v)):
            self.class_to_sample_v[i_c] = np.array(self.class_to_sample_v[i_c])

        return None



    def read_bboxes_df(self):
        """Parse the annotations CSV and build the per-sample boxes from it.

        The four box coordinate columns are read as float32.
        """
        coord_dtypes_d = {
            col_name: np.float32
            for col_name in ('x_min', 'y_min', 'x_max', 'y_max')
        }

        self.bbox_df = pd.read_csv(self.bboxes_df_path, dtype=coord_dtypes_d)

        self._build_bbox_d()
        return None
        
        
    def update_fold_filter(self):
        """Compute which samples belong to the current mode / fold.

        Steps:
        1. Keep the samples that contain any class of `select_classes_v` (all
           of them when it is None), then discard the ones that contain any
           class of `remove_classes_v`.
        2. The last `test_split` fraction of the remaining samples is the test
           split; the rest is divided into `n_folds` contiguous folds (the
           last fold absorbs the remainder). 'cv_val' selects fold `i_fold`,
           'cv_trn' selects the other folds and 'cv_tst' the test split.
           Without cross-validation every sample is selected.
        3. With annotations and `consensus_level > 1`, the samples left
           without boxes by `clean_predictions` are discarded.

        Sets `self.selected_sample_ids`, `self.fold_sample_filter` and
        `self.fold_samples`.
        """
        assert self.i_fold >= 0 and self.i_fold < self.n_folds, 'ERROR, incorrect i_fold.'
        
        # Apply sample class filtering
        if self.select_classes_v is None:
            self.selected_sample_ids = self.all_sample_ids
            
        else:
            self.selected_sample_ids = set()
            for i_c in self.select_classes_v:
                self.selected_sample_ids = self.selected_sample_ids.union( self.class_to_sample_v[i_c] )
            
            self.selected_sample_ids = np.array( sorted(self.selected_sample_ids) )
            
            
        if self.remove_classes_v is not None:
            self.selected_sample_ids = set(self.selected_sample_ids)
            
            for i_c in self.remove_classes_v:
                self.selected_sample_ids = self.selected_sample_ids.difference( self.class_to_sample_v[i_c] )
            
            self.selected_sample_ids = np.array( sorted(self.selected_sample_ids) )
        # # # # # # # # # # # # # # # # #
        
        n_samples = len( self.selected_sample_ids )
        n_samples_cv = int( (1.0-self.test_split) * n_samples)
        
        # NOTE: the `np.bool` alias was removed from NumPy; the builtin `bool`
        # is the same dtype it used to point to.
        if self.do_CV:
            if self.iter_tst:
                f_samples = np.zeros(n_samples, dtype=bool)
                f_samples[n_samples_cv:] = True
                
            else:
                n_samples_per_fold = n_samples_cv // self.n_folds
                f_samples = np.zeros(n_samples, dtype=bool)

                if self.i_fold < self.n_folds - 1:
                    f_samples[ self.i_fold * n_samples_per_fold: (self.i_fold+1) * n_samples_per_fold] = True
                else:
                    # The last fold also takes the samples left by the integer division
                    f_samples[ self.i_fold * n_samples_per_fold: n_samples_cv] = True

                if self.iter_trn:
                    # Training = complement of the validation fold inside the CV range
                    f_samples[:n_samples_cv]  = ~(f_samples[:n_samples_cv])
                    
        else:
            f_samples = np.ones(n_samples, dtype=bool)
        
        
        if self.bbox_df is not None and self.consensus_level > 1:
            for i_s in np.argwhere(f_samples).T[0]:
                
                sample_id = self.selected_sample_ids[i_s]
                
                # NOTE(review): here the boxes are passed/read under the key
                # 'bboxes' while __getitem__ uses 'bbox' - confirm that
                # clean_predictions() supports both.
                x = clean_predictions(
                    [ self.bbox_d[ sample_id ] ],
                    iou_th=self.clean_iou_th,
                    mode=self.clean_mode,
                    consensus_level=self.consensus_level,
                )

                if x[0]['bboxes'].shape[0] == 0:
                    f_samples[i_s] = False
        
        
        self.fold_sample_filter = f_samples 
        self.fold_samples = self.selected_sample_ids[self.fold_sample_filter]

        return None
    
    
    def get_image(self, sample_id):
        """Return the downsampled image of a sample.

        The image comes from the HDF5 cache when `use_img_cache` is set;
        otherwise `<images_dir>/<sample_id>.dicom` is read and downsampled.
        """
        if self.use_img_cache:
            return self.f_h5[sample_id][:]

        dicom_path = os.path.join(self.images_dir, f'{sample_id}.dicom')
        image, _, _ = self.read_and_downsample_dicom_image(dicom_path)

        return image
    
    
    def __getitem__(self, index):
        """Return the sample at position `index` of the current fold.

        The returned dict holds:
        - 'sample_id': id of the sample.
        - 'image': float array, channels-first, resized to `model_resolution`.
        - 'original_shape': shape stored in `misc_d` for the sample.
        - 'bboxes': float32 array of boxes, (x0, y0, x1, y1) or, with
          `effdet_xy_invert`, (y0, x0, y1, x1).
        - 'cls': int array with the class of every box.

        Unlabeled datasets get one full-image box of class -1.
        """
        
        if index < 0:
            index = self.__len__() + index
        
        sample_id = self.fold_samples[index]
        
        image = self.get_image(sample_id)
        
        original_shape = self.misc_d[sample_id]
        
        img_h, img_w = image.shape
        orig_img_h, orig_img_w = original_shape
        
        scale_h = img_h / orig_img_h
        scale_w = img_w / orig_img_w
        
        image = calc_image_features(image)
        
        sample = {
            'sample_id': sample_id,
            'image': image,
            'original_shape': self.misc_d[sample_id],
        }
        
        
        if self.bbox_df is not None:
            # The images are already resized.
            # Format at this point is (x0,y0,x1,y1)
            scale_v = np.array([scale_w, scale_h, scale_w, scale_h])
            sample['bboxes'] = self.bbox_d[sample_id]['bboxes'] * scale_v
            sample['cls'] = self.bbox_d[sample_id]['cls']
            
            if self.select_classes_v is not None:
                # `np.bool` was removed from NumPy: use the builtin `bool`.
                f = np.zeros(sample['cls'].shape[0], dtype=bool)
                    
                for i in self.select_classes_v:
                    f[sample['cls'] == i] = True

                # Boolean indexing returns copies, so the in-place re-classing
                # below does not touch the arrays stored in `self.bbox_d`.
                sample['bboxes'] = sample['bboxes'][f]
                sample['cls']    = sample['cls'][f]
                
                if self.reclass_samples:
                    for i_c, i_cls in enumerate(sample['cls']):
                        sample['cls'][i_c] = self.cls_orig2new_d[i_cls]
                    
            
            if self.sample_bbox_and_class:
                idx = np.random.randint(0, sample['cls'].shape[0])
                
                (x0, y0, x1, y1) = sample['bboxes'][idx].astype(int)
                
                # Random margin of up to 1/4 of the box size. randint() needs
                # high > low, so boxes narrower than 4 px get a zero margin
                # instead of raising.
                d_x = np.random.randint(0, max(1, (x1-x0)//4) )
                d_y = np.random.randint(0, max(1, (y1-y0)//4) )
                x0 = max(0, x0-d_x)
                y0 = max(0, y0-d_y)
                
                x1 = min(sample['image'].shape[1], x1+d_x)
                y1 = min(sample['image'].shape[0], y1+d_y)
                
                sample['image'] = sample['image'][y0:y1, x0:x1]
                sample['cls'] = sample['cls'][idx:idx+1]
                sample['bboxes'] = np.array( [ (0, 0, sample['image'].shape[1], sample['image'].shape[0]) ] )
                
            if self.clean_boxes:
                to_clean_d = {'bbox': sample['bboxes'], 'cls':sample['cls']}
                cleaned_d = clean_predictions(
                    [to_clean_d],
                    iou_th=self.clean_iou_th,
                    mode=self.clean_mode,
                    consensus_level=self.consensus_level
                )[0]
                
                sample['bboxes'] = cleaned_d['bbox']
                sample['cls'] = cleaned_d['cls']
                
        else:
            sample['bboxes'] = np.array( [ (0, 0, image.shape[1], image.shape[0]) ] )
            sample['cls']    = np.array( [-1] )
        
        

        if self.do_augmentation:
            # Best effort: retry the random augmentation until it succeeds and
            # leaves at least one box inside the image.
            for i_try in range(10):
                try:
                    transform_sample = self.TR_trn(**sample)
                    if len(transform_sample['bboxes']) > 0:
                        # Updating sample
                        sample = transform_sample
                        break

                except Exception as e:
                    # Deliberately ignored: another random attempt follows.
                    pass

            else:
                # The loop finished without `break`: every attempt failed.
                if self.show_warnings:
                    print(f' - WARNING: Imposible to Augmentate image, idx={index}.', file=sys.stderr)
                    
                # doing a basic transform
                sample = self.TR_val(**sample)
                
        else:
            sample = self.TR_val(**sample)
        
        # (H, W, C) -> (C, H, W)
        sample['image'] = sample['image'].transpose([2,0,1])
        sample['bboxes'] = np.array(sample['bboxes'], dtype=np.float32)
        if self.effdet_xy_invert:
            # Formating: (x0,y0,x1,y1) to (y0,x0,y1,x1) for EffDet
            sample['bboxes'] = sample['bboxes'][:, [1,0,3,2]]
        
        # `np.int` was removed from NumPy: use the builtin `int`.
        sample['cls'] = np.array(sample['cls'], dtype=int)
        
        
        return sample
    
    
    def open_h5_file(self, mode='r'):
        """Open the HDF5 images file, closing any handle that is already open.

        In write mode, when the file already exists, the user is asked for
        confirmation on stdin and the process exits unless the answer is 'y'.
        """
        if self.f_h5 is not None:
            self.close_file()

        would_overwrite = (mode == 'w') and os.path.exists(self.h5_path)
        if would_overwrite:
            print(f' The file "{self.h5_path}" already exists, if you continue the file will be deleted. Contunue (y/n) ?')
            answer = input()
            if answer.lower() != 'y':
                print('Operation aborted.')
                sys.exit()

        self.f_h5 = h5py.File(self.h5_path, mode)
        return None
    
    
    def close_file(self):
        """Close the HDF5 images file when there is a handle (it is kept in `f_h5`)."""
        if self.f_h5 is None:
            return None

        self.f_h5.close()
        return None
    
    
    def read_misc_d(self):
        """Load `self.misc_d` from the pickle file `self.misc_path`."""
        loaded_misc_d = load_obj(self.misc_path)
        self.misc_d = loaded_misc_d
        return None
        
        
    def save_misc_d(self, warn=True):
        """Pickle `self.misc_d` into `self.misc_path`.

        Args:
            warn: when True (and `misc_d` is not empty) and the target file
                already exists, ask for confirmation on stdin; the process
                exits unless the answer is 'y'.
        """
        if warn and self.misc_d:
            # The existence check must use the file PATH; `self.misc_d` is the
            # dict being saved and os.path.exists() raises TypeError on it.
            if os.path.exists(self.misc_path):
                print(f' The file "{self.misc_path}" already exists, if you continue the file will be deleted. Contunue (y/n) ?')
                r = input()
                if r.lower() != 'y':
                    print('Operation aborted.')
                    sys.exit()
            
        
        save_obj(self.misc_d, self.misc_path)
        return None
    
    
    def downsample_img(
        self,
        image,
        downsample_factor=2):
        """Resize `image` so both sides are integer-divided by `downsample_factor`."""
        src_h, src_w = image.shape[0], image.shape[1]

        # cv2.resize() expects the target size as (width, height)
        target_wh = (src_w // downsample_factor, src_h // downsample_factor)

        return cv2.resize(image, target_wh)
    
    
    def read_and_downsample_dicom_image(self, file_path):
        """Read an image file and downsample it by `self.downsample_factor`.

        Files with the '.dicom' extension are read with `read_dicom_image`,
        any other file with `read_image_file`.

        Returns:
            (downsampled image, shape of the image before downsampling,
            file name without extension).
        """
        base_name = os.path.basename(file_path)
        file_id, file_ext = os.path.splitext(base_name)

        reader_fn = read_dicom_image if file_ext == '.dicom' else read_image_file
        full_image = reader_fn(file_path)

        downsampled = self.downsample_img(
            full_image,
            downsample_factor=self.downsample_factor)

        return downsampled, full_image.shape, file_id
    
    
    def _gen_images_ds(self):
        """Generate the images cache from the files found in `images_dir`.

        Every `*.dicom` / `*.png` file is read and downsampled; the images are
        stored in the HDF5 file (one dataset per file id) and the shapes of the
        images before downsampling in `self.misc_d`, which is finally pickled.
        """
        all_files_v = glob.glob(os.path.join(self.images_dir, '*.dicom') ) + glob.glob(os.path.join(self.images_dir, '*.png') )
        
        self.open_h5_file(mode='w')
        self.misc_d = {}
        
        try:
            file_it = tqdm( all_files_v )
            for file_path in file_it:
                image_rs, image_shape, file_id = self.read_and_downsample_dicom_image(file_path)
                
                file_it.set_description(file_id)
                self.f_h5.create_dataset(file_id, data=image_rs)
                self.misc_d[file_id] = image_shape

        finally:
            # Never leave the HDF5 handle open, even when a file fails to load
            self.close_file()

        self.save_misc_d(warn=False)
        
        return None
    

    def __len__(self):
        """Number of samples selected for the current mode / fold."""
        n_fold_samples = len(self.fold_samples)
        return n_fold_samples
    
    
    def _build_albumentations(self):
        """Create the albumentations pipelines used by `__getitem__`.

        - `self.TR_trn`: training augmentations followed by a resize to
          `model_resolution`. A different (lighter) pipeline is used when
          `sample_bbox_and_class` is set.
        - `self.TR_val`: resize only.

        All the pipelines handle boxes in 'pascal_voc' format, labelled by the
        'cls' field.
        """
        res_h, res_w = self.model_resolution[0], self.model_resolution[1]

        def _resize():
            # Final resize to the resolution expected by the model
            return A.Resize(height=res_h, width=res_w, p=1.0)

        def _bbox_params():
            # A fresh BboxParams per pipeline, as in a hand-written Compose
            return A.BboxParams(
                format='pascal_voc',
                min_area=0,
                min_visibility=0,
                label_fields=['cls'])

        # Cutout holes: squares of 1/20 of the shortest side
        hole_size = min(self.model_resolution) // 20
        n_holes = int(0.05 * np.prod(self.model_resolution) / hole_size**2)

        # Disabled experiments of the original pipeline: RandomSizedCrop,
        # HueSaturationValue, ToGray, VerticalFlip, RandomRotate90, Transpose
        # and ToTensorV2.
        default_trn_steps = [
            # NOTE(review): x_min == x_max and y_min == y_max describe a
            # zero-size crop window - confirm the intended coordinates.
            A.Crop(x_min=128, y_min=128, x_max=128, y_max=128, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=5, p=0.6),
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ],
                p=0.5),
            _resize(),
            A.Cutout(
                num_holes=n_holes,
                max_h_size=hole_size,
                max_w_size=hole_size,
                fill_value=0,
                p=0.5),
        ]
        self.TR_trn = A.Compose(default_trn_steps, p=1.0, bbox_params=_bbox_params())

        if self.sample_bbox_and_class:
            # Single-box crops: stronger rotation, fixed-size cutout holes
            crop_trn_steps = [
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15, p=0.6),
                A.Blur(blur_limit=3, p=0.6),
                A.Cutout(
                    num_holes=10,
                    max_h_size=5,
                    max_w_size=5,
                    fill_value=0,
                    p=0.5),
                _resize(),
            ]
            self.TR_trn = A.Compose(crop_trn_steps, p=1.0, bbox_params=_bbox_params())

        self.TR_val = A.Compose([_resize()], p=1.0, bbox_params=_bbox_params())

        return None

    
#     def get_GT_Dataframe(self, merge_mode='first'):
#         """ Reeturns the fold GT Dataframe"""
#         gt_df = self.bbox_df[self.bbox_df.image_id.isin(self.fold_samples)].copy()
#         gt_df.fillna(0, inplace=True)
#         gt_df.loc[gt_df["class_id"] == 14, ['x_max', 'y_max']] = 1.0
        
        
        
#         gb = gt_df.groupby(['image_id', 'class_id'])

#         if merge_mode == 'mean':
#             gt_df = gb.agg(lambda x: ' '.join(np.unique(x.values)) if type(x.values[0]) is str else x.values.mean() ).reset_index()

#         elif merge_mode == 'first':
#             gt_df = gb.first().reset_index()

#         elif merge_mode is None:
#             gt_df = gt_df.copy()

#         else:
#             raise Exception(f'Mode "{merge_mode}" unknown.')
        
        
#         if (self.select_classes_v is not None):
#             selected_cls_v = self.select_classes_v
#             if 14 in self.select_classes_v or -1 in self.select_classes_v:
#                 selected_cls_v = self.select_classes_v + [14]
            
#             gt_df = gt_df[ gt_df.class_id.isin(selected_cls_v) ]
#             gt_df.class_id = gt_df.class_id.map(lambda x: self.cls_orig2new_d[x])
                    
#         return gt_df
    
    
    def get_GT_Dataframe(
        self, 
        clean_mode='random',
        clean_iou_th=0.1,
        consensus_level=None):
        """Return the ground-truth annotations of the current fold as a DataFrame.

        Every sample in ``self.fold_samples`` is passed through
        ``clean_predictions`` (defined elsewhere in this file; presumably it
        merges overlapping boxes from different radiologists -- confirm there)
        and the cleaned annotations are flattened into one row per box.

        Args:
            clean_mode: forwarded to ``clean_predictions`` as ``mode``.
            clean_iou_th: forwarded to ``clean_predictions`` as ``iou_th``.
            consensus_level: forwarded to ``clean_predictions``; falls back to
                ``self.consensus_level`` when ``None``.

        Returns:
            pandas.DataFrame with columns ``image_id, class_id, class_name,
            rad_id, x_min, y_min, x_max, y_max``. Rows whose class is ``-1``
            are reported as class ``14`` with the placeholder box
            ``(0, 0, 1, 1)``. When ``self.select_classes_v`` is set, only those
            classes are kept and ``class_id`` is remapped through
            ``self.cls_orig2new_d``.
        """

        if consensus_level is None:
            consensus_level = self.consensus_level

        ret_d = {
            'image_id': [],
            'class_id': [],
            'class_name': [],
            'rad_id': [],

            'x_min': [],
            'y_min': [],
            'x_max': [],
            'y_max': [],
        }

        for sample_id in self.fold_samples:
            sample_d = clean_predictions(
                [self.bbox_d[sample_id]],
                iou_th=clean_iou_th,
                mode=clean_mode,
                consensus_level=consensus_level,
            )[0]

            for i_c, rad_id, bbox in zip(sample_d['cls'], sample_d['rad_id'], sample_d['bboxes']):
                # Internally "no finding" is stored as -1; the dataframe uses 14.
                if i_c == -1:
                    i_c = 14

                if i_c == 14:
                    # Placeholder box used for the "no finding" class.
                    x_min, y_min, x_max, y_max = 0.0, 0.0, 1.0, 1.0
                else:
                    x_min, y_min, x_max, y_max = bbox

                ret_d['image_id'].append(sample_id)
                ret_d['class_id'].append(i_c)
                ret_d['class_name'].append(class2str_v[i_c])
                ret_d['rad_id'].append(' '.join(rad_id))
                ret_d['x_min'].append(x_min)
                ret_d['y_min'].append(y_min)
                ret_d['x_max'].append(x_max)
                ret_d['y_max'].append(y_max)

        gt_df = pd.DataFrame(ret_d)

        if self.select_classes_v is not None:
            selected_cls_v = self.select_classes_v
            if 14 in self.select_classes_v or -1 in self.select_classes_v:
                selected_cls_v = self.select_classes_v + [14]

            # Take an explicit copy: assigning a column on a boolean-filtered
            # slice is the classic SettingWithCopy hazard.
            gt_df = gt_df[gt_df.class_id.isin(selected_cls_v)].copy()
            gt_df['class_id'] = gt_df['class_id'].map(lambda x: self.cls_orig2new_d[x])

        return gt_df
    
    def get_samples_wh_v(self, cls, all_samples=True):
        """Collect the (width, height) of every box of class ``cls`` at model resolution.

        Box sizes are taken from ``self.bbox_d`` (boxes are read as
        ``x_min, y_min, x_max, y_max``) and rescaled from each image's original
        size, read from ``self.misc_d[sample_id]`` as ``(height, width)``, to
        ``self.model_resolution`` (``(height, width)``).

        Args:
            cls: class id to select.
            all_samples: iterate ``self.all_sample_ids`` when True, otherwise
                only ``self.fold_samples``.

        Returns:
            ``(N, 2)`` array of ``[w, h]`` rows, or an empty 1-D array when no
            box of that class exists.
        """
        target_h, target_w = self.model_resolution
        sample_ids = self.all_sample_ids if all_samples else self.fold_samples

        collected = []
        for sample_id in sample_ids:
            annotations = self.bbox_d[sample_id]
            cls_boxes = annotations['bboxes'][annotations['cls'] == cls]
            if cls_boxes.shape[0] == 0:
                continue

            src_h, src_w = self.misc_d[sample_id]
            xy_scale = np.array([target_w / src_w, target_h / src_h])
            collected.append((cls_boxes[:, 2:] - cls_boxes[:, :2]) * xy_scale)

        if not collected:
            return np.array([])
        return np.vstack(collected)
    
    def get_sample_idx_from_id(self, s_id):
        """Return the positions (1-D index array) in ``self.fold_samples`` equal to ``s_id``."""
        matches = np.argwhere(self.fold_samples == s_id)
        return matches[:, 0]


    def plot_sample(self, idx=0, verbose=False, do_show=True):
        """Draw sample ``idx`` with its bounding boxes on matplotlib figure 0.

        Args:
            idx: dataset index; the sample is fetched with ``self[idx]`` and
                must provide ``'image'`` (channel-first array, presumably
                scaled to [0, 1] -- confirm against ``__getitem__``),
                ``'cls'`` and ``'bboxes'`` (``x0, y0, x1, y1`` per row).
            verbose: print the class and box of every drawn annotation.
            do_show: call ``plt.show()`` after drawing.
        """
        sample = self[idx]

        plt.figure(0, figsize=(20,20))

        # CHW float -> HWC uint8. The transpose leaves a non C-contiguous
        # layout, which OpenCV drawing functions may reject, so make it
        # contiguous explicitly.
        img = np.ascontiguousarray(
            (255*sample['image'].transpose((1,2,0))).astype(np.uint8))

        for cls, bbox in zip(sample['cls'], sample['bboxes']):
            # Plain Python ints: `np.int` was removed in NumPy 1.24.
            x0, y0, x1, y1 = (int(v) for v in bbox)
            img = cv2.rectangle(img, (x0, y0), (x1, y1), class2color_v[cls])
            if verbose:
                print(f' - cls={cls}   bbox={bbox}')

        plt.imshow(img)

        if do_show:
            plt.show()

        return None



    
    def filter_radiologists(self, rad_id_v=['R13']):
        """Restrict the fold to annotations made by the radiologists in ``rad_id_v``.

        The fold sample list is rebuilt from ``self.selected_sample_ids`` and
        ``self._read_bboxes_df_or_cache()`` is called first (presumably it
        reloads ``self.bbox_df`` / ``self.bbox_d`` -- confirm there). Then:

        * ``self.bbox_df`` keeps only rows whose ``rad_id`` is in ``rad_id_v``;
        * every entry of ``self.bbox_d`` is filtered in place to the boxes of
          those radiologists; an entry with findings but no box from them is
          replaced by a single full-image box of class ``-1`` / ``'NoRad'``;
        * samples with no annotation from the requested radiologists are
          removed from ``self.fold_samples``.

        Args:
            rad_id_v: radiologist ids to keep (the default list is never mutated).
        """
        self.fold_samples = self.selected_sample_ids[self.fold_sample_filter]
        self._read_bboxes_df_or_cache()
        self.bbox_df = self.bbox_df[self.bbox_df.rad_id.isin(rad_id_v)]

        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        keep_sample = np.ones(self.fold_samples.shape[0], dtype=bool)

        for sample_id, bbox_d in self.bbox_d.items():
            n_boxes = bbox_d['rad_id'].shape[0]

            # Mask of boxes annotated by any of the requested radiologists.
            keep_box = np.zeros(n_boxes, dtype=bool)
            for rad_id in rad_id_v:
                keep_box |= (bbox_d['rad_id'].T[0] == rad_id)

            n_rads = keep_box.sum()
            if n_rads == 0:
                # No requested radiologist saw this image: drop it from the fold.
                i = np.argwhere(self.fold_samples == sample_id).T[0]

                if i.shape[0] > 0:
                    keep_sample[i[0]] = False

            if bbox_d['cls'][0] == -1:
                # "No finding" sample.
                # NOTE(review): only 'rad_id' is filtered here, so 'bboxes' and
                # 'cls' keep their original length -- confirm this is intended.
                if n_rads > 0:
                    bbox_d['rad_id'] = bbox_d['rad_id'][keep_box]
                continue

            if n_rads == 0:
                # Findings exist, but none from the requested radiologists:
                # replace them with a single full-image "no finding" box.
                original_shape = self.misc_d[sample_id]
                bbox_d['bboxes'] = np.array([ [0.0, 0.0, original_shape[1], original_shape[0] ]])
                bbox_d['cls'] = np.array([-1])
                bbox_d['rad_id'] = np.array([['NoRad']])

            else:
                for k in bbox_d.keys():
                    bbox_d[k] = bbox_d[k][keep_box]

        self.fold_samples = self.fold_samples[keep_sample]

        return None


    def filter_class(self, cls_v=[0]):
        """Keep only the fold samples that contain at least one box of a class in ``cls_v``.

        The fold sample list is rebuilt from ``self.selected_sample_ids`` and
        ``self._read_bboxes_df_or_cache()`` is called first (presumably it
        reloads ``self.bbox_df`` / ``self.bbox_d`` -- confirm there).
        ``self.bbox_df`` is reduced to all rows of the images that have at
        least one annotation of a requested class; ``self.bbox_d`` is not
        modified.

        Args:
            cls_v: class ids to look for (the default list is never mutated).
        """
        self.fold_samples = self.selected_sample_ids[self.fold_sample_filter]
        self._read_bboxes_df_or_cache()

        images_with_cls = self.bbox_df[self.bbox_df.class_id.isin(cls_v)].image_id.unique()
        self.bbox_df = self.bbox_df[self.bbox_df.image_id.isin(images_with_cls)]

        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        keep_sample = np.zeros(self.fold_samples.shape[0], dtype=bool)

        for i_sample, sample_id in enumerate(self.fold_samples):
            sample_cls = self.bbox_d[sample_id]['cls']
            keep_sample[i_sample] = any(cls in sample_cls for cls in cls_v)

        self.fold_samples = self.fold_samples[keep_sample]

        return None

    
# Module-level dataset instance: training split (`mode='cv_trn'`) of fold 1 out
# of 5, without augmentation and with class 14 removed.
# NOTE(review): `ds_path` and `downsample_factor` are hard-coded here, while the
# commented-out variant below reads `DS_PATH` / `DOWNSAMPLE_FACTOR` -- confirm
# which one is intended.
ds = FoldDataset(
    ds_path='../input/train-test-ds-bbox-cache/',
    ds_name='train',
    images_dir=os.path.join(DS_DIR, 'train'),
    model_resolution=(512, 512),
    df_path=os.path.join(DS_DIR, 'train.csv'),
    
    mode='cv_trn',
    i_fold=1,
    n_folds=5,
    test_split=0.1,
    
    do_augmentation=False,
    downsample_factor=2,
    remove_classes_v=[14],
    select_classes_v=None,
    show_warnings=False,
    random_seed=3128,
)


# ds = FoldDataset(
#     ds_path=DS_PATH,
#     ds_name='train',
#     images_dir=os.path.join(DS_DIR, 'train'),
#     model_resolution=(512, 512),
#     df_path=os.path.join(DS_DIR, 'train.csv'),
    
#     mode='cv_val',
#     i_fold=1,
#     n_folds=5,
#     test_split=0.1,
    
#     do_augmentation=False,
#     downsample_factor=DOWNSAMPLE_FACTOR,
#     remove_classes_v=[],
#     select_classes_v=None,
#     reclass_samples = True,
#     show_warnings=False,
#     random_seed=3128,
    
#     clean_boxes=True,
#     clean_mode='random',
#     clean_iou_th=0.1,
    
#     consensus_level=2,
# )

# len(ds)

# Model

In [None]:
def cos_decay(start_val=1.0, end_val=1e-4, steps=100):
    """Build a half-cosine schedule going from ``start_val`` (x=0) to ``end_val`` (x=steps).

    Returns a function ``x -> value``; the schedule is not clamped outside
    ``[0, steps]``.
    """
    def schedule(x):
        blend = (1 - np.cos(x * np.pi / steps)) / 2  # 0 at x=0, 1 at x=steps
        return blend * (end_val - start_val) + start_val

    return schedule

def linear_warmup(start_val=1e-4, end_val=1.0, steps=5):
    """Build a linear ramp going from ``start_val`` (x=0) to ``end_val`` (x=steps).

    Returns a function ``x -> value``; the ramp is not clamped outside
    ``[0, steps]``.
    """
    def schedule(x):
        progress = x / steps
        return progress * (end_val - start_val) + start_val

    return schedule

def scheduler_lambda(lr_frac=1e-4, warmup_epochs=5, cos_decay_epochs=60):
    """Build the LR multiplier ``f(epoch)`` used with ``LambdaLR``.

    The multiplier ramps linearly from ``lr_frac`` to 1.0 during the first
    ``warmup_epochs`` epochs, follows a cosine decay from 1.0 back to
    ``lr_frac`` over the next ``cos_decay_epochs`` epochs, and stays at
    ``lr_frac`` afterwards.
    """
    # The warm-up ramp is only created when there is a warm-up phase.
    if warmup_epochs > 0:
        warmup_fn = linear_warmup(start_val=lr_frac, end_val=1.0, steps=warmup_epochs)

    decay_fn = cos_decay(start_val=1.0, end_val=lr_frac, steps=cos_decay_epochs)
    decay_end = warmup_epochs + cos_decay_epochs

    def f(x):
        if x < warmup_epochs:
            return warmup_fn(x)
        if x <= decay_end:
            return decay_fn(x - warmup_epochs)
        return lr_frac

    return f



class ModelX(nn.Module):
    def __init__(
        self,
        model_resolution=(768, 512), 
        n_input_channels=3,
        n_classes=14, # Not counting background
        n_extras=0,
        extra_loss_weight=1.0,

        init_lr=1e-3,
        use_scheduler=True,
        n_warmup_epochs=5,
        n_decay_epochs=60,
        lr_prop=1e-4,
        
        n_steps_grad_update=1,
        optimizer_name='adam',
        clip_grad_norm=5.0,
        weight_decay=0.0,
        
        use_pretrained_model=True,
        
        backbone_name='yolo',#'tf_efficientdet_d4',
        trainable_backbone_layers=3,
        
        anchors_d=None,
        
        start_ckpt=None,
        
        checkpoint_base_path='./model_checkpoint',
        model_name=' ModelX_v1',
        device=None,
        parallelize_backbone=False,
        ):
        """Store the configuration, build the detector and prepare its optimizer.

        Args:
            model_resolution: input resolution; stored as a numpy array.
            n_input_channels: number of image channels fed to the detector.
            n_classes: number of foreground classes (background excluded).
            n_extras, extra_loss_weight: copied into the EfficientDet config as
                ``extra_variables`` / ``extra_loss_weight``.
            init_lr: learning rate handed to the optimizer.
            use_scheduler: build a ``LambdaLR`` scheduler with the optimizer.
            n_warmup_epochs, n_decay_epochs, lr_prop: warm-up length, cosine
                decay length and final LR fraction of that scheduler.
            n_steps_grad_update, clip_grad_norm: stored only; they are used by
                code outside this method.
            optimizer_name: only ``'adam'`` is implemented.
            weight_decay: when > 0, decay is applied to non-BatchNorm weights only.
            use_pretrained_model: load pretrained backbone weights.
            backbone_name: detector selector (lower-cased), see ``_build_backbone``.
            trainable_backbone_layers: forwarded to torchvision's
                ``fasterrcnn_resnet50_fpn``.
            anchors_d: optional dict with ``'ANCHOR_SIZES'`` and
                ``'ASPECT_RATIOS'`` for the timm-based Faster R-CNN.
            start_ckpt: checkpoint path restored right after the model is built.
            checkpoint_base_path: directory for checkpoints (created if missing).
            model_name: label printed at construction time.
            device: forwarded to ``self.set_device``; forced to ``'cuda:0'``
                when ``parallelize_backbone`` is set.
            parallelize_backbone: call ``self._parallelize_backbone()`` after
                the model has been built.
        """
        super().__init__()

        # ---- data / head configuration ----
        self.model_resolution  = np.array(model_resolution)
        self.n_input_channels  = n_input_channels
        self.n_classes         = n_classes
        self.n_extras          = n_extras
        self.extra_loss_weight = extra_loss_weight

        # ---- learning-rate schedule ----
        self.lr              = init_lr
        self.use_scheduler   = use_scheduler
        self.n_warmup_epochs = n_warmup_epochs
        self.n_decay_epochs  = n_decay_epochs
        self.lr_prop         = lr_prop

        # ---- optimisation ----
        self.optimizer_name      = optimizer_name
        self.clip_grad_norm      = clip_grad_norm
        self.weight_decay        = weight_decay
        self.n_steps_grad_update = n_steps_grad_update

        # ---- backbone selection ----
        self.use_pretrained_model      = use_pretrained_model
        self.backbone_name             = backbone_name.lower()
        self.trainable_backbone_layers = trainable_backbone_layers
        self.parallelize_backbone      = parallelize_backbone
        self.anchors_d                 = anchors_d

        # Flags switched on by `_build_backbone` for the matching architecture.
        self.use_effdet = False
        self.use_yolo = False

        # ---- checkpoints ----
        self.checkpoint_base_path = checkpoint_base_path
        self.model_name           = model_name
        self.start_ckpt           = start_ckpt

        self.do_bbox_cleaning = True

        print(f'New Model: "{self.model_name}"')

        if self.parallelize_backbone:
            print(' - Setting main device = cuda:0')
            device = 'cuda:0'

        # Created later by `build_optimizer`; defined now so that
        # `save_checkpoint` can test them against None.
        self.optimizer = None
        self.scheduler = None

        # Device selection
        self.set_device(device)

        # Output directory for checkpoints
        if not os.path.exists(self.checkpoint_base_path):
            print(f'Creating save dir: "{self.checkpoint_base_path}"')
            os.makedirs(self.checkpoint_base_path)

        # Detector construction, then move it to the selected device
        self._build_backbone()
        self.to(self.device)

        if self.start_ckpt is not None:
            self.restore_checkpoint(self.start_ckpt)

        if self.parallelize_backbone:
            self._parallelize_backbone()

        # Model summary
        self.calc_total_weights()

        # Optimizer (and scheduler, when enabled)
        self.build_optimizer()

        return None
    
    
    @torch.jit.ignore
    def get_trainable_weights(self, verbose=True):
        """Return the list of parameters that should be optimised.

        For EfficientDet models this is the stem (``conv_stem`` + ``bn1``) of
        the feature extractor plus the FPN, the class/box heads and, when
        present, ``extra_net`` -- regardless of their ``requires_grad`` flag.
        For every other backbone it is every parameter with
        ``requires_grad=True``.

        Args:
            verbose: print the total number of selected weights (in millions).
        """
        if self.use_effdet:
            detector = self.backbone
            # With a parallelised backbone the feature extractor is wrapped and
            # its layers are reached through `.module`.
            feature_net = detector.backbone.module if self.parallelize_backbone else detector.backbone

            param_sources = [
                feature_net.conv_stem,
                feature_net.bn1,
                detector.fpn,
                detector.class_net,
                detector.box_net,
            ]
            if getattr(detector, 'extra_net', None) is not None:
                param_sources.append(detector.extra_net)

            trainable_params_v = [p for source in param_sources for p in source.parameters()]

        else:
            trainable_params_v = [p for p in self.parameters() if p.requires_grad]

        if verbose:
            n_w = sum(np.prod(p.shape) for p in trainable_params_v)
            print(f' - Total trainable weights: {n_w/1e6:0.03} M')

        return trainable_params_v
    
    
    
    @torch.jit.ignore
    def build_optimizer(self, params_v=None):
        """Create the optimizer (and the scheduler / YOLO loss when enabled).

        Args:
            params_v: iterable of parameters to optimise; defaults to
                ``self.get_trainable_weights()``.

        With ``weight_decay > 0`` three Adam parameter groups are created, in
        this order: BatchNorm2d weights (no decay), all other ``weight``
        parameters (decay) and ``bias`` parameters (no decay). In that mode:

        * ``weight`` / ``bias`` parameters of sub-modules that are NOT in
          ``params_v`` get ``requires_grad=False`` as a side effect;
        * parameters that are not exposed as a module's ``weight`` or ``bias``
          attribute are not handed to the optimizer.

        Returns:
            The new optimizer, also stored in ``self.optimizer``.

        Raises:
            Exception: for any optimizer name other than ``'adam'``.
        """
        if params_v is None:
            params_v = self.get_trainable_weights(verbose=False)

        # Materialise once: `params_v` may be a generator (e.g.
        # `module.parameters()`), which would otherwise be exhausted by the id
        # scan below and reach Adam empty.
        params_v = list(params_v)
        trainable_ids = {id(p) for p in params_v}  # set -> O(1) membership tests

        if self.optimizer_name.lower() == 'adam':
            if self.weight_decay > 0.0:
                pg0, pg1, pg2 = [], [], []  # BN weights / other weights / biases
                for k, v in self.named_modules():
                    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
                        if id(v.bias) in trainable_ids:
                            pg2.append(v.bias)  # biases, no decay
                        else:
                            v.bias.requires_grad_(False)

                    if isinstance(v, nn.BatchNorm2d):
                        if id(v.weight) in trainable_ids:
                            pg0.append(v.weight)  # weights, no decay
                        else:
                            v.weight.requires_grad_(False)

                    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
                        if id(v.weight) in trainable_ids:
                            pg1.append(v.weight)  # weights, decay
                        else:
                            v.weight.requires_grad_(False)

                # All three groups are passed at construction time so that an
                # empty BatchNorm group does not make Adam fail with "optimizer
                # got an empty parameter list". Group order is unchanged, which
                # keeps saved optimizer state dicts loadable.
                self.optimizer = optim.Adam(
                    [
                        {'params': pg0},                                      # weights, no decay
                        {'params': pg1, 'weight_decay': self.weight_decay},   # weights, decay
                        {'params': pg2},                                      # biases, no decay
                    ],
                    lr=self.lr,
                )

                del pg0, pg1, pg2

            else:
                self.optimizer = optim.Adam(
                    params_v,
                    lr=self.lr,
                    weight_decay=self.weight_decay,
                )

        else:
            raise Exception(f'Un implemented optimizer: {self.optimizer_name}')

        self.optimizer.zero_grad()
        self.opt_step = 0

        if self.use_scheduler:
            self._build_scheduler()

        if self.use_yolo:
            # NOTE(review): `ComputeLoss` comes from yolov5's `utils.loss`; its
            # import is commented out in the notebook's import cell -- confirm
            # it is defined elsewhere before using a YOLO backbone.
            self.yolo_compute_loss = ComputeLoss(
                self.backbone,
                autobalance=False)

        self.get_trainable_weights(verbose=True)
        return self.optimizer
    
    
    @torch.jit.ignore
    def _build_scheduler(self, LAST_EPOCH=None):
        """Attach a warm-up + cosine-decay ``LambdaLR`` scheduler to ``self.optimizer``.

        The multiplier function is stored in ``self.lr_lf`` and the scheduler
        in ``self.scheduler``.

        Args:
            LAST_EPOCH: when given, the scheduler is fast-forwarded so that its
                ``last_epoch`` equals this value (used when resuming training).
        """
        lr_multiplier = scheduler_lambda(
            lr_frac=self.lr_prop,
            warmup_epochs=self.n_warmup_epochs,
            cos_decay_epochs=self.n_decay_epochs)
        self.lr_lf = lr_multiplier

        self.scheduler = lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lr_multiplier)

        if LAST_EPOCH is None:
            return None

        print(' - Scheduler, Setting last_epoch =', LAST_EPOCH)
        # One `step()` from LAST_EPOCH-1 lands on LAST_EPOCH and refreshes the LR.
        self.scheduler.last_epoch = LAST_EPOCH-1
        self.scheduler.step()

        return None
        
        
        
    @torch.jit.ignore
    def scheduler_step(self, verbose=True):
        """Advance the LR scheduler by one step and return the current learning rate.

        When no scheduler is in use a warning is written to stderr and the
        learning rate reported by ``self.get_lr()`` is returned unchanged.
        """
        if not self.use_scheduler:
            print(' - WARNING: scheduler not configured.', file=sys.stderr)
        else:
            self.scheduler.step()

        current_lr = self.get_lr()
        if verbose and self.use_scheduler:
            print(f' - Scheduler: LR={current_lr:0.02e}')

        return current_lr
    
    
    @torch.jit.ignore
    def yolo_check_anchors(self, wh_v, anchors=None, thr=4.0):
        """Measure how well ``anchors`` cover the label sizes ``wh_v``.

        A label matches an anchor when both its width and height ratios to the
        anchor lie within ``[1/thr, thr]``.

        Args:
            wh_v: ``(N, 2)`` array of label widths/heights.
            anchors: ``(A, 2)`` array of anchor sizes; defaults to the anchors
                currently stored in the YOLO detection layer.
            thr: maximum allowed size ratio between a label and an anchor.

        Returns:
            dict with ``'best_possible_recall'`` (fraction of labels matched by
            at least one anchor) and ``'anchors_above_threshold'`` (mean number
            of matching anchors per label).
        """
        if anchors is None:
            anchors = self.yolo_get_anchors()

        min_ratio = 1. / thr

        # (N, A, 2) size ratios, folded into (0, 1]; the worse of the two axes
        # is the fit of a label/anchor pair.
        ratios = wh_v[:, None] / anchors[None]
        fit = np.minimum(ratios, 1. / ratios).min(axis=2)
        best_fit = fit.max(1)

        return {
            'best_possible_recall': (best_fit > min_ratio).mean(),
            'anchors_above_threshold': (fit > min_ratio).sum(axis=1).mean(),
        }


    @torch.jit.ignore
    def yolo_get_anchors(self):
        """Return a copy of the YOLO detection layer's ``anchor_grid`` as an ``(A, 2)`` numpy array.

        Not implemented for a data-parallel backbone.
        """
        assert self.use_yolo

        # The detection layer is the last module of the YOLO model.
        detect_layer = self.backbone.model[-1]
        flat_anchors = detect_layer.anchor_grid.clone().cpu().view(-1, 2)

        return flat_anchors.numpy()

    @torch.jit.ignore
    def yolo_set_anchors(self, new_anchors):
        """Overwrite the anchors of the YOLO detection layer in place.

        ``anchor_grid`` receives the sizes as given and ``anchors`` receives
        the same sizes divided by each detection level's stride. If the anchor
        areas and the strides then run in opposite directions, both tensors
        are flipped along the level axis.

        Not implemented for a data-parallel backbone.

        Args:
            new_anchors: array-like with as many values as ``anchor_grid``.
        """
        assert self.use_yolo

        m = self.backbone.model[-1]

        anchors_t = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
        m.anchor_grid[:] = anchors_t.clone().view_as(m.anchor_grid)  # for inference
        m.anchors[:] = anchors_t.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss

        # Anchor areas must grow in the same direction as the strides.
        areas = m.anchor_grid.prod(-1).view(-1)
        area_trend = areas[-1] - areas[0]
        stride_trend = m.stride[-1] - m.stride[0]
        if area_trend.sign() != stride_trend.sign():
            print('Reversing anchor order')
            m.anchors[:] = m.anchors.flip(0)
            m.anchor_grid[:] = m.anchor_grid.flip(0)

        return None


    @torch.jit.ignore
    def yolo_kmean_anchors(self, wh_v, thr=4.0, gen=1000, verbose=False, do_plot=False):
        """Create k-means anchors from label sizes and refine them with a genetic search.

        Adapted from YOLOv5's ``utils.autoanchor.kmean_anchors``. The number of
        anchors equals the number currently stored in the YOLO detection layer.
        Results depend on the global ``numpy.random`` state.

        Arguments:
            wh_v: ``(N, 2)`` array with the labels' widths/heights (pixels at
                model resolution).
            thr: anchor-label wh ratio threshold hyperparameter
                hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using the genetic algorithm
            verbose: print all results
            do_plot: additionally plot k-means centroids for k=1..20 and the
                width/height histograms

        Return:
            ``(n_anchors, 2)`` float32 array of anchors sorted by area.

        Raises:
            ValueError: if ``wh_v`` is not a non-empty 2-D array.
        """
        from scipy.cluster.vq import kmeans

        assert self.use_yolo

        # Not implemented if Data Parallel
        m = self.backbone.model[-1]
        n = m.anchor_grid.numel() // 2  # number of anchors

        img_size = max(self.model_resolution)

        thr = 1. / thr
        prefix = 'autoanchor: '

        # Get label wh
        wh0 = wh_v.copy()  # unfiltered wh
        if wh0.ndim != 2 or wh0.shape[0] == 0:
            raise ValueError('yolo_kmean_anchors needs a non-empty (N, 2) array of box sizes.')

        def metric(k, wh):  # compute metrics
            r = wh[:, None] / k[None]
            x = np.minimum(r, 1./r).min(axis=2)  # ratio metric
            return x, x.max(axis=1)  # x, best_x

        def anchor_fitness(k):  # mutation fitness (uses the filtered `wh`)
            _, best = metric(np.array(k), wh)
            return (best * (best > thr)).mean()  # fitness

        def print_results(k):
            x, best = metric(k, wh0)
            bpr, aat = (best > thr).mean(), (x > thr).mean() * n  # best possible recall, anch > thr
            print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
            print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
                  f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')

            for i, anchor in enumerate(k):
                print('%i,%i' % (round(anchor[0]), round(anchor[1])), end=',  ' if i < len(k) - 1 else '\n')  # use in *.cfg
            return k

        # Filter
        n_tiny = (wh0 < 3.0).any(1).sum()
        if n_tiny and verbose:
            print(f'{prefix}WARNING: Extremely small objects found. {n_tiny} of {len(wh0)} labels are < 3 pixels in size.')

        wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels

        # Kmeans calculation
        if verbose:
            print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')

        wh_std = wh.std(0)  # sigmas for whitening
        k, dist = kmeans(wh / wh_std, n, iter=30)  # points, mean distance
        k *= wh_std
        wh = np.array(wh)  # filtered
        wh0 = np.array(wh0)  # unfiltered

        k = k[np.argsort(k.prod(1))]  # sort small to large
        if verbose:
            k = print_results(k)

        # Evolve
        npr = np.random
        # fitness, shape, mutation prob, mutation sigma. The sigma has its own
        # name so that it does not clobber the whitening std used for plotting.
        f, sh, mp, sigma = anchor_fitness(k), k.shape, 0.9, 0.1
        if verbose:
            pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
        else:
            pbar = range(gen)
        for _ in pbar:
            v = np.ones(sh)
            while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
                v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * sigma + 1).clip(0.3, 3.0)
            kg = (k.copy() * v).clip(min=2.0)
            fg = anchor_fitness(kg)
            if fg > f:
                f, k = fg, kg.copy()
                if verbose:
                    pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
                    print_results(k)

        ret_k = k[np.argsort(k.prod(1))]  # sort small to large

        if verbose:
            print_results(ret_k)

        # Plot
        if do_plot:
            k_v, d_v = [None] * 20, [None] * 20
            for i in range(1, 21):
                # Same whitening as the main clustering above.
                k_i, d_i = kmeans(wh / wh_std, i)  # points, mean distance

                k_v[i-1] = wh_std * k_i
                d_v[i-1] = wh_std * d_i  # per-axis (W, H) distance in pixels

            fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
            for n_k, k_i in enumerate(k_v):
                ax[0].plot(k_i[:,0], k_i[:,1], 'o', label=f'k={n_k+1}')

            ax[0].set_xlabel('W sentroid')
            ax[0].set_ylabel('H sentroid')
            ax[0].set_xlim( (0, img_size) )
            ax[0].set_ylim( (0, img_size) )
            ax[0].set_title('Points: k_Mean')

            ax[1].plot(np.arange(1,21), d_v, marker='.')
            ax[1].set_xlabel('Number of sentroids')
            ax[1].set_ylabel('ErrMeanDistance')
            ax[1].set_title('Points: ErrMeanDistance')

            plt.show()

            # plot wh
            fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
            ax[0].hist(wh_v[:, 0], bins=100, range=(0,img_size))
            ax[0].set_title('W dist')
            ax[1].hist(wh_v[:, 1], bins=100, range=(0,img_size))
            ax[1].set_title('H dist')

            plt.show()

        return ret_k.astype(np.float32)

    @torch.jit.ignore
    def _build_backbone(self):
        
        if 'efficientdet' in self.backbone_name:
            self.use_effdet = True
            
            self.backbone_config = get_efficientdet_config(self.backbone_name)
            self.backbone_config.num_classes = self.n_classes
            self.backbone_config.image_size  = tuple([int(i) for i in self.model_resolution])
            self.backbone_config.extra_variables = self.n_extras
            self.backbone_config.extra_loss_weight = self.extra_loss_weight

            self.backbone = EfficientDet(
                self.backbone_config,
                pretrained_backbone=self.use_pretrained_model)

            if self.n_input_channels != 3:
                first_conv = self.backbone.backbone.conv_stem

                # Updating first ConvHead
                self.backbone.backbone.conv_stem = nn.Conv2d(
                    self.n_input_channels,
                    first_conv.out_channels,
                    kernel_size=first_conv.kernel_size,
                    stride=first_conv.stride,
                    padding=first_conv.padding,
                    bias=False,
                )

                # Deleting unused conv
                del(first_conv)

            self.net_labeler_train   = DetBenchTrain(self.backbone)
            self.net_labeler_predict = DetBenchPredict(self.backbone)
            
        elif self.backbone_name == 'fasterrcnn_resnet50_fpn':
                self.backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(
                    pretrained=False,
                    progress=True,
                    num_classes=self.n_classes+1,
                    pretrained_backbone=self.use_pretrained_model,
                    trainable_backbone_layers=self.trainable_backbone_layers)
            
        elif 'fasterrcnn_' in self.backbone_name:

            self.layers_d = {}

            timm_backbone = timm.create_model(
                model_name=self.backbone_name.replace('fasterrcnn_', ''),
                pretrained=self.use_pretrained_model,
                num_classes=0,
                in_chans=self.n_input_channels,
            )


            assert timm_backbone.global_pool is not None, 'ERROR, there is not a global_pool layer'
            timm_backbone.global_pool = nn.Identity()        

            self.return_layers = OrderedDict()
            self.in_channels_list = []
            featmap_names = []
            i_l = 0
            for layer_info_d in timm_backbone.feature_info:
                layer_name = layer_info_d['module']
                layer_chs = layer_info_d['num_chs']

                if 'layer' not in layer_name:
                    continue

                self.in_channels_list.append(layer_chs)
                self.return_layers[layer_name] = f'{i_l}'
                featmap_names.append(f'{i_l}')

                i_l += 1

            featmap_names.append('pool')
            assert len(featmap_names)  > 0

            self.n_featmap_layers = len(featmap_names)

            N_OUT_CHANNELS_FPN = 256

            fpn_backbone = BackboneWithFPN(
                timm_backbone, 
                return_layers=self.return_layers,
                in_channels_list=self.in_channels_list, 
                out_channels=N_OUT_CHANNELS_FPN,
                extra_blocks=None)

            
            if self.anchors_d is None:
                ANCHOR_SIZES  = ( (16,), (32,), (64,), (128,), (256,) )
#                 ANCHOR_SIZES  = ( (8,16), (32,64), (64,128), (128,256), (256,448) )
                ASPECT_RATIOS = ( (0.33, 0.5, 1.0, 2.0, 3.0), ) * len(ANCHOR_SIZES)
            else:
                ANCHOR_SIZES = self.anchors_d['ANCHOR_SIZES']
                ASPECT_RATIOS = self.anchors_d['ASPECT_RATIOS']
                
                print(' - Using custom anchors: ')
                print(' |-> ANCHOR_SIZES:',  ANCHOR_SIZES)
                print(' |-> ASPECT_RATIOS:', ASPECT_RATIOS)
                
            
            anchor_generator = AnchorGenerator(
                    sizes=ANCHOR_SIZES,
                    aspect_ratios=ASPECT_RATIOS,
                )

            roi_pooler = torchvision.ops.MultiScaleRoIAlign(
                featmap_names=featmap_names,
                output_size=7,
                sampling_ratio=2)

            self.backbone = FasterRCNN(
                fpn_backbone,
                num_classes=self.n_classes+1,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler)


            self.layers_d['timm_backbone'] = timm_backbone
            self.layers_d['fpn_backbone'] = fpn_backbone
            self.layers_d['anchor_generator'] = anchor_generator
            self.layers_d['roi_pooler'] = roi_pooler


        elif 'yolo' in self.backbone_name:
            self.use_yolo = True
            self.backbone = torch.hub.load(
                'ultralytics/yolov5',
                self.backbone_name,
                pretrained=self.use_pretrained_model,
                channels=self.n_input_channels,
                classes=self.n_classes, #+1, # First Class will be background
                autoshape=False,
            )

            # Model parameters
            nc = self.n_classes
            imgsz = max(self.model_resolution)

            nl = self.backbone.model[-1].nl

            hyp = {}
            hyp['box'] = 0.03 * 100
            hyp['cls'] = 0.25 * 100
            hyp['obj'] = 1.00 * 100

            hyp['cls_pw'] = 0.60
            hyp['obj_pw'] = 0.90

            hyp['fl_gamma'] = 3.0

            hyp['box'] *= 3. / nl  # scale to layers
            hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
            hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
            hyp['anchor_t'] = 2.0

            self.backbone.hyp = hyp
            self.backbone.gr = 1.0
            
            
            for i, (n,p) in enumerate(self.named_parameters()):
                p.requires_grad_(True)
                
                
                
        else:
            raise Exception('Unknown backbone name: "{}"'.format(self.backbone_name))

        return None
    
    
    
    def forward(self, *args):
        """Forward all positional arguments to the wrapped detector and return its output."""
        detector = self.backbone
        return detector(*args)
    
    @torch.jit.ignore
    def get_epoch_from_ckpt_path(self, ckpt_path):
        """Extract the epoch number encoded in a checkpoint file name.

        The file name is split on ``'_'`` and the first token of the form
        ``E<int>`` (case-insensitive) gives the epoch, e.g.
        ``model_E012_loss.ckpt`` -> 12. The file extension is stripped from the
        last token so that names ending in the epoch (``model_E12.ckpt``) are
        recognised too.

        Returns:
            The epoch as ``int``, or ``None`` when no such token exists.
        """
        tokens = os.path.basename(ckpt_path).split('_')
        tokens[-1] = os.path.splitext(tokens[-1])[0]

        for token in tokens:
            if token[:1].upper() != 'E':
                continue
            try:
                return int(token[1:])
            except ValueError:
                # Starts with "E" but is not an epoch tag (e.g. "effdet").
                continue

        return None
    
    @torch.jit.ignore
    def find_last_saved_ckpt(self):
        """Locate the checkpoint with the highest epoch in the checkpoint folder.

        Every ``*.ckpt`` file found directly in ``self.checkpoint_base_path``
        is paired with the epoch parsed by ``get_epoch_from_ckpt_path``.

        Returns:
            ``(epoch, ckpt_path)`` of the checkpoint with the highest epoch, or
            ``(None, None)`` when the folder contains no ``*.ckpt`` file.
            ``epoch`` is ``None`` when only files without an epoch token exist.
        """
        pattern = os.path.join(self.checkpoint_base_path, '*.ckpt')
        all_path_v = [
            (self.get_epoch_from_ckpt_path(ckpt_path), ckpt_path)
            for ckpt_path in glob.glob(pattern)
        ]

        if len(all_path_v) == 0:
            return None, None

        # Files without an epoch token (epoch is None) are ranked below every
        # numbered checkpoint; comparing None with int raises TypeError.
        all_path_v.sort(key=lambda item: -1 if item[0] is None else item[0])
        return all_path_v[-1]
        
        
        
    @torch.jit.ignore
    def save_checkpoint(
        self,
        step=0,
        loss=None,
        file_name='weights.ckpt',
        verbose=True,
        ):
        """Serialize the model (and optimizer/scheduler, if any) with ``torch.save``.

        The file is written to ``self.checkpoint_base_path / file_name`` and
        holds a dictionary with the keys ``'step'``, ``'model_state_dict'``,
        ``'optimizer_state_dict'``, ``'scheduler_state_dict'`` and ``'loss'``.
        The optimizer/scheduler entries are ``None`` when the corresponding
        attribute is ``None``.

        Args:
            step: Value stored under ``'step'``.
            loss: Value stored under ``'loss'``.
            file_name: Name of the checkpoint file inside the base folder.
            verbose: When true, print the path that was written.

        Returns:
            The full path of the written checkpoint.
        """
        base_path = self.checkpoint_base_path

        # torch.save does not create missing folders; make sure the target
        # directory exists so the first save of a fresh run cannot fail.
        if base_path:
            os.makedirs(base_path, exist_ok=True)

        PATH = os.path.join(base_path, file_name)

        optimizer_state = self.optimizer.state_dict() if self.optimizer is not None else None
        scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None

        torch.save({
            'step': step,
            'model_state_dict':     self.state_dict(),
            'optimizer_state_dict': optimizer_state,
            'scheduler_state_dict': scheduler_state,
            'loss': loss,
            }, PATH)

        if verbose:
            print(f' Saved checkpoint: {PATH}.')

        return PATH
    
    @torch.jit.ignore
    def restore_checkpoint(
        self, 
        PATH=None, 
        load_optimizer=True,
        load_scheduler=True,
        verbose=True):
        """Load model weights (and optionally optimizer/scheduler state) from disk.

        Weights are loaded key by key: a parameter is taken from the file only
        when the key exists there with the same shape. Otherwise the current
        value is kept, a warning is printed to stderr and the optimizer state
        is not restored (it would no longer match the parameters).

        Args:
            PATH: Checkpoint file. When ``None`` the checkpoint returned by
                ``self.find_last_saved_ckpt()`` is used.
            load_optimizer: Restore ``self.optimizer`` from the checkpoint. On
                failure the optimizer is rebuilt with ``self.build_optimizer()``.
            load_scheduler: Restore ``self.scheduler`` from the checkpoint. On
                failure the scheduler is rebuilt with
                ``self._build_scheduler(LAST_EPOCH=...)``, using the epoch
                parsed from the checkpoint file name.
            verbose: When true, print the restored path.

        Returns:
            The object returned by ``torch.load`` for the checkpoint.

        Raises:
            FileNotFoundError: ``PATH`` is ``None`` and no checkpoint was found.
        """
        if PATH is None:
            print('Restoring last epoch.')
            # Search this instance's checkpoint folder; do not depend on a
            # notebook-level variable called `model`.
            LAST_EPOCH, PATH = self.find_last_saved_ckpt()
            if PATH is None:
                raise FileNotFoundError(
                    f'No "*.ckpt" file found in "{self.checkpoint_base_path}".')

        else:
            LAST_EPOCH = None

        checkpoint = torch.load(
            PATH,
            map_location=self.device,)

        # Accept both the dictionary written by save_checkpoint and a bare
        # state_dict.
        if 'model_state_dict' in checkpoint.keys():
            saved_state_dict = checkpoint['model_state_dict']

        else:
            saved_state_dict = checkpoint

        current_state_dict = self.state_dict()
        new_state_dict = OrderedDict()
        for key in current_state_dict.keys():
            if (key in saved_state_dict.keys()) and (saved_state_dict[key].shape == current_state_dict[key].shape):
                new_state_dict[key] = saved_state_dict[key]

            else:
                # The saved optimizer state refers to the saved parameters, so
                # it is skipped as soon as one parameter cannot be restored.
                load_optimizer = False
                if key not in saved_state_dict.keys():
                    print(f' - WARNING: key="{key}" not found in saved checkpoint.\n   Weights will not be loaded.', file=sys.stderr)
                else:
                    s0 = tuple(saved_state_dict[key].shape)
                    s1 = tuple(current_state_dict[key].shape)
                    print(f' - WARNING: shapes mismatch in "{key}": {s0} vs {s1}.\n   Weights will not be loaded.', file=sys.stderr)
                new_state_dict[key] = current_state_dict[key]

        self.load_state_dict( new_state_dict )

        if self.optimizer is not None:
            if load_optimizer:
                try:
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                except Exception:
                    print(' - WARNING: ERROR while loading the optimizer. The optimizer will be reseted.', file=sys.stderr)
                    self.build_optimizer()
            else:
                print(' - WARNING: Optimizer will not be loaded.', file=sys.stderr)

        if self.scheduler is not None:
            if load_scheduler:
                try:
                    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                except Exception:
                    print(' - WARNING: ERROR while loading the scheduler. The Scheduler will be reseted.', file=sys.stderr)

                    if LAST_EPOCH is None:
                        LAST_EPOCH = self.get_epoch_from_ckpt_path(PATH)

                    self._build_scheduler(LAST_EPOCH=LAST_EPOCH)

            else:
                print(' - WARNING: Scheduler will not be loaded.', file=sys.stderr)

        if verbose:
            print(f' - Restored checkpoint: {PATH}.')

        return checkpoint 
    
    @torch.jit.ignore
    def calc_total_weights(self, verbose=True):
        """Count the scalar weights held by ``self.parameters()``.

        Args:
            verbose: When true, print the total in millions.

        Returns:
            The total number of elements over all parameters, as an ``int``.
        """
        # Tensor.numel() replaces np.prod(..., dtype=np.int): the `np.int`
        # alias no longer exists in current NumPy releases.
        n_w = sum(int(p.numel()) for p in self.parameters())

        if verbose:
            print(' - Total weights: {:0.02f}M'.format(n_w/1e6))

        return n_w
    
    
    @torch.jit.export
    def flip(self, tensor, dim=1):
        """Return a copy of ``tensor`` with its elements reversed along ``dim``.

        Args:
            tensor: Tensor to flip.
            dim: Dimension along which the order is reversed.

        Returns:
            A new tensor of the same shape, on the same device as ``tensor``.
        """
        # torch.flip works on the tensor's own device, so the result no longer
        # depends on self.device matching the input.
        return torch.flip(tensor, [dim])
    
    @torch.jit.ignore
    def set_device(self, device=None):
        """Choose the device stored in ``self.device``.

        Args:
            device: ``None`` selects ``cuda:0`` when CUDA is available and
                ``cpu`` otherwise. A device string (for example ``'cuda:1'``)
                or an existing ``torch.device`` is used as given.

        Returns:
            The selected ``torch.device``.

        Raises:
            Exception: ``device`` is of any other type.
        """
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, (str, torch.device)):
            self.device = torch.device( device )
        else:
            raise Exception('Not implemented')

        print(f' - Selecting device: {self.device}')
        return self.device
    
    
    @torch.jit.ignore
    def unscale_bboxes(self, bboxes, original_shape):
        """Map boxes from model-resolution pixels back to the original image.

        Args:
            bboxes: Array whose last axis holds ``x0, y0, x1, y1``.
            original_shape: ``(height, width)`` of the original image.

        Returns:
            ``bboxes`` divided by the per-axis ratio between
            ``self.model_resolution`` (read as ``(height, width)``) and
            ``original_shape``.
        """
        net_h, net_w = self.model_resolution
        src_h, src_w = original_shape

        ratio_w = net_w / src_w
        ratio_h = net_h / src_h

        # Column order: x0, y0, x1, y1
        return bboxes / np.array([ratio_w, ratio_h, ratio_w, ratio_h])
        
        
    @torch.jit.ignore
    def clean_bboxes(self, pred_d, original_shape=None):
        """Clip the boxes of one prediction to the image and drop empty ones.

        ``pred_d`` is modified in place: ``pred_d['bbox']`` is clipped to
        ``[EPS, max - EPS]`` on both axes, and every box whose width or height
        is not positive afterwards is removed from ``'bbox'``, ``'cls'`` and
        ``'p_det'``.

        Args:
            pred_d: Dictionary with the array entries ``'bbox'`` (columns
                ``x0, y0, x1, y1``), ``'cls'`` and ``'p_det'``.
            original_shape: ``(height, width)`` used as clipping limits; when
                ``None`` the limits come from ``self.model_resolution``.

        Returns:
            Always ``None``.
        """
        EPS = 1e-8
        boxes = pred_d['bbox']
        if len(boxes) == 0:
            return None

        limits = self.model_resolution if original_shape is None else original_shape
        y_max, x_max = limits

        lower = np.array([EPS] * 4, dtype=boxes.dtype)
        upper = np.array([x_max-EPS, y_max-EPS, x_max-EPS, y_max-EPS], dtype=boxes.dtype)
        boxes = np.clip(boxes, lower, upper)
        pred_d['bbox'] = boxes

        # (x1 - x0, y1 - y0) per box; a non-positive side marks an empty box.
        sides = boxes[:, 2:] - boxes[:, :2]
        is_empty = (sides <= 0.0).any(axis=-1)
        if is_empty.any():
            keep = ~is_empty
            for key in ('bbox', 'cls', 'p_det'):
                pred_d[key] = pred_d[key][keep]

        return None
    
    
    @torch.jit.ignore
    def get_bboxes(self, outputs, det_th, data=None, unscale_bboxes=False):
        """Turn a raw detection tensor into one prediction dictionary per image.

        The last axis of ``outputs`` is read as ``x0, y0, x1, y1, score,
        class``. Detections whose score is not above ``det_th`` are dropped and
        the class index is shifted by -1.

        Args:
            outputs: Detection tensor with one entry per image on the first
                axis.
            det_th: Score threshold (strictly greater-than).
            data: Batch dictionary. When given, ``len(data['sample_id'])`` must
                match the batch size; ``data['original_shape']`` is read when
                ``unscale_bboxes`` is true.
            unscale_bboxes: Rescale boxes to the original image size with
                ``self.unscale_bboxes``.

        Returns:
            List of dictionaries with the keys ``'bbox'``, ``'p_det'`` and
            ``'cls'``. They are passed through ``self.clean_bboxes`` when
            ``self.do_bbox_cleaning`` is set.
        """
        if data is not None:
            assert len(data['sample_id']) == outputs.shape[0]

        batch_pred = outputs.detach().cpu().numpy()

        preds_v = []
        for i_sample, img_pred in enumerate(batch_pred):
            # Column layout: [0:4] box, [4] score, [5] class.
            boxes = img_pred[..., :4]
            scores = img_pred[..., 4:5]
            classes = img_pred[..., 5:]

            keep = scores[:, 0] > det_th

            original_shape = None
            if unscale_bboxes:
                original_shape = data['original_shape'][i_sample]
                boxes = self.unscale_bboxes(boxes, original_shape)

            pred_d = {
                'bbox': boxes[keep],
                'p_det': scores[keep, 0],
                # The class number is shifted down by one.
                'cls': classes[keep, 0].astype(int) - 1,
            }

            if self.do_bbox_cleaning:
                self.clean_bboxes(pred_d, original_shape)

            preds_v.append(pred_d)

        return preds_v
    
    
    
    def xyxy2nxywh(self, x, ret_array=None):
        """Convert pixel ``x0, y0, x1, y1`` boxes to normalized ``cx, cy, w, h``.

        Horizontal values are divided by ``self.model_resolution[1]`` and
        vertical values by ``self.model_resolution[0]``.

        Args:
            x: 2-D tensor whose first four columns are ``x0, y0, x1, y1``.
            ret_array: Optional output tensor of the same shape as ``x``. It
                may be ``x`` itself for an in-place conversion. When ``None`` a
                float32 tensor of zeros is allocated on ``self.device``.

        Returns:
            The output tensor, with columns 0-3 set to ``cx, cy, w, h``.
        """
        if ret_array is None:
            x_ret = torch.zeros(x.shape, dtype=torch.float32, device=self.device)
        else:
            assert (ret_array.shape == x.shape)
            x_ret = ret_array

        res_h = self.model_resolution[0]
        res_w = self.model_resolution[1]

        # Compute every column before writing any of them: if ret_array shares
        # memory with x, writing column 0 first would corrupt the inputs that
        # the width/height columns still need.
        cx = (x[:,0] + x[:,2]) / (2*res_w)
        cy = (x[:,1] + x[:,3]) / (2*res_h)
        w  = (x[:,2] - x[:,0]) / res_w
        h  = (x[:,3] - x[:,1]) / res_h

        x_ret[:,0] = cx
        x_ret[:,1] = cy
        x_ret[:,2] = w
        x_ret[:,3] = h

        return x_ret 

    def nxywh2xyxy(self, x, ret_array=None):
        """Convert normalized ``cx, cy, w, h`` boxes to pixel ``x0, y0, x1, y1``.

        Horizontal values are multiplied by ``self.model_resolution[1]`` and
        vertical values by ``self.model_resolution[0]``.

        Args:
            x: 2-D tensor whose first four columns are ``cx, cy, w, h``.
            ret_array: Optional output tensor of the same shape as ``x``. It
                may be ``x`` itself for an in-place conversion. When ``None`` a
                float32 tensor of zeros is allocated on ``self.device``.

        Returns:
            The output tensor, with columns 0-3 set to ``x0, y0, x1, y1``.
        """
        if ret_array is None:
            x_ret = torch.zeros(x.shape, dtype=torch.float32, device=self.device)
        else:
            assert (ret_array.shape == x.shape)
            x_ret = ret_array

        res_h = self.model_resolution[0]
        res_w = self.model_resolution[1]

        # Compute every corner before writing any of them: if ret_array shares
        # memory with x, writing column 0 first would corrupt the centre
        # coordinate that x1 still needs.
        x0 = (x[:,0] - 0.5 * x[:,2]) * res_w
        y0 = (x[:,1] - 0.5 * x[:,3]) * res_h
        x1 = (x[:,0] + 0.5 * x[:,2]) * res_w
        y1 = (x[:,1] + 0.5 * x[:,3]) * res_h

        x_ret[:,0] = x0
        x_ret[:,1] = y0
        x_ret[:,2] = x1
        x_ret[:,3] = y1

        return x_ret
    
    
    def yolo_build_target(self, data):
        """Stack the ground-truth boxes of a batch into one YOLO target tensor.

        For each sample, boxes with a negative class (background) are dropped.
        Each remaining box becomes one row ``[sample_index, class, cx, cy, w,
        h]``, where the last four values come from ``self.xyxy2nxywh``.

        Args:
            data: Batch dictionary with per-sample tensors in ``data['bboxes']``
                and ``data['cls']``.

        Returns:
            Float32 tensor on ``self.device`` with six columns and one row per
            kept box, samples concatenated in batch order.
        """
        rows_per_sample = []
        for sample_idx, (sample_boxes, sample_cls) in enumerate(zip(data['bboxes'], data['cls'])):
            # Negative class ids mark background entries.
            is_object = (sample_cls >= 0)
            sample_boxes = sample_boxes[is_object]
            sample_cls = sample_cls[is_object]

            rows = torch.empty(
                (len(sample_boxes), 6),
                dtype=torch.float32,
                device=self.device,
            )
            rows[:, 0] = sample_idx
            rows[:, 1] = sample_cls

            # Writes cx, cy, w, h straight into columns 2-5 of `rows`.
            self.xyxy2nxywh(sample_boxes, ret_array=rows[:, 2:])

            rows_per_sample.append(rows)

        return torch.vstack(rows_per_sample)



    @torch.jit.ignore
    def predict(
        self,
        data,
        det_th=0.4, 
        output_losses=False,
        training=False,
        filter_boxes=True,
        unscale_bboxes=False,
        yolo_iou_thres=1.0,
    ):
        """Run the detector on one batch and return losses and/or detections.

        The code path depends on the backbone family: EfficientDet
        (``self.use_effdet``), YOLO (``self.use_yolo``) or, otherwise, a
        detector called as ``self(images, targets)`` for losses and
        ``self(images)`` for predictions.

        Args:
            data: Batch dictionary. ``data['image']`` is always read;
                ``'bboxes'`` and ``'cls'`` are read when ``output_losses`` is
                set and ``'original_shape'`` when ``unscale_bboxes`` is set. If
                ``data['image']`` is not a tensor the batch is converted with
                ``data2tensor``; otherwise every top-level tensor value of
                ``data`` is moved to ``self.device`` (the dictionary is updated
                in place).
            det_th: Detections with a score not above this value are dropped.
            output_losses: Also compute the training losses.
            training: Put the module in train mode with gradients enabled;
                otherwise eval mode with gradients disabled.
            filter_boxes: Produce per-image detection dictionaries.
            unscale_bboxes: Rescale boxes to each image's original size.
            yolo_iou_thres: IoU threshold handed to ``non_max_suppression``
                (YOLO path only).

        Returns:
            * Without losses: a list with one ``{'bbox', 'cls', 'p_det'}``
              dictionary per image when ``filter_boxes`` is set. Otherwise the
              raw EfficientDet output, or ``None`` on the other two paths.
            * With losses: a dictionary with ``'loss'``, ``'class_loss'`` and
              ``'box_loss'``, plus ``'detections'`` when boxes were produced.

        Note:
            The gradient mode is changed with ``torch.set_grad_enabled`` and is
            not restored on return. On the YOLO and generic paths
            ``filter_boxes`` switches the module to eval mode even when
            ``training`` is true.
        """
        if training:
            self.train()
            torch.set_grad_enabled(True)

        else:
            self.eval()
            torch.set_grad_enabled(False)

        if type(data['image']) is not torch.Tensor:
            data = data2tensor(data, device=self.device)

        else:
            for k in data:
                # Tensor.to() returns a new tensor, so it has to be stored back.
                if isinstance(data[k], torch.Tensor):
                    data[k] = data[k].to(self.device)

        images = data['image']
        if self.use_effdet:
            if output_losses:
                target_d = {
                    # Swap the column pairs: x0, y0, x1, y1 -> y0, x0, y1, x1.
                    'bbox': [x[:, [1,0,3,2]] for x in data['bboxes']],
                    'cls':  [x + 1 for x in data['cls']],   # We must sum 1 to the class number
                    'img_scale': None,
                    'img_size': None,
                }

                if self.n_extras > 0:
                    target_d['extra'] = [x for x in data['extra']]

                # dict with: 'loss', 'class_loss', 'box_loss'
                outputs = self.net_labeler_train.forward(
                    images,
                    target_d
                )

                if not training and filter_boxes:
                    outputs['detections'] = self.get_bboxes(outputs['detections'], det_th, data, unscale_bboxes)

            else:
                outputs = self.net_labeler_predict.forward(
                    images,
                )

                if filter_boxes:
                    outputs = self.get_bboxes(outputs, det_th, data, unscale_bboxes)

        elif self.use_yolo:
            outputs = None
            if output_losses:
                # The loss is computed from the train-mode output.
                if not self.training:
                    self.train()

                yolo_target = self.yolo_build_target(data)
                yolo_out = self(images)

                loss, (lbox, lobj, lcls, loss2) = self.yolo_compute_loss(
                    yolo_out,
                    yolo_target)

                outputs = {
                    'loss': loss,
                    'class_loss': lcls,
                    'box_loss': lbox + lobj,
                }

            if filter_boxes:
                if self.training:
                    self.eval()

                yolo_out = self(images)

                # NOTE(review): `non_max_suppression` must be available as a
                # global; confirm where the notebook defines or imports it.
                pred_v = non_max_suppression(
                    yolo_out[0],
                    conf_thres=det_th,
                    iou_thres=yolo_iou_thres,
                )

                outputs_v = []
                for i_sample, p_matrix in enumerate(pred_v):
                    p_matrix = p_matrix.detach().cpu().numpy()

                    bbox   = p_matrix[:,:4]
                    scores = p_matrix[:,4]
                    # Builtin int: the `np.int` alias is gone from current NumPy.
                    cls    = p_matrix[:,5].astype(int) # - 1 # Class 0 will be background

                    if unscale_bboxes:
                        original_shape = data['original_shape'][i_sample]
                        bbox = self.unscale_bboxes(
                            bbox,
                            original_shape)

                    else:
                        original_shape = None

                    f = (scores > det_th) # * (cls != -1)

                    pred_d = {
                            'bbox' : bbox[f],
                            'cls': cls[f], 
                            'p_det': scores[f],
                        }

                    if self.do_bbox_cleaning:
                        self.clean_bboxes(
                            pred_d,
                            original_shape,
                        )

                    outputs_v.append(
                        pred_d
                    )

                if outputs is None:
                    outputs = outputs_v

                else:
                    outputs['detections'] = outputs_v

        else:
            outputs = None

            if output_losses:
                # The loss dictionary is requested in train mode, with targets.
                if not self.training:
                    self.train()

                target = []
                for cls, bboxes in zip( data['cls'], data['bboxes'] ):
                    target.append(
                        {
                            'boxes': bboxes,
                            'labels': cls + 1,# We must sum 1 to the class number

                        }
                    )

                loss_d = self(images, target)

                outputs = {
                    'loss': loss_d['loss_classifier'] + loss_d['loss_objectness'] + loss_d['loss_box_reg'] + loss_d['loss_rpn_box_reg'],
                    'class_loss': loss_d['loss_classifier'],
                    'box_loss': loss_d['loss_objectness'] + loss_d['loss_box_reg'] + loss_d['loss_rpn_box_reg'],
                }

            if filter_boxes:
                if self.training:
                    self.eval()

                pred_v = self(images)

                outputs_v = []
                for i_sample, pred_d in enumerate(pred_v):
                    bbox   = pred_d['boxes'].detach().cpu().numpy()
                    scores = pred_d['scores'].detach().cpu().numpy()
                    cls    = pred_d['labels'].detach().cpu().numpy()- 1 # We must substract 1 to the class number

                    if unscale_bboxes:
                        original_shape = data['original_shape'][i_sample]
                        bbox = self.unscale_bboxes(
                            bbox,
                            original_shape)

                    else:
                        original_shape = None

                    f = scores > det_th

                    pred_d = {
                            'bbox' : bbox[f],
                            'cls': cls[f],
                            'p_det': scores[f],
                        }

                    if self.do_bbox_cleaning:
                        self.clean_bboxes(
                            pred_d,
                            original_shape,
                        )

                    outputs_v.append(
                        pred_d
                    )

                if outputs is None:
                    outputs = outputs_v

                else:
                    outputs['detections'] = outputs_v

        return outputs
    
    
    @torch.jit.ignore
    def predict_TTA(
        self,
        data,
        det_th=0.00,
        unscale_bboxes=True,
        TTA_clean_iou_th=0.2,
        TTA_clean_mode='median_pmean',
        TTA_max_angle=5,
        TTA_delta_angle=1,
    ):
        """Predict with test-time augmentation (rotations and horizontal flips).

        For every image the model is run once per combination of rotation angle
        (``0 .. TTA_max_angle`` in steps of ``TTA_delta_angle``) and horizontal
        flip (on/off). The predicted boxes are mapped back with the inverse
        transform, concatenated, merged with ``clean_predictions`` and finally
        thresholded with ``filter_det_th``.

        Args:
            data: Batch dictionary with ``'image'`` (indexed per sample and
                transposed with a tuple, i.e. used as a channel-first array),
                ``'sample_id'`` and ``'original_shape'``.
            det_th: Score threshold used for the single predictions and for the
                final filtering.
            unscale_bboxes: Forwarded to ``predict``.
            TTA_clean_iou_th: ``iou_th`` forwarded to ``clean_predictions``.
            TTA_clean_mode: ``mode`` forwarded to ``clean_predictions``.
            TTA_max_angle: Largest rotation angle, inclusive.
            TTA_delta_angle: Step between consecutive rotation angles.

        Returns:
            The per-image predictions returned by ``filter_det_th``.
        """
        output_v = []
        for i_b in range(data['image'].shape[0]):
            tta_pred_v = []
            # Each augmentation variant is tagged as a separate "model" so the
            # variants can be ensembled by clean_predictions.
            model_id = 0
            for angle in np.arange(0, TTA_max_angle+1, TTA_delta_angle):
                for do_flop in [True, False]:
                    aug_v = []
                    un_aug_v = []
                    if do_flop:
                        aug_v.append( A.HorizontalFlip(p=1.0, always_apply=True) )

                    aug_v.append( A.Rotate(p=1.0, limit=(angle,  angle), always_apply=True) )
                    aug = A.Compose(
                        aug_v,
                        p=1.0)

                    # Inverse transform: undo the steps in reverse order
                    # (rotate back first, then flip back).
                    un_aug_v.append( A.Rotate(p=1.0, limit=(-angle,  -angle), always_apply=True) )
                    if do_flop:
                        un_aug_v.append( A.HorizontalFlip(p=1.0, always_apply=True) )

                    un_aug = A.Compose(
                        un_aug_v,
                        bbox_params=A.BboxParams(
                            format='pascal_voc',
                            min_area=0, 
                            min_visibility=0,
                            label_fields=['cls', 'p_det']),
                        p=1.0)

                    # Albumentations works on HWC images; the batch is CHW.
                    tta_image = aug(
                        image=data['image'][i_b].transpose((1,2,0))
                    )['image'].transpose((2,0,1))[None,...]

                    data_tta = {
                        'sample_id': data['sample_id'][i_b:i_b+1],
                        'original_shape': data['original_shape'][i_b:i_b+1],
                        'image': tta_image,
                    }

                    aug_pred_v = self.predict(
                        data_tta,
                        det_th=det_th,
                        output_losses=False,
                        training=False,
                        filter_boxes=True,
                        unscale_bboxes=unscale_bboxes,
                    )

                    orig_h, orig_w = data['original_shape'][i_b]

                    # NOTE(review): the dummy image has the original size, which
                    # matches the boxes only when unscale_bboxes is true; confirm
                    # the intended behaviour for unscale_bboxes=False.
                    data_untta = un_aug(
                        image=np.ones( (orig_h, orig_w, 3) ),
                        bboxes=aug_pred_v[0]['bbox'],
                        cls=aug_pred_v[0]['cls'],
                        p_det=aug_pred_v[0]['p_det'],
                    )

                    tta_pred_v.append(
                        {
                            'bbox':np.array(data_untta['bboxes']),
                            'cls':np.array(data_untta['cls']),
                            'p_det':np.array(data_untta['p_det']),
                            # Builtin int: the `np.int` alias is gone from
                            # current NumPy.
                            'model_id': model_id * np.ones(len(data_untta['cls']), dtype=int)
                        }
                    )

                    model_id += 1

            # NOTE(review): np.concatenate raises ValueError when no variant
            # produced a detection for this image; confirm whether that case
            # can occur with the thresholds in use.
            tta_pred_d = {}
            for k in tta_pred_v[0].keys():
                tta_pred_d[k] = np.concatenate( [pred_d[k] for pred_d in tta_pred_v if len(pred_d[k]) > 0], axis=0)

            output_v.append(tta_pred_d)

        output_v = clean_predictions(
            output_v,
            iou_th=TTA_clean_iou_th,
            mode=TTA_clean_mode,
            consensus_level=1,
            n_models2ensemble=model_id,
        )

        output_v = filter_det_th(
            output_v,
            det_th)

        return output_v
    
    @torch.jit.ignore
    def train_step(self, data):
        """Run one forward/backward pass with gradient accumulation.

        The batch loss is taken from ``self.predict(..., output_losses=True,
        training=True, filter_boxes=False)``, divided by
        ``self.n_steps_grad_update`` (unless that is 1) and back-propagated.
        Gradients are clipped after every backward pass when
        ``self.clip_grad_norm`` is positive. The optimizer steps, and its
        gradients are cleared, whenever ``self.opt_step`` is a non-zero
        multiple of ``self.n_steps_grad_update``; ``self.opt_step`` is then
        incremented.

        Args:
            data: Batch dictionary forwarded to ``predict``.

        Returns:
            The unmodified output dictionary of ``predict``.
        """
        outputs = self.predict(
            data=data,
            output_losses=True,
            training=True,
            filter_boxes=False
        )

        accum_steps = self.n_steps_grad_update
        if accum_steps == 1:
            step_loss = outputs['loss']
        else:
            # Scale so the accumulated gradient is an average over the steps.
            step_loss = outputs['loss'] / accum_steps

        step_loss.backward()

        max_norm = self.clip_grad_norm
        if max_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm)

        is_update_step = (self.opt_step != 0) and (self.opt_step % accum_steps == 0)
        if is_update_step:
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.opt_step += 1
        return outputs


    @torch.jit.ignore
    def put_boxes(self, preds_v, data, do_plot=False):
        """Draw ground-truth and predicted boxes on the images of a batch.

        Ground-truth boxes (``data['bboxes']`` / ``data['cls']``) are drawn
        with thickness 2. Predicted boxes are drawn with thickness 1 and
        labelled with the class name and score. Colours and names come from
        the globals ``class2color_v`` and ``class2str_v``.

        Args:
            preds_v: One ``{'bbox', 'cls', 'p_det'}`` dictionary per image.
            data: Batch dictionary. ``data['image']`` is a channel-first tensor
                batch that is multiplied by 255 before drawing.
            do_plot: Show every annotated image with matplotlib.

        Returns:
            ``np.array`` built from the annotated ``uint8`` RGB images.
        """
        BS = len( preds_v )

        assert BS == data['image'].shape[0], 'Wrong Batch Size'

        box_image_v = []
        for i_b in range(BS):

            pred_d = preds_v[i_b]

            image = data['image'][i_b].cpu().numpy().transpose([1,2,0])
            box_image = (image * 255)

            # More than three channels: keep the first three. Otherwise use
            # channel 0 only and replicate it to three channels below.
            if len(box_image.shape) == 3 and box_image.shape[-1] > 3:
                box_image = box_image[:,:,:3]

            elif len(box_image.shape) == 3:
                box_image = box_image[:,:,0]

            if len(box_image.shape) == 2:
                box_image = box_image[:,:, None] * np.ones(3)

            box_image = box_image.copy().astype(np.uint8)

            bboxes_v = data['bboxes'][i_b]
            cls_v = data['cls'][i_b]
            if type(bboxes_v) == torch.Tensor:
                bboxes_v = bboxes_v.detach().cpu().numpy()

            if type(cls_v) == torch.Tensor:
                cls_v = cls_v.detach().cpu().numpy()

            # Ground truth.
            for bbox, idx_class in zip( bboxes_v, cls_v):
                # Plain Python ints (truncated towards zero): `np.int` no
                # longer exists and OpenCV expects integer point coordinates.
                x0, y0, x1, y1 = (int(v) for v in bbox)

                box_image = cv2.rectangle(
                    box_image,
                    (x0,y0),
                    (x1,y1),
                    class2color_v[idx_class],
                    thickness=2,
                )

            # Predictions.
            for bbox, idx_class, p_det in zip( pred_d['bbox'], pred_d['cls'], pred_d['p_det']):

                x0, y0, x1, y1 = (int(v) for v in bbox)

                box_image = cv2.rectangle(
                    box_image,
                    (x0,y0),
                    (x1,y1),
                    class2color_v[idx_class],
                    thickness=1,
                    )

                box_image = cv2.putText(
                    box_image,
                    class2str_v[idx_class] + f'({p_det:0.1f})',
                    org=(x0, y0-3),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.3,
                    color=class2color_v[idx_class],
                    thickness=1,
                    lineType=cv2.LINE_AA,
                    bottomLeftOrigin=False
                    )

            box_image_v.append(box_image)

            if do_plot:
                plt.figure(0, figsize=(20,20) )
                plt.imshow(box_image)
                plt.show()

        return np.array( box_image_v )
    
    
    @torch.jit.ignore
    def set_lr(self, new_lr=1e-3):
        """Set the learning rate of the optimizer and remember it in ``self.lr``.

        Only parameter groups that already have an ``'lr'`` entry are updated.

        Args:
            new_lr: The learning rate to apply.

        Returns:
            Always ``None``.
        """
        groups_with_lr = (g for g in self.optimizer.param_groups if 'lr' in g.keys())
        for group in groups_with_lr:
            group['lr'] = new_lr

        self.lr = new_lr

        return None
    
    
    @torch.jit.ignore
    def get_lr(self, new_lr=1e-3):
        """Return the current learning rate.

        With an optimizer, this is the ``'lr'`` of its first parameter group
        (the first one whose value is not ``None``). Without an optimizer it
        is ``self.lr``.

        Args:
            new_lr: Unused; kept so the signature stays unchanged.

        Returns:
            The learning rate, or ``None`` when the optimizer has no parameter
            groups.
        """
        if self.optimizer is None:
            return self.lr

        current_lr = None
        for group in self.optimizer.param_groups:
            current_lr = group['lr']
            if current_lr is not None:
                break

        return current_lr
    
    
    @torch.jit.ignore
    def _parallelize_backbone(self, devices_v=None):
        if devices_v is None:
            self.devices_v = [torch.device(i) for i in range(torch.cuda.device_count())]
        else:
            self.devices_v = [torch.device(i) for i in devices_v]
        
        
        if self.use_effdet:
            self.backbone.backbone = nn.DataParallel(self.backbone.backbone, self.devices_v)
            self.backbone.fpn = nn.DataParallel(self.backbone.fpn, self.devices_v)
            self.backbone.class_net = nn.DataParallel(self.backbone.class_net, self.devices_v)
            self.backbone.box_net = nn.DataParallel(self.backbone.box_net, self.devices_v)

            self.net_labeler_train   = nn.DataParallel( DetBenchTrain(self.backbone) , self.devices_v)
            self.net_labeler_predict = nn.DataParallel( DetBenchPredict(self.backbone), self.devices_v)
            
        else:
            self.backbone = nn.DataParallel(self.backbone, self.devices_v)
        
        print(' - DataParallel, using devices:', [f'{d.type}:{d.index}'for d in self.devices_v])
        return None
    

# Building Model

In [None]:
# Detector architecture. Alternatives that were tried:
# backbone_name = 'yolov5l'
# backbone_name = f'fasterrcnn_resnet101d'
backbone_name = 'tf_efficientdet_d2'

# Cross-validation and schedule settings, taken from the global configuration.
N_FOLDS = 5
I_FOLD, N_EPOCHS = GLOBAL_I_FOLD, GLOBAL_N_EPOCHS
N_WARMUP_EPOCHS, N_DECAY_EPOCHS = GLOBAL_N_WARMUP_EPOCHS, GLOBAL_N_DECAY_EPOCHS

CLS = GLOBAL_CLS
RAD_ID = None

# Dataset / experiment settings.
VERSION = 19
CONSENSUS_LEVEL = None
TEST_SPLIT = 0.05
CLEAN_BOXES = False
CLEAN_IOU_TH = 0.8

# The last entry of class2str_v is the "no finding" class.
cls_nofinding = len(class2str_v) - 1

# Single-class mode keeps the chosen class plus class 14; otherwise all classes.
SELECT_CLASSES = None if CLS is None else [CLS, 14] #[2, 5, 6, 9, 11, 13, 14]
if CLS is None:
    CLS = 'All'

REMOVE_CLASSES = None # [cls_nofinding]

# R8 R9 R10 = 3093 samples
# R11 to R17 = 203 samples
RAD_ID_FILTER_V = None if RAD_ID is None else [RAD_ID] + [f'R{i}' for i in range(11,18)]
if RAD_ID is None:
    RAD_ID  = 'RA'

# Number of classes the dataset exposes after selection / removal.
N_CLASSES = len(class2str_v)  # counting class 14
if SELECT_CLASSES:
    N_CLASSES = len(SELECT_CLASSES)
elif REMOVE_CLASSES:
    N_CLASSES = N_CLASSES - len(REMOVE_CLASSES) + int(cls_nofinding in REMOVE_CLASSES)

MODEL_RESOLUTION = GLOBAL_MODEL_RESOLUTION

# Only the Faster R-CNN backbones receive custom anchors.
anchors_d = (
    get_anchos_from_cls(CLS, model_resolution=MODEL_RESOLUTION)
    if 'fasterrcnn' in backbone_name
    else None
)
    

In [None]:
# Keyword arguments for the ModelX constructor.
model_cfg_d = dict(
    # Input
    model_resolution=MODEL_RESOLUTION,
    n_input_channels=4,

    anchors_d=anchors_d,

    # Heads
    n_classes=N_CLASSES-1,
    n_extras=0,
    extra_loss_weight=6.0,

    # Backbone
    use_pretrained_model=True,
#     backbone_name=f'fasterrcnn_resnet{RN}_fpn',
#     backbone_name=f'tf_efficientdet_d{D}',
#     backbone_name=f'fasterrcnn_resnet{RN}',
    backbone_name=backbone_name,

    # Learning-rate schedule
    init_lr=GLOBAL_LR,
    use_scheduler=True,
    n_warmup_epochs=N_WARMUP_EPOCHS,
    n_decay_epochs=N_DECAY_EPOCHS,
    lr_prop=1e-3,

    # Optimization
    optimizer_name='adam',
    n_steps_grad_update=GLOBAL_GRAD_STEPS,

    start_ckpt=None, #'resnet200d_fold0.0_best_loss_init_weights.ckpt',
    parallelize_backbone=False,

    clip_grad_norm=3.0,
    weight_decay=1e-5,

    # Naming / output folder
    model_name=f'ModelX_V{VERSION}',
    checkpoint_base_path=f'./{backbone_name}_F{I_FOLD}_{RAD_ID}_C{CLS}_V{VERSION}',

    device=GLOBAL_DEVICE,
)

In [None]:
if True:
    N_HIST = 1000
    model  = ModelX(**model_cfg_d)

    loss_names_v = ['loss', 'box_loss', 'class_loss'] #, 'extra_loss']


    trn_fsma_d = {}
    for k in loss_names_v:
        trn_fsma_d[k] = FastSMA(
            maxlen=N_HIST,
            label='mean = ',
            print_format='0.02f',
            save_filename=os.path.join(model.checkpoint_base_path, f'{k}_trn.fsma')
        )

    val_fsma_d = {}
    for k in loss_names_v:
        val_fsma_d[k] = FastSMA(
            maxlen=N_HIST,
            label='mean = ',
            print_format='0.02f',
            save_filename=os.path.join(model.checkpoint_base_path, f'{k}_val.fsma')
        )
        
        
    %matplotlib inline

In [None]:
# List every model parameter ("OPT" = trainable, "---" = frozen) and count the
# trainable weights.
n_w = 0
print('Optimizable parameters:')
for i_p, (n, p) in enumerate(model.named_parameters()):
    print('{:4d}  {:s}  {:50s}  {:}'.format(i_p, 'OPT' if p.requires_grad else '---',  n, p.shape) )
    if p.requires_grad:
        n_w += np.prod(p.shape)

print(f'Total optimizable weights: {n_w/1e6:0.02f} Mw')

# Training Datasets

In [None]:
def batch_merge(samples_v):
    """Collate a list of sample dictionaries into one batch dictionary.

    The batch has the keys of the first sample. Each key maps to the list of
    per-sample values, except ``'image'``, whose values are stacked into a
    single array with ``np.stack``.

    Args:
        samples_v: Non-empty list of sample dictionaries.

    Returns:
        The batch dictionary.
    """
    batch_d = {}
    for key in samples_v[0].keys():
        batch_d[key] = []

    for sample_d in samples_v:
        for key in sample_d:
            batch_d[key].append(sample_d[key])

    batch_d['image'] = np.stack(batch_d['image'])

    return batch_d


def data2tensor(data, device=None, pin_memory=False):
    """Convert the numpy members of a batch dictionary to tensors, in place.

    ``data['image']`` is always converted. For every other entry, the first
    element decides: if it is a ``np.ndarray``, the first ``len(data['image'])``
    elements of that entry are converted one by one. Entries that are empty,
    or that cannot be indexed with ``[0]``, are left untouched.

    Args:
        data: Batch dictionary (for example the output of ``batch_merge``).
        device: Target device forwarded to ``torch.tensor``.
        pin_memory: Forwarded to ``torch.tensor``.

    Returns:
        The same ``data`` dictionary.
    """
    data['image'] = torch.tensor(data['image'], device=device, pin_memory=pin_memory)
    n_samples = len(data['image'])

    for k in data.keys():
        values = data[k]
        try:
            first = values[0]
        except (IndexError, KeyError, TypeError):
            # Empty list or non-sequence metadata: nothing to convert.
            continue

        if type(first) is np.ndarray:
            for i_s in range(n_samples):
                values[i_s] = torch.tensor( values[i_s], device=device, pin_memory=pin_memory)

    return data


def load_fold_ds(
    i_fold,
    N_FOLDS=5,
    SELECT_CLASSES=None,
    REMOVE_CLASSES=None,
    CONSENSUS_LEVEL=2,
    TEST_SPLIT=0.1,
    CLEAN_BOXES=True,
    CLEAN_IOU_TH=0.5,
    RAD_ID_FILTER_V=None,
    TRAIN_DS_NAME=TRAIN_DS_NAME,
    DS_PATH=DS_PATH,
):
    """Build the training and validation datasets of one cross-validation fold.

    Both datasets are ``FoldDataset`` objects created with the same settings.
    They differ only in ``mode`` (``'cv_trn'`` / ``'cv_val'``) and in
    augmentation, which is enabled for training only. Images and ``train.csv``
    are read from the global ``DS_DIR``; the resolution comes from the global
    ``MODEL_RESOLUTION``.

    Args:
        i_fold: Index of the fold.
        N_FOLDS: Number of folds.
        SELECT_CLASSES: Forwarded as ``select_classes_v``.
        REMOVE_CLASSES: Forwarded as ``remove_classes_v``.
        CONSENSUS_LEVEL: Forwarded as ``consensus_level``.
        TEST_SPLIT: Forwarded as ``test_split``.
        CLEAN_BOXES: Forwarded as ``clean_boxes``.
        CLEAN_IOU_TH: Forwarded as ``clean_iou_th``.
        RAD_ID_FILTER_V: When non-empty, both datasets are filtered with
            ``filter_radiologists``.
        TRAIN_DS_NAME: Forwarded as ``ds_name``.
        DS_PATH: Forwarded as ``ds_path``.

    Returns:
        ``(ds_trn, ds_val)``.
    """
    # Everything the two splits have in common.
    shared_kwargs = dict(
        ds_path=DS_PATH,
        ds_name=TRAIN_DS_NAME,
        images_dir=os.path.join(DS_DIR, 'train'),
        model_resolution=MODEL_RESOLUTION,
        df_path=os.path.join(DS_DIR, 'train.csv'),

        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=TEST_SPLIT,

        downsample_factor=2,
        remove_classes_v=REMOVE_CLASSES,
        select_classes_v=SELECT_CLASSES,
        show_warnings=False,
        do_random_shuffle=True,
        random_seed=3128,

        clean_boxes=CLEAN_BOXES,
        clean_mode='random',
        clean_iou_th=CLEAN_IOU_TH,

        consensus_level=CONSENSUS_LEVEL,
    )

    ds_trn = FoldDataset(mode='cv_trn', do_augmentation=True, **shared_kwargs)
    ds_val = FoldDataset(mode='cv_val', do_augmentation=False, **shared_kwargs)

    if RAD_ID_FILTER_V is not None and len(RAD_ID_FILTER_V) > 0:
        for ds in (ds_trn, ds_val):
            ds.filter_radiologists(RAD_ID_FILTER_V)

    return ds_trn, ds_val


def load_tst_ds(
    i_fold,
    N_FOLDS=5,
    SELECT_CLASSES=None,
    REMOVE_CLASSES=None,
    CONSENSUS_LEVEL=2,
    TEST_SPLIT=0.1,
    CLEAN_BOXES=True,
    CLEAN_IOU_TH=0.5,
    RAD_ID_FILTER_V=None,
    TRAIN_DS_NAME=TRAIN_DS_NAME,
    DS_PATH=DS_PATH,
):
    """Build the test dataset and the 'cv_tst' split of the training data.

    Two FoldDataset objects are created:

    * `ds_tst`: dataset named 'test', read from the 'test' image folder under
      the module-level DS_DIR, with no dataframe (`df_path=None`),
      `mode='none'`, no class selection/removal and no consensus level.
    * `ds_tst_oof`: the training dataset in `mode='cv_tst'`, using the
      class-selection, cleaning and consensus settings passed in.

    Args:
        i_fold: fold index forwarded to FoldDataset.
        N_FOLDS: forwarded as `n_folds`.
        SELECT_CLASSES / REMOVE_CLASSES: applied to `ds_tst_oof` only.
        CONSENSUS_LEVEL: applied to `ds_tst_oof` only.
        TEST_SPLIT: forwarded as `test_split`.
        CLEAN_BOXES / CLEAN_IOU_TH: forwarded as `clean_boxes` / `clean_iou_th`.
        RAD_ID_FILTER_V: when not None and non-empty, passed to
            `ds_tst_oof.filter_radiologists` (the test dataset is not filtered).
        TRAIN_DS_NAME / DS_PATH: forwarded as `ds_name` (for `ds_tst_oof`) /
            `ds_path`.

    Returns:
        Tuple `(ds_tst, ds_tst_oof)`.
    """
    ds_tst = FoldDataset(
        ds_path=DS_PATH,
        ds_name='test',
        images_dir=os.path.join(DS_DIR, 'test'),
        model_resolution=MODEL_RESOLUTION,
        df_path=None,

        mode='none',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=TEST_SPLIT,

        do_augmentation=False,
        downsample_factor=2,
        remove_classes_v=[],
        select_classes_v=None,

        show_warnings=False,
        do_random_shuffle=False,
        random_seed=3128,

        clean_boxes=CLEAN_BOXES,
        clean_mode='random',
        clean_iou_th=CLEAN_IOU_TH,

        consensus_level=None,
    )

    ds_tst_oof = FoldDataset(
        ds_path=DS_PATH,
        ds_name=TRAIN_DS_NAME,
        images_dir=os.path.join(DS_DIR, 'train'),
        model_resolution=MODEL_RESOLUTION,
        # Same CSV location that load_fold_ds uses for the training data; the
        # bare relative name 'train.csv' only resolved if the working
        # directory happened to be DS_DIR.
        df_path=os.path.join(DS_DIR, 'train.csv'),

        mode='cv_tst',
        i_fold=i_fold,
        n_folds=N_FOLDS,
        test_split=TEST_SPLIT,

        do_augmentation=False,
        downsample_factor=2,
        remove_classes_v=REMOVE_CLASSES,
        select_classes_v=SELECT_CLASSES,
        show_warnings=False,
        do_random_shuffle=False,
        random_seed=3128,

        clean_boxes=CLEAN_BOXES,
        clean_mode='random',
        clean_iou_th=CLEAN_IOU_TH,

        consensus_level=CONSENSUS_LEVEL,
    )

    if RAD_ID_FILTER_V is not None and len(RAD_ID_FILTER_V) > 0:
        ds_tst_oof.filter_radiologists(RAD_ID_FILTER_V)

    return ds_tst, ds_tst_oof

In [None]:
# Arguments shared by both dataset loaders (passed by keyword so the two
# calls cannot drift apart).
fold_ds_kwargs = dict(
    i_fold=I_FOLD,
    N_FOLDS=N_FOLDS,
    SELECT_CLASSES=SELECT_CLASSES,
    REMOVE_CLASSES=REMOVE_CLASSES,
    CONSENSUS_LEVEL=CONSENSUS_LEVEL,
    TEST_SPLIT=TEST_SPLIT,
    CLEAN_BOXES=CLEAN_BOXES,
    CLEAN_IOU_TH=CLEAN_IOU_TH,
    RAD_ID_FILTER_V=RAD_ID_FILTER_V,
    TRAIN_DS_NAME=TRAIN_DS_NAME,
    DS_PATH=DS_PATH,
)

ds_trn, ds_val = load_fold_ds(**fold_ds_kwargs)
ds_tst, ds_tst_oof = load_tst_ds(**fold_ds_kwargs)


# Using FATIH's VALIDATION SAMPLES
# The fold split computed by FoldDataset is overridden: the validation ids come
# from a stored list, and every other known sample id goes to training.
# NOTE(review): load_obj presumably unpickles this file - only use trusted inputs.
val_fatih_samples_id_v = load_obj('../input/validation-samples/fatihs_validation_samples.pickle')
trn_fatih_samples_id_v = np.array(
    [sample_id for sample_id in ds_trn.all_sample_ids if sample_id not in val_fatih_samples_id_v]
)

ds_trn.fold_samples = trn_fatih_samples_id_v
ds_val.fold_samples = val_fatih_samples_id_v

# The out-of-fold test split is not used below.
del ds_tst_oof

print(f'- TRN samples: {len(ds_trn)}' )
print(f'- VAL samples: {len(ds_val)}' )
print(f'- TST samples: {len(ds_tst)}' )

# Detector Training

In [None]:
def train_one_epoch(
    model,
    i_fold,
    i_epoch,
    ds_trn,
    trn_fsma_d,
    batch_size=1,
    shuffle=True,
    num_workers=16,
    pin_memory=False):
    """Run one pass over `ds_trn`, calling `model.train_step` on every batch.

    Args:
        model: object exposing `train_step(batch)`, which returns a mapping
            whose values support `.item()`.
        i_fold, i_epoch: only used in the progress-bar description.
        ds_trn: dataset handed to the DataLoader (batches are collated with
            `batch_merge`).
        trn_fsma_d: dict of loss name -> tracker with `append(value)` and
            `mean()`. Must contain 'loss', 'box_loss' and 'class_loss';
            'extra_loss' is optional.
        batch_size, shuffle, num_workers, pin_memory: DataLoader settings.

    Returns:
        dict mapping every key of `trn_fsma_d` to the mean of the values
        recorded for it during this epoch.
    """
    loader = DataLoader(
        ds_trn,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge)
    progress = tqdm(loader)

    description_template = '[TRN_F={:d}E={:d}_L={:0.03f}_B={:0.03f}_C={:0.03f}_X={:0.03f}]'

    # Per-step loss values collected for this epoch only.
    step_losses = {name: [] for name in trn_fsma_d}

    for batch in progress:
        batch_losses = model.train_step(batch)

        for name, tracker in trn_fsma_d.items():
            value = batch_losses[name].item()
            tracker.append(value)
            step_losses[name].append(value)

        # The progress bar shows the trackers' running means.
        progress.set_description(description_template.format(
            i_fold,
            i_epoch,
            trn_fsma_d['loss'].mean(),
            trn_fsma_d['box_loss'].mean(),
            trn_fsma_d['class_loss'].mean(),
            trn_fsma_d['extra_loss'].mean() if 'extra_loss' in trn_fsma_d else 0.0,
        ))

    return {name: np.mean(values) for name, values in step_losses.items()}


def validate_one_epoch(
    model,
    i_fold,
    i_epoch,
    ds_val,
    val_fsma_d,
    batch_size=1,
    shuffle=False,
    num_workers=16,
    pin_memory=False,

    det_th=0.00,
    clear_predictions=True,
    clear_iou_th=0.50,
    clear_mode='p_det_weight',
):
    """Run one validation pass over `ds_val` and compute detection metrics.

    Args:
        model: object exposing `predict(...)`; with `output_losses=True` the
            result must contain a 'detections' list (one dict per sample) and
            one entry per key of `val_fsma_d` supporting `.item()`.
        i_fold, i_epoch: only used in the progress-bar description.
        ds_val: dataset handed to the DataLoader; must also provide
            `get_GT_Dataframe()`.
        val_fsma_d: dict of loss name -> tracker with `append(value)` and
            `mean()`. Must contain 'loss', 'box_loss' and 'class_loss';
            'extra_loss' is optional.
        batch_size, shuffle, num_workers, pin_memory: DataLoader settings.
        det_th: forwarded to `model.predict`.
        clear_predictions: when True, each batch's detections are passed
            through `clean_predictions`.
        clear_iou_th, clear_mode: forwarded to `clean_predictions` as
            `iou_th` / `mode`.

    Returns:
        dict with the epoch mean of every tracked loss. When metric
        computation succeeds it also holds 'metrics_v', 'summary_df', 'mAP'
        and 'fig'; if it fails the error is printed and those keys are absent.
    """
    gt_df = ds_val.get_GT_Dataframe()

    val_iter = tqdm(DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=batch_merge))

    # Per-step loss values collected for this epoch only.
    epoch_losses_d = {k: [] for k in val_fsma_d}

    all_preds_v = []
    for val_data in val_iter:
        val_loss_d = model.predict(
            val_data,
            output_losses=True,
            training=False,
            det_th=det_th,
            filter_boxes=True,
            unscale_bboxes=True)

        pred_v = val_loss_d['detections']

        if clear_predictions:
            pred_v = clean_predictions(
                pred_v,
                iou_th=clear_iou_th,
                mode=clear_mode)

        # Tag every prediction with the sample it belongs to.
        for i_s, pred_d in enumerate(pred_v):
            pred_d['sample_id'] = val_data['sample_id'][i_s]
            pred_d['original_shape'] = val_data['original_shape'][i_s]

        all_preds_v += pred_v

        for k, fsma in val_fsma_d.items():
            l = val_loss_d[k].item()
            fsma.append(l)
            epoch_losses_d[k].append(l)

        val_iter.set_description('[VAL_F={:d}E={:d}_L={:0.03f}_B={:0.03f}_C={:0.03f}_X={:0.03f}]'.format(
            i_fold,
            i_epoch,
            val_fsma_d['loss'].mean(),
            val_fsma_d['box_loss'].mean(),
            val_fsma_d['class_loss'].mean(),
            val_fsma_d['extra_loss'].mean() if 'extra_loss' in val_fsma_d else 0.0,
            ))

    for k in epoch_losses_d.keys():
        epoch_losses_d[k] = np.mean(epoch_losses_d[k])

    # Metrics are best-effort: a failure must not abort training, but it must
    # be visible. `Exception` (rather than a bare except) keeps
    # KeyboardInterrupt / SystemExit working.
    try:
        metrics_v, summary_df, mAP = calc_metrics(
            all_preds_v,
            gt_df,
            iou_thresh=0.4,
            show_summay=False,
            n_classes=N_CLASSES
        )

        fig = plot_PvsR_curve(metrics_v, class2color_v, False)

        epoch_losses_d['metrics_v'] = metrics_v
        epoch_losses_d['summary_df'] = summary_df
        epoch_losses_d['mAP'] = mAP
        epoch_losses_d['fig'] = fig

    except Exception as err:
        print(f'Problems with calc_metrics or plot_PvsR_curve: {err!r}')

    return epoch_losses_d



def save_model(
    model,
    trn_fsma_d,
    val_fsma_d,
    i_epoch,
    i_fold):
    """Write a checkpoint for `model` and persist every loss tracker.

    The checkpoint file name encodes the fold, the epoch, `model.model_name`
    and the current means of the training and validation 'loss' trackers.

    Args:
        model: object with a `model_name` attribute and a
            `save_checkpoint(step=, loss=, file_name=, verbose=)` method.
        trn_fsma_d, val_fsma_d: dicts of trackers exposing `mean()` and
            `save()`; both must contain a 'loss' entry.
        i_epoch, i_fold: stored in the checkpoint's `step` dict and used in
            the file name.

    Returns:
        Whatever `model.save_checkpoint` returns.
    """
    trn_loss_tracker = trn_fsma_d['loss']
    val_loss_tracker = val_fsma_d['loss']

    ckpt_file_name = f'F{i_fold:}_E{i_epoch:}_{model.model_name}_T{trn_loss_tracker.mean():0.03f}_V{val_loss_tracker.mean():0.03f}.ckpt'

    step_info = {'i_epoch': i_epoch, 'i_fold': i_fold}
    loss_info = {'trn': trn_loss_tracker.mean(), 'val': val_loss_tracker.mean()}

    ckpt_path = model.save_checkpoint(
        step=step_info,
        loss=loss_info,
        file_name=ckpt_file_name,
        verbose=True,
    )

    # Persist training trackers first, then validation trackers.
    for tracker in (*trn_fsma_d.values(), *val_fsma_d.values()):
        tracker.save()

    return ckpt_path


class Logger:
    """Append log entries to a text file and save figures as PNG images.

    Text goes to `<log_folder>/<log_name>.txt`; figures are written to
    `<log_folder>/S=<step><log_name>.png`.
    """

    def __init__(
        self,
        log_folder='./folder',
        log_name='fold_cls',

    ):
        """Remember the target folder/name and create the folder if needed."""
        self.log_folder = log_folder
        self.log_name = log_name

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.log_folder, exist_ok=True)

    def _append_line(self, text):
        """Append `text` followed by a newline to the text log file."""
        with open(os.path.join(self.log_folder, self.log_name + '.txt'), 'a') as f:
            f.write(text + '\n')

    def log(
        self,
        data,
        step='',
        do_show=True):
        """Log one item, dispatching on its type.

        Args:
            data: the item to log.
                * list / tuple: every element is logged in turn (recursively).
                * str: appended to the text log as-is.
                * matplotlib Figure: saved as a PNG whose name includes `step`.
                * anything else: its `repr` is appended to the text log.
            step: label used in the file name of saved figures.
            do_show: when True, text/objects are also printed and figures
                are also displayed.

        Returns:
            None.
        """
        # Exact type check on purpose: only plain lists/tuples are unpacked,
        # so tuple subclasses (e.g. namedtuples) are still logged as one repr.
        if type(data) in (list, tuple):
            for item in data:
                self.log(
                    item,
                    step=step,
                    do_show=do_show,
                )

        elif isinstance(data, str):
            # isinstance so that str subclasses are written as plain text
            # rather than through repr().
            self._append_line(data)

            if do_show:
                print(data)

        elif isinstance(data, plt.Figure):
            data.savefig(os.path.join(self.log_folder, f'S={step}' + self.log_name + '.png'))

            if do_show:
                data.show()
                plt.show()

        else:
            self._append_line(repr(data))

            if do_show:
                print(data)

        return None

# Loading Last Ckpt

In [None]:
if GLOBAL_CONTINUE_TRAINING:
    # Resume: restore the checkpoint reported by find_last_saved_ckpt(), which
    # also yields the `last_epoch` value read by the training cell below.
    last_epoch, model_path = model.find_last_saved_ckpt()
    _ = model.restore_checkpoint(model_path)

    # Reload the saved state of every loss tracker (training first, then
    # validation).
    for fsma_d in (trn_fsma_d, val_fsma_d):
        for fsma in fsma_d.values():
            fsma.load()

In [None]:
if not EVAL_CKPTS:
    # Training mode. When resuming, continue from the epoch after the restored
    # checkpoint and advance the scheduler once so it matches that epoch.
    if GLOBAL_CONTINUE_TRAINING and last_epoch is not None:
        START_EPOCH = last_epoch + 1
        model.scheduler_step()

        assert model.scheduler.last_epoch == START_EPOCH, 'Scheduler ERROR'

    else:
        START_EPOCH = 0

    N_EPOCHS = GLOBAL_N_EPOCHS
    N_WORKERS = GLOBAL_N_WORKERS
    BATCH_SIZE = GLOBAL_BATCH_SIZE
    PIN_MEMORY = True

    VAL_AFTER_EPOCH = 0
    VAL_EVERY_N_EPOCHS = 1

    # To ensure that the last epoch will be saved
    N_EPOCHS = N_EPOCHS - (N_EPOCHS%VAL_EVERY_N_EPOCHS)

    # Loss values that are written to the epoch summary. `np.float` was
    # removed in NumPy 1.24 (referencing it raises AttributeError);
    # np.floating covers every NumPy float scalar type.
    SCALAR_LOSS_TYPES = (float, np.floating)

    L = Logger(
        log_folder=os.path.join(model.checkpoint_base_path, 'log'),
        log_name=f'log_CLS{CLS}_F{I_FOLD}'
    )

    torch.cuda.empty_cache()

    if model.anchors_d is not None:
        L.log(' - Using custom anchors: ', do_show=False)
        L.log(f" |-> INPUT_SHAPE: {model.model_resolution}", do_show=False)
        L.log(f" |-> ANCHOR_SIZES: {model.anchors_d['ANCHOR_SIZES']}", do_show=False)
        L.log(f" |-> ASPECT_RATIOS: {model.anchors_d['ASPECT_RATIOS']}", do_show=False)

    # NOTE(review): the upper bound is inclusive, so a fresh run executes
    # N_EPOCHS + 1 epochs - confirm this is intended.
    for i_epoch in range(START_EPOCH, N_EPOCHS + 1):
        L.log('\n - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n ')
        L.log(f'Starting: Epoch = {i_epoch}  Fold = {I_FOLD}   Class = {CLS}   LR={model.get_lr():0.02e}')

        trn_loss_epoch_d = train_one_epoch(
            model,
            I_FOLD,
            i_epoch,
            ds_trn,
            trn_fsma_d,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=N_WORKERS,
            pin_memory=PIN_MEMORY)

        # Validation (and checkpointing) only happens on selected epochs.
        show_val_sum = False
        if (i_epoch % VAL_EVERY_N_EPOCHS) == 0 and i_epoch >= VAL_AFTER_EPOCH:
            show_val_sum = True

            val_loss_epoch_d = validate_one_epoch(
                model,
                I_FOLD,
                i_epoch,
                ds_val,
                val_fsma_d,
                batch_size=BATCH_SIZE,
                shuffle=False,
                num_workers=N_WORKERS,
                pin_memory=PIN_MEMORY)

        L.log(f' Epoch {i_epoch} Summary:')
        L.log(' - trn_loss_epoch_d:')
        for k, v in trn_loss_epoch_d.items():
            if isinstance(v, SCALAR_LOSS_TYPES):
                L.log(f'  |-> {k} = {v:0.04f}')

        if show_val_sum:
            L.log(' - val_loss_epoch_d:')
            for k, v in val_loss_epoch_d.items():
                if isinstance(v, SCALAR_LOSS_TYPES):
                    L.log(f'  |-> {k} = {v:0.04f}')

            if 'fig' in val_loss_epoch_d.keys():
                L.log(val_loss_epoch_d['fig'], step=i_epoch)

            if 'summary_df' in val_loss_epoch_d.keys():
                L.log(val_loss_epoch_d['summary_df'])

            save_path = save_model(
                model,
                trn_fsma_d,
                val_fsma_d,
                i_epoch,
                I_FOLD)

            L.log(f' - Ckpt saved: "{save_path}" \n')

        lr = model.scheduler_step(False)
        L.log(f' - Scheduler: New lr = {lr:0.02e}')


# Detector Inference

# filtering width, height and area of predictions

In [None]:
def filter_wh(
    pred_v,
    img_shape_d,
    min_wh_v,
    max_wh_v,
    min_a_v,
    max_a_v,
):
    """Drop predicted boxes whose relative size is outside per-class bounds.

    For every prediction dict, the box width/height (third/fourth minus
    first/second 'bbox' column) is divided by the reversed entry of
    `img_shape_d` for that sample, and the relative area is the product of the
    two. A box is removed when any of its relative sides or its relative area
    lies outside the bounds of its class.

    Args:
        pred_v: list of dicts with a 'sample_id' and the per-box arrays
            'bbox', 'cls' and 'p_det'. The input is not modified.
        img_shape_d: mapping sample_id -> image shape; it is reversed before
            dividing (presumably stored as (height, width) - TODO confirm).
        min_wh_v, max_wh_v: per-class lower/upper bounds for the relative
            width and height, indexed by class id.
        min_a_v, max_a_v: per-class lower/upper bounds for the relative area,
            indexed by class id.

    Returns:
        A deep copy of `pred_v` in which 'bbox', 'cls' and 'p_det' keep only
        the boxes that satisfy all bounds. Samples without boxes are returned
        unchanged.
    """
    pred_v = copy.deepcopy(pred_v)

    for pred_d in pred_v:
        n_boxes = pred_d['cls'].shape[0]
        if n_boxes == 0:
            continue

        sample_id = pred_d['sample_id']
        wh_bb = (pred_d['bbox'][:,2:]-pred_d['bbox'][:,:2]) / img_shape_d[sample_id][::-1]
        a_bb = wh_bb[:,0] * wh_bb[:,1]

        # Builtin bool: the `np.bool` alias was removed in NumPy 1.24.
        keep = np.ones(n_boxes, dtype=bool)
        for i, i_c in enumerate(pred_d['cls']):
            side_out_of_range = (wh_bb[i] > max_wh_v[i_c]).any() \
                or (wh_bb[i] < min_wh_v[i_c]).any()
            area_out_of_range = (a_bb[i] > max_a_v[i_c]).any() \
                or (a_bb[i] < min_a_v[i_c]).any()

            if side_out_of_range or area_out_of_range:
                keep[i] = False

        for k in ['bbox', 'cls', 'p_det']:
            pred_d[k] = pred_d[k][keep]

    return pred_v


def get_max_wh_v(ds, sigmas=3.0, N_CLASSES=15, do_plotting=False):
    """Compute per-class size bounds (mean +/- `sigmas` * std) from a dataset.

    For each class index in `range(N_CLASSES - 1)` the box sizes returned by
    `ds.get_samples_wh_v(i_c)` are divided by the reversed
    `ds.model_resolution`; the mean and standard deviation of the two size
    columns and of their product (area) give the bounds, which are finally
    clipped to [0, 1].

    Args:
        ds: dataset exposing `get_samples_wh_v(i_c)` and `model_resolution`.
        sigmas: number of standard deviations around the mean.
        N_CLASSES: total class count; the last class index is skipped.
        do_plotting: when True, a histogram of the areas is shown per class.

    Returns:
        Tuple `(min_wh_v, max_wh_v, min_a_v, max_a_v)` of NumPy arrays with
        one row / entry per processed class.
    """
    def _mean_band(values):
        """Return (mean - sigmas*std, mean + sigmas*std) along axis 0."""
        center = values.mean(axis=0)
        half_width = sigmas * values.std(axis=0)
        return center - half_width, center + half_width

    wh_lower, wh_upper = [], []
    area_lower, area_upper = [], []

    for class_idx in range(N_CLASSES - 1):
        rel_wh = ds.get_samples_wh_v(class_idx) / ds.model_resolution[::-1]
        lower, upper = _mean_band(rel_wh)
        wh_upper.append(upper)
        wh_lower.append(lower)

        rel_area = rel_wh[:, 0] * rel_wh[:, 1]
        lower, upper = _mean_band(rel_area)
        area_upper.append(upper)
        area_lower.append(lower)

        if do_plotting:
            plt.hist(rel_area, bins=300, label=f'c{class_idx}')
            plt.legend()
            plt.grid()
            plt.show()

    # Relative sizes cannot leave [0, 1], so clamp the statistical bounds.
    min_wh_v, max_wh_v, min_a_v, max_a_v = (
        np.clip(np.array(bounds), 0, 1)
        for bounds in (wh_lower, wh_upper, area_lower, area_upper)
    )

    return min_wh_v, max_wh_v, min_a_v, max_a_v




In [None]:
# Per-class bounds (mean +/- 3 std, clipped to [0, 1]) for the relative box
# width/height and area, computed from the training dataset. They are the
# limits handed to filter_wh in the ensemble cells below.
min_wh_v, max_wh_v, min_a_v, max_a_v = get_max_wh_v(ds=ds_trn, sigmas=3, N_CLASSES=N_CLASSES)

# Dataset Inference

In [None]:
if EVAL_CKPTS:
    # Settings shared by the validation and the test inference passes.
    eval_kwargs = dict(
        det_th=0.00,
        unscale_bboxes=True,
        batch_size=5,
        num_workers=8,
        pin_memory=False,
        do_clean_predictions=False,
        clean_iou_th=0.4,
        clean_mode='p_det_weight',
        do_TTA=False,
        TTA_clean_iou_th=0.2,
        TTA_clean_mode='median_pmean',
    )

    for ckpt_path in CKPTS_v:
        epoch = model.get_epoch_from_ckpt_path(ckpt_path)

        _ = model.restore_checkpoint(ckpt_path)

        # The ensemble cells below read the results through the global names
        # preds_v_val_<epoch> / preds_v_tst_<epoch>. They are bound through
        # globals() rather than exec(), so the checkpoint-derived epoch text
        # is only ever used as a name and never executed as code.
        pred_v = evalueate_dataset(ds_val, model, **eval_kwargs)
        globals()[f'preds_v_val_{epoch}'] = pred_v

        pred_v = evalueate_dataset(ds_tst, model, **eval_kwargs)
        globals()[f'preds_v_tst_{epoch}'] = pred_v

In [None]:
# Threshold handed to filter_det_th before and after ensembling in the cells
# below; it is also embedded in the names of the exported CSV files.
FILTER_TH = 0.05

> # DS_VAL Ensemble

In [None]:
if EVAL_CKPTS:
    # Validation predictions of the checkpoints for epochs 82, 74 and 62.
    # These names are created dynamically by the dataset-inference cell, so
    # the matching checkpoints must be listed in CKPTS_v.
    to_ensemble_v = [
        preds_v_val_82,
        preds_v_val_74,
        preds_v_val_62,
    ]

    # Apply the FILTER_TH threshold to every model's predictions before merging.
    to_ensemble_v = [filter_det_th(pred_v, FILTER_TH) for pred_v in to_ensemble_v]


    # Merge the per-model predictions into one list.
    ens_val_pred_v = join_predictions(
        to_ensemble_v, 
        add_model_id=True)

    print('Ensembing:',len(to_ensemble_v), 'models')

    pred_v = ens_val_pred_v

    # Post-processing chain: clean_predictions -> filter_wh -> norm_p_det ->
    # filter_det_th. The same chain is applied to the test predictions in the
    # DS_TST cell.
    pred_v =  clean_predictions(
                    pred_v,
                    iou_th=0.6,
                    mode='p_det_weight_pmean',
                    consensus_level=1,
                    n_models2ensemble=len(to_ensemble_v),
    )

    # Remove boxes whose relative size is outside the per-class bounds
    # computed from the training dataset.
    pred_v = filter_wh(
        pred_v=pred_v,
        img_shape_d=ds_val.misc_d,
        min_wh_v=min_wh_v,
        max_wh_v=max_wh_v,
        min_a_v=min_a_v,
        max_a_v=max_a_v,
    )

    pred_v = norm_p_det(pred_v)

    pred_v = filter_det_th(pred_v, FILTER_TH)

    # Ground truth of the validation dataset, used to score the ensemble.
    gt_df = ds_val.get_GT_Dataframe(clean_mode='p_det_max')

    metrics_v, summary_df, mAP = calc_metrics(
                pred_v,
                gt_df,
                iou_thresh=0.4,
                show_summay=True,
                n_classes=N_CLASSES
            )

    # Sum of all AP values except the last row, divided by 15.
    # NOTE(review): the divisor is hard-coded; presumably it is the total
    # class count (N_CLASSES) - confirm.
    print('mAP-AP14: {:0.04f}'.format(summary_df.AP.values[:-1].sum()/15))

    _ = plot_PvsR_curve(metrics_v, class2color_v, do_show=True)
    
    
    
    # Export the predictions under a name describing the ensemble.
    # NOTE(review): 'EnsE82E74E62' is hard-coded and must be kept in sync with
    # the list of predictions above.
    preds_df = predictions_to_df(
        pred_v,
        f'ds_val_F{I_FOLD}_{RAD_ID}_V{VERSION}_{backbone_name}_EnsE82E74E62_DetTH{FILTER_TH:0.02f}_woCLS14Filter.csv')

# DS_TST Ensemble

In [None]:
if EVAL_CKPTS:
    # Test predictions of the checkpoints for epochs 82, 74 and 62. These
    # names are created dynamically by the dataset-inference cell, so the
    # matching checkpoints must be listed in CKPTS_v.
    to_ens_tst_v = [
        preds_v_tst_82,
        preds_v_tst_74,
        preds_v_tst_62,
    ]

    # Apply the FILTER_TH threshold to every model's predictions before merging.
    to_ens_tst_v = [filter_det_th(pred_v, FILTER_TH) for pred_v in to_ens_tst_v]

    # Merge the per-model predictions into one list.
    ens_tst_pred_v = join_predictions(
        to_ens_tst_v, 
        add_model_id=True)

    print('TST Ensemble:',len(to_ens_tst_v), 'models')

    pred_v = ens_tst_pred_v

    # Same post-processing chain as in the DS_VAL cell: clean_predictions ->
    # filter_wh -> norm_p_det -> filter_det_th.
    pred_v =  clean_predictions(
                    pred_v,
                    iou_th=0.6,
                    mode='p_det_weight_pmean',
                    consensus_level=1,
                    n_models2ensemble=len(to_ens_tst_v),
    )

    # Remove boxes whose relative size is outside the per-class bounds
    # computed from the training dataset.
    pred_v = filter_wh(
        pred_v=pred_v,
        img_shape_d=ds_tst.misc_d,
        min_wh_v=min_wh_v,
        max_wh_v=max_wh_v,
        min_a_v=min_a_v,
        max_a_v=max_a_v,
    )

    pred_v = norm_p_det(pred_v)

    pred_v = filter_det_th(pred_v, FILTER_TH)
    
    # First export: predictions before add_class_14 is applied
    # ('woCLS14Filter' in the file name).
    preds_df = predictions_to_df(
        pred_v,
        f'ds_tst_F{I_FOLD}_{RAD_ID}_V{VERSION}_{backbone_name}_EnsE82E74E62_DetTH{FILTER_TH:0.02f}_woCLS14Filter.csv')
    
    
    # Applying class 14 filtering 
    # add_class_14 reads the predictions of a separate 2-class model from the
    # CSV file given below.
    pred_v = add_class_14(
        pred_v,
        pred_clf_c14_filename='../input/vinbigdata-2class-prediction/2-cls test pred.csv',
        rm_preds_high_th=False,
    )

    # Second export: predictions after add_class_14 ('wCLS14Filter').
    preds_df = predictions_to_df(
        pred_v,
        f'ds_tst_F{I_FOLD}_{RAD_ID}_V{VERSION}_{backbone_name}_EnsE82E74E62_DetTH{FILTER_TH:0.02f}_wCLS14Filter.csv')