In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only input tree and echo the full path of every file in it.
for root, _, names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(root, name) for name in names)
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import cv2
import os
import re

# Imports for image transforms
# Albumentations bounding box augmentation docs: https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# Torch imports
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Locations of the competition data and of the pre-trained weights.
INPUT_DIR = '/kaggle/input/global-wheat-detection'
# PRETRAINED_DIR = '/kaggle/input/wheat-dataset-project'
PRETRAINED_DIR = '/kaggle/input/wheat-frcnn-bayesian/'
# NOTE(review): OUTPUT_DIR is not referenced anywhere in the cells shown; the
# notebook header says the writable directory is /kaggle/working/ - confirm
# before using this path.
OUTPUT_DIR = '/kaggle/output/'
TRAIN_DIR = f'{INPUT_DIR}/train'
TEST_DIR = f'{INPUT_DIR}/test'
# MODEL_LOC = f'{PRETRAINED_DIR}/fasterrcnn_resnet50_fpn_TRAINED.pth'
# os.path.join copes with PRETRAINED_DIR written with or without a trailing
# slash, so the resulting path never contains a doubled '//' separator.
MODEL_LOC = os.path.join(PRETRAINED_DIR, 'fasterrcnn_resnet50_fpn_BAYESOPT.pth')

# Model Inference
* Create TestDataset class (similar to WheatDataset but without bboxes; the transform only converts the image to a tensor)
* Create dataset using the TestDataset class
* Create dataloader
* Loop over images, image_ids in dataloader
    * Within each iteration, get outputs by calling model(images)
    * Loop over i, image in enumerate(images)
        * Get boxes and scores from outputs for element i
        * Threshold boxes and scores
        * Convert boxes from \[xmin ymin xmax ymax\] form into \[x y w h\] form
        * Build a result dict with the image id and prediction string (in competition format) and append it to the result list
* Sample from outputs as before (with score threshold on boxes) to display prediction

In [None]:
def test_transform():
    """Build the inference-time pipeline, which only converts the image to a tensor."""
    to_tensor = ToTensorV2(p=1.0)
    pipeline = A.Compose([to_tensor])
    return pipeline

class TestDataset(Dataset):
    """Inference-only dataset yielding ``(image, image_id)`` pairs.

    Args:
        df: frame with an ``image_id`` column; each unique id is one sample.
        directory: folder holding the ``<image_id>.jpg`` files.
        transforms: optional callable invoked as ``transforms(image=...)`` that
            returns a dict with an ``'image'`` entry (e.g. ``test_transform()``).
    """

    def __init__(self, df, directory, transforms=None):
        super().__init__()

        # One sample per distinct image id, in order of first appearance.
        self.image_ids = df['image_id'].unique()
        self.df = df
        self.dir = directory
        self.transforms = transforms

    def __len__(self):
        """Return the number of distinct images."""
        return int(self.image_ids.shape[0])

    def __getitem__(self, index: int):
        """Load, colour-convert and scale the image at ``index``.

        Returns:
            ``(image, image_id)``; the image is float32 RGB scaled to [0, 1],
            after ``transforms`` has been applied when one was supplied.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        image_id = self.image_ids[index]
        image_path = f'{self.dir}/{image_id}.jpg'
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising, so
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        # cv2 reads images into BGR format, must convert to RGB for f-RCNN
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        # f-RCNN requires values in [0,1]; the [C, H, W] layout is left to the
        # transform (the image is still [H, W, C] at this point).
        image /= 255.0

        if self.transforms:
            transformed = self.transforms(image=image)
            image = transformed['image']

        return image, image_id

In [None]:
# Rebuild the Faster R-CNN architecture without downloading any weights; the
# trained parameters come from the checkpoint at MODEL_LOC instead.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
in_features = model.roi_heads.box_predictor.cls_score.in_features
# presumably background + wheat head - confirm against the training notebook
num_classes = 2
# Swap the box predictor for one sized to num_classes so the checkpoint's
# head weights fit.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when this session runs on the CPU.
pretrained_state_dict = torch.load(MODEL_LOC, map_location=device)
model.load_state_dict(pretrained_state_dict)
# Switch to evaluation mode for inference, then move the model to the device.
model.eval()
model.to(device)

In [None]:
# The sample submission supplies the image ids to run inference on.
sample_submission_path = f'{INPUT_DIR}/sample_submission.csv'
test_df = pd.read_csv(sample_submission_path)
test_df.head()

In [None]:
# Quick sanity check: (rows, columns) of the sample-submission frame.
test_df.shape

In [None]:
# Dataset over the test images, using the tensor-only inference transform.
test_dataset = TestDataset(
    df=test_df,
    directory=TEST_DIR,
    transforms=test_transform(),
)

def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    For ``(image, image_id)`` samples this yields ``(images, image_ids)``
    without stacking the images into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Batches of four samples loaded by four worker processes; collate_fn keeps
# each batch as (images, image_ids) tuples.
test_dl = DataLoader(
    test_dataset,
    batch_size=4,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# Minimum detection score for a box to be kept.
THRESHOLD = 0.5
# Collects one {'image_id', 'PredictionString'} record per test image.
res = list()

def getPredString(outputTup):
    """Format detections as a single space-separated prediction string.

    Args:
        outputTup: iterable of ``(score, (x, y, w, h))`` entries.

    Returns:
        ``"score x y w h score x y w h ..."``, or ``""`` when there are no
        entries.
    """
    def _format_entry(entry):
        # Index the pair and unpack the box so a box without exactly four
        # values still raises.
        confidence = entry[0]
        left, top, width, height = entry[1]
        return f'{confidence} {left} {top} {width} {height}'

    return " ".join(map(_format_entry, outputTup))
    
    

# Run the detector over every test batch and collect one submission record
# per image. no_grad() disables autograd bookkeeping, which inference does
# not need, so less memory is held per batch.
with torch.no_grad():
    for imgs, img_ids in test_dl:
        imgs = [image.to(device) for image in imgs]
        model_outputs = model(imgs)

        for this_id, output in zip(img_ids, model_outputs):
            scores = output['scores'].detach().cpu().numpy()
            bboxes = output['boxes'].detach().cpu().numpy()

            # Build the score mask once and apply it to both arrays so boxes
            # and scores stay aligned.
            keep = scores >= THRESHOLD
            bboxes = bboxes[keep].astype(np.int32)
            scores = scores[keep]

            # Convert [xmin, ymin, xmax, ymax] into [x, y, w, h] in place.
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
            bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]

            res.append({'image_id': this_id, 'PredictionString': getPredString(zip(scores, bboxes))})

# Replace the sample-submission frame with the actual predictions.
test_df = pd.DataFrame(res, columns=['image_id', 'PredictionString'])

In [None]:
# Preview the first ten rows of the prediction frame.
test_df.head(10)

In [None]:
# Draw the thresholded detections for one image of the last processed batch.
# Index 1 is used when the batch has it; a one-image final batch falls back
# to index 0 instead of raising IndexError.
sample_idx = min(1, len(imgs) - 1)

# permute() only changes strides, so the resulting NumPy array is a
# non-contiguous view; OpenCV drawing needs a contiguous buffer to draw in place.
sample = np.ascontiguousarray(imgs[sample_idx].permute(1, 2, 0).cpu().numpy())
scores = model_outputs[sample_idx]['scores'].detach().cpu().numpy()
bboxes = model_outputs[sample_idx]['boxes'].detach().cpu().numpy()
bboxes = bboxes[scores >= THRESHOLD].astype(np.int32)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

for box in bboxes:
    # Plain Python ints are accepted as point coordinates by every OpenCV build.
    xmin, ymin, xmax, ymax = (int(coord) for coord in box)
    # The image is float RGB scaled to [0, 1], so full red is 1.0 rather than
    # an 8-bit value that imshow would have to clip.
    cv2.rectangle(sample,
                  (xmin, ymin),
                  (xmax, ymax),
                  (1.0, 0.0, 0.0), 2)

ax.set_axis_off()
ax.imshow(sample)

In [None]:
# Write the predictions to submission.csv in the current directory, without the index column.
test_df.to_csv('submission.csv', index=False)