# About

This is a notebook designed to follow up on [this post](https://www.kaggle.com/c/global-wheat-detection/discussion/172827).

### Who this is for
- You participated in the [Global Wheat Detection](https://www.kaggle.com/c/global-wheat-detection) competition
- It helps if you've worked with [Alex Shonenkov](https://www.kaggle.com/shonenkov)'s EfficientDet implementation. This notebook is a fork (of a fork of a fork :P) of [Alex's one](https://www.kaggle.com/shonenkov/inference-efficientdet).

Otherwise, you can just read about the general idea [here](https://www.kaggle.com/c/global-wheat-detection/discussion/172827).

### How to read
If both points above apply to you and you want to get the value out of this notebook quickly, just follow the markdown prompts between the code cells.


In [None]:
!pip install --no-deps '../input/timm-package/timm-0.1.26-py3-none-any.whl' > /dev/null
!pip install --no-deps '../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl' > /dev/null

Below I introduce `QUAD_SIZE`, which sets the side length of the quadrants I'll be analysing. I found 640 to be the Goldilocks number (not too small, not too large).

In [None]:
import sys
sys.path.insert(0, "../input/timm-efficientdet-pytorch")
sys.path.insert(0, "../input/omegaconf")
sys.path.insert(0, "../input/weightedboxesfusion")

from ensemble_boxes import *
import torch
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data import Dataset,DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import gc
from matplotlib import pyplot as plt
from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
from effdet.efficientdet import HeadNet
from scipy.spatial.distance import cosine
import os

# Side length (px) every loaded image is resized to before quadrants are cut.
FULL_SIZE = 1024
# Side length (px) every crop / full image is resized to before inference.
INF_SIZE = 640
# Side length (px) of each corner crop taken from the FULL_SIZE image.
QUAD_SIZE = 640

# wbf params (passed to weighted_boxes_fusion as skip_box_thr / iou_thr)
SKIP_BOX_THR = 0.43 # 0.43
IOU_THR = 0.44 # 0.44

# Augmentations

Nothing new here if you've already done TTA.

In [None]:
### REVERSE AUGMENTATIONS FOR BBOXES

def hflip_predictions(predictions, size=None):
    """
    Undo a horizontal flip on predicted boxes.

    Every box is mirrored about the vertical centre line of a square image:
    (x1, y1, x2, y2) becomes (size - x2, y1, size - x1, y2).

    predictions: single-element list; predictions[0] is a list with one dict
        per image in the batch, each holding a 'boxes' array whose first four
        columns are x1, y1, x2, y2. The arrays are modified in place.
    size: side length of the image the boxes refer to. Defaults to INF_SIZE.

    Returns the same `predictions` object.
    """
    if size is None:
        size = INF_SIZE
    for prediction in predictions[0]:
        boxes = prediction['boxes']
        # both right-hand expressions are evaluated into new arrays before
        # either column is overwritten, so the swap is safe
        boxes[:, 0], boxes[:, 2] = size - boxes[:, 2], size - boxes[:, 0]
        prediction['boxes'] = boxes
    return predictions

def vflip_predictions(predictions, size=None):
    """
    Undo a vertical flip on predicted boxes.

    Every box is mirrored about the horizontal centre line of a square image:
    (x1, y1, x2, y2) becomes (x1, size - y2, x2, size - y1).

    predictions: single-element list; predictions[0] is a list with one dict
        per image in the batch, each holding a 'boxes' array whose first four
        columns are x1, y1, x2, y2. The arrays are modified in place.
    size: side length of the image the boxes refer to. Defaults to INF_SIZE.

    Returns the same `predictions` object.
    """
    if size is None:
        size = INF_SIZE
    for prediction in predictions[0]:
        boxes = prediction['boxes']
        # both right-hand expressions are evaluated into new arrays before
        # either column is overwritten, so the swap is safe
        boxes[:, 1], boxes[:, 3] = size - boxes[:, 3], size - boxes[:, 1]
        prediction['boxes'] = boxes
    return predictions

def rotate_predictions(predictions, size=None):
    """
    Undo the 90 degree rotation augmentation on predicted boxes.

    Every box (x1, y1, x2, y2) becomes (size - y2, x1, size - y1, x2).
    NOTE(review): this is the inverse of a counter-clockwise 90 degree image
    rotation, which is presumably what the 'r' augmentation produces -
    confirm against the albumentations version in use.

    predictions: single-element list; predictions[0] is a list with one dict
        per image in the batch, each holding a 'boxes' array whose first four
        columns are x1, y1, x2, y2. The arrays are modified in place.
    size: side length of the image the boxes refer to. Defaults to INF_SIZE.

    Returns the same `predictions` object.
    """
    if size is None:
        size = INF_SIZE
    for prediction in predictions[0]:
        boxes = prediction['boxes']
        # copy the columns first: the writes below would otherwise clobber
        # values that are still needed
        x1 = boxes[:, 0].copy()
        y1 = boxes[:, 1].copy()
        x2 = boxes[:, 2].copy()
        y2 = boxes[:, 3].copy()
        boxes[:, 0] = size - y2
        boxes[:, 1] = x1
        boxes[:, 2] = size - y1
        boxes[:, 3] = x2
        prediction['boxes'] = boxes
    return predictions

# Single-character codes used to spell an augmentation pipeline:
# '' (empty string) = no extra augmentation, only resize + tensor conversion
# h = horizontal flip
# v = vertical flip
# r = rotation by 90 degrees
# c = one randomly chosen colour jitter (hue/saturation/value or brightness/contrast)
# s / t = resize to INF_SIZE / convert to tensor; apply_augs wraps every pipeline in these
# the characters of an aug string are applied from left to right
AUG_TRANSFORMS = {
    's': A.Resize(height=INF_SIZE, width=INF_SIZE, p=1.0),
    't': ToTensorV2(p=1.0),
    'h': A.HorizontalFlip(p=1.0),
    'v': A.VerticalFlip(p=1.0),
    'r': A.Rotate((90,90), p=1.0),
    'c': A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2,
                                     val_shift_limit=0.2, p=1),
                A.RandomBrightnessContrast(brightness_limit=0.2, 
                                           contrast_limit=0.2, p=1),
            ],p=1),
}

# Functions that map boxes predicted on an augmented image back to the
# un-augmented frame. Only the geometric codes have an entry.
UNWRAP_FUNCS = {
    'h': hflip_predictions,
    'v': vflip_predictions,
    'r': rotate_predictions,
}

# Augmentation strings evaluated for every crop; one prediction set is
# produced per entry. A larger alternative set:
# AUGS = ['', 'h', 'v', 'r', 'hv', 'rh', 'rv', 'rhv']
AUGS = ['', 'h', 'v', 'r']

def apply_augs(image):
    """
    Build one augmented copy of `image` per entry in AUGS.

    Each pipeline is: resize ('s'), then the transforms named by the
    characters of the aug string in order, then tensor conversion ('t').
    Returns the augmented images as a list, in AUGS order.
    """
    augmented = []
    for aug_str in AUGS:
        # 's' + codes + 't' spells the full pipeline as transform keys
        steps = [AUG_TRANSFORMS[code] for code in 's' + aug_str + 't']
        pipeline = A.Compose(steps, p=1.0)
        augmented.append(pipeline(image=image)['image'])
    return augmented


def unwrap_augs(ls_predictions):
    """
    Map predictions made on augmented images back to the un-augmented frame.

    ls_predictions: list with exactly one prediction set per entry in AUGS,
        in the same order.

    For each set, the unwrap functions for its aug string are applied in
    reverse order (last augmentation undone first). Codes with no entry in
    UNWRAP_FUNCS (e.g. the colour jitter 'c') leave the boxes untouched, so
    they may be combined freely with geometric codes.

    Returns a list of corrected prediction sets, in AUGS order.
    Raises ValueError if the number of prediction sets does not match AUGS.
    """
    # explicit check rather than assert, so it still runs under `python -O`
    if len(ls_predictions) != len(AUGS):
        raise ValueError(
            "Warning length of predictions needs to be the same as length of AUGS")
    ls_corrected_predictions = []
    for aug_str, corrected_predictions in zip(AUGS, ls_predictions):
        # unwrap backwards
        for aug_char in reversed(aug_str):
            unwrap = UNWRAP_FUNCS.get(aug_char)
            if unwrap is not None:
                corrected_predictions = unwrap(corrected_predictions)
        ls_corrected_predictions.append(corrected_predictions)
    return ls_corrected_predictions

# Data

Definitely some new stuff here. Look at how I cut out quadrants in the `__getitem__` function.

In [None]:
DATA_ROOT_PATH = '../input/global-wheat-detection/test'

class DatasetRetriever(Dataset):
    """
    Dataset yielding, for one image id, every crop under every augmentation.

    Each item is the tuple (ls_images, image_id, do_stitch_edges):
      - ls_images[q][a] is crop q under the a-th entry of AUGS, where
        q = 0..3 are the top-left, top-right, bottom-left and bottom-right
        corner crops and q = 4 is the full image;
      - image_id is the id that was requested;
      - do_stitch_edges is always an empty list.

    image_ids: indexable collection of image ids (file names without '.jpg').
    crop_size: side length of the corner crops, taken from the image after it
        has been resized to FULL_SIZE. The stitching helpers in this file
        shift boxes using QUAD_SIZE, so keep the two equal.
    transforms: optional callable applied to every augmented image as
        transforms(image=...)['image'].
    """

    def __init__(self, image_ids, crop_size=QUAD_SIZE, transforms=None):
        super().__init__()
        self.image_ids = image_ids
        self.crop_size = crop_size
        self.transforms = transforms

    def __getitem__(self, index: int):
        image_id = self.image_ids[index]
        do_stitch_edges = []
        orig_image = self.load_image(image_id)
        orig_image = cv2.resize(orig_image, (FULL_SIZE, FULL_SIZE))
        size = self.crop_size
        # 0th dim: the four corner crops followed by the full image
        quad_images = [
            orig_image[:size, :size, :],    # top left
            orig_image[:size, -size:, :],   # top right
            orig_image[-size:, :size, :],   # bottom left
            orig_image[-size:, -size:, :],  # bottom right
            orig_image,
        ]
        # 1st dim: one image per entry in AUGS
        ls_images = []
        for quad_image in quad_images:
            aug_images = apply_augs(quad_image)
            if self.transforms:
                aug_images = [
                    self.transforms(image=aug_image)['image']
                    for aug_image in aug_images
                ]
            ls_images.append(aug_images)
        return ls_images, image_id, do_stitch_edges

    def load_image(self, image_id):
        """
        Read '<DATA_ROOT_PATH>/<image_id>.jpg' as a float32 RGB array scaled
        by 1/255. Raises FileNotFoundError if the file cannot be read.
        """
        path = f'{DATA_ROOT_PATH}/{image_id}.jpg'
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        return image

    def __len__(self) -> int:
        return len(self.image_ids)

In [None]:
# Image ids are the file names without directory or extension. os.path is
# used so this does not depend on the path separator, and the ids are sorted
# so the processing order does not depend on the order glob returns.
dataset = DatasetRetriever(
    image_ids=np.array(sorted(
        os.path.splitext(os.path.basename(path))[0]
        for path in glob(f'{DATA_ROOT_PATH}/*.jpg')
    )))


def collate_fn(batch):
    """Transpose a batch: a sequence of per-sample tuples becomes a tuple of per-field tuples."""
    fields = zip(*batch)
    return tuple(fields)

# Batches of 2 images, in dataset order (shuffle off), keeping a final
# partial batch (drop_last off). collate_fn regroups each batch per field.
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    drop_last=False,
    collate_fn=collate_fn
)

# Nets

Nothing new here.

In [None]:
def load_net(checkpoint_path):
    """
    Build a tf_efficientdet_d5 detector, load weights and move it to the GPU.

    checkpoint_path: path to a checkpoint file containing a
        'model_state_dict' entry.

    Returns the network wrapped in DetBenchEval, in eval mode, on CUDA.
    """
    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)

    # replace the classification head with a single-class one and set the
    # inference resolution before the weights are loaded
    config.num_classes = 1
    config.image_size = INF_SIZE
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))

    # The model has not been moved to the GPU yet, so load the checkpoint onto
    # the CPU as well: this avoids a temporary second copy of the weights in
    # GPU memory and works for checkpoints saved from any device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])

    # drop the checkpoint before the model is moved to the GPU
    del checkpoint
    gc.collect()

    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()

# Models to ensemble: make_predictions runs every net in this list and
# concatenates their detections.
NETS = [
    load_net('../input/alex-gwd-demo-models/fold1-84_(1)024-last.bin')
]

# Helpers

New stuff here. `aug_ensemble_predictions` and `quad_ensemble_predictions` are designed to work together in a nested loop. We have 4 image orientations, and 5 images (4 quadrants + full image), so we are ensembling over 20 predictions per image.

In [None]:
def make_predictions(images, nets, score_threshold=0.22):
    """
    Run every net on a batch of images and collect the confident detections.

    images: list of image tensors of identical shape; they are stacked into
        one batch.
    nets: list of callables invoked as net(images, scales), each returning
        one detection tensor per image whose columns 0-3 are the box and
        column 4 is the score.
    score_threshold: detections scoring at or below this value are dropped.

    Columns 2 and 3 of each kept box are added to columns 0 and 1
    (presumably converting x, y, w, h into x1, y1, x2, y2 - confirm against
    the detector's output format).

    Returns a single-element list; its item is a list with one dict per image
    holding 'boxes' and 'scores' arrays, concatenated over all nets.

    The batch is placed on the GPU when CUDA is available and on the CPU
    otherwise; the nets must live on the same device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = torch.stack(images).to(device).float()
    # a scale of 1 for every image; built once and shared by all nets
    img_scales = torch.ones(images.shape[0], device=device).float()
    predictions = []
    with torch.no_grad():
        dets = [net(images, img_scales) for net in nets]
        for i in range(images.shape[0]):
            ls_boxes = []
            ls_scores = []
            for det in dets:
                # one device-to-host copy per detection tensor
                det_i = det[i].detach().cpu().numpy()
                scores = det_i[:, 4]
                keep = np.where(scores > score_threshold)[0]
                # fancy indexing copies, so the in-place edits below do not
                # touch det_i
                boxes = det_i[:, :4][keep]
                boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
                boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
                ls_boxes.append(boxes)
                ls_scores.append(scores[keep])
            predictions.append({
                'boxes': np.concatenate(ls_boxes),
                'scores': np.concatenate(ls_scores),
            })
    return [predictions]


def run_wbf(predictions, image_index, image_size=INF_SIZE, weights=None):
    """
    Fuse all prediction sets for one image with weighted boxes fusion.

    Boxes are divided by `image_size` before fusion and multiplied back
    afterwards; every box gets the label 1. Returns (boxes, scores, labels)
    as produced by weighted_boxes_fusion, with boxes rescaled.
    """
    per_set = [prediction[image_index] for prediction in predictions]
    norm_boxes = [(item['boxes'] / image_size).tolist() for item in per_set]
    set_scores = [item['scores'].tolist() for item in per_set]
    set_labels = [np.ones(item['scores'].shape[0]).tolist() for item in per_set]
    fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
        norm_boxes, set_scores, set_labels,
        weights=weights, iou_thr=IOU_THR, skip_box_thr=SKIP_BOX_THR)
    return fused_boxes * image_size, fused_scores, fused_labels

def prune_edge_predictions(predictions, quad_ix, size, buffer=20):
    """
    Drop boxes that reach into the buffer zone along a quadrant's inner edges.

    Only the two edges where the quadrant was cut out of the larger image are
    checked; boxes near the two outer edges are kept:
        0 (top left)     -> right and bottom edges
        1 (top right)    -> left and bottom edges
        2 (bottom left)  -> right and top edges
        3 (bottom right) -> left and top edges
    MUST BE DONE BEFORE scale_predictions

    predictions: dict with a 'boxes' array (columns x1, y1, x2, y2) and a
        matching 'scores' array; both entries are replaced in place.
    quad_ix: quadrant index, 0..3 as listed above.
    size: side length of the image the boxes refer to.
    buffer: width of the strip along each checked edge that a box must stay
        clear of to be kept.

    Returns the same `predictions` dict.
    Raises ValueError if quad_ix is not one of 0, 1, 2, 3.
    """
    if quad_ix not in (0, 1, 2, 3):
        raise ValueError(f"quad_ix must be 0, 1, 2 or 3, got {quad_ix!r}")
    boxes = predictions['boxes']
    scores = predictions['scores']
    if len(boxes) > 0:
        if quad_ix % 2 == 0:
            # left column: the cut is on the right-hand side
            keep_x = boxes[:, 2] < (size - buffer)
        else:
            # right column: the cut is on the left-hand side
            keep_x = boxes[:, 0] > buffer
        if quad_ix < 2:
            # top row: the cut is along the bottom
            keep_y = boxes[:, 3] < (size - buffer)
        else:
            # bottom row: the cut is along the top
            keep_y = boxes[:, 1] > buffer
        keep = keep_x & keep_y
        boxes = boxes[keep]
        scores = scores[keep]
    predictions['boxes'] = boxes
    predictions['scores'] = scores
    return predictions

def scale_predictions(predictions, scaling):
    """Replace predictions['boxes'] with the boxes multiplied by `scaling`; returns the same dict."""
    rescaled = predictions['boxes'] * scaling
    predictions['boxes'] = rescaled
    return predictions

def shift_predictions(predictions, quad_ix, shift):
    """
    Translate a quadrant's boxes into the coordinates of the full image.

    quad_ix selects the quadrant:
        0 ==> top left      (no shift)
        1 ==> top right     (x shifted)
        2 ==> bottom left   (y shifted)
        3 ==> bottom right  (x and y shifted)
    The 'boxes' array is modified in place and the same dict is returned.
    MUST BE DONE BEFORE combine_predictions
    """
    dx = shift * (quad_ix % 2)
    dy = shift * (0 if quad_ix < 2 else 1)
    shifted = predictions['boxes']
    # columns are x1, y1, x2, y2
    for col, offset in enumerate((dx, dy, dx, dy)):
        shifted[:, col] += offset
    predictions['boxes'] = shifted
    return predictions

def combine_predictions(ls_predictions):
    """
    Merge several prediction sets into one set per image.

    ls_predictions: list of prediction sets; each set is a list with one
        dict ('boxes', 'scores') per image in the batch.

    For every image the boxes and scores of all sets are concatenated and
    ordered from highest to lowest score. Returns a single-element list
    whose item is the list of merged per-image dicts.
    """
    n_images = len(ls_predictions[0])
    merged = []
    for batch_ix in range(n_images):
        per_image = [pred_set[batch_ix] for pred_set in ls_predictions]
        all_boxes = np.concatenate([item['boxes'] for item in per_image])
        all_scores = np.concatenate([item['scores'] for item in per_image])
        # ascending argsort, reversed -> highest score first
        order = np.argsort(all_scores)[::-1]
        merged.append({
            'boxes': all_boxes[order],
            'scores': all_scores[order],
        })
    return [merged]


def aug_ensemble_predictions(aug_images, nets):
    """
    Predict on every augmented batch and map the boxes back.

    aug_images: list with one entry per augmentation, each entry being the
        batch of images under that augmentation.
    Returns one corrected prediction set per augmentation, i.e. a result
    indexed as (n_augs, n_batches).
    """
    raw_predictions = [make_predictions(batch_images, nets) for batch_images in aug_images]
    return unwrap_augs(raw_predictions)


def quad_ensemble_predictions(ls_images, nets):
    """
    Predict on every crop and augmentation and express all boxes in
    full-image coordinates.

    ls_images: 2D list; the first dim holds 5 crops (4 quadrants followed by
        the full image), the second dim the augmentations of each crop.

    Quadrant predictions (index < 4) are pruned along their inner edges,
    scaled by QUAD_SIZE/INF_SIZE and shifted by FULL_SIZE-QUAD_SIZE; the
    remaining predictions are only scaled by FULL_SIZE/INF_SIZE.
    Returns a flat list of prediction sets, one per crop and augmentation.
    """
    collected = []
    for quad_ix, aug_images in enumerate(ls_images):
        is_quadrant = quad_ix < 4
        for batched in aug_ensemble_predictions(aug_images, nets):
            per_image = batched[0]
            for batch_ix, pred in enumerate(per_image):
                if not is_quadrant:
                    per_image[batch_ix] = scale_predictions(pred, FULL_SIZE/INF_SIZE)
                    continue
                pred = prune_edge_predictions(pred, quad_ix, INF_SIZE)
                pred = scale_predictions(pred, QUAD_SIZE/INF_SIZE)
                per_image[batch_ix] = shift_predictions(pred, quad_ix, FULL_SIZE-QUAD_SIZE)
            collected.extend(batched)
    return collected

Check if it works by predicting on a sample test image. Notice the reconstructions of the image from the quadrants. I could have used the original image, but it was useful to check that I was reassembling everything properly.

In [None]:
# Sanity check: predict on the first batch and draw the fused boxes on an
# image rebuilt from the four quadrant crops.
images_batches, image_ids, do_stitch_edges = next(iter(data_loader))

# images_batches comes out in the following dim structure: (batch, quad, aug)
# I need it in (quad, aug, batch)
ls_images = [
    [
        [sample[quad_ix][aug_ix] for sample in images_batches]
        for aug_ix in range(len(images_batches[0][0]))
    ]
    for quad_ix in range(len(images_batches[0]))
]

ls_predictions = quad_ensemble_predictions(ls_images, NETS)
combined_predictions = combine_predictions(ls_predictions)

# show the second image of the batch, or the only one if the batch has one
i = min(1, len(image_ids) - 1)

# paste the un-augmented (aug index 0) quadrant crops back into place
reconstructed_image = np.zeros(shape=(FULL_SIZE, FULL_SIZE, 3))
quad_slices = [
    (slice(None, QUAD_SIZE), slice(None, QUAD_SIZE)),    # top left
    (slice(None, QUAD_SIZE), slice(-QUAD_SIZE, None)),   # top right
    (slice(-QUAD_SIZE, None), slice(None, QUAD_SIZE)),   # bottom left
    (slice(-QUAD_SIZE, None), slice(-QUAD_SIZE, None)),  # bottom right
]
for quad_ix, (rows, cols) in enumerate(quad_slices):
    quad = ls_images[quad_ix][0][i].permute(1, 2, 0).cpu().numpy()
    reconstructed_image[rows, cols, :] = cv2.resize(quad, (QUAD_SIZE, QUAD_SIZE))

boxes, scores, labels = run_wbf(combined_predictions, image_index=i)
boxes = boxes.astype(np.int32).clip(min=0, max=FULL_SIZE)

fig, ax = plt.subplots(1, 1, figsize=(16, 10))

for box in boxes:
    # plain Python ints: some OpenCV builds reject numpy integer scalars
    x1, y1, x2, y2 = (int(v) for v in box)
    cv2.rectangle(reconstructed_image, (x1, y1), (x2, y2), (0, 1, 0), 2)

ax.set_axis_off()
ax.imshow(reconstructed_image);

In [None]:
def format_prediction_string(boxes, scores):
    """
    Render detections as one space-separated string.

    Each detection contributes "<score to 4 decimals> <b0> <b1> <b2> <b3>",
    where b0..b3 are the first four values of its box; detections are paired
    in order and joined by single spaces.
    """
    return " ".join(
        "{0:.4f} {1} {2} {3} {4}".format(score, box[0], box[1], box[2], box[3])
        for score, box in zip(scores, boxes)
    )

In [None]:
results = []

for images_batches, image_ids, do_stitch_edges in data_loader:
    # regroup (batch, quad, aug) into (quad, aug, batch)
    ls_images = [
        [list(per_aug) for per_aug in zip(*(sample[quad_ix] for sample in images_batches))]
        for quad_ix in range(len(images_batches[0]))
    ]

    ls_predictions = quad_ensemble_predictions(ls_images, NETS)
    combined_predictions = combine_predictions(ls_predictions)

    for i, image_id in enumerate(image_ids):
        boxes, scores, labels = run_wbf(combined_predictions, image_index=i)
        boxes = boxes.astype(np.int32).clip(min=0, max=FULL_SIZE)

        # turn (x1, y1, x2, y2) into (x, y, width, height)
        boxes[:, 2:4] -= boxes[:, 0:2]

        result = {
            'image_id': image_id,
            'PredictionString': format_prediction_string(boxes, scores)
        }
        results.append(result)

In [None]:
# One row per image; written without the index column as 'submission.csv'.
test_df = pd.DataFrame(results, columns=['image_id', 'PredictionString'])
test_df.to_csv('submission.csv', index=False)
# last expression of the cell: previews the first rows
test_df.head()