The model weights used below were produced by the training notebook: [Training Notebook](https://www.kaggle.com/teykaihong/gwd-pytorch-fasterrcnn-training).

In [None]:
# Move into the attached `myfile` input directory. Presumably this is where the
# local helper modules imported in the next cell (utils, transforms,
# ensemble_boxes_wbf) live -- verify against the dataset contents.
%cd ../input/myfile

In [None]:
import torch
import torchvision
import glob
import os
import cv2
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import utils
import transforms as T
from PIL import Image
from ensemble_boxes_wbf import *
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

%matplotlib inline
matplotlib.rcParams['font.size'] = 14
matplotlib.rcParams['figure.figsize'] = (27, 27)

In [None]:
# Climb two directory levels and enter `working`. Assuming the notebook started
# in `working`, this undoes the earlier `%cd ../input/myfile` so that relative
# paths used later ('../input/...', 'submission.csv') resolve from there.
%cd ..
%cd ..
%cd working

In [None]:
class WheatDataset(Dataset):
    """Inference dataset over every file found directly inside ``root``.

    Indexing yields ``(image, image_name)`` where ``image`` is the RGB image
    (after the optional transform) and ``image_name`` is the file name
    without its directory and extension.

    Args:
        root: Directory whose entries are treated as images.
        transform: Optional callable invoked as ``transform(image, target)``
            and returning an ``(image, target)`` pair; only the image is kept.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # glob returns entries in arbitrary, filesystem-dependent order;
        # sorting makes indices (and the resulting submission) reproducible.
        self.images = sorted(glob.glob(os.path.join(root, '*')))

    @staticmethod
    def _image_id(image_path):
        """Return the file name of ``image_path`` stripped of its extension."""
        # basename/splitext cope with dots inside the name ("scan.v2.png"),
        # names without an extension, and the platform's path separator --
        # cases where manual '.' / '/' splitting returns a wrong id.
        return os.path.splitext(os.path.basename(image_path))[0]

    def __getitem__(self, index):
        image_path = self.images[index]
        image_name = self._image_id(image_path)

        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            # No annotations exist at inference time, so an empty target is
            # passed only to satisfy the (image, target) call convention.
            image, _ = self.transform(image, {})
        return image, image_name

    def __len__(self):
        return len(self.images)

def get_transform(train):
    """Build the preprocessing pipeline from the local ``transforms`` module.

    ``T.ToTensor()`` is always applied; when ``train`` is truthy a
    ``T.RandomHorizontalFlip(0.5)`` step is appended after it. The steps are
    wrapped in ``T.Compose``.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

# Directory holding the test images, relative to the current directory.
root = '../input/global-wheat-detection/test'

# Inference pipeline: no training-only augmentation, one image per batch,
# fixed order so predictions line up with the collected image ids.
test_ds = WheatDataset(root=root, transform=get_transform(train=False))
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=1,
    shuffle=False,
    collate_fn=utils.collate_fn,
)

In [None]:
# Checkpoint written by the training notebook.
weight_path = '../input/gwd-pytorch-fasterrcnn-training/fasterrcnn_resnet50_fpn_plabel.pth'
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def get_model(num_classes):
    """Create a Faster R-CNN (ResNet-50 FPN) with a fresh box-predictor head.

    The network is built without any pretrained weights; the stock box
    predictor is replaced by a ``FastRCNNPredictor`` that takes the same
    number of input features and outputs ``num_classes`` classes.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=False,
    )
    roi_heads = detector.roi_heads
    # Reuse the feature width of the existing classifier for the new head.
    roi_heads.box_predictor = FastRCNNPredictor(
        roi_heads.box_predictor.cls_score.in_features,
        num_classes,
    )
    return detector

# Two output classes are requested from the detection head.
model = get_model(2).to(device)
# map_location remaps the stored tensors onto the selected device; without it
# a checkpoint saved from CUDA tensors cannot be loaded on a CPU-only host.
model.load_state_dict(torch.load(weight_path, map_location=device))

In [None]:
def to_PredString(scores, boxes):
    """Serialise detections into a space-separated prediction string.

    For every ``(score, box)`` pair the output contains the score rounded to
    four decimals followed by the first four values of the box, in the order
    they are stored.

    Args:
        scores: Iterable of confidence values (NumPy/torch scalars or plain
            Python numbers).
        boxes: Iterable of boxes, each indexable with at least four values.

    Returns:
        The joined string; an empty string when there are no detections.
    """
    def _as_python(value):
        # NumPy and torch scalars expose .item(); plain numbers pass through.
        return value.item() if hasattr(value, 'item') else value

    fields = []
    for score, box in zip(scores, boxes):
        fields.append(round(_as_python(score), 4))
        fields.extend(_as_python(coord) for coord in box[:4])
    return ' '.join(str(field) for field in fields)

In [None]:
class BaseWheatTTA():
    """Interface for a test-time-augmentation (TTA) step.

    A step transforms an image (or a batch of images) before inference and
    maps the predicted boxes back to the original image frame afterwards.
    Subclasses override ``augment``, ``batch_augment`` and
    ``deaugment_boxes``.
    """

    # Side length, in pixels, that the box de-augmentation arithmetic in the
    # subclasses is written against.
    image_size = 1024

    def augment(self, image):
        """Return the augmented version of a single image."""
        raise NotImplementedError

    def batch_augment(self, images):
        """Return the augmented version of a batch of images."""
        raise NotImplementedError

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the augmented image back to the original."""
        raise NotImplementedError

    def deaugmented_boxes(self, boxes):
        """Backward-compatible alias for :meth:`deaugment_boxes`.

        The hook was originally declared under this name although every
        implementation and caller uses ``deaugment_boxes``.
        """
        return self.deaugment_boxes(boxes)

class TTAHorizontalFlip(BaseWheatTTA):
    """TTA step that mirrors the image along its last axis.

    Assumes single images are laid out (C, H, W) and batches (N, C, H, W),
    and that boxes are rows of (x1, y1, x2, y2) -- TODO confirm with callers.
    """

    def augment(self, image):
        """Flip a single image along dimension 2."""
        return image.flip(2)

    def batch_augment(self, images):
        """Flip a batch of images along dimension 3."""
        return images.flip(3)

    def deaugment_boxes(self, boxes):
        """Return a copy of ``boxes`` with columns 0 and 2 mirrored back.

        The input tensor is left untouched: the result is written into a
        clone rather than into the caller's tensor.
        """
        res_boxes = boxes.clone()
        # Columns are swapped while mirroring so that column 0 stays the
        # smaller of the two values.
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [2, 0]]
        return res_boxes

class TTAVerticalFlip(BaseWheatTTA):
    """TTA step that mirrors the image along its second-to-last axis.

    Assumes single images are laid out (C, H, W) and batches (N, C, H, W),
    and that boxes are rows of (x1, y1, x2, y2) -- TODO confirm with callers.
    """

    def augment(self, image):
        """Flip a single image along dimension 1."""
        return image.flip(1)

    def batch_augment(self, images):
        """Flip a batch of images along dimension 2."""
        return images.flip(2)

    def deaugment_boxes(self, boxes):
        """Return a copy of ``boxes`` with columns 1 and 3 mirrored back.

        The input tensor is left untouched: the result is written into a
        clone rather than into the caller's tensor.
        """
        res_boxes = boxes.clone()
        # Columns are swapped while mirroring so that column 1 stays the
        # smaller of the two values.
        res_boxes[:, [1, 3]] = self.image_size - boxes[:, [3, 1]]
        return res_boxes

class TTARotate90(BaseWheatTTA):
    """TTA step that applies a single ``torch.rot90`` quarter turn.

    Single images are rotated in the plane of dimensions (1, 2), batches in
    the plane of dimensions (2, 3).
    """

    def augment(self, image):
        """Rotate one image by a quarter turn."""
        return torch.rot90(image, k=1, dims=(1, 2))

    def batch_augment(self, images):
        """Rotate a batch of images by a quarter turn."""
        return torch.rot90(images, k=1, dims=(2, 3))

    def deaugment_boxes(self, boxes):
        """Return a new tensor with the boxes mapped back to the original frame.

        Columns 1 and 3 of the result take the values of input columns 0 and
        2; columns 0 and 2 become ``image_size`` minus input columns 3 and 1.
        The input tensor is not modified.
        """
        size = self.image_size
        restored = boxes.clone()
        # Both assignments read from the untouched input, so their order is
        # irrelevant.
        restored[:, [1, 3]] = boxes[:, [0, 2]]
        restored[:, [0, 2]] = size - boxes[:, [3, 1]]
        return restored

class TTACompose(BaseWheatTTA):
    """Chain several TTA steps into one.

    Augmentations run in the given order; boxes are de-augmented in the
    opposite order so each step undoes its own transformation.
    """

    def __init__(self, transforms):
        # Sequence of TTA steps; may be empty, in which case this is a no-op.
        self.transforms = transforms

    def augment(self, image):
        """Apply every step's ``augment`` to a single image, first to last."""
        augmented = image
        for step in self.transforms:
            augmented = step.augment(augmented)
        return augmented

    def batch_augment(self, images):
        """Apply every step's ``batch_augment`` to a batch, first to last."""
        augmented = images
        for step in self.transforms:
            augmented = step.batch_augment(augmented)
        return augmented

    def deaugment_boxes(self, boxes):
        """Apply every step's ``deaugment_boxes``, last step first."""
        restored = boxes
        for step in reversed(self.transforms):
            restored = step.deaugment_boxes(restored)
        return restored

In [None]:
from itertools import product

# Every on/off combination of the three basic steps (2 * 2 * 2 = 8 pipelines).
# `None` marks "step switched off"; the all-off combination yields an empty
# TTACompose.
tta_transforms = [
    TTACompose([step for step in combination if step is not None])
    for combination in product(
        [TTAHorizontalFlip(), None],
        [TTAVerticalFlip(), None],
        [TTARotate90(), None],
    )
]

def make_tta_prediction(model, image, threshold=0.5):
    """Run the model once per entry of the module-level ``tta_transforms``.

    Args:
        model: Detection model called with a list holding one image tensor;
            its first output must provide ``'boxes'`` and ``'scores'``.
        image: Indexable whose first element is the image tensor to predict.
        threshold: Detections whose score is not above this are dropped.

    Returns:
        One entry per TTA pipeline; each entry is a single-element list
        holding ``{'boxes': ..., 'scores': ...}`` with the boxes already
        mapped back through that pipeline's ``deaugment_boxes``.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for tta_transform in tta_transforms:
            augmented = tta_transform.augment(image[0]).to(device)
            output = model([augmented])[0]

            # One mask keeps boxes and scores aligned.
            keep = output['scores'] > threshold
            kept_boxes = tta_transform.deaugment_boxes(output['boxes'][keep])
            kept_scores = output['scores'][keep]

            predictions.append([{
                'boxes': kept_boxes,
                'scores': kept_scores,
            }])
    return predictions

def run_wbf(predictions, image_size, iou_thr=0.5, skip_box_thr=0.0001, weights=None):
    """Merge per-augmentation detections with weighted boxes fusion.

    Args:
        predictions: Sequence with one entry per model/augmentation; the
            first element of each entry is a dict holding tensors under
            ``'boxes'`` (pixel coordinates) and ``'scores'``.
        image_size: Side length of the image in pixels; used both to clip
            the boxes and to normalise them for the fusion routine.
        iou_thr: Forwarded to ``weighted_boxes_fusion``.
        skip_box_thr: Forwarded to ``weighted_boxes_fusion``.
        weights: Forwarded to ``weighted_boxes_fusion``.

    Returns:
        ``(boxes, scores, labels)`` as returned by ``weighted_boxes_fusion``,
        with the boxes scaled back to pixel coordinates.
    """
    # Largest valid pixel coordinate; derived from image_size so that sizes
    # other than 1024 are clipped and normalised consistently.
    max_coord = image_size - 1

    # The fusion routine works on coordinates normalised to [0, 1].
    boxes = [
        (prediction[0]['boxes'].clip(min=0, max=max_coord) / max_coord).tolist()
        for prediction in predictions
    ]
    scores = [prediction[0]['scores'].tolist() for prediction in predictions]
    # Every detection gets the same label, 1.
    labels = [[1] * len(prediction_boxes) for prediction_boxes in boxes]

    boxes, scores, labels = weighted_boxes_fusion(
        boxes, scores, labels,
        weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr,
    )
    boxes = boxes * max_coord
    return boxes, scores, labels

In [None]:
# Parallel lists filled by the inference loop: one image id and one
# prediction string per test image.
image_ids, PredStrings = [], []

def show_result(image: torch.Tensor, boxes: np.ndarray):
    """Draw ``boxes`` on the first image of ``image`` and show it with pyplot.

    Args:
        image: Indexable whose first element is a channel-first image tensor;
            values are multiplied by 255 before conversion to bytes.
        boxes: Iterable of boxes; the first four values of each are used as
            the two corner points (x1, y1) and (x2, y2) of a rectangle.
    """
    # Channel-first -> channel-last, scaled to bytes.
    pixels = image[0].permute(1, 2, 0).mul(255).byte().cpu().numpy()
    # Round-tripping through PIL yields a fresh RGB array for OpenCV to draw on.
    img = np.array(Image.fromarray(pixels).convert('RGB'))
    for box in boxes:
        # cv2.rectangle needs integer points; cast explicitly so NumPy scalars
        # and float boxes are handled too.
        x1, y1, x2, y2 = (int(value) for value in box[:4])
        cv2.rectangle(img, (x1, y1), (x2, y2), (225, 0, 0), 3)
    plt.imshow(img)
    plt.axis('off')

# Predict every test image with TTA, fuse the per-augmentation boxes and
# collect one prediction string per image.
model.eval()
with torch.no_grad():
    for i, (image, image_id) in enumerate(test_dl):
        prediction = make_tta_prediction(model, image)

        boxes, scores, labels = run_wbf(predictions=prediction, image_size=1024)
        # Snap the fused boxes to whole pixels inside the 1024-pixel image.
        boxes = np.clip(boxes.round().astype(np.int32), 0, 1023)

        # Optional preview of the detections:
        #     plt.subplot(5, 2, i + 1)
        #     show_result(image, boxes)

        # Replace the second corner by the box extent: columns 2 and 3 become
        # (col 2 - col 0) and (col 3 - col 1).
        boxes[:, 2:4] -= boxes[:, 0:2]

        pred_string = to_PredString(scores, boxes)

        image_ids.append(image_id[0])
        PredStrings.append(pred_string)

In [None]:
# Two-column table: one row per test image, in the order they were predicted.
test_df = pd.DataFrame(dict(image_id=image_ids, PredictionString=PredStrings))
test_df

In [None]:
# Write the table without the DataFrame index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)