This kernel uses the GitHub repos [efficientdet-pytorch](https://github.com/rwightman/efficientdet-pytorch) and [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) by [@rwightman](https://www.kaggle.com/rwightman). Don't forget to give them a star ;)

In [None]:
!pip install --no-deps '../input/timm-package/timm-0.1.26-py3-none-any.whl' > /dev/null
!pip install --no-deps '../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl' > /dev/null

In [None]:
import sys
sys.path.insert(0, "../input/timm-efficientdet-pytorch")
sys.path.insert(0, "../input/omegaconf")
sys.path.insert(0, "../input/weightedboxesfusion")

from ensemble_boxes import *
import torch
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data import Dataset,DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import gc
from matplotlib import pyplot as plt
from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
from effdet.efficientdet import HeadNet
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

In [None]:
def get_valid_transforms():
    """Build the deterministic inference-time preprocessing pipeline.

    Every image is resized to 512x512 and then converted to a torch tensor by
    ``ToTensorV2``; both steps always run (``p=1.0``).
    """
    resize_step = A.Resize(height=512, width=512, p=1.0)
    tensor_step = ToTensorV2(p=1.0)
    return A.Compose([resize_step, tensor_step], p=1.0)

In [None]:
DATA_ROOT_PATH = '../input/global-wheat-detection/test'

class DatasetRetriever(Dataset):
    """Inference dataset that reads test images by id from ``DATA_ROOT_PATH``.

    Each item is ``(image, image_id)``. ``image`` is an RGB float32 array
    scaled to [0, 1], or whatever ``transforms`` turns it into (a tensor when
    the pipeline from ``get_valid_transforms`` is used).
    """

    def __init__(self, image_ids, transforms=None):
        """
        Args:
            image_ids: sequence (numpy array or list) of ids, i.e. file names
                under ``DATA_ROOT_PATH`` without the ``.jpg`` extension.
            transforms: optional albumentations-style callable, invoked as
                ``transforms(image=...)`` and returning a dict with ``'image'``.
        """
        super().__init__()
        self.image_ids = image_ids
        self.transforms = transforms

    def __getitem__(self, index: int):
        image_id = self.image_ids[index]
        path = f'{DATA_ROOT_PATH}/{image_id}.jpg'
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None;
            # fail here with the path instead of a cryptic cvtColor error.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        if self.transforms:
            image = self.transforms(image=image)['image']
        return image, image_id

    def __len__(self) -> int:
        # len() works for numpy arrays and plain Python sequences alike.
        return len(self.image_ids)

In [None]:
def collate_fn(batch):
    """Transpose ``[(image, id), ...]`` into ``(images, ids)`` tuples.

    Images are not stacked here; the prediction helpers stack them.
    """
    transposed = zip(*batch)
    return tuple(transposed)


# Test ids are the .jpg file names in DATA_ROOT_PATH with the extension cut off.
dataset = DatasetRetriever(
    image_ids=np.array([p.split('/')[-1][:-4] for p in glob(f'{DATA_ROOT_PATH}/*.jpg')]),
    transforms=get_valid_transforms(),
)

# Deterministic order (no shuffling) and no dropped tail batch, so every test
# image gets a prediction.
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    drop_last=False,
    collate_fn=collate_fn,
)

In [None]:
from collections import OrderedDict
def load_net(checkpoint_path, version='d5'):
    """Build a single-class EfficientDet in ``DetBenchEval`` and load weights.

    Args:
        checkpoint_path: file readable by ``torch.load`` whose
            ``'model_state_dict'`` entry holds the weights; its anchor tensor is
            stored under ``'anchor_labeler.anchors.boxes'``.
        version: ``'d5'`` or ``'d7'``; selects the ``tf_efficientdet_*`` config.

    Returns:
        The ``DetBenchEval`` model, in eval mode, on the default CUDA device.

    Raises:
        NotImplementedError: if ``version`` is not supported.
    """
    config_names = {
        'd5': 'tf_efficientdet_d5',
        'd7': 'tf_efficientdet_d7',
    }
    if version not in config_names:
        # NotImplementedError is the exception class; the NotImplemented
        # constant cannot be raised.
        raise NotImplementedError(f'Unsupported EfficientDet version: {version!r}')

    config = get_efficientdet_config(config_names[version])
    net = EfficientDet(config, pretrained_backbone=False)

    # Swap in a one-class head; the config is updated first so that HeadNet
    # and DetBenchEval both see num_classes=1 and the 512 input size.
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))

    # Load onto the CPU so the checkpoint does not sit in GPU memory next to
    # the model; load_state_dict copies the tensors into the parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')['model_state_dict']

    net = DetBenchEval(net, config)
    # The checkpoint keeps the anchors under 'anchor_labeler.anchors.boxes'
    # while the DetBenchEval state dict expects 'anchors.boxes'.
    checkpoint['anchors.boxes'] = checkpoint.pop('anchor_labeler.anchors.boxes')
    net.load_state_dict(checkpoint)
    del checkpoint
    gc.collect()

    net.eval()
    return net.cuda()


# 
# net = load_net('../input/wheat-effdet5-fold0-best-checkpoint/fold0-best-all-states.bin')
# net = load_net('../input/efficient-det-fold4/efficient_det_fold4.pth')
# net = load_net('../input/gwd-efficient-det-fold0-30e/efficient_det_fold0.pth')

# Set to True if the checkpoints emit boxes in (y, x) order; the prediction
# code then swaps the columns into (x1, y1, x2, y2).
USE_YXYX_TO_XYXY = False

# Ensemble of one default-version (d5) checkpoint per fold 0-4. load_net puts
# each model on the GPU, and all of them are loaded up front, in this order.
models = [
    load_net(checkpoint_path)
    for checkpoint_path in (
        '../input/gwd-efficient-det-fold0-30e/EfficientDetModel_fold0_035e.pth',
        '../input/gwdeffnetfold110ev5/EfficientDetModel_fold1_030e.pth',
        '../input/gwdeffnetfold210ev5/EfficientDetModel_fold2_058e.pth',
        '../input/gwd-efficient-det-fold3-25e/EfficientDetModel_fold3_037e.pth',
        '../input/gwdeffnet-fold422ev5/EfficientDetModel_fold4_034e.pth',
    )
]


In [None]:
class BaseWheatTTA:
    """Interface for one test-time-augmentation (TTA) step.  author: @shonenkov

    A step transforms a single image (``augment``) or a batch
    (``batch_augment``) and maps boxes predicted on the transformed input back
    into the original frame (``deaugment_boxes``). Every method here is a stub
    that raises NotImplementedError.
    """
    # Side length, in pixels, of the square network input the boxes refer to.
    image_size = 512

    def augment(self, image):
        """Return the transformed single image."""
        raise NotImplementedError

    def batch_augment(self, images):
        """Return the transformed batch of images."""
        raise NotImplementedError

    def deaugment_boxes(self, boxes):
        """Map boxes from the augmented frame back to the original one."""
        raise NotImplementedError

class TTAHorizontalFlip(BaseWheatTTA):
    """Mirror the image along its height axis.  author: @shonenkov

    Despite the class name, this reverses dim 1 of a CHW image (dim 2 of an
    NCHW batch), i.e. the rows, so it is a top/bottom mirror. Accordingly,
    ``deaugment_boxes`` mirrors columns 1 and 3, which are the y coordinates of
    (x1, y1, x2, y2) boxes.
    """

    def augment(self, image):
        return torch.flip(image, dims=[1])

    def batch_augment(self, images):
        return torch.flip(images, dims=[2])

    def deaugment_boxes(self, boxes):
        """Mirror y in place (y1 <- S - y2, y2 <- S - y1) and return ``boxes``."""
        mirrored_y = self.image_size - boxes[:, [3, 1]]
        boxes[:, [1, 3]] = mirrored_y
        return boxes

class TTAVerticalFlip(BaseWheatTTA):
    """Mirror the image along its width axis.  author: @shonenkov

    Despite the class name, this reverses dim 2 of a CHW image (dim 3 of an
    NCHW batch), i.e. the columns, so it is a left/right mirror. Accordingly,
    ``deaugment_boxes`` mirrors columns 0 and 2, which are the x coordinates of
    (x1, y1, x2, y2) boxes.
    """

    def augment(self, image):
        return torch.flip(image, dims=[2])

    def batch_augment(self, images):
        return torch.flip(images, dims=[3])

    def deaugment_boxes(self, boxes):
        """Mirror x in place (x1 <- S - x2, x2 <- S - x1) and return ``boxes``."""
        mirrored_x = self.image_size - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = mirrored_x
        return boxes
    
class TTARotate90(BaseWheatTTA):
    """Rotate the image 90 degrees in the (H, W) plane.  author: @shonenkov

    ``rot90`` with k=1 rotates from the first towards the second given dim.
    ``deaugment_boxes`` maps rotated (x', y') back with x = S - y', y = x'. The
    returned boxes have their min/max corners swapped, which
    ``TTACompose.prepare_boxes`` reorders.
    """

    def augment(self, image):
        return image.rot90(1, [1, 2])

    def batch_augment(self, images):
        return images.rot90(1, [2, 3])

    def deaugment_boxes(self, boxes):
        """Return a new array with the rotation undone; ``boxes`` is left untouched."""
        size = self.image_size
        restored = boxes.copy()
        restored[:, 0] = size - boxes[:, 1]
        restored[:, 2] = size - boxes[:, 3]
        restored[:, 1] = boxes[:, 2]
        restored[:, 3] = boxes[:, 0]
        return restored

class TTACompose(BaseWheatTTA):
    """Chain several TTA steps into one.  author: @shonenkov

    Images go through ``transforms`` first-to-last. Boxes are deaugmented
    last-to-first and then normalized so that x1 <= x2 and y1 <= y2. An empty
    list acts as the identity.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def augment(self, image):
        result = image
        for step in self.transforms:
            result = step.augment(result)
        return result

    def batch_augment(self, images):
        result = images
        for step in self.transforms:
            result = step.batch_augment(result)
        return result

    def prepare_boxes(self, boxes):
        """Return a copy of ``boxes`` with each box's corners ordered min-then-max."""
        x_pair = boxes[:, [0, 2]]
        y_pair = boxes[:, [1, 3]]
        ordered = boxes.copy()
        ordered[:, 0] = x_pair.min(axis=1)
        ordered[:, 2] = x_pair.max(axis=1)
        ordered[:, 1] = y_pair.min(axis=1)
        ordered[:, 3] = y_pair.max(axis=1)
        return ordered

    def deaugment_boxes(self, boxes):
        # Undo the steps in the reverse of the order they were applied.
        for step in reversed(self.transforms):
            boxes = step.deaugment_boxes(boxes)
        return self.prepare_boxes(boxes)
    
class TTAChannelShuffle(BaseWheatTTA):
    """Reverse the order of the 3 colour channels (e.g. RGB -> BGR).

    The geometry is unchanged, so boxes come back as-is.
    """
    # Channel index order used to reverse a 3-channel image.
    _REVERSED_CHANNELS = [2, 1, 0]

    def augment(self, image):
        return image[self._REVERSED_CHANNELS]

    def batch_augment(self, images):
        return images[:, self._REVERSED_CHANNELS]

    def deaugment_boxes(self, boxes):
        return boxes
    
class TTAGrayScale(BaseWheatTTA):
    """Replace an RGB image with its luma, repeated over 3 channels.

    The luma uses the ITU-R 601-2 weights that PIL's ``convert('L')`` (and so
    torchvision's ``to_grayscale``) applies. It is computed directly in torch on
    the input's device: there is no CPU/PIL round trip and no forced ``.cuda()``,
    and the result is not quantized to 8 bits. The geometry is unchanged, so
    boxes come back as-is.
    """
    _LUMA_WEIGHTS = (0.299, 0.587, 0.114)

    def _grayscale(self, images, channel_dim):
        """Return the 3-channel luma of ``images``, whose RGB channels lie on ``channel_dim``."""
        if not images.is_floating_point():
            # Integer input is taken as 0..255, matching ToTensor's scaling.
            images = images.float() / 255.0
        weights = torch.tensor(self._LUMA_WEIGHTS, dtype=images.dtype, device=images.device)
        shape = [1] * images.dim()
        shape[channel_dim] = len(self._LUMA_WEIGHTS)
        luma = (images * weights.view(shape)).sum(dim=channel_dim, keepdim=True)
        return torch.cat([luma, luma, luma], dim=channel_dim)

    def augment(self, image):
        # Single CHW image: channels on dim 0.
        return self._grayscale(image, channel_dim=0)

    def batch_augment(self, images):
        # NCHW batch: channels on dim 1.
        return self._grayscale(images, channel_dim=1)

    def deaugment_boxes(self, boxes):
        return boxes
    
class TTACrop(BaseWheatTTA):
    """Zoom into a window of the network input.

    ``augment``/``batch_augment`` cut ``[y_min:y_max, x_min:x_max]`` out of
    the image(s) and resize the window back to ``image_size`` x ``image_size``
    with nearest-neighbour interpolation. ``deaugment_boxes`` undoes this.
    It scales x by ``window_width / image_size`` and y by
    ``window_height / image_size``, then shifts by the window origin, so
    non-square and off-diagonal windows work too. ``None`` bounds mean
    0 / ``image_size``; this assumes the input is ``image_size`` square, as it
    is in this notebook.
    """

    def __init__(self, x_min=None, x_max=None, y_min=None, y_max=None):
        self.x_min = x_min
        self.x_max = x_max
        self.y_min = y_min
        self.y_max = y_max

    def _window(self):
        """Return ``(x0, x1, y0, y1)`` with ``None`` bounds resolved."""
        x0 = 0 if self.x_min is None else self.x_min
        x1 = self.image_size if self.x_max is None else self.x_max
        y0 = 0 if self.y_min is None else self.y_min
        y1 = self.image_size if self.y_max is None else self.y_max
        return x0, x1, y0, y1

    def augment(self, image):
        # Reuse the batched path on a batch of one CHW image.
        return self.batch_augment(image.unsqueeze(0))[0]

    def batch_augment(self, images):
        crops = images[:, :, self.y_min:self.y_max, self.x_min:self.x_max]
        return torch.nn.functional.interpolate(
            crops, size=(self.image_size, self.image_size), mode="nearest")

    def deaugment_boxes(self, boxes):
        """Map (x1, y1, x2, y2) boxes from the zoomed frame back to the full input.

        Returns a new int32 array. Coordinates are truncated to whole pixels
        and clipped to the window before being offset into it. ``boxes`` is
        not modified.
        """
        x0, x1, y0, y1 = self._window()
        width, height = x1 - x0, y1 - y0
        scaled = np.array(boxes, dtype=np.float64)  # always a copy
        scaled[:, [0, 2]] *= width / self.image_size
        scaled[:, [1, 3]] *= height / self.image_size
        out = scaled.astype(np.int32)
        out[:, [0, 2]] = out[:, [0, 2]].clip(0, width - 1) + x0
        out[:, [1, 3]] = out[:, [1, 3]].clip(0, height - 1) + y0
        return out

In [None]:
# Visual sanity check of one TTA combination on a single test image: the
# original image, predictions on the augmented image, and those predictions
# mapped back with `deaugment_boxes`. You can try your own combinations:
transform = TTACompose([
    # TTARotate90(),
    # TTAVerticalFlip(),
    # TTAChannelShuffle(),
    TTACrop(0, 256, 0, 256),
    # TTACrop(256, 512, 0, 256),
    # TTACrop(0, 256, 256, 512),
    # TTACrop(256, 512, 256, 512),
])

fig, ax = plt.subplots(1, 3, figsize=(16, 6))

image, image_id = dataset[5]

numpy_image = image.permute(1, 2, 0).cpu().numpy().copy()

ax[0].imshow(numpy_image)
ax[0].set_title('original')

tta_image = transform.augment(image)
tta_image_numpy = tta_image.permute(1, 2, 0).cpu().numpy().copy()

net = models[0]
# Inference only: no_grad stops autograd from recording (and holding on to)
# the forward graph.
with torch.no_grad():
    det = net(tta_image.unsqueeze(0).float().cuda(), torch.tensor([1]).float().cuda())
# Columns 0-3 are truncated to integer pixels and treated as (x, y, w, h);
# column 4 is the score.
boxes = det[0].int().detach().cpu().numpy()[:, :4]
scores = det[0].detach().cpu().numpy()[:, 4]
indexes = np.where(scores > 0.35)[0]
boxes = boxes[indexes]

# (x, y, w, h) -> (x1, y1, x2, y2)
boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
if USE_YXYX_TO_XYXY:
    boxes[:, [0, 1, 2, 3]] = boxes[:, [1, 0, 3, 2]]

for box in boxes:
    # Plain Python ints for the cv2 point arguments.
    cv2.rectangle(tta_image_numpy, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 1, 0), 2)

ax[1].imshow(tta_image_numpy)
ax[1].set_title('tta')

boxes = transform.deaugment_boxes(boxes)

for box in boxes:
    cv2.rectangle(numpy_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 1, 0), 2)

ax[2].imshow(numpy_image)
ax[2].set_title('deaugment predictions');

In [None]:
from itertools import product

# All 16 on/off combinations of the four TTA steps (the all-off combination is
# the identity). Within a combination the steps are applied in the order
# horizontal flip, vertical flip, rotate, channel shuffle.
tta_transforms = [
    TTACompose([step for step in combination if step])
    for combination in product(
        [TTAHorizontalFlip(), None],
        [TTAVerticalFlip(), None],
        [TTARotate90(), None],
        [TTAChannelShuffle(), None],
    )
]
    

In [None]:
def _detections_to_predictions(det, score_threshold, yxyx_to_xyxy=False):
    """Turn a batch of raw detector output into one box/score dict per image.

    Args:
        det: batched detections; per image, columns 0-3 are treated as
            (x, y, w, h) and column 4 as the score.
        score_threshold: detections with score <= this value are dropped.
        yxyx_to_xyxy: swap the coordinate columns for (y, x)-ordered output.

    Returns:
        A list of ``{'boxes': (k, 4) array of (x1, y1, x2, y2), 'scores': (k,)}``,
        one per image in the batch.
    """
    # One device->host transfer for the whole batch.
    det = det.detach().cpu().numpy()
    per_image = []
    for image_det in det:
        keep = image_det[:, 4] > score_threshold
        boxes = image_det[keep, :4]  # boolean indexing copies
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]
        if yxyx_to_xyxy:
            boxes = boxes[:, [1, 0, 3, 2]]
        per_image.append({
            'boxes': boxes,
            'scores': image_det[keep, 4],
        })
    return per_image

def make_predictions(images, score_threshold=0.35):
    """Run every model in ``models`` on a batch without TTA.

    Returns:
        One entry per model, each a list of per-image prediction dicts. This is
        the layout ``run_wbf(predictions, i)`` expects, so image ``i`` is fused
        across all models rather than only the first.
    """
    images = torch.stack(images).float().cuda()
    # Presumably per-image scale factors for DetBenchEval; 1 keeps boxes in
    # network-input pixels -- verify against the effdet version in use.
    img_scales = torch.tensor([1] * images.shape[0]).float().cuda()
    predictions = []
    with torch.no_grad():
        for net in models:
            det = net(images, img_scales)
            predictions.append(
                _detections_to_predictions(det, score_threshold, USE_YXYX_TO_XYXY))
    return predictions

def make_tta_predictions(images, score_threshold=0.35):
    """Run every model in ``models`` under every TTA in ``tta_transforms``.

    Returns:
        One entry per (model, TTA) pair, models in the outer loop. Each entry
        is a list of per-image prediction dicts whose boxes are already mapped
        back to the un-augmented frame.
    """
    images = torch.stack(images).float().cuda()
    img_scales = torch.tensor([1] * images.shape[0]).float().cuda()
    predictions = []
    with torch.no_grad():
        for net in models:
            for tta_transform in tta_transforms:
                # Clone so an in-place augmentation cannot corrupt the shared batch.
                det = net(tta_transform.batch_augment(images.clone()), img_scales)
                result = _detections_to_predictions(det, score_threshold, USE_YXYX_TO_XYXY)
                for prediction in result:
                    prediction['boxes'] = tta_transform.deaugment_boxes(prediction['boxes'])
                predictions.append(result)
    return predictions


# fold0 fold3
# -------------WBF--------------
# [Best Iou Thr]: 0.473
# [Best Skip Box Thr]: 0.408
# [Best Score]: 0.7386
# ------------------------------

#  fold 0 3 4 
# -------------WBF--------------
# [Best Iou Thr]: 0.472
# [Best Skip Box Thr]: 0.448
# [Best Score]: 0.7369
# ------------------------------

# v89
# -------------WBF--------------
# [Best Iou Thr]: 0.479
# [Best Skip Box Thr]: 0.430
# [Best Score]: 0.7347
# ------------------------------
def run_wbf(predictions, image_index, image_size=512, iou_thr=0.479, skip_box_thr=0.430, weights=None):
    """Fuse one image's boxes from several prediction sources with Weighted Boxes Fusion.

    Args:
        predictions: list of sources (models and/or TTA variants). Each source
            is a list indexed by image position holding
            ``{'boxes': (k, 4) x1, y1, x2, y2 in pixels, 'scores': (k,)}``.
        image_index: position of the image inside every source's list.
        image_size: side of the square image the boxes refer to.
        iou_thr, skip_box_thr: forwarded to ``weighted_boxes_fusion``; the
            defaults are tuned values.
        weights: optional per-source weights forwarded to
            ``weighted_boxes_fusion`` (None = equal weights).

    Returns:
        ``(boxes, scores, labels)`` from WBF, with boxes rescaled to pixels.
    """
    scale = image_size - 1
    # WBF works on coordinates normalized to [0, 1]; clip explicitly so boxes
    # that overhang the image edge are bounded before fusion.
    boxes = [np.clip(p[image_index]['boxes'] / scale, 0.0, 1.0).tolist() for p in predictions]
    scores = [p[image_index]['scores'].tolist() for p in predictions]
    labels = [np.ones(len(p[image_index]['scores'])).tolist() for p in predictions]
    boxes, scores, labels = weighted_boxes_fusion(
        boxes, scores, labels, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    return boxes * scale, scores, labels

In [None]:
import matplotlib.pyplot as plt

# Visual check of the full TTA + WBF pipeline on the first image of the first batch.
images, image_ids = next(iter(data_loader))

predictions = make_tta_predictions(images)

i = 0
# Contiguous copy: cv2 drawing needs a contiguous writable array, and drawing
# on the permuted view would also write into the `images` batch tensor.
sample = images[i].permute(1, 2, 0).cpu().numpy().copy()

boxes, scores, labels = run_wbf(predictions, image_index=i)
# WBF boxes are already in the coordinates of the network input drawn here;
# the x2 scaling applied for the submission must not be used for this plot.
boxes = boxes.astype(np.int32).clip(min=0, max=sample.shape[1] - 1)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

for box in boxes:
    cv2.rectangle(sample, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (1, 0, 0), 1)

ax.set_axis_off()
ax.imshow(sample);

In [None]:
def format_prediction_string(boxes, scores):
    """Join detections into the submission's ``PredictionString``.

    Each detection becomes ``"<score> <b0> <b1> <b2> <b3>"``: the score is
    printed with four decimals and the first four box values as-is.
    Detections are separated by single spaces; no detections give ``""``.
    """
    return " ".join(
        "{0:.4f} {1} {2} {3} {4}".format(score, *box[:4])
        for score, box in zip(scores, boxes)
    )

In [None]:
# Run TTA + WBF over the whole test set and build one submission row per image.
# boxes_ / scores_ keep each image's fused (x1, y1, x2, y2) boxes and scores,
# keyed by image_id, for the visualization further down.
results = []
boxes_ = {}
scores_ = {}

for images, image_ids in data_loader:
    predictions = make_tta_predictions(images)
    for i, image in enumerate(images):
        image_id = image_ids[i]
        boxes, scores, labels = run_wbf(predictions, image_index=i)

        if image.size()[1] != 512:
            boxes = boxes.astype(np.int32)
        else:
            # A 512 input is scaled x2 back up and clamped to 0..1023.
            boxes = (boxes * 2).astype(np.int32).clip(min=0, max=1023)

        boxes_[image_id] = boxes.copy()
        scores_[image_id] = scores.copy()

        # The submission wants (x, y, width, height).
        boxes[:, 2] -= boxes[:, 0]
        boxes[:, 3] -= boxes[:, 1]

        confident = scores >= 0.05
        boxes = boxes[confident].astype(np.int32)
        scores = scores[confident]

        result = dict(image_id=image_id, PredictionString=format_prediction_string(boxes, scores))
        results.append(result)

In [None]:
# Write the rows collected above to submission.csv (without the index) and
# display the first rows as the cell output.
test_df = pd.DataFrame(results, columns=['image_id', 'PredictionString'])
test_df.to_csv('submission.csv', index=False)
test_df.head()

In [None]:
# Draw the stored WBF predictions (with scores) for the first test images on a
# grid with one row per batch and one column per image in the batch.
n_rows, n_cols = 5, 2
fig, ax = plt.subplots(n_rows, n_cols, figsize=(30, 70))

for j, (images, image_ids) in enumerate(data_loader):
    if j >= n_rows:
        # The grid is full; the test set may hold more images than there are axes.
        break
    for i, image in enumerate(images[:n_cols]):
        image_id = image_ids[i]
        boxes = boxes_[image_id].copy()
        # boxes_ holds submission-frame coordinates (x2 for 512 inputs); map
        # them back onto the network input that is drawn here.
        if image.size()[1] == 512:
            boxes = (boxes / 2).astype(np.int32)
        else:
            boxes = boxes.astype(np.int32)
        # Contiguous copy so cv2 can draw on it without touching the batch tensor.
        sample = image.permute(1, 2, 0).cpu().numpy().copy()

        for box, score in zip(boxes, scores_[image_id]):
            # The image is float in [0, 1], so colours are given in that range.
            cv2.rectangle(sample,
                          (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])),
                          (1, 0, 0), 2)
            cv2.putText(sample, '%.2f' % score, (int(box[2]), int(box[3])),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (1, 1, 1), 2, cv2.LINE_AA)
        ax[j][i].set_title(f"{image_id}")
        ax[j][i].imshow(sample)

plt.show()