This notebook detects 2 class objects.
- class1: helmet without impact
- class2: helmet with impact

The object detection part is based on the [EfficientDet notebook](https://www.kaggle.com/shonenkov/training-efficientdet) for the [global wheat detection competition](https://www.kaggle.com/c/global-wheat-detection) by [shonenkov](https://www.kaggle.com/shonenkov), which uses the [efficientdet-pytorch GitHub repository](https://github.com/rwightman/efficientdet-pytorch) by [@rwightman](https://www.kaggle.com/rwightman).

The inference part can be found [here](https://www.kaggle.com/its7171/2class-object-detection-inference/).

In [1]:
# !pip install ./timm-0.1.26-py3-none-any.whl
# !tar xfz ./pkgs.tgz
# for pytorch1.6
# cmd = "sed -i -e 's/ \/ / \/\/ /' timm-efficientdet-pytorch/effdet/bench.py"
# !$cmd

# Import

In [2]:
import sys
# sys.path.insert(0, "timm-efficientdet-pytorch")
sys.path.insert(0, "efficientdet-pytorch-master")
sys.path.insert(0, "omegaconf")

import torch
import os
from datetime import datetime
import time
import random
import cv2
import pandas as pd
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from glob import glob
import pandas as pd
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain, DetBenchPredict
from effdet.efficientdet import HeadNet
from tqdm import tqdm
from IPython.core.debugger import set_trace
import warnings
from scipy.optimize import linear_sum_assignment

warnings.filterwarnings("ignore")

SEED = 42

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and current CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed_everything(SEED)

# Network input resolution in pixels: the albumentations pipelines resize every
# image to IMG_H x IMG_W, and f1_calc uses the same values to scale boxes back
# to a 1280x720 frame.
IMG_H = 512
IMG_W = 512

## Utils

In [3]:
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the weighted sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Args:
            val: the new measurement (e.g. a batch loss).
            n: weight of the measurement (e.g. the batch size).
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [4]:
def post_processing(data_dict):
    """Drop predictions that are not confirmed by the other camera view.

    Keys of ``data_dict`` are expected to look like ``<game>_<play>_<view>``
    where ``<view>`` is ``Endzone`` or ``Sideline``. A predicted box is kept
    only if the opposite view of the same play has at least one predicted box
    whose frame number (element 0 of the box) is within 4 frames of it.

    Args:
        data_dict: mapping ``video -> (gt_boxes, pred_boxes)``; every box is a
            sequence whose first element is the frame number.

    Returns:
        A new dict with the same keys and ground-truth lists, and with
        ``pred_boxes`` replaced by the filtered list. The input is not mutated.

    Notes:
        If the opposite view is absent from ``data_dict`` it is treated as
        having no predictions, so every prediction of the present view is
        dropped (previously this case raised ``KeyError``).
    """
    new_data_dict = {}

    for video, (gt_boxes, pred_boxes) in data_dict.items():
        name_parts = video.split("_")
        video_name = "_".join(name_parts[:2])
        other_view = "Sideline" if name_parts[2] == "Endzone" else "Endzone"
        other_key = f'{video_name}_{other_view}'

        # Frame numbers of every prediction in the opposite view, collected
        # once per video instead of re-reading the boxes for each candidate.
        other_frames = [box[0] for box in data_dict.get(other_key, ([], []))[1]]

        confirmed = [
            box for box in pred_boxes
            if any(abs(box[0] - frame) <= 4 for frame in other_frames)
        ]
        new_data_dict[video] = (gt_boxes, confirmed)

    return new_data_dict

In [5]:
def iou(bbox1, bbox2):
    """Intersection-over-union of two axis-aligned boxes.

    Args:
        bbox1: four numbers ``(x0, y0, x1, y1)``.
        bbox2: four numbers ``(x0, y0, x1, y1)``.

    Returns:
        ``0`` when the boxes do not overlap (touching edges count as no
        overlap), otherwise intersection area divided by union area.
    """
    ax0, ay0, ax1, ay1 = (float(v) for v in bbox1)
    bx0, by0, bx1, by1 = (float(v) for v in bbox2)

    # Width/height of the overlap rectangle; non-positive means disjoint.
    inter_w = min(ax1, bx1) - max(ax0, bx0)
    inter_h = min(ay1, by1) - max(ay0, by0)
    if inter_w <= 0 or inter_h <= 0:
        return 0

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    inter_area = inter_w * inter_h
    union_area = area_a + area_b - inter_area

    return inter_area / union_area

def precision_calc(gt_boxes, pred_boxes):
    """Count true positives, false positives and false negatives.

    Each box is ``[frame, x0, y0, x1, y1]``. A ground-truth/prediction pair is
    an admissible match when the frames differ by at most 4 and the boxes'
    IoU is at least 0.35. Pairs are chosen with the Hungarian algorithm on a
    0/1 cost matrix (0 = admissible match).

    Args:
        gt_boxes: list of ground-truth boxes.
        pred_boxes: list of predicted boxes.

    Returns:
        Tuple ``(tp, fp, fn)``.
    """
    n_gt, n_pred = len(gt_boxes), len(pred_boxes)
    cost = np.ones((n_gt, n_pred))

    for gt_idx, gt in enumerate(gt_boxes):
        for pred_idx, pred in enumerate(pred_boxes):
            frame_gap = abs(gt[0] - pred[0])
            # IoU is only evaluated for pairs that are close enough in time.
            if not frame_gap > 4 and not iou(gt[1:], pred[1:]) < 0.35:
                cost[gt_idx, pred_idx] = 0

    rows, cols = linear_sum_assignment(cost)

    tp = sum(1 for r, c in zip(rows, cols) if cost[r, c] == 0)
    # Assigned pairs with cost 1 are not real matches: each one is both a
    # missed ground truth and a spurious prediction.
    bad_pairs = len(rows) - tp
    fn = n_gt - rows.shape[0] + bad_pairs
    fp = n_pred - cols.shape[0] + bad_pairs
    return tp, fp, fn

def f1_calc(val_data, threshold, val_type=1, post_process=False):
    """Compute the impact-detection F1 score over a set of validation images.

    Side effect: the raw ``val_data`` is dumped to ``val_data_type{val_type}.txt``
    in the current working directory on every call.

    Args:
        val_data: mapping ``image_name -> (gt_boxes, pred_boxes, pred_scores)``.
            Image names are parsed as ``<video>_<frame>.<ext>``. Ground-truth
            boxes are read as ``(y0, x0, y1, x1)`` and swapped to x-first order;
            predicted boxes are used in the order given.
        threshold: predictions with ``score > threshold`` are kept.
        val_type: ``1`` uses ground truth only for images listed in the
            module-level ``impact_images``; ``2`` uses ground truth for every
            image. NOTE(review): ``impact_images`` is only assigned in
            commented-out code in this notebook, so ``val_type=1`` needs it to
            be defined first.
        post_process: when true, predictions are filtered with
            ``post_processing`` before scoring.

    Returns:
        The F1 score aggregated over all videos.
    """
    with open(f'val_data_type{val_type}.txt', 'w') as writer:
        writer.write(str(val_data))

    def _rescale_to_frame(box):
        # Map coordinates from the IMG_W x IMG_H network input to 1280x720 (in place).
        box[0] = box[0] * 1280 / IMG_W
        box[1] = box[1] * 720 / IMG_H
        box[2] = box[2] * 1280 / IMG_W
        box[3] = box[3] * 720 / IMG_H
        return box

    data_dict = {}
    for key, value in val_data.items():
        name_parts = key.split(".")[0].split("_")
        video = "_".join(name_parts[:-1])
        frame = int(name_parts[-1])

        video_gt, video_pred = data_dict.setdefault(video, ([], []))

        use_ground_truth = (val_type == 1 and key in impact_images) or val_type == 2
        if use_ground_truth:
            for ori_box in value[0]:
                box = ori_box.copy()
                # yxyx -> xyxy
                box[[0,1,2,3]] = box[[1, 0, 3, 2]]
                video_gt.append([frame] + _rescale_to_frame(box).tolist())

        for ori_box, score in zip(value[1], value[2]):
            if score > threshold:
                box = _rescale_to_frame(ori_box.copy())
                # Keep predictions inside the 1280x720 frame.
                box[0] = box[0].clip(min=0, max=1280-1)
                box[2] = box[2].clip(min=0, max=1280-1)
                box[1] = box[1].clip(min=0, max=720-1)
                box[3] = box[3].clip(min=0, max=720-1)
                video_pred.append([frame] + box.tolist())

    if post_process:
        data_dict = post_processing(data_dict)

    ftp, ffp, ffn = [], [], []
    for video_gt, video_pred in data_dict.values():
        video_tp, video_fp, video_fn = precision_calc(video_gt, video_pred)
        ftp.append(video_tp)
        ffp.append(video_fp)
        ffn.append(video_fn)

    tp = np.sum(ftp)
    fp = np.sum(ffp)
    fn = np.sum(ffn)
    # Small epsilons guard against division by zero when there are no boxes.
    precision = tp / (tp + fp + 1e-6)
    recall =  tp / (tp + fn +1e-6)
    f1_score = 2*(precision*recall)/(precision+recall+1e-6)

    print(f'Threshold: {threshold}, PC: {post_process}, TP: {tp}, FP: {fp}, FN: {fn}, PRECISION: {precision:.4f}, RECALL: {recall:.4f}, F1 SCORE: {f1_score}')
    return f1_score

## Config

In [6]:
class TrainGlobalConfig:
    """Static hyper-parameters and runtime switches for training."""

    # --- data loading ---
    num_workers = 8
    batch_size = 2

    # --- optimisation ---
    n_epochs = 30
    lr = 0.0002
    score_threshold = 0.4

    # --- output / logging ---
    folder = 'effdet5-models'   # sub-folder used by Fitter for logs and checkpoints
    verbose = True
    verbose_step = 1            # print progress every N steps

    # --- scheduler stepping: per optimisation step and/or after validation ---
    step_scheduler = False
    validation_scheduler = True

    # Device string handed to torch.device(...)
    gpu = 'cuda:1'

    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = {
        'mode': 'min',
        'factor': 0.5,
        'patience': 1,
        'verbose': False,
        'threshold': 0.0001,
        'threshold_mode': 'abs',
        'cooldown': 0,
        'min_lr': 1e-8,
        'eps': 1e-08,
    }

## Data Preparation

In [7]:
# video_labels = pd.read_csv('/home/thinh/nfl/train_labels.csv').fillna(0)
# video_labels = video_labels[video_labels['frame'] != 0].reset_index(drop=True)

# Load the pre-computed label table and derive the play identifier
# ("<gameKey>_<playID>") from the video file name.
video_labels = pd.read_csv('aug_train_labels_44.csv')
video_labels['video_name'] = [
    "_".join(video_file.split("_")[:2]) for video_file in video_labels['video']
]
video_labels.head()

Unnamed: 0,gameKey,playID,view,video,frame,label,left,width,top,height,impact,impactType,confidence,visibility,image_name,x,y,w,h,video_name
0,57583,82,Endzone,57583_000082_Endzone.mp4,34,V73,655,21,331,15,1,0,0.0,0.0,57583_000082_Endzone_34.png,655,331,21,15,57583_000082
1,57583,82,Endzone,57583_000082_Endzone.mp4,34,H99,583,21,312,30,2,0,0.0,0.0,57583_000082_Endzone_34.png,583,312,21,30,57583_000082
2,57583,82,Endzone,57583_000082_Endzone.mp4,34,V15,1069,22,301,20,1,0,0.0,0.0,57583_000082_Endzone_34.png,1069,301,22,20,57583_000082
3,57583,82,Endzone,57583_000082_Endzone.mp4,34,H97,402,21,313,29,1,0,0.0,0.0,57583_000082_Endzone_34.png,402,313,21,29,57583_000082
4,57583,82,Endzone,57583_000082_Endzone.mp4,34,V72,445,21,328,16,1,0,0.0,0.0,57583_000082_Endzone_34.png,445,328,21,16,57583_000082


In [8]:
# valid_video_labels = pd.read_csv('/home/thinh/nfl/train_labels.csv').fillna(0)
# valid_video_labels = valid_video_labels[valid_video_labels['frame'] != 0].reset_index(drop=True)
# valid_video_labels['video_name'] = valid_video_labels['video'].apply(lambda x: "_".join(x.split("_")[:2]))
# valid_video_labels['image_name'] = valid_video_labels['video'].str.replace('.mp4', '') + '_' + valid_video_labels['frame'].astype(str) + '.png'
# valid_video_labels['impact'] = valid_video_labels['impact'].astype(int)+1
# valid_video_labels['x'] = valid_video_labels['left']
# valid_video_labels['y'] = valid_video_labels['top']
# valid_video_labels['w'] = valid_video_labels['width']
# valid_video_labels['h'] = valid_video_labels['height']
# impact_images = valid_video_labels[valid_video_labels['impact'] == 2].image_name.unique()

In [9]:
# video_labels_with_impact = video_labels[video_labels['impact'] > 0]
# for row in tqdm(video_labels_with_impact[['video','frame','label']].values):
#     frames = np.array([-4,-3,-2,-1,1,2,3,4])+row[1]
#     video_labels.loc[(video_labels['video'] == row[0]) 
#                                  & (video_labels['frame'].isin(frames))
#                                  & (video_labels['label'] == row[2]), 'impact'] = 1
    
# video_labels['image_name'] = video_labels['video'].str.replace('.mp4', '') + '_' + video_labels['frame'].astype(str) + '.png'
# video_labels = video_labels[video_labels.groupby('image_name')['impact'].transform("sum") > 0].reset_index(drop=True)
# video_labels['impact'] = video_labels['impact'].astype(int)+1
# video_labels['x'] = video_labels['left']
# video_labels['y'] = video_labels['top']
# video_labels['w'] = video_labels['width']
# video_labels['h'] = video_labels['height']
# video_labels.head()

In [10]:
# video_labels.to_csv('aug_train_labels_44.csv', index=False)

In [11]:
# video_names = np.random.permutation(video_labels.video_name.unique())
# valid_video_len = int(len(video_names)*0.2)
# video_valid = video_names[:valid_video_len]
# video_train = video_names[valid_video_len:]
# Fixed hold-out plays; every image of these plays goes to validation.
video_valid = ['57583_000082', '57586_004152', '57911_000147', '57997_003691', '57680_002206', '58095_004022', '57906_000718', '58005_001254', '57679_003316', '58103_003494', '57998_002181', '58048_000086']

images_valid = video_labels.loc[video_labels['video_name'].isin(video_valid), 'image_name'].unique()

# valid_video_labels = valid_video_labels[valid_video_labels.video_name.isin(video_valid)]
# f1_images_valid = valid_video_labels.image_name.unique()

# Everything that is not a hold-out play is used for training.
images_train = video_labels.loc[~video_labels['video_name'].isin(video_valid), 'image_name'].unique()

In [12]:
# skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# df_folds = video_labels[['image_name']].copy()
# df_folds.loc[:, 'bbox_count'] = 1
# df_folds = df_folds.groupby('image_name').count()
# df_folds.loc[:, 'video'] = video_labels[['image_name', 'video']].groupby('image_name').min()['video']
# df_folds.loc[:, 'stratify_group'] = np.char.add(
#     df_folds['video'].values.astype(str),
#     df_folds['bbox_count'].apply(lambda x: f'_{x // 20}').values.astype(str),
# )

# df_folds.loc[:, 'fold'] = 0
# for fold_number, (train_index, val_index) in enumerate(skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
#     df_folds.loc[df_folds.iloc[val_index].index, 'fold'] = fold_number
    
# df_folds.head()

In [13]:
# export frame in video, only frame with impact
# def mk_images(video_name, video_labels, video_dir, out_dir, only_with_impact=True):
#     video_path=f"{video_dir}/{video_name}"
#     video_name = os.path.basename(video_path)
#     vidcap = cv2.VideoCapture(video_path)
    
#     if only_with_impact:
#         boxes_all = video_labels.query("video == @video_name")
#         print(video_path, boxes_all[boxes_all.impact > 1.0].shape[0])
#     else:
#         print(video_path)
        
#     frame = 0
#     while True:
#         it_worked, img = vidcap.read()
#         if not it_worked:
#             break
#         frame += 1
#         if only_with_impact:
#             boxes = video_labels.query("video == @video_name and frame == @frame")
#             boxes_with_impact = boxes[boxes.impact > 1.0]
#             if boxes_with_impact.shape[0] == 0:
#                 continue
                
#         image_path = f'{out_dir}/{video_name}'.replace('.mp4',f'_{frame}.png')
#         _ = cv2.imwrite(image_path, img)

In [14]:
# uniq_video = video_labels.video.unique()
# video_dir = '/home/thinh/nfl/train'
# out_dir = '/home/thinh/nfl/train_images'
# !mkdir -p $out_dir
# for video_name in uniq_video:
#     mk_images(video_name, video_labels, video_dir, out_dir)

## Albumentations

In [15]:
def get_train_transforms():
    """Build the albumentations pipeline used for training samples.

    Applies random colour jitter, a rare grayscale conversion, flips, 90-degree
    rotations, shift/scale/rotate and transpose, then resizes to
    ``IMG_H x IMG_W`` and converts to a tensor. Bounding boxes are in
    ``pascal_voc`` format with class ids passed through the ``labels`` field.
    No ``A.Normalize`` step is applied.
    """
    colour_jitter = A.OneOf([
        A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2,
                             val_shift_limit=0.2, p=0.9),
        A.RandomBrightnessContrast(brightness_limit=0.2,
                                   contrast_limit=0.2, p=0.9),
    ], p=0.5)

    steps = [
        colour_jitter,
        A.ToGray(p=0.01),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.Transpose(p=0.5),
        A.Resize(height=IMG_H, width=IMG_W, p=1.0),
        ToTensorV2(p=1.0),
    ]

    box_settings = A.BboxParams(
        format='pascal_voc',
        min_area=0,
        min_visibility=0,
        label_fields=['labels']
    )
    return A.Compose(steps, p=1.0, bbox_params=box_settings)

def get_valid_transforms():
    """Build the deterministic albumentations pipeline for validation samples.

    Only resizes to ``IMG_H x IMG_W`` and converts to a tensor; bounding boxes
    are in ``pascal_voc`` format with class ids in the ``labels`` field.
    No ``A.Normalize`` step is applied.
    """
    steps = [
        A.Resize(height=IMG_H, width=IMG_W, p=1.0),
        ToTensorV2(p=1.0),
    ]
    box_settings = A.BboxParams(
        format='pascal_voc',
        min_area=0,
        min_visibility=0,
        label_fields=['labels']
    )
    return A.Compose(steps, p=1.0, bbox_params=box_settings)


# def get_train_transforms():
#     return A.Compose(
#         [
#             A.RandomSizedCrop(min_max_height=(600, 720), height=720, width=1280, p=0.1),
#             A.OneOf([
#                 A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2, 
#                                      val_shift_limit=0.2, p=0.9),
#                 A.RandomBrightnessContrast(brightness_limit=0.2, 
#                                            contrast_limit=0.2, p=0.9),
#             ], p=0.9),
#             A.ToGray(p=0.01),
#             A.HorizontalFlip(p=0.5),
#             A.VerticalFlip(p=0.5),
#             A.RandomRotate90(p=0.5),
#             A.Transpose(p=0.5),
#             A.JpegCompression(quality_lower=85, quality_upper=95, p=0.2),
#             A.OneOf([
#                 A.Blur(blur_limit=3, p=1.0),
#                 A.MedianBlur(blur_limit=3, p=1.0)
#             ],p=0.1),
#             A.Resize(height=IMG_H, width=IMG_W, p=1),
# #             A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=0.5),
#             ToTensorV2(p=1.0),
#         ], 
#         p=1.0, 
#         bbox_params=A.BboxParams(
#             format='pascal_voc',
#             min_area=0, 
#             min_visibility=0,
#             label_fields=['labels']
#         )
#     )

# def get_valid_transforms():
#     return A.Compose(
#         [
#             A.Resize(height=IMG_H, width=IMG_W, p=1.0),
#             ToTensorV2(p=1.0),
#         ], 
#         p=1.0, 
#         bbox_params=A.BboxParams(
#             format='pascal_voc',
#             min_area=0, 
#             min_visibility=0,
#             label_fields=['labels']
#         )
#     )

## Dataset

In [16]:
TRAIN_ROOT_PATH = 'train_images'

class DatasetRetriever(Dataset):
    """Dataset yielding ``(image, target, image_id)`` for helmet detection.

    Args:
        marking: DataFrame with at least the columns ``image_name``, ``x``,
            ``y``, ``w``, ``h`` and ``impact`` (the class label).
        image_ids: NumPy array of image file names to serve.
        transforms: optional albumentations ``Compose`` taking ``image``,
            ``bboxes`` and ``labels``.
        test: stored on the instance; not read by the current loading code.
        root_dir: directory containing the image files. When ``None`` the
            original location ``/home/thinh/nfl/<TRAIN_ROOT_PATH>`` is used.
    """

    def __init__(self, marking, image_ids, transforms=None, test=False, root_dir=None):
        super().__init__()

        self.image_ids = image_ids
        self.marking = marking
        self.transforms = transforms
        self.test = test
        self.root_dir = root_dir

    def __getitem__(self, index: int):
        """Return ``(image, target, image_id)`` for the sample at ``index``.

        ``target`` holds ``boxes``, ``labels`` and ``image_id``. After a
        successful transform the boxes are a tensor in **yxyx** order.
        """
        image_id = self.image_ids[index]

        image, boxes, labels = self.load_image_and_boxes(index)

        target = {}
        target['boxes'] = boxes
        target['labels'] = torch.tensor(labels)
        target['image_id'] = torch.tensor([index])

        if self.transforms:
            # Random augmentations may push every box out of the image, so
            # retry a few times until at least one box survives.
            # NOTE(review): if all 10 attempts lose every box, the untransformed
            # NumPy image and xyxy boxes are returned as-is - confirm callers
            # can never hit this path.
            for i in range(10):
                sample = self.transforms(**{
                    'image': image,
                    'bboxes': target['boxes'],
                    'labels': labels
                })
                if len(sample['bboxes']) > 0:
                    image = sample['image']
                    target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
                    target['boxes'][:,[0,1,2,3]] = target['boxes'][:,[1,0,3,2]]  #yxyx: be warning
                    target['labels'] = torch.tensor(sample['labels'])
                    break

        return image, target, image_id

    def __len__(self) -> int:
        return self.image_ids.shape[0]

    def load_image_and_boxes(self, index):
        """Read one image and its boxes.

        Returns:
            ``(image, boxes, labels)`` where ``image`` is an RGB float32 array
            scaled to ``[0, 1]``, ``boxes`` is an ``(N, 4)`` array in
            ``x0, y0, x1, y1`` order and ``labels`` are the ``impact`` values.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_id = self.image_ids[index]

        root_dir = getattr(self, 'root_dir', None)
        if root_dir is None:
            root_dir = f'/home/thinh/nfl/{TRAIN_ROOT_PATH}'
        image_path = f'{root_dir}/{image_id}'

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0

        records = self.marking[self.marking['image_name'] == image_id]
        # Copy so the in-place xywh -> xyxy conversion below never writes into
        # (or fails on a read-only view of) the DataFrame's buffer.
        boxes = records[['x', 'y', 'w', 'h']].values.copy()
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        labels = records['impact'].values
        return image, boxes, labels

    def load_mixup_image_and_boxes(self, index):
        """Average the image at ``index`` with a random one and merge their boxes."""
        image, boxes, labels = self.load_image_and_boxes(index)
        r_image, r_boxes, r_labels = self.load_image_and_boxes(random.randint(0, self.image_ids.shape[0] - 1))
        return (image+r_image)/2, np.vstack((boxes, r_boxes)).astype(np.int32), np.concatenate((labels, r_labels))


    def load_cutmix_image_and_boxes(self, index, imsize=720):
        """
        Compose a 4-image mosaic of size ``imsize x imsize`` around a random centre.

        This implementation of cutmix author:  https://www.kaggle.com/nvnnghia
        Refactoring and adaptation: https://www.kaggle.com/shonenkov

        Returns:
            ``(image, boxes, labels)`` with boxes clipped to the mosaic and
            zero-area boxes (and their labels) removed.
        """
        w, h = imsize, imsize
        s = imsize // 2

        xc, yc = [int(random.uniform(imsize * 0.25, imsize * 0.75)) for _ in range(2)]  # center x, y
        indexes = [index] + [random.randint(0, self.image_ids.shape[0] - 1) for _ in range(3)]

        result_image = np.full((imsize, imsize, 3), 1, dtype=np.float32)
        result_boxes = []
        # `np.int` was an alias of the builtin `int` and no longer exists in NumPy >= 1.24.
        result_labels = np.array([], dtype=int)

        for i, index in enumerate(indexes):
            image, boxes, labels = self.load_image_and_boxes(index)
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
            result_image[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
            # Offset that maps source-image coordinates into mosaic coordinates.
            padw = x1a - x1b
            padh = y1a - y1b

            boxes[:, 0] += padw
            boxes[:, 1] += padh
            boxes[:, 2] += padw
            boxes[:, 3] += padh

            result_boxes.append(boxes)
            result_labels = np.concatenate((result_labels, labels))

        result_boxes = np.concatenate(result_boxes, 0)
        np.clip(result_boxes[:, 0:], 0, 2 * s, out=result_boxes[:, 0:])
        result_boxes = result_boxes.astype(np.int32)
        # Drop boxes that were clipped down to zero area.
        index_to_use = np.where((result_boxes[:,2]-result_boxes[:,0])*(result_boxes[:,3]-result_boxes[:,1]) > 0)
        result_boxes = result_boxes[index_to_use]
        result_labels = result_labels[index_to_use]

        return result_image, result_boxes, result_labels

## Fitter

In [17]:
class Fitter:
    """Training/validation driver for a ``DetBenchTrain``-style model.

    The wrapped ``model`` is called as ``model(images, target_dict)`` and is
    expected to return a dict with a ``'loss'`` entry (and ``'detections'`` in
    eval mode); its ``.model`` attribute is what gets checkpointed.

    Args:
        model: the detection bench to train.
        device: ``torch.device`` the batches are moved to.
        config: object exposing ``folder``, ``lr``, ``n_epochs``, ``verbose``,
            ``verbose_step``, ``step_scheduler``, ``validation_scheduler``,
            ``SchedulerClass`` and ``scheduler_params``.
    """

    def __init__(self, model, device, config):
        self.config = config
        self.epoch = 0

        self.base_dir = f'/home/thinh/nfl/{config.folder}'
        os.makedirs(self.base_dir, exist_ok=True)

        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        self.model = model
        self.device = device

        # NOTE(review): an earlier revision built bias/LayerNorm "no weight
        # decay" parameter groups here but never handed them to the optimizer.
        # That dead code is removed; the optimizer keeps using all parameters
        # with AdamW's default weight decay, exactly as before.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
        self.log(f'Fitter prepared. Device is {self.device}')

    def fit(self, train_loader, validation_loader, fold):
        """Run ``config.n_epochs`` epochs, saving a checkpoint after each one.

        Args:
            train_loader: loader yielding ``(images, targets, image_ids)``.
            validation_loader: loader with the same batch structure.
            fold: fold identifier, only used in the checkpoint file name.
        """
        # Created for the (currently disabled) mixed-precision path in train_one_epoch.
        scaler = torch.cuda.amp.GradScaler()

        for e in range(self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')

            t = time.time()
            summary_loss = self.train_one_epoch(train_loader, scaler)

            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')

            t = time.time()
            summary_loss = self.validation(validation_loader)

            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')
            # Every epoch is checkpointed, so this holds the latest validation
            # loss rather than the best one seen so far.
            self.best_summary_loss = summary_loss.avg
            self.model.eval()
            # Device tag comes from the injected config, falling back to the device itself.
            device_tag = getattr(self.config, 'gpu', self.device)
            self.save(f'{self.base_dir}/tito-512/tito-checkpoint-D6-512-A-epoch{str(self.epoch).zfill(3)}-fold{str(fold)}-{device_tag}.bin')

            if self.config.validation_scheduler:
                self.scheduler.step(metrics=summary_loss.avg)

            self.epoch += 1

    def _forward_eval_batch(self, images, targets):
        """Move one batch to the device and run the model on it.

        Returns ``(outputs, boxes, labels, batch_size)`` where ``boxes`` and
        ``labels`` are the per-image ground-truth tensors on ``self.device``.
        """
        images = torch.stack(images)
        batch_size = images.shape[0]
        images = images.to(self.device).float()
        boxes = [target['boxes'].to(self.device).float() for target in targets]
        labels = [target['labels'].to(self.device).float() for target in targets]

        target_res = {}
        target_res['bbox'] = boxes
        target_res['cls'] = labels
        target_res['img_scale'] = torch.tensor([1]*batch_size).float().to(self.device)
        target_res['img_size'] = torch.tensor(images.shape[2:]).repeat(batch_size, 1).to(self.device)

        outputs = self.model(images, target_res)
        return outputs, boxes, labels, batch_size

    @staticmethod
    def _collect_impacts(valid_data, detections, boxes, labels, image_ids, batch_size):
        """Store class-2 (impact) ground truth and predictions per image.

        Detection columns are read as: 0-3 box, 4 score, 5 class label.
        ``valid_data[image_id]`` becomes ``(gt_boxes, pred_boxes, pred_scores)``.
        """
        for i in range(batch_size):
            det = detections[i].detach().cpu().numpy()

            # using only label = 2
            pre_indexes = np.where(det[:, 5] == 2)[0]
            true_indexes = np.where(labels[i].cpu().numpy() == 2)[0]

            valid_data[image_ids[i]] = (
                boxes[i][true_indexes].cpu().numpy(),
                det[:, :4][pre_indexes],
                det[:, 4][pre_indexes],
            )

    def validation(self, val_loader):
        """Evaluate the loss over ``val_loader`` and return its ``AverageMeter``."""
        # Collected for the optional F1 evaluation; not used for the returned loss.
        valid_data = {}
        self.model.eval()
        summary_loss = AverageMeter()
        t = time.time()

        for step, (images, targets, image_ids) in enumerate(val_loader):
            if self.config.verbose and step % self.config.verbose_step == 0:
                progress = (
                    f'Val Step {step}/{len(val_loader)}, '
                    f'summary_loss: {summary_loss.avg:.5f}, '
                    f'time: {(time.time() - t):.5f}'
                )
                print(progress, end='\r')

            with torch.no_grad():
                outputs, boxes, labels, batch_size = self._forward_eval_batch(images, targets)
                self._collect_impacts(valid_data, outputs['detections'], boxes, labels, image_ids, batch_size)
                summary_loss.update(outputs['loss'].detach().item(), batch_size)

        return summary_loss

    def f1_validation(self, val_loader):
        """Print F1 at several score thresholds, with and without post-processing.

        Returns:
            The F1 score of the last evaluation (threshold 0.5 with
            post-processing enabled).
        """
        valid_data = {}
        self.model.eval()

        for step, (images, targets, image_ids) in enumerate(val_loader):
            with torch.no_grad():
                outputs, boxes, labels, batch_size = self._forward_eval_batch(images, targets)
                self._collect_impacts(valid_data, outputs['detections'], boxes, labels, image_ids, batch_size)

        for threshold in [0.2, 0.3, 0.4, 0.5]:
            val_f1 = f1_calc(valid_data, threshold, val_type=2)
            val_f1 = f1_calc(valid_data, threshold, val_type=2, post_process=True)

        return val_f1

    def train_one_epoch(self, train_loader, scaler):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: loader yielding ``(images, targets, image_ids)``.
            scaler: AMP gradient scaler; currently unused because the
                mixed-precision path is disabled.

        Returns:
            ``AverageMeter`` with the batch-size-weighted mean training loss.
        """
        self.model.train()
        summary_loss = AverageMeter()
        t = time.time()

        for step, (images, targets, image_ids) in enumerate(train_loader):
            if self.config.verbose and step % self.config.verbose_step == 0:
                progress = (
                    f'Train Step {step}/{len(train_loader)}, '
                    f'summary_loss: {summary_loss.avg:.5f}, '
                    f'time: {(time.time() - t):.5f}'
                )
                print(progress, end='\r')

            images = torch.stack(images)
            images = images.to(self.device).float()
            batch_size = images.shape[0]
            boxes = [target['boxes'].to(self.device).float() for target in targets]
            labels = [target['labels'].to(self.device).float() for target in targets]

            target_res = {}
            target_res['bbox'] = boxes
            target_res['cls'] = labels

            self.optimizer.zero_grad()
            outputs = self.model(images, target_res)

            loss = outputs['loss']
            loss.backward()

            summary_loss.update(loss.detach().item(), batch_size)

            self.optimizer.step()

            if self.config.step_scheduler:
                self.scheduler.step()

        return summary_loss

    def save(self, path):
        """Write model/optimizer/scheduler state to ``path``.

        Missing parent directories are created first, because the checkpoint
        location used by ``fit`` is a sub-folder that ``__init__`` does not create.
        """
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        self.model.eval()
        torch.save({
            'model_state_dict': self.model.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
        }, path)

    def load(self, path):
        """Restore state written by ``save`` and resume at the following epoch."""
        # map_location lets a checkpoint saved on one device be restored on another.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, message):
        """Print ``message`` when verbose and always append it to the log file."""
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')

## Training

In [18]:
def collate_fn(batch):
    """Transpose a list of samples into per-field tuples.

    ``[(img0, tgt0, id0), (img1, tgt1, id1)]`` becomes
    ``((img0, img1), (tgt0, tgt1), (id0, id1))`` so that variable-sized
    targets are not stacked by the default collate.
    """
    per_field = zip(*batch)
    return tuple(per_field)
    
def run_training(num_fold):
    """Build datasets, loaders and a fresh network, then fit, once per fold.

    Every fold uses the same ``images_train`` / ``images_valid`` image ids
    with ``video_labels`` as annotations; only the fold number passed to
    ``Fitter.fit`` changes between iterations.

    Args:
        num_fold: Number of fold iterations to run.
    """
    device = torch.device(TrainGlobalConfig.gpu)

    for fold_number in range(num_fold):
        print('Fold: {}'.format(fold_number + 1))

        # Fold-based selection of image ids is currently disabled:
        #   train: df_folds[df_folds['fold'] != fold_number].index.values
        #   valid: df_folds[df_folds['fold'] == fold_number].index.values
        train_dataset = DatasetRetriever(
            image_ids=images_train,
            marking=video_labels,
            transforms=get_train_transforms(),
            test=False,
        )
        validation_dataset = DatasetRetriever(
            image_ids=images_valid,
            marking=video_labels,
            transforms=get_valid_transforms(),
            test=True,
        )
        # A separate F1 validation dataset/loader (f1_images_valid with
        # valid_video_labels) used to be built here as well; it is disabled.

        # Settings shared by the training and validation loaders.
        shared_loader_args = dict(
            batch_size=TrainGlobalConfig.batch_size,
            num_workers=TrainGlobalConfig.num_workers,
            pin_memory=False,
            collate_fn=collate_fn,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=RandomSampler(train_dataset),
            drop_last=True,
            **shared_loader_args,
        )
        val_loader = torch.utils.data.DataLoader(
            validation_dataset,
            shuffle=False,
            sampler=SequentialSampler(validation_dataset),
            **shared_loader_args,
        )

        # Each fold starts from a newly built network.
        net = get_net()
        net.to(device)

        fitter = Fitter(model=net, device=device, config=TrainGlobalConfig)
        fitter.fit(train_loader, val_loader, fold_number)

In [19]:
def get_net():
    """Build the two-class EfficientDet-D6 network wrapped for training.

    Creates the ``tf_efficientdet_d6`` config with the notebook's image size,
    loads weights from the local file ``./efficientdet_d6-51cb0132.pth``,
    resets the head to two classes and returns the network wrapped in
    ``DetBenchTrain``.
    """
    # An earlier D5 variant (tf_efficientdet_d5, efficientdet_d5-ef44aea8.pth,
    # image_size 768, DetBenchTrain(net, config)) is no longer used.
    config = get_efficientdet_config('tf_efficientdet_d6')
    config.image_size = [IMG_H, IMG_W]
    config.norm_kwargs = {'eps': .001, 'momentum': .01}

    detector = EfficientDet(config, pretrained_backbone=False)
    detector.load_state_dict(torch.load('./efficientdet_d6-51cb0132.pth'))

    # Replace the classification head after the weights above are loaded.
    # NOTE(review): reset_head presumably rebuilds the class head already;
    # the explicit HeadNet assignment is kept as-is — confirm against effdet.
    detector.reset_head(num_classes=2)
    detector.class_net = HeadNet(config, num_outputs=config.num_classes)

    return DetBenchTrain(detector)

# NOTE(review): run_training() calls get_net() itself for every fold, so this
# module-level instance is not the one trained by the run_training call below.
# Confirm whether any later cell needs `net` before removing it.
net = get_net()

In [20]:
# Run a single fold; run_training uses the same fixed train/valid split for
# every fold, so additional folds would repeat the same data.
run_training(num_fold=1)

Fold: 1
Fitter prepared. Device is cuda:1

2021-01-03T05:06:48.572167
LR: 0.0002
[RESULT]: Train. Epoch: 0, summary_loss: 14.52404, time: 2220.42307
[RESULT]: Val. Epoch: 0, summary_loss: 0.63802, time: 129.39115

2021-01-03T05:45:59.381341
LR: 0.0002
[RESULT]: Train. Epoch: 1, summary_loss: 0.59334, time: 2282.33316
[RESULT]: Val. Epoch: 1, summary_loss: 0.83075, time: 128.80048

2021-01-03T06:26:11.838923
LR: 0.0002
[RESULT]: Train. Epoch: 2, summary_loss: 0.52223, time: 2313.03370
[RESULT]: Val. Epoch: 2, summary_loss: 0.67350, time: 133.41074

2021-01-03T07:06:59.476148
LR: 0.0001
[RESULT]: Train. Epoch: 3, summary_loss: 0.43752, time: 2341.93928
[RESULT]: Val. Epoch: 3, summary_loss: 0.50064, time: 135.57372

2021-01-03T07:48:18.000709
LR: 0.0001
[RESULT]: Train. Epoch: 4, summary_loss: 0.41468, time: 2321.06744
[RESULT]: Val. Epoch: 4, summary_loss: 0.53225, time: 134.25587

2021-01-03T08:29:14.454625
LR: 0.0001
[RESULT]: Train. Epoch: 5, summary_loss: 0.39821, time: 2336.18835
[

KeyboardInterrupt: 