# PyTorch starter - FasterRCNN Inference

- You can find the [train notebook here](https://www.kaggle.com/pestipeti/pytorch-starter-fasterrcnn-train)
- The weights are [available here](https://www.kaggle.com/dataset/7d5f1ed9454c848ecb909c109c6fa8e573ea4de299e249c79edc6f47660bf4c5)

In [None]:
!pip install --no-deps '../input/timm-package/timm-0.1.26-py3-none-any.whl' > /dev/null
!pip install --no-deps '../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl' > /dev/null

In [None]:
import sys
sys.path.insert(0, "../input/timm-efficientdet-pytorch")
sys.path.insert(0, "../input/omegaconf")
sys.path.insert(0, "../input/weightedboxesfusion")

from ensemble_boxes import *
from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
from effdet.efficientdet import HeadNet

import pandas as pd
import numpy as np
import cv2
import os
import re
import gc

from PIL import Image

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt


# Root of the competition dataset and its train/test image folders.
DIR_INPUT = '/kaggle/input/global-wheat-detection'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'


# Other weight dataset paths, kept for reference (not used by the code below):
#'/kaggle/input/wwdadamfasterrcnn'
#'/kaggle/input/fasterrcnn-wheat-detection60epoch'
#'/kaggle/input/fasterrcnn-resnet50-fpn-wheat-detection'
#'/kaggle/input/global-wheat-detection-public'
# '/kaggle/input/gwd-fasterrcnn-30-epoch-good'

# Parent folder of all attached datasets; checkpoint paths are built from it.
DIR_MAIN = '/kaggle/input'
# NOTE(review): DIR_WEIGHTS is not referenced anywhere in the visible code.
DIR_WEIGHTS = '/kaggle/input/gwd-fasterrcnn-20-epoch-mosaic'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#fasterrcnn_resnet50_fpn.pth
#fasterrcnn_resnet50_fpn_heavy_aug.pth
#fasterrcnn_resnet50_fpn_10_epoch_mosaic.pth

In [None]:
from collections import OrderedDict
def load_efficient_net(checkpoint_path):
    """Build a single-class ``tf_efficientdet_d5`` detector and load its weights.

    The checkpoint at ``checkpoint_path`` is read with ``torch.load``. Keys
    that start with ``'anchors.boxes'`` are skipped and every remaining key
    has its first six characters removed before the state dict is loaded
    (presumably a ``'model.'`` prefix -- TODO confirm against the training
    code).

    Returns the network wrapped in ``DetBenchEval``, switched to eval mode
    and moved to the GPU.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)

    # The config is adjusted after the backbone is built, then a fresh
    # classification head is attached for the single class.
    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs={'eps': .001, 'momentum': .01},
    )

    raw_state = torch.load(checkpoint_path)

    # Re-key the weights: drop anchor entries, strip the 6-character prefix.
    renamed_state = OrderedDict(
        (key[6:], raw_state[key])
        for key in raw_state
        if not key.startswith('anchors.boxes')
    )
    detector.load_state_dict(renamed_state)

    del raw_state
    gc.collect()

    bench = DetBenchEval(detector, cfg)
    bench.eval()
    return bench.cuda()

def load_net(checkpoint_path):
    """Build a Faster R-CNN (ResNet-50 FPN) wheat detector and load its weights.

    The network is created without any pretrained weights, its box predictor
    is replaced by a two-class head (wheat + background), and the state dict
    stored at ``checkpoint_path`` is loaded into it.

    Args:
        checkpoint_path: path to a file holding the model ``state_dict``.

    Returns:
        The model in eval mode, placed on the module-level ``device``
        (GPU when available, otherwise CPU).
    """
    net = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    num_classes = 2  # 1 class (wheat) + background
    # get number of input features for the classifier
    in_features = net.roi_heads.box_predictor.cls_score.in_features
    # replace the default head with one sized for our classes
    net.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Deserialize onto the CPU: this works on machines without CUDA and
    # avoids holding a second, temporary copy of the weights on the GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint)

    # Release the raw checkpoint before the model is moved to its device.
    del checkpoint
    gc.collect()

    net = net.to(device)
    net.eval()
    return net


# Earlier single-checkpoint candidates, kept for reference (the trailing
# "TEST" numbers are presumably leaderboard scores -- TODO confirm):
#         load_net(f'{DIR_MAIN}/fasterrcnn-resnet50-fpn-wheat-detection/fasterrcnn_resnet50_fpn.pth'), #TEST: 0.6643
#         load_net(f'{DIR_MAIN}/wwd-sgd-heavy-aug/fasterrcnn_resnet50_fpn_heavy_aug.pth'), #TEST: 0.6623
# #         load_net(f'{DIR_MAIN}/gwd-fasterrcnn-10-epoch-mosaic/fasterrcnn_resnet50_fpn_10_epoch_mosaic.pth'), # TEST: 0.6570
#         load_net(f'{DIR_MAIN}/gwd-fasterccn-30e-mosaic-heavy/fasterrcnn_resnet50_fpn_30e_mosaic_heavy.pth') # TEST: 0.6606

# The ensemble: one Faster R-CNN per fold checkpoint (fold0 .. fold4),
# loaded in fold order.
models = [
    load_net(f'{DIR_MAIN}/gwd-fasterrcnn-resnet50-fpn-fold{fold}/fasterrcnn_resnet50_fpn_fold{fold}.pth')
    for fold in range(5)
]



In [None]:
# Load the sample submission; its 'image_id' column is what the test dataset
# below uses to enumerate the images to predict.
test_df = pd.read_csv(f'{DIR_INPUT}/sample_submission.csv')
test_df.shape

In [None]:
class BaseWheatTTA:
    """Interface for a test-time augmentation and its inverse on boxes.

    Subclasses implement the image transform (single image and batch
    variants) plus the mapping that brings predicted boxes back into the
    frame of the un-augmented image.

    author: @shonenkov
    """

    # Side length subclasses use when mapping box coordinates back.
    image_size = 512

    def augment(self, image):
        """Transform one image; must be overridden."""
        raise NotImplementedError()

    def batch_augment(self, images):
        """Transform a batch of images; must be overridden."""
        raise NotImplementedError()

    def deaugment_boxes(self, boxes):
        """Undo the transform on predicted boxes; must be overridden."""
        raise NotImplementedError()

class TTAHorizontalFlip(BaseWheatTTA):
    """Flip TTA over dimension 1 of a single image (dimension 2 of a batch).

    author: @shonenkov
    """

    def augment(self, image):
        """Return ``image`` reversed along dimension 1."""
        return torch.flip(image, dims=(1,))

    def batch_augment(self, images):
        """Return ``images`` reversed along dimension 2."""
        return torch.flip(images, dims=(2,))

    def deaugment_boxes(self, boxes):
        """Mirror box columns 1 and 3 around ``image_size``.

        The two columns are swapped while mirroring. ``boxes`` is modified
        in place and also returned.
        """
        mirrored = self.image_size - boxes[:, [3, 1]]
        boxes[:, [1, 3]] = mirrored
        return boxes

class TTAVerticalFlip(BaseWheatTTA):
    """Flip TTA over dimension 2 of a single image (dimension 3 of a batch).

    author: @shonenkov
    """

    def augment(self, image):
        """Return ``image`` reversed along dimension 2."""
        return torch.flip(image, dims=(2,))

    def batch_augment(self, images):
        """Return ``images`` reversed along dimension 3."""
        return torch.flip(images, dims=(3,))

    def deaugment_boxes(self, boxes):
        """Mirror box columns 0 and 2 around ``image_size``.

        The two columns are swapped while mirroring. ``boxes`` is modified
        in place and also returned.
        """
        mirrored = self.image_size - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = mirrored
        return boxes
    
class TTARotate90(BaseWheatTTA):
    """Quarter-turn TTA built on ``torch.rot90`` with ``k=1``.

    author: @shonenkov
    """

    def augment(self, image):
        """Rotate one image once in the plane of dimensions (1, 2)."""
        return torch.rot90(image, k=1, dims=(1, 2))

    def batch_augment(self, images):
        """Rotate a batch once in the plane of dimensions (2, 3)."""
        return torch.rot90(images, k=1, dims=(2, 3))

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the rotated image back to the original frame.

        Works on a copy, so ``boxes`` itself is left untouched. The output
        corners are not guaranteed to be ordered (min, max).
        """
        size = self.image_size
        restored = boxes.copy()
        # Columns 0/2 come from the mirrored columns 1/3 ...
        restored[:, 0] = size - boxes[:, 1]
        restored[:, 2] = size - boxes[:, 3]
        # ... and columns 1/3 take columns 2/0 unchanged.
        restored[:, 1] = boxes[:, 2]
        restored[:, 3] = boxes[:, 0]
        return restored

class TTACompose(BaseWheatTTA):
    """Chain several TTA transforms into one.

    Images go through the transforms in list order; boxes are mapped back
    through them in reverse order and then normalised so that each box is
    stored as (min, min, max, max).

    author: @shonenkov
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def augment(self, image):
        """Apply every transform's ``augment`` in order."""
        for step in self.transforms:
            image = step.augment(image)
        return image

    def batch_augment(self, images):
        """Apply every transform's ``batch_augment`` in order."""
        for step in self.transforms:
            images = step.batch_augment(images)
        return images

    def prepare_boxes(self, boxes):
        """Return a copy of ``boxes`` with both coordinate pairs sorted."""
        ordered = boxes.copy()
        pair_02 = boxes[:, [0, 2]]
        pair_13 = boxes[:, [1, 3]]
        ordered[:, 0] = pair_02.min(axis=1)
        ordered[:, 2] = pair_02.max(axis=1)
        ordered[:, 1] = pair_13.min(axis=1)
        ordered[:, 3] = pair_13.max(axis=1)
        return ordered

    def deaugment_boxes(self, boxes):
        """Undo all transforms (last applied first), then sort the corners."""
        for step in reversed(self.transforms):
            boxes = step.deaugment_boxes(boxes)
        return self.prepare_boxes(boxes)

In [None]:
from itertools import product

# Every on/off combination of the three base transforms (2 * 2 * 2 = 8
# pipelines). ``None`` marks a transform as switched off and is dropped
# before composing, so the last pipeline is the identity.
tta_transforms = [
    TTACompose([step for step in combination if step])
    for combination in product(
        [TTAHorizontalFlip(), None],
        [TTAVerticalFlip(), None],
        [TTARotate90(), None],
    )
]

In [None]:
class WheatTestDataset(Dataset):
    """Dataset yielding ``(image, image_id)`` pairs for inference.

    Args:
        dataframe: table with an ``image_id`` column; each unique id is one item.
        image_dir: directory containing ``<image_id>.jpg`` files.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a mapping with an ``'image'`` entry.
    """

    # Per-channel statistics used by ``normalize``. Defined on the class so
    # that ``normalize`` works even before any item has been fetched.
    image_mean = [0.43216, 0.394666, 0.37645]
    image_std = [0.22803, 0.22145, 0.216989]

    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()

        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load image ``index`` as RGB and apply the transforms.

        When transforms are configured, their output is divided by 255.0;
        without transforms the raw RGB array from OpenCV is returned as is.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_id = self.image_ids[index]
        image_path = f'{self.image_dir}/{image_id}.jpg'

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None; fail with a clear
            # message instead of an opaque error inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            # Mean/std normalisation is intentionally not applied here;
            # use ``normalize`` explicitly if it is needed.
            image = self.transforms(image=image)['image'] / 255.0

        return image, image_id

    def __len__(self) -> int:
        """Number of unique image ids."""
        return self.image_ids.shape[0]

    def normalize(self, image):
        """Return ``(image - mean) / std`` with per-channel statistics.

        ``image`` is expected to be a tensor whose first dimension indexes
        the three channels (the statistics are broadcast as ``[:, None, None]``).
        """
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

In [None]:
# Albumentations
def get_test_transform():
    """Return the inference pipeline: resize to 512x512, then ``ToTensorV2``."""
    steps = [
        A.Resize(512, 512),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)

# Albumentations
def get_flip_test_transform():
    """Return a pipeline that always flips vertically, then applies ``ToTensorV2``.

    Unlike ``get_test_transform`` there is no resize step.
    """
    steps = [
        A.VerticalFlip(p=1.0),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)


In [None]:
def collate_fn(batch):
    """Transpose a list of samples into one tuple per sample field.

    ``[(img0, id0), (img1, id1)]`` becomes ``((img0, img1), (id0, id1))``;
    nothing is stacked.
    """
    columns = zip(*batch)
    return tuple(columns)

# Test images, resized to 512x512 by get_test_transform().
test_dataset = WheatTestDataset(test_df, DIR_TEST, get_test_transform())

# Sequential loader (no shuffling, no dropped remainder); collate_fn keeps
# each batch as a tuple of images and a tuple of ids instead of stacking.
test_data_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    drop_last=False,
    collate_fn=collate_fn
)

In [None]:
def format_prediction_string(boxes, scores):
    """Serialise detections as ``"score v0 v1 v2 v3"`` groups joined by spaces.

    Scores are written with four decimals; the four box values are written
    as they are. Pairs are formed with ``zip``, so extra boxes or scores
    beyond the shorter sequence are ignored.
    """
    return " ".join(
        "{0:.4f} {1} {2} {3} {4}".format(score, box[0], box[1], box[2], box[3])
        for score, box in zip(scores, boxes)
    )

In [None]:
def make_predictions(images, score_threshold=0.1):
    """Run every model once on the batch (no TTA) and collect thresholded boxes.

    Each model is called as ``net(images, scales)`` and is expected to return,
    per image, rows whose first four columns are a box as (x, y, w, h) and
    whose fifth column is the score -- the layout the code below indexes
    (assumed to match the EfficientDet ``DetBenchEval`` wrapper; TODO confirm).

    Args:
        images: sequence of equally sized image tensors.
        score_threshold: detections with a score not above this are dropped.

    Returns:
        One list per model, each holding one ``{'boxes', 'scores'}`` dict per
        image with boxes converted to (x1, y1, x2, y2) numpy arrays. This is
        the ``predictions[pass][image]`` layout that ``run_wbf`` indexes.
    """
    images = torch.stack(images).to(device).float()
    scales = torch.ones(images.shape[0], dtype=torch.float32, device=images.device)

    predictions = []
    with torch.no_grad():
        for net in models:
            det = net(images, scales)
            result = []
            for i in range(images.shape[0]):
                detections = det[i].detach().cpu().numpy()
                scores = detections[:, 4]
                keep = np.where(scores > score_threshold)[0]
                # Filter exactly once; fancy indexing also yields a copy, so
                # the in-place conversion below cannot touch ``detections``.
                boxes = detections[keep, :4]
                # (x, y, w, h) -> (x1, y1, x2, y2)
                boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
                boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
                result.append({
                    'boxes': boxes,
                    'scores': scores[keep],
                })
            # Keep each model's results separate so every model contributes
            # to the fusion step.
            predictions.append(result)
    return predictions

def make_tta_predictions(images, score_threshold=0.5):
    """Run every model under every TTA transform on array-style detections.

    Each model is called as ``net(batch)`` and is expected to return, per
    image, rows whose first four columns are a box as (x, y, w, h) and whose
    fifth column is the score -- the layout the code below indexes.

    Args:
        images: sequence of equally sized image tensors.
        score_threshold: detections with a score not above this are dropped.

    Returns:
        A list with one entry per (model, transform) pass; each entry is a
        list with one ``{'boxes', 'scores'}`` dict per image. Boxes are
        (x1, y1, x2, y2) numpy arrays mapped back through the transform.
    """
    images = torch.stack(images).float().to(device)
    predictions = []
    with torch.no_grad():
        for net in models:
            for tta_transform in tta_transforms:
                result = []
                det = net(tta_transform.batch_augment(images.clone()))
                for i in range(images.shape[0]):
                    detections = det[i].detach().cpu().numpy()
                    scores = detections[:, 4]
                    keep = np.where(scores > score_threshold)[0]
                    # Fancy indexing returns a fresh array, so the in-place
                    # edits below never reach the model output.
                    boxes = detections[keep, :4]
                    # (x, y, w, h) -> (x1, y1, x2, y2)
                    boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
                    boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
                    boxes = tta_transform.deaugment_boxes(boxes)
                    result.append({
                        'boxes': boxes,
                        'scores': scores[keep],
                    })
                predictions.append(result)
    return predictions

def make_tta_ensemble_predictions(images, score_threshold=0.5):
    """Run every model under every TTA transform on dict-style detections.

    Each model is called as ``net(batch)`` and must return, per image, a
    mapping with ``'boxes'`` and ``'scores'`` tensors. Boxes are used as
    returned (no width/height conversion) and mapped back through the
    transform that produced them.

    Args:
        images: sequence of equally sized image tensors.
        score_threshold: detections with a score not above this are dropped.

    Returns:
        A list with one entry per (model, transform) pass; each entry is a
        list with one ``{'boxes', 'scores'}`` dict of numpy arrays per image.
    """
    # Follow the module-level ``device`` (as the calling loop does) instead
    # of assuming CUDA is present.
    images = torch.stack(images).float().to(device)
    predictions = []
    with torch.no_grad():
        for net in models:
            for tta_transform in tta_transforms:
                result = []
                det = net(tta_transform.batch_augment(images.clone()))

                for i in range(images.shape[0]):
                    boxes = det[i]['boxes'].detach().cpu().numpy()
                    scores = det[i]['scores'].detach().cpu().numpy()
                    keep = np.where(scores > score_threshold)[0]
                    # ``boxes[keep]`` is already a fresh array, so transforms
                    # that edit in place cannot alter the model output.
                    boxes = tta_transform.deaugment_boxes(boxes[keep])
                    result.append({
                        'boxes': boxes,
                        'scores': scores[keep],
                    })
                predictions.append(result)
    return predictions

def make_ensemble_predictions(images):
    """Run every model once on the batch, without TTA or thresholding.

    Args:
        images: iterable of image tensors; each is moved to ``device``.

    Returns:
        A list with one entry per model, each being that model's raw output
        for the whole batch.
    """
    images = [image.to(device) for image in images]

    # Pure inference: disabling autograd avoids building graphs and keeps
    # the outputs free of gradient tracking.
    # NOTE: an unfinished flip-TTA experiment used to be sketched here in
    # comments; TTA is handled by make_tta_ensemble_predictions.
    with torch.no_grad():
        return [net(images) for net in models]

# from ensemble_boxes import *
# def run_wbf(predictions, image_index, image_size=512, iou_thr=0.55, skip_box_thr=0.7, weights=None):
#     boxes = [prediction[image_index]['boxes'].data.cpu().numpy()/(image_size-1) for prediction in predictions]
#     scores = [prediction[image_index]['scores'].data.cpu().numpy() for prediction in predictions]
#     labels = [np.ones(prediction[image_index]['scores'].shape[0]) for prediction in predictions]
#     boxes, scores, labels = weighted_boxes_fusion(boxes, scores, labels, weights=None, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
#     boxes = boxes*(image_size-1)
#     return boxes, scores, labels

def run_wbf(predictions, image_index, image_size=512, iou_thr=0.44, skip_box_thr=0.43, weights=None):
    """Fuse one image's boxes from all prediction passes with weighted boxes fusion.

    Args:
        predictions: ``predictions[pass][image]`` dicts holding numpy
            ``'boxes'`` and ``'scores'`` arrays.
        image_index: which image of the batch to fuse.
        image_size: boxes are divided by ``image_size - 1`` before fusion
            and multiplied back afterwards.
        iou_thr: forwarded to ``weighted_boxes_fusion``.
        skip_box_thr: forwarded to ``weighted_boxes_fusion``.
        weights: optional per-pass weights forwarded to
            ``weighted_boxes_fusion``; ``None`` leaves the choice to it.

    Returns:
        ``(boxes, scores, labels)`` as returned by the fusion, with boxes
        scaled back to pixel coordinates. Every box gets label 1.
    """
    per_pass = [prediction[image_index] for prediction in predictions]
    boxes = [(entry['boxes'] / (image_size - 1)).tolist() for entry in per_pass]
    scores = [entry['scores'].tolist() for entry in per_pass]
    labels = [np.ones(entry['scores'].shape[0]).tolist() for entry in per_pass]
    # Forward the caller's weights; previously they were silently discarded.
    boxes, scores, labels = weighted_boxes_fusion(
        boxes, scores, labels,
        weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr,
    )
    boxes = boxes * (image_size - 1)
    return boxes, scores, labels

In [None]:
# Main inference loop: predict every test batch with the TTA ensemble, fuse
# the passes per image with WBF and collect one submission row per image.
# NOTE: detection_threshold is not used inside this loop; it is only read by
# the visualisation cell further down.
detection_threshold = 0.5
results = []

for images, image_ids in test_data_loader:

    images = list(image.to(device) for image in images)
    # Earlier prediction variants, kept for reference:
#     outputs = model(images)
#     outputs = make_ensemble_predictions(images)
    # outputs[pass][image] -> {'boxes', 'scores'} for each (model, TTA) pass.
    outputs = make_tta_ensemble_predictions(images)

    for i, image in enumerate(images):
#         boxes, scores, labels = run_wbf([outputs],image_size=1024, image_index=i)
#         boxes = (boxes*2).astype(np.int32).clip(min=0, max=1023)

        # Fuse the boxes that all passes produced for image i.
        boxes, scores, labels = run_wbf(outputs, image_index=i)
    
#         boxes = outputs[i]['boxes'].data.cpu().numpy()
#         scores = outputs[i]['scores'].data.cpu().numpy()
        
#         boxes = boxes[scores >= detection_threshold].astype(np.int32)
#         scores = scores[scores >= detection_threshold]
        
        # Images resized to 512 get their boxes doubled and clipped to
        # [0, 1023] (assumes the source images are 1024 px -- TODO confirm).
        # NOTE(review): for any other size the boxes stay un-rounded floats.
        if image.shape[1] == 512:
            boxes = (boxes*2).astype(np.int32).clip(min=0, max=1023)
            
        image_id = image_ids[i]
        
        # (x1, y1, x2, y2) -> (x, y, width, height) for the prediction string.
        boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
        boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
        
        result = {
            'image_id': image_id,
            'PredictionString': format_prediction_string(boxes, scores)
        }

        
        results.append(result)


In [None]:
# Peek at the first two collected result rows.
results[0:2]

In [None]:
# Build the submission table; this rebinds test_df, replacing the sample
# submission that was loaded earlier.
test_df = pd.DataFrame(results, columns=['image_id', 'PredictionString'])
test_df.head()

In [None]:
# Pick the second image of the last processed batch (channels moved last for
# plotting) together with the boxes/scores of the first (model, TTA) pass.
# NOTE(review): index 1 raises IndexError if the last batch holds a single
# image -- confirm against the test set size.
sample = images[1].permute(1,2,0).cpu().numpy()
boxes = outputs[0][1]['boxes']#.data.cpu().numpy()
scores = outputs[0][1]['scores']#.data.cpu().numpy()

# Keep only confident boxes and cast to integer pixel coordinates for drawing.
boxes = boxes[scores >= detection_threshold].astype(np.int32)

In [None]:
# Draw the selected boxes onto the sample image and show it without axes.
fig, ax = plt.subplots(1, 1, figsize=(16, 8))

for box in boxes:
    # Corners are (box[0], box[1]) and (box[2], box[3]); line thickness 2.
    cv2.rectangle(sample,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (220, 0, 0), 2)
    
ax.set_axis_off()
ax.imshow(sample)

In [None]:
# Sanity check: shape of the plotted sample array.
print(sample.shape)

In [None]:
# Write the submission file without the DataFrame index column.
test_df.to_csv('submission.csv', index=False)