* [Training Notebook](https://www.kaggle.com/ceshine/wheat-detection-training-efficientdet-public)
* [Inference Notebook](https://www.kaggle.com/ceshine/effdet-wheat-head-detection-inference-public)

In [None]:
# Clone the project repository and make it the working directory, so the
# `wheat` package imported later can be found (presumably it lives at the
# repository root — confirm).
!git clone https://github.com/ceshine/global-wheat-detection
%cd global-wheat-detection

In [None]:
# Install runtime dependencies: omegaconf, pytorch_lightning_spells, and the
# EfficientDet implementation straight from rwightman's GitHub master branch.
!pip install omegaconf pytorch_lightning_spells
!pip install https://github.com/rwightman/efficientdet-pytorch/archive/master.zip

In [None]:
import cv2
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from wheat.dataset import WheatDataset, get_train_transforms
from wheat.model import get_train_efficientdet

In [None]:
# Root of the resized 512x512 wheat dataset. Other paths are built by plain
# string concatenation (e.g. BASE_DIR + "train.csv"), so the trailing slash matters.
BASE_DIR = "/kaggle/input/wheat-dataset-resized-to-512x512/512/"

In [None]:
# Load the box annotations. Each "bbox" cell is a bracketed, comma-separated
# string (presumably "[x, y, w, h]", matching the column order below): drop the
# surrounding brackets and parse the numbers into a float array per row.
df = pd.read_csv(BASE_DIR + "train.csv")
bboxes = np.stack(
    df["bbox"].apply(lambda s: np.fromstring(s[1:-1], sep=","))
)

# Spread the parsed values into one numeric column each.
for i, col in enumerate(("x", "y", "w", "h")):
    df[col] = bboxes[:, i]

# Right/bottom edges = origin + extent.
df["x2"], df["y2"] = df["x"] + df["w"], df["y"] + df["h"]

## Regular Augmentations

In [None]:
# Dataset with the regular training augmentations (cutout enabled) at 512 px.
# mosaic_p is not passed, so WheatDataset's own default applies.
dataset = WheatDataset(
    df=df,
    image_dir=BASE_DIR + "train/",
    transforms=get_train_transforms(image_size=512, cutout=True),
)

In [None]:
def sample_aug(idx, dataset, *, nrows=4, ncols=2):
    """Plot a grid of repeated draws of one dataset item, with boxes overlaid.

    ``dataset[idx]`` is fetched once per grid cell. With random transforms,
    each cell therefore shows a different augmentation of the same sample.
    Ground-truth boxes are drawn as red rectangles on top of each image.

    Args:
        idx: Index of the item to visualise.
        dataset: Indexable returning ``(img, target)``. ``img`` is a
            channels-first NumPy array (transposed to HWC here for plotting).
            ``target["bbox"]`` holds the boxes as an ``(N, 4)`` array.
        nrows: Number of grid rows (default 4, the original layout).
        ncols: Number of grid columns (default 2, the original layout).
    """
    # squeeze=False keeps `axes` 2-D even for a 1xN or Nx1 grid; 7 inches per
    # cell reproduces the original 14x28 figure for the default 4x2 layout.
    _, axes = plt.subplots(
        nrows, ncols, figsize=(7 * ncols, 7 * nrows), squeeze=False
    )
    # axes.flat walks the grid row by row, same order as a nested row/col loop.
    for ax in axes.flat:
        img, target = dataset[idx]
        # CHW -> HWC for matplotlib; copy so cv2 draws into a contiguous buffer
        # rather than a transposed view.
        img = img.transpose(1, 2, 0).copy()
        # The (1, 0, 3, 2) reorder yields (x1, y1, x2, y2) for cv2, which
        # assumes boxes are stored as (y1, x1, y2, x2) — TODO confirm against
        # WheatDataset. reshape(-1, 4) also tolerates an empty 1-D bbox array.
        boxes = np.asarray(target["bbox"]).reshape(-1, 4)[:, [1, 0, 3, 2]]
        # .tolist() gives plain Python ints; some OpenCV builds reject NumPy
        # integer scalars as point coordinates.
        for x1, y1, x2, y2 in boxes.round().astype(int).tolist():
            cv2.rectangle(img, (x1, y1), (x2, y2), (220, 0, 0), 2)
        ax.axis("off")
        ax.imshow(img)

In [None]:
# Repeated augmented draws of item 0 (each grid cell calls dataset[0] anew).
sample_aug(0, dataset)

In [None]:
# Repeated augmented draws of item 5.
sample_aug(5, dataset)

## Mosaic Augmentation
Below, cutout is disabled and `mosaic_p` is set to `1.0`.

References:
1. [ultralytics/yolov5](https://github.com/ultralytics/yolov5/blob/831773f5a23926658ee76459ce37550643432123/utils/datasets.py#L529)
2. [shonenkov/training-efficientdet](https://www.kaggle.com/shonenkov/training-efficientdet)

In [None]:
# Mosaic variant of the dataset: cutout disabled and mosaic_p=1.0.
# Presumably mosaic_p is the per-sample mosaic probability, so every draw is a
# mosaic — confirm against WheatDataset.
dataset = WheatDataset(
    df=df,
    image_dir=BASE_DIR + "train/",
    transforms=get_train_transforms(image_size=512, cutout=False),
    mosaic_p=1.0,
)

In [None]:
# Item 0 drawn repeatedly from the mosaic-configured dataset.
sample_aug(0, dataset)

In [None]:
# Item 5 drawn repeatedly from the mosaic-configured dataset.
sample_aug(5, dataset)