# Global Wheat Detection Faster R-CNN

In [None]:
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import os
import ast

In [None]:
from torch.utils.data import Dataset
import torchvision.transforms as T
import torch

# Dataset

In [None]:
class GlobalWheatDataset(Dataset):
    """Detection dataset pairing images in a directory with boxes from ``train.csv``.

    Each item is ``(image_tensor, target)`` where ``target`` holds ``boxes``
    (``[xmin, ymin, xmax, ymax]`` rows), ``labels``, ``image_id``, ``area``
    and ``iscrowd`` tensors.

    Args:
        path: Root directory that contains ``train.csv`` and the image folders.
        train_or_test_path: Name of the image sub-directory under ``path``.
        transforms: Callable applied to the PIL image, or ``None``.
    """

    def __init__(self, path, train_or_test_path, transforms):
        self.path = path
        self.train_or_test_path = train_or_test_path
        self.transforms = transforms

        # os.path.join works whether or not ``path`` ends with a separator.
        self.df = pd.read_csv(os.path.join(path, 'train.csv'))

        # Map every annotated image id to a stable integer.
        self.ids = {v: k for k, v in enumerate(np.unique(self.df.image_id.values))}
        self.imgs_list = sorted(os.listdir(os.path.join(path, train_or_test_path)))

    def _image_id(self, idx):
        """Return the image id (file name without extension) of item ``idx``."""
        return os.path.splitext(os.path.basename(self.imgs_list[idx]))[0]

    def _image_path(self, idx):
        """Return the full path of the image file for item ``idx``."""
        return os.path.join(self.path, self.train_or_test_path, self.imgs_list[idx])

    def _raw_boxes(self, image_id):
        """Return the ``[x, y, w, h]`` boxes of ``image_id`` as a float32 ``(N, 4)`` array.

        Images without any annotation row yield an empty ``(0, 4)`` array.
        """
        rows = self.df.loc[self.df.image_id == image_id, 'bbox'].values
        boxes = np.array([ast.literal_eval(row) for row in rows], dtype=np.float32)
        return boxes.reshape(-1, 4)

    def get_rectangles(self, idx):
        """Return one matplotlib ``Rectangle`` patch per annotated box of item ``idx``."""
        image_id = self._image_id(idx)
        return [
            patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
            for x, y, w, h in self._raw_boxes(image_id).tolist()
        ]

    def format_boxes(self, boxes):
        """Convert ``[x, y, w, h]`` rows to ``[xmin, ymin, xmax, ymax]`` rows.

        Args:
            boxes: Array-like of boxes; an empty input is accepted.

        Returns:
            A new ``(N, 4)`` array (the input is not modified). An empty input
            gives an empty ``(0, 4)`` array.

        Raises:
            ValueError: If the input cannot be viewed as rows of four values.
        """
        boxes = np.array(boxes).reshape(-1, 4)
        # replace width, height with xmax, ymax
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]
        return boxes

    def get_image(self, idx):
        """Return the RGB image of item ``idx`` as a numpy array."""
        return np.array(Image.open(self._image_path(idx)).convert("RGB"))

    def draw(self, idx):
        """Show the image of item ``idx`` with its annotated boxes overlaid."""
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(self.get_image(idx))
        for rectangle in self.get_rectangles(idx):
            ax.add_patch(rectangle)
        plt.show()

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the ``idx``-th file in ``imgs_list``."""
        # The id is derived from the file that is actually loaded, so the boxes
        # always belong to the returned image.
        image_id = self._image_id(idx)

        # format boxes width, height -> xmax, ymax; float boxes are what the
        # torchvision detection models document for their targets.
        boxes = torch.as_tensor(
            self.format_boxes(self._raw_boxes(image_id)), dtype=torch.float32)

        target = {}
        target["boxes"] = boxes
        target["labels"] = torch.ones((len(boxes),), dtype=torch.int64)
        # Images that have no row in the CSV get the id -1.
        target["image_id"] = torch.tensor([self.ids.get(image_id, -1)])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(boxes),), dtype=torch.int64)

        img = Image.open(self._image_path(idx)).convert("RGB")

        # NOTE(review): transforms are applied to the image only; a geometric
        # transform (e.g. a flip) would leave the boxes pointing at the
        # un-transformed positions.
        if self.transforms is not None:
            img = self.transforms(img)

        return T.ToTensor()(img), target

    def __len__(self):
        """Return the number of image files found in the image directory."""
        return len(self.imgs_list)

# Augmentations

In [None]:
def get_transform(train):
    """Compose the image transforms for a training or an evaluation dataset.

    Args:
        train: When truthy, a random horizontal flip (p=0.5) is included;
            otherwise the composed pipeline is empty.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    # NOTE(review): the dataset applies this pipeline to the image only, so
    # the flip is not mirrored onto the bounding boxes -- confirm whether a
    # joint image/box transform is needed.
    steps = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose(steps)

In [None]:
# Root of the competition data and the names of its two image sub-folders.
path = '../input/global-wheat-detection/'
train_path, test_path = 'train', 'test'

In [None]:
# Build a training dataset just to preview one annotated sample.
dataset = GlobalWheatDataset(
    path=path,
    train_or_test_path=train_path,
    transforms=get_transform(train=True),
)
dataset.draw(idx=8)

# Create Datasets and DataLoaders

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def my_collate(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    Args:
        batch: Iterable of samples, each itself an iterable of fields
            (here ``(image, target)`` pairs).

    Returns:
        A tuple with one tuple per field, e.g. ``(images, targets)``.
    """
    columns = zip(*batch)
    return tuple(columns)

In [None]:
# One dataset per image folder; only the training one requests augmentation.
dataset = GlobalWheatDataset(
    path=path, train_or_test_path=train_path, transforms=get_transform(train=True))
dataset_test = GlobalWheatDataset(
    path=path, train_or_test_path=test_path, transforms=get_transform(train=False))

# Random permutation of every index of each dataset.
indices = torch.randperm(len(dataset)).tolist()
indices_test = torch.randperm(len(dataset_test)).tolist()

# Wrap the datasets so they are visited in the permuted order.
dataset = torch.utils.data.Subset(dataset, indices)
dataset_test = torch.utils.data.Subset(dataset_test, indices_test)

# my_collate keeps images and targets as tuples instead of stacking them.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=my_collate,
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=my_collate,
)

# Create Faster R-CNN Model with ResNet-50

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Number of output classes for the new box predictor
# (presumably background + wheat -- the dataset labels every box as 1).
num_classes = 2

# Start from the pretrained Faster R-CNN with a ResNet-50 FPN backbone.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap the box predictor for a fresh head with the same input width.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
# Move the model's parameters and buffers to the selected device.
model = model.to(device)

# Training

In [None]:
# Only parameters that require gradients are handed to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# and a learning rate scheduler: multiplies the LR by 0.1 every 8 calls to step()
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=8,
    gamma=0.1,
)

In [None]:
# Mean training loss of every epoch, in order.
total_errors = []

# The detection model only returns its loss dict in training mode; set it
# explicitly so re-running this cell after an evaluation cell still works.
model.train()

for epoch in range(101):
    losses_arr = []

    for images, targets in data_loader:
        # Move the batch to the same device as the model.
        images = [image.to(device) for image in images]
        targets = [{k: torch.as_tensor(v).detach().to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        # In training mode the model returns a dict of individual losses;
        # their sum is the quantity that is optimised.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        losses_arr.append(losses.item())

        losses.backward()
        optimizer.step()

    # NOTE(review): lr_scheduler is created in an earlier cell but is never
    # stepped here, so the learning rate stays constant -- confirm intended.
    total_errors.append(np.mean(losses_arr))
    if epoch % 10 == 0:
        print("Epoch:{0:3d}, Loss:{1:1.3f}".format(epoch, total_errors[-1]))

# Evaluate Loss

In [None]:
# Plot the mean training loss recorded for each epoch.
plt.plot(total_errors)

# Prediction

Once the training loss is small, you can try predicting the bounding boxes for the test data.

In [None]:
def get_boxes(result):
    """Turn the predicted boxes of the first image into matplotlib rectangles.

    Args:
        result: Sequence of per-image prediction dicts; only ``result[0]['boxes']``
            is read, as an iterable of ``[xmin, ymin, xmax, ymax]`` rows.

    Returns:
        A list with one red, unfilled ``patches.Rectangle`` per box.
    """
    rectangles = []

    for box in result[0]['boxes']:
        # Convert to plain Python floats so matplotlib never receives tensor
        # elements (which may live on the GPU or track gradients).
        xmin, ymin, xmax, ymax = (float(v) for v in box)

        rectangles.append(
            patches.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                linewidth=1, edgecolor='r', facecolor='none'))

    return rectangles

In [None]:
def draw_boxes(image, boxes):
    """Display the first image of a batch with the given patches drawn on top.

    Args:
        image: Indexable batch of image tensors; ``image[0]`` is permuted with
            ``(1, 2, 0)`` before plotting (assumes channel-first layout).
        boxes: Iterable of matplotlib patches to add to the axes.
    """
    # move the depth (channel) axis to the end for imshow
    pixels = image[0].permute(1, 2, 0).cpu().numpy()

    _, axes = plt.subplots(1)
    axes.imshow(pixels)

    for patch in boxes:
        axes.add_patch(patch)

    plt.show()

In [None]:
# Take one batch from the test loader.
images, targets = next(iter(data_loader_test))
# Move to the selected device as float32; this also works on CPU-only
# machines, unlike a cast to torch.cuda.FloatTensor.
images = [image.to(device=device, dtype=torch.float32) for image in images]

In [None]:
# Switch to evaluation mode so the model returns predictions instead of
# losses; eval must be *called* -- a bare attribute reference does nothing.
model.eval()
model = model.to(device)

In [None]:
# Inference needs no gradients; no_grad avoids building the autograd graph.
with torch.no_grad():
    result = model(images)
boxes = get_boxes(result)