# Import dependencies

In [None]:
import logging
import os
import re
import gc
import json
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt

# Install packages

In [None]:
!pip install ../input/pytorch-16/torch-1.6.0cu101-cp37-cp37m-linux_x86_64.whl

In [None]:
!pip install ../input/pytorch-16/torchvision-0.7.0cu101-cp37-cp37m-linux_x86_64.whl

In [None]:
!pip install ../input/pretrainedmodels/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4/ > /dev/null # no output

In [None]:
!pip install ../input/wheat-pkgs/EfficientNet-PyTorch-master/EfficientNet-PyTorch-master/ > /dev/null # no output

In [None]:
!pip install ../input/wheat-pkgs/timm-0.1.20-py3-none-any.whl > /dev/null # no output

In [None]:
!pip install ../input/wheat-pkgs/segmentation_models.pytorch-master/segmentation_models.pytorch-master > /dev/null # no output

# Import more dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import segmentation_models_pytorch as smp

from wheat_infer_utils import *
from wheat_centernet_models import PoseBiFPNNet
from wheat_pseudo_train_l_helpers import (
    set_seed,
    create_logging,
    WheatDataset,
    FastDataLoader,
    collate,
    ModleWithLoss,
    CtdetLoss,
    ModelEMA,
    get_constant_schedule_with_warmup,
    train_one_epoch,
    get_train_transforms,
    freeze_bn
)

# Configuration

In [None]:
# Checkpoint files of the three BiFPN models that are ensembled below
# (named after folds 0, 1 and 3 in the inference sections).
bifpn_path_0 = '../input/wheat-weights/model_centernet_effnetb5_bifpn_00099.pth'
bifpn_path_1 = '../input/wheat-weights/model_centernet_effnetb5_bifpn_fold1_00099.pth'
bifpn_path_3 = '../input/wheat-weights/model_centernet_effnetb5_bifpn_fold3_lb_ema_00099.pth'

In [None]:
class Config:
    """Inference options shared by the test datasets and ``CtdetDetector``.

    All values are plain class attributes; later cells overwrite ``pad`` and
    ``test_scales`` on the ``opt`` instance before each prediction pass.
    """

    # Network
    arch = 'timm-efficientnet-b5'
    heads = {'hm': 1,
             'wh': 2,
             'reg': 2}
    head_conv = 64
    reg_offset = True
    cat_spec_wh = False

    # Image
    img_size = 1024
    in_scale = 1024 / img_size
    down_ratio = 4

    # Per-channel normalisation statistics. A stray trailing comma previously
    # turned `mean` into a 1-tuple wrapping the list, unlike `std`.
    mean = [0.315290, 0.317253, 0.214556]
    std = [0.245211, 0.238036, 0.193879]
    num_classes = 1

    pad = 63

    # Test
    batch_size = 8
    K = 128
    max_per_image = 128

    fix_res = False
    test_scales = [1]
    flip_test = False
    nms = False
    gpus = [0]
    amp = True
    
opt = Config()
# A non-negative first GPU index selects CUDA; a negative one falls back to the CPU.
device = torch.device('cuda') if opt.gpus[0] >= 0 else torch.device('cpu')

In [None]:
def change_key(d):
    """Drop the first dot-separated component from every key of ``d``, in place.

    ``d`` must support ``popitem(False)`` (FIFO pop), e.g. an ``OrderedDict``.
    Entries are popped from the front and re-inserted at the back, so the
    original ordering is kept. A key without any dot becomes the empty string.
    """
    remaining = len(d)
    while remaining > 0:
        old_key, value = d.popitem(False)
        # Everything after the first '.', or '' when there is no dot
        _, _, stripped = old_key.partition('.')
        d[stripped] = value
        remaining -= 1

# Prepare labels

In [None]:
# Locations of the competition data
DIR_INPUT = '../input/global-wheat-detection'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'

# Training annotations; the `bbox` column is split into x/y/w/h columns in the next cell
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
train_df.shape

In [None]:
# Create the four box columns with a -1 placeholder; they are overwritten
# with the parsed `bbox` values further down in this cell.
train_df['x'] = -1
train_df['y'] = -1
train_df['w'] = -1
train_df['h'] = -1

def expand_bbox(x):
    """Extract the numbers contained in a bbox string.

    Args:
        x: text such as ``"[834.0, 222.0, 56.0, 36.0]"``.

    Returns:
        A NumPy array holding the matched number strings (not yet converted to
        floats), or the list ``[-1, -1, -1, -1]`` when no number is found.
    """
    matches = re.findall("([0-9]+[.]?[0-9]*)", x)
    if len(matches) == 0:
        return [-1, -1, -1, -1]
    return np.array(matches)

# Split every bbox string into its four components (still strings at this point).
train_df[['x', 'y', 'w', 'h']] = np.stack(train_df['bbox'].apply(lambda x: expand_bbox(x)))
train_df.drop(columns=['bbox'], inplace=True)
# `np.float` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
# `float` is the exact same dtype (float64) and works on every version.
train_df['x'] = train_df['x'].astype(float)
train_df['y'] = train_df['y'].astype(float)
train_df['w'] = train_df['w'].astype(float)
train_df['h'] = train_df['h'].astype(float)

train_df.head()

In [None]:
# DEBUG
# DIR_TEST = '../input/wheat-fake-test'

# Define test dataset

In [None]:
class WheatDatasetTest(torch.utils.data.Dataset):
    """Test-time dataset yielding a resized image together with its original.

    Args:
        opt: options object; only ``opt.img_size`` is read.
        image_dir: directory whose entries are treated as image files.
        transforms: optional callable applied to the RGB image before resizing
            (used for the flip test-time augmentations).
        mean, std: per-channel statistics stored as ``(1, 1, 3)`` float32
            arrays; they are not applied inside this class.
    """

    def __init__(self, opt, image_dir, transforms=None,
                 mean=(0.315290, 0.317253, 0.214556),
                 std=(0.245211, 0.238036, 0.193879)):

        self.opt = opt

        self.image_dir = image_dir
        # Sorted so that every dataset built on the same directory enumerates
        # the images in the same order: the predictions of the plain and
        # flipped datasets are zipped together positionally later on, and
        # os.listdir gives no ordering guarantee.
        self.img_id = sorted(os.listdir(self.image_dir))

        self.transforms = transforms

        self.mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array(std, dtype=np.float32).reshape(1, 1, 3)

    def __len__(self):
        return len(self.img_id)

    def __getitem__(self, idx):
        """Return ``(image_resized, file_name, image, h0, w0)``.

        ``h0``/``w0`` are the height/width of the image as read from disk,
        ``image`` is the (optionally transformed) full-size RGB image and
        ``image_resized`` is that image resized to ``opt.img_size`` squared.
        """
        # Indexing img_id first keeps the IndexError that ends plain
        # `for item in dataset` iteration.
        img_path = os.path.join(self.image_dir, self.img_id[idx])

        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError('Could not read image: {}'.format(img_path))
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h0, w0 = image.shape[0:2]

        if self.transforms is not None:
            image = self.transforms(image)
        image_resized = cv2.resize(image, (self.opt.img_size, self.opt.img_size))
        return image_resized, self.img_id[idx], image, h0, w0

In [None]:
def flip_lr(img):
    """Return a contiguous copy of ``img`` (H, W, C) mirrored left-to-right."""
    mirrored = img[:, ::-1, :]
    # The reversed slice is a negative-stride view; make it contiguous
    return np.ascontiguousarray(mirrored)

def deaug_lr(img, boxes):
    """Map boxes predicted on a left-right flipped image back to the unflipped frame.

    Columns 0 and 2 of ``boxes`` are mirrored around the width of ``img`` and
    swapped so that column 0 stays the smaller coordinate. ``boxes`` is
    modified in place and also returned.
    """
    _, width = img.shape[:2]
    mirrored_x = width - boxes[:, (2, 0)]
    boxes[:, (0, 2)] = mirrored_x
    return boxes

def flip_ud(img):
    """Return a contiguous copy of ``img`` (H, W, C) mirrored top-to-bottom."""
    mirrored = img[::-1, :, :]
    # The reversed slice is a negative-stride view; make it contiguous
    return np.ascontiguousarray(mirrored)

def deaug_ud(img, boxes):
    """Map boxes predicted on an up-down flipped image back to the unflipped frame.

    Columns 1 and 3 of ``boxes`` are mirrored around the *height* of ``img``
    and swapped so that column 1 stays the smaller coordinate. ``boxes`` is
    modified in place and also returned.
    """
    h, w = img.shape[:2]
    # Vertical coordinates must be mirrored with the image height; the width
    # used previously only gave the right answer for square images.
    boxes[:, (1, 3)] = h - boxes[:, (3, 1)]
    return boxes

# Visualization helpers

In [None]:
# Colours used when drawing predictions: box outline/label background and label text.
BOX_COLOR_PRED = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)


def visualize_bbox(img, bbox, score, color, thickness=2):
    """Draw one box and its score label onto ``img`` (in place) and return it.

    Args:
        img: image array to draw on.
        bbox: ``(x_min, y_min, x_max, y_max)`` corner coordinates.
        score: confidence value, rendered with four decimals.
        color: colour of the box outline and of the label background.
        thickness: outline thickness in pixels.
    """
    left, top, right, bottom = bbox
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35

    # Box outline
    cv2.rectangle(img, (left, top), (right, bottom), color=color, thickness=thickness)

    # Filled label background sitting on top of the box, sized to the text
    label = "{:.4f}".format(score)
    (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, 1)
    cv2.rectangle(img, (left, top - int(1.3 * label_h)), (left + label_w, top), color, -1)

    # Score text
    cv2.putText(img, label, (left, top - int(0.3 * label_h)), font, font_scale,
                TEXT_COLOR, lineType=cv2.LINE_AA)
    return img


def visualize(annotations):
    """Plot an image with its predicted boxes and scores.

    Args:
        annotations: dict with keys ``'image'``, ``'bboxes'`` and ``'scores'``.
            The image is copied before drawing, so the input stays untouched.
    """
    canvas = annotations['image'].copy()
    boxes, scores = annotations['bboxes'], annotations['scores']
    for box, box_score in zip(boxes, scores):
        canvas = visualize_bbox(canvas, box, box_score, color=BOX_COLOR_PRED)
    plt.figure(figsize=(12, 12))
    plt.imshow(canvas)

# Inference

In [None]:
# Plain test set
testdataset = WheatDatasetTest(opt, DIR_TEST)
print('Total number of images in test set: {}'.format(len(testdataset)))

# Flipped variants of the same directory, used for test-time augmentation
testdataset_lr = WheatDatasetTest(opt, DIR_TEST, transforms=flip_lr)
testdataset_ud = WheatDatasetTest(opt, DIR_TEST, transforms=flip_ud)

In [None]:
# Inference helper
def do_predict(opt, model, threshold, flip_type=0, return_ids=False, return_shapes=False):
    """Run the detector over the whole test set for one flip variant.

    Args:
        opt: inference options; ``opt.img_size`` is used to rescale boxes.
        model: network handed to ``CtdetDetector``.
        threshold: detections whose score (column 4) is not above this value
            are dropped.
        flip_type: 0 = no flip, 1 = left-right flip, 2 = up-down flip.
        return_ids: also return the image file names without extension.
        return_shapes: also return the original image heights and widths.

    Returns:
        ``(pred_boxes, pred_scores)``, followed by ``height_list, width_list``
        when ``return_shapes`` is set, followed by ``img_ids`` when
        ``return_ids`` is set. Boxes are per-image arrays in pixel coordinates
        of the original image, clipped to the image bounds.

    Raises:
        ValueError: if ``flip_type`` is not 0, 1 or 2.
    """
    if flip_type == 0:
        test_dataset, deaug_transform = testdataset, None
    elif flip_type == 1:
        test_dataset, deaug_transform = testdataset_lr, deaug_lr
    elif flip_type == 2:
        test_dataset, deaug_transform = testdataset_ud, deaug_ud
    else:
        # Previously this fell through and failed later with an unbound local
        raise ValueError('flip_type must be 0, 1 or 2, got {!r}'.format(flip_type))

    detector = CtdetDetector(opt, model)

    pred_boxes = []
    pred_scores = []
    height_list = []
    width_list = []
    img_ids = []

    for img, img_id, img0, h0, w0 in tqdm(test_dataset):

        ret = detector.run(img)
        results = ret['results'][1]
        # Boolean indexing yields a fresh array, so the in-place edits below
        # never touch the detector output.
        results = results[results[:, 4] > threshold]

        pred_box = results[:, :4]
        if deaug_transform is not None:
            # Undo the flip in the coordinate frame of the resized image
            pred_box = deaug_transform(img, pred_box)

        # rescale from the network input size to the original size & clip
        pred_box[:, 0] = np.clip(pred_box[:, 0] / opt.img_size * w0, 0, w0 - 1)
        pred_box[:, 1] = np.clip(pred_box[:, 1] / opt.img_size * h0, 0, h0 - 1)
        pred_box[:, 2] = np.clip(pred_box[:, 2] / opt.img_size * w0, 0, w0 - 1)
        pred_box[:, 3] = np.clip(pred_box[:, 3] / opt.img_size * h0, 0, h0 - 1)

        pred_boxes.append(pred_box)
        pred_scores.append(results[:, 4])
        if return_ids:
            img_ids.append(os.path.splitext(img_id)[0])
        if return_shapes:
            height_list.append(h0)
            width_list.append(w0)

    outputs = [pred_boxes, pred_scores]
    if return_shapes:
        outputs += [height_list, width_list]
    if return_ids:
        outputs.append(img_ids)
    return tuple(outputs)

# On BiFPN model

## Load Fold 0 weights

In [None]:
bifpn_model = PoseBiFPNNet(opt.arch, opt.heads, opt.head_conv)
checkpoint = torch.load(bifpn_path_0, map_location=device)

# Strip the first dot-separated component from every weight name so the keys
# line up with the bare network (presumably a wrapper prefix from training - TODO confirm).
change_key(checkpoint['model'])
bifpn_model.load_state_dict(checkpoint['model'])
bifpn_model.to(device)

# The checkpoint dict is no longer needed once the weights are loaded
del checkpoint
gc.collect()

In [None]:
opt.pad = 63

# Pass 1: test scale 1.1, score threshold 0.30 - no flip, left-right flip, up-down flip.
# The unflipped run also returns the original image shapes and the image ids.
opt.test_scales = [1.1, ]
threshold = 0.30

bifpn0_pred_boxes_0   , bifpn0_pred_scores_0, h0_list, w0_list, img_ids = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=True, return_shapes=True)
bifpn0_pred_boxes_0_lr, bifpn0_pred_scores_0_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn0_pred_boxes_0_ud, bifpn0_pred_scores_0_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)


# Pass 2: test scale 1.25, score threshold 0.28 - same three flip variants.
opt.test_scales = [1.25, ]
threshold = 0.28

bifpn0_pred_boxes_l   , bifpn0_pred_scores_l    = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
bifpn0_pred_boxes_l_lr, bifpn0_pred_scores_l_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn0_pred_boxes_l_ud, bifpn0_pred_scores_l_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)

In [None]:
# Drop the fold-0 model and release cached GPU memory before loading the next fold
del bifpn_model
gc.collect()
torch.cuda.empty_cache()

## Load Fold 1 weights

In [None]:
bifpn_model = PoseBiFPNNet(opt.arch, opt.heads, opt.head_conv)
checkpoint = torch.load(bifpn_path_1, map_location=device)

# Strip the first dot-separated component from every weight name so the keys
# line up with the bare network (presumably a wrapper prefix from training - TODO confirm).
change_key(checkpoint['model'])
bifpn_model.load_state_dict(checkpoint['model'])
bifpn_model.to(device)

# The checkpoint dict is no longer needed once the weights are loaded
del checkpoint
gc.collect()

In [None]:
opt.pad = 63

# Pass 1: test scale 1.1, score threshold 0.30 - no flip, left-right flip, up-down flip.
opt.test_scales = [1.1, ]
threshold = 0.30

bifpn1_pred_boxes_0   , bifpn1_pred_scores_0    = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
bifpn1_pred_boxes_0_lr, bifpn1_pred_scores_0_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn1_pred_boxes_0_ud, bifpn1_pred_scores_0_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)


# Pass 2: test scale 1.25, score threshold 0.28 - same three flip variants.
opt.test_scales = [1.25, ]
threshold = 0.28

bifpn1_pred_boxes_l   , bifpn1_pred_scores_l    = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
bifpn1_pred_boxes_l_lr, bifpn1_pred_scores_l_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn1_pred_boxes_l_ud, bifpn1_pred_scores_l_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)

In [None]:
# Drop the fold-1 model and release cached GPU memory before loading the next fold
del bifpn_model
gc.collect()
torch.cuda.empty_cache()

## Load Fold 3 weights

In [None]:
bifpn_model = PoseBiFPNNet(opt.arch, opt.heads, opt.head_conv)
checkpoint = torch.load(bifpn_path_3, map_location=device)

# Strip the first dot-separated component from every weight name so the keys
# line up with the bare network (presumably a wrapper prefix from training - TODO confirm).
change_key(checkpoint['model'])
bifpn_model.load_state_dict(checkpoint['model'])
bifpn_model.to(device)

# The checkpoint dict is no longer needed once the weights are loaded
del checkpoint
gc.collect()

In [None]:
opt.pad = 63

# Pass 1: test scale 1.1, score threshold 0.30 - no flip, left-right flip, up-down flip.
opt.test_scales = [1.1, ]
threshold = 0.30

bifpn3_pred_boxes_0   , bifpn3_pred_scores_0    = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
bifpn3_pred_boxes_0_lr, bifpn3_pred_scores_0_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn3_pred_boxes_0_ud, bifpn3_pred_scores_0_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)


# Pass 2: test scale 1.25, score threshold 0.28 - same three flip variants.
opt.test_scales = [1.25, ]
threshold = 0.28

bifpn3_pred_boxes_l   , bifpn3_pred_scores_l    = do_predict(opt, bifpn_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
bifpn3_pred_boxes_l_lr, bifpn3_pred_scores_l_lr = do_predict(opt, bifpn_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
bifpn3_pred_boxes_l_ud, bifpn3_pred_scores_l_ud = do_predict(opt, bifpn_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)

In [None]:
# Drop the fold-3 model and release cached GPU memory before the ensemble / training steps
del bifpn_model
gc.collect()
torch.cuda.empty_cache()

## Ensemble

In [None]:
def normalize_boxes(boxes, h0, w0):
    """Scale pixel boxes to fractions of the image size.

    Args:
        boxes: array whose columns 0 and 2 are x-coordinates and columns 1 and
            3 are y-coordinates, in pixels.
        h0, w0: image height and width in pixels.

    Returns:
        A new array of the same shape and dtype with columns 0/2 divided by
        ``w0`` and columns 1/3 divided by ``h0``. The input is left untouched,
        so calling this twice on the same predictions (e.g. when a cell is
        re-run) no longer normalises them twice.
    """
    normalized = boxes.copy()
    normalized[:, 0] = normalized[:, 0] / w0
    normalized[:, 1] = normalized[:, 1] / h0
    normalized[:, 2] = normalized[:, 2] / w0
    normalized[:, 3] = normalized[:, 3] / h0
    return normalized

def denormalize_clip_boxes(boxes, h0, w0):
    """Scale fractional boxes back to pixels and clip them to the image.

    Args:
        boxes: array whose columns 0 and 2 are x-fractions and columns 1 and 3
            are y-fractions of the image size.
        h0, w0: image height and width in pixels.

    Returns:
        A new array with columns 0/2 multiplied by ``w0`` and clipped to
        ``[0, w0 - 1]`` and columns 1/3 multiplied by ``h0`` and clipped to
        ``[0, h0 - 1]``. The input array is left untouched.
    """
    pixels = boxes.copy()
    pixels[:, 0] = np.clip(pixels[:, 0] * w0, 0, w0 - 1)
    pixels[:, 1] = np.clip(pixels[:, 1] * h0, 0, h0 - 1)
    pixels[:, 2] = np.clip(pixels[:, 2] * w0, 0, w0 - 1)
    pixels[:, 3] = np.clip(pixels[:, 3] * h0, 0, h0 - 1)
    return pixels

In [None]:
import sys
sys.path.insert(0, "../input/weightedboxesfusion")
import ensemble_boxes

iou_thr = 0.44
skip_box_thr = 0.00001

# One entry per prediction run: 3 folds x 2 test scales x 3 flip variants.
# The two lists must list the runs in the same order so that boxes and
# scores stay paired.
bifpn_box_runs = [
    bifpn0_pred_boxes_0, bifpn0_pred_boxes_0_lr, bifpn0_pred_boxes_0_ud,
    bifpn0_pred_boxes_l, bifpn0_pred_boxes_l_lr, bifpn0_pred_boxes_l_ud,

    bifpn1_pred_boxes_0, bifpn1_pred_boxes_0_lr, bifpn1_pred_boxes_0_ud,
    bifpn1_pred_boxes_l, bifpn1_pred_boxes_l_lr, bifpn1_pred_boxes_l_ud,

    bifpn3_pred_boxes_0, bifpn3_pred_boxes_0_lr, bifpn3_pred_boxes_0_ud,
    bifpn3_pred_boxes_l, bifpn3_pred_boxes_l_lr, bifpn3_pred_boxes_l_ud,
]
bifpn_score_runs = [
    bifpn0_pred_scores_0, bifpn0_pred_scores_0_lr, bifpn0_pred_scores_0_ud,
    bifpn0_pred_scores_l, bifpn0_pred_scores_l_lr, bifpn0_pred_scores_l_ud,

    bifpn1_pred_scores_0, bifpn1_pred_scores_0_lr, bifpn1_pred_scores_0_ud,
    bifpn1_pred_scores_l, bifpn1_pred_scores_l_lr, bifpn1_pred_scores_l_ud,

    bifpn3_pred_scores_0, bifpn3_pred_scores_0_lr, bifpn3_pred_scores_0_ud,
    bifpn3_pred_scores_l, bifpn3_pred_scores_l_lr, bifpn3_pred_scores_l_ud,
]

pred_boxes_ensemble = []
pred_scores_ensemble = []

# zip(*runs) regroups the per-run lists into one tuple of 18 arrays per image
per_image = zip(zip(*bifpn_box_runs), zip(*bifpn_score_runs), h0_list, w0_list)
for run_boxes, run_scores, h0, w0 in tqdm(per_image, total=len(h0_list)):

    # Normalise a copy: the raw predictions stay in pixel coordinates, so this
    # cell can be re-run without normalising them a second time.
    boxes_list = [normalize_boxes(b.copy(), h0, w0).tolist() for b in run_boxes]
    scores_list = [s.tolist() for s in run_scores]
    # Single-class problem: every box gets label 0
    labels_list = [[0] * len(b) for b in run_boxes]

    boxes, scores, _ = ensemble_boxes.ensemble_boxes_wbf.weighted_boxes_fusion(boxes_list, scores_list, labels_list, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    pred_boxes_ensemble.append(boxes)
    pred_scores_ensemble.append(scores)

In [None]:
# Convert the fused boxes from image fractions back to pixels of each original image and clip them.
# NOTE(review): the result is stored under the same name, so re-running this cell rescales the boxes again.
pred_boxes_ensemble = [denormalize_clip_boxes(a, h0, w0) for a, h0, w0 in zip(pred_boxes_ensemble, h0_list, w0_list)]
# Shallow copy of the list; the score arrays themselves are unchanged
pred_scores_ensemble = [a for a in pred_scores_ensemble]

In [None]:
# visualization
idx = -7
img = testdataset[idx][2]
print(testdataset[idx][1])
visualize({'image': img, 'bboxes': (pred_boxes_ensemble[idx]).astype(int), 'scores': pred_scores_ensemble[idx]})

# Generate Pseudo Labels

In [None]:
# Turn the ensembled test predictions into per-image annotation records that
# are handed to WheatDataset as `data_dict_pseudo` during pseudo-label training.
data_dict_pseudo = []
id_generator = 20201000000
for idx, (bboxes, h0, w0) in enumerate(zip(tqdm(pred_boxes_ensemble), h0_list, w0_list)):
    img_dict = {
        # Read the file name straight from the dataset index: `testdataset[idx][1]`
        # yields the same string but decodes and resizes the whole image first.
        'file_name': os.path.join(DIR_TEST, testdataset.img_id[idx]),
        'height': h0,
        'width': w0,
        'id': id_generator,
    }
    annotations = []
    for bbox in bboxes:
        # Corner format (x0, y0, x1, y1) -> rounded (x, y, width, height)
        xywh = np.round([bbox[0], bbox[1], bbox[2]-bbox[0], bbox[3]-bbox[1]]).astype(float).tolist()
        annotations.append({
            'area': xywh[2] * xywh[3],
            'bbox': xywh,
            'category_id': 0,
            # NOTE(review): presumably the code for absolute XYWH boxes - confirm against WheatDataset
            'bbox_mode': 1
        })
    img_dict['annotations'] = annotations
    id_generator += 1

    data_dict_pseudo.append(img_dict)

In [None]:
# Keep only the test images for which at least one box was predicted
data_dict_pseudo = [d for d in data_dict_pseudo if len(d['annotations']) > 0] # removal

# Training on pseudo labels

In [None]:
class TrainConfig:
    """Options for the pseudo-label fine-tuning run executed by ``main``."""

    seed = 2519

    # --- network ---
    arch = 'timm-efficientnet-b5'
    heads = {'hm': 1, 'wh': 2, 'reg': 2}
    head_conv = 64
    reg_offset = True

    # --- image / augmentation ---
    data_root = '../input/global-wheat-detection'
    crop_size = 896
    scale = 0.0
    shift = 0.0
    rotate = 15.0
    shear = 5.0
    down_ratio = 4

    debug = False

    # --- loss weights ---
    hm_weight = 1
    off_weight = 1
    wh_weight = 0.1

    # --- optimisation ---
    batch_size = 4
    base_lr = 0.25e-4
    warmup_iters = 0
    total_epochs = 9
    stage_epochs = 9
    freeze_bn = True
    accumulate = 3
    # First CUDA device when available, CPU otherwise
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ema = False
    amp = True

    # --- logging ---
    output_dir = './'
    logs_dir = os.path.join(output_dir, 'logs')
    log_interval = 10

    # --- checkpoints ---
    checkpoint = 1
    # Weights loaded before fine-tuning starts; `main` checks this before `resume`
    load_model = '../input/wheat-weights/model_centernet_effnetb5_bifpn_00099.pth'
    resume = ''


train_opt = TrainConfig()

In [None]:
# Annotation records of the train / validation split used for fine-tuning
with open('../input/wheat-splits/wheat_train_3.json', 'r') as f:
    data_dict_train = json.load(f)

with open('../input/wheat-splits/wheat_valid_3.json', 'r') as f:
    data_dict_valid = json.load(f)

In [None]:
def main(opt):
    """Fine-tune the BiFPN detector on the training split plus the pseudo-labelled test images.

    Reads the module-level ``data_dict_train`` and ``data_dict_pseudo``.

    Args:
        opt: training options (see ``TrainConfig``).

    Returns:
        The ``state_dict`` of the trained ``ModleWithLoss`` wrapper; the caller
        strips the leading key component with ``change_key`` before loading it
        into a bare ``PoseBiFPNNet``.
    """
    # Neither name is imported explicitly anywhere in this notebook (at best
    # they leak in through a star import), so bring them into scope here.
    import time
    from torch.cuda.amp import GradScaler

    set_seed(opt.seed)
    torch.backends.cudnn.benchmark = True

    create_logging(opt.logs_dir, 'w')

    train_dataset = WheatDataset(
        opt,
        opt.data_root,
        data_dict_train,
        data_dict_pseudo=data_dict_pseudo,
        img_size=1024,
        transforms=get_train_transforms(opt.crop_size),
        is_train=True,
        load_to_ram=False)

    logging.info('{} images in training set'.format(len(train_dataset.data_dict)))

    train_loader = FastDataLoader(
        train_dataset,
        opt.batch_size,
        collate_fn=collate,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=2)

    model = ModleWithLoss(PoseBiFPNNet(opt.arch, opt.heads, opt.head_conv), CtdetLoss(opt)).to(opt.device)
    if opt.freeze_bn:
        model.apply(freeze_bn) # freeze bn
    if opt.ema:
        logging.info('Training with EMA')
        ema = ModelEMA(model)
    else:
        ema = None
    # Only parameters that still require gradients are optimised
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.base_lr)
    scheduler = get_constant_schedule_with_warmup(optimizer, opt.warmup_iters)
    current_epoch = 0

    # `load_model` takes precedence over `resume`: it restores the weights
    # only, while `resume` also restores optimizer/scheduler state and epoch.
    # map_location keeps loading independent of the device the checkpoint was
    # saved from.
    if opt.load_model != '':
        checkpoint = torch.load(opt.load_model, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])

    elif opt.resume != '':
        checkpoint = torch.load(opt.resume, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        current_epoch = checkpoint['epoch'] + 1

    scaler = GradScaler() if opt.amp else None
    for epoch in tqdm(range(current_epoch, current_epoch + opt.stage_epochs)):

        epoch_start_time = time.time()
        train_one_epoch(opt, model, optimizer, scheduler, train_loader, epoch, ema=ema, scaler=scaler)

        logging.info('-' * 89)
        logging.info('end of epoch {:4d} | time: {:5.2f}s |'.format(epoch, (time.time() - epoch_start_time)))
        logging.info('-' * 89)

    return model.state_dict()

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

# pseudo_model = PoseBiFPNNet(opt.arch, opt.heads, opt.head_conv)

# if len(os.listdir('../input/global-wheat-detection/test/')) < 11:
#     checkpoint = torch.load(bifpn_path_0, map_location=device)

#     change_key(checkpoint['model'])
#     pseudo_model.load_state_dict(checkpoint['model'])
#     pseudo_model.to(device)

#     del checkpoint
#     gc.collect()
    
# else:
#     state_dict = main(train_opt)
#     pseudo_model.load_state_dict(state_dict)
#     pseudo_model.to(device)
    
#     del state_dict
#     gc.collect()

In [None]:
gc.collect()
torch.cuda.empty_cache()

pseudo_model = PoseBiFPNNet(train_opt.arch, train_opt.heads, train_opt.head_conv)

# With fewer than 20 test images, shorten the run: two epochs on the first
# 300 training records.
if len(os.listdir(DIR_TEST)) < 20:
    train_opt.total_epochs = 2
    train_opt.stage_epochs = 2
    data_dict_train = data_dict_train[:300]

# Fine-tune, then strip the leading key component so the weights fit the bare network
state_dict = main(train_opt)
change_key(state_dict)

pseudo_model.load_state_dict(state_dict)
pseudo_model.to(device)

del state_dict
gc.collect()

# Inference using pseudo model

In [None]:
opt.pad = 63

# Pass 1: test scale 1.1, score threshold 0.32 - no flip, left-right flip, up-down flip.
# The unflipped run also refreshes the image shapes and ids used for the submission.
opt.test_scales = [1.1, ]
threshold = 0.32

pseudo_pred_boxes_0   , pseudo_pred_scores_0, h0_list, w0_list, img_ids = do_predict(opt, pseudo_model, threshold=threshold, flip_type=0, return_ids=True, return_shapes=True)
pseudo_pred_boxes_0_lr, pseudo_pred_scores_0_lr = do_predict(opt, pseudo_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
pseudo_pred_boxes_0_ud, pseudo_pred_scores_0_ud = do_predict(opt, pseudo_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)


# Pass 2: test scale 1.25, score threshold 0.30 - same three flip variants.
opt.test_scales = [1.25, ]
threshold = 0.30

pseudo_pred_boxes_l   , pseudo_pred_scores_l    = do_predict(opt, pseudo_model, threshold=threshold, flip_type=0, return_ids=False, return_shapes=False)
pseudo_pred_boxes_l_lr, pseudo_pred_scores_l_lr = do_predict(opt, pseudo_model, threshold=threshold, flip_type=1, return_ids=False, return_shapes=False)
pseudo_pred_boxes_l_ud, pseudo_pred_scores_l_ud = do_predict(opt, pseudo_model, threshold=threshold, flip_type=2, return_ids=False, return_shapes=False)

In [None]:
iou_thr = 0.44
skip_box_thr = 0.00001

# One entry per prediction run of the pseudo model: 2 test scales x 3 flip
# variants. Both lists must list the runs in the same order so that boxes and
# scores stay paired.
pseudo_box_runs = [
    pseudo_pred_boxes_0, pseudo_pred_boxes_0_lr, pseudo_pred_boxes_0_ud,
    pseudo_pred_boxes_l, pseudo_pred_boxes_l_lr, pseudo_pred_boxes_l_ud,
]
pseudo_score_runs = [
    pseudo_pred_scores_0, pseudo_pred_scores_0_lr, pseudo_pred_scores_0_ud,
    pseudo_pred_scores_l, pseudo_pred_scores_l_lr, pseudo_pred_scores_l_ud,
]

pred_boxes_pseudo = []
pred_scores_pseudo = []

# zip(*runs) regroups the per-run lists into one tuple of 6 arrays per image
per_image = zip(zip(*pseudo_box_runs), zip(*pseudo_score_runs), h0_list, w0_list)
for run_boxes, run_scores, h0, w0 in tqdm(per_image, total=len(h0_list)):

    # Normalise a copy: the raw predictions stay in pixel coordinates, so this
    # cell can be re-run without normalising them a second time.
    boxes_list = [normalize_boxes(b.copy(), h0, w0).tolist() for b in run_boxes]
    scores_list = [s.tolist() for s in run_scores]
    # Single-class problem: every box gets label 0
    labels_list = [[0] * len(b) for b in run_boxes]

    boxes, scores, _ = ensemble_boxes.ensemble_boxes_wbf.weighted_boxes_fusion(boxes_list, scores_list, labels_list, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    pred_boxes_pseudo.append(boxes)
    pred_scores_pseudo.append(scores)

In [None]:
# Convert the fused boxes from image fractions back to pixels of each original image and clip them.
# NOTE(review): the result is stored under the same name, so re-running this cell rescales the boxes again.
pred_boxes_pseudo = [denormalize_clip_boxes(a, h0, w0) for a, h0, w0 in zip(pred_boxes_pseudo, h0_list, w0_list)]
# Shallow copy of the list; the score arrays themselves are unchanged
pred_scores_pseudo = [a for a in pred_scores_pseudo]

In [None]:
# visualization
idx = -7
img = testdataset[idx][2]
print(testdataset[idx][1])
visualize({'image': img, 'bboxes': (pred_boxes_pseudo[idx]).astype(int), 'scores': pred_scores_pseudo[idx]})

# Generate Submission file

In [None]:
def format_prediction_string(boxes, scores):
    """Build the space-separated prediction string for one image.

    Args:
        boxes: array of boxes given as (xmin, ymin, w, h); values are
            truncated to integers.
        scores: one confidence per box, written with four decimals.

    Returns:
        ``"score xmin ymin w h"`` groups joined by single spaces; the empty
        string when there are no boxes.
    """
    int_boxes = boxes.astype(int)
    return " ".join(
        f'{score:.4f} {box[0]} {box[1]} {box[2]} {box[3]}'
        for score, box in zip(scores, int_boxes)
    )

In [None]:
pred_strs = []
for bboxes, scores in zip(pred_boxes_pseudo, pred_scores_pseudo):

    if len(bboxes) > 0:

        # Work on a copy so that `pred_boxes_pseudo` keeps its corner format:
        # converting in place made a second run of this cell subtract x/y again.
        xywh = bboxes.copy()
        # (xmin, ymin, xmax, ymax) -> (xmin, ymin, w, h)
        xywh[:, 2] -= xywh[:, 0]
        xywh[:, 3] -= xywh[:, 1]
        xywh = xywh.round()

        pred_strs.append(format_prediction_string(xywh, scores))

    else:
        # Images without detections get an empty prediction string
        pred_strs.append('')

In [None]:
# One row per test image: its id (file name without extension) and its prediction string
test_df = pd.DataFrame({'image_id': img_ids, 'PredictionString':pred_strs})
test_df

In [None]:
# Write the submission file without the DataFrame index column
test_df.to_csv('submission.csv', index=False)