In [12]:
import os

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torch.utils.data import DataLoader, Dataset

In [2]:
# Kaggle competition inputs. Note: TEST_CSV_DIR is a file path despite its name.
_COMPETITION_ROOT = "../input/global-wheat-detection/"
TEST_DIR = _COMPETITION_ROOT + "test/"
TEST_CSV_DIR = _COMPETITION_ROOT + "sample_submission.csv"

# Fine-tuned Faster R-CNN checkpoint (state dict) produced by a separate training run.
WEIGHTS = "../input/santi-model-2/santi_2.pth"

In [3]:
# The sample submission lists every test image_id; the dataset below iterates over them.
test_df = pd.read_csv(filepath_or_buffer=TEST_CSV_DIR)

In [4]:
class WheatTestDataset(Dataset):
    """Inference dataset yielding one (image_tensor, image_id) pair per unique image.

    Args:
        root: directory containing ``<image_id>.jpg`` files.
        bboxes: DataFrame with an ``image_id`` column; only the unique ids are used.
        transforms: optional Albumentations-style callable invoked as
            ``transforms(image=ndarray)`` and returning a dict with ``'image'``.
            It is expected to produce a CHW array/tensor (e.g. end with ToTensorV2).

    Each item is a float32 tensor of shape (3, H, W) with values scaled to [0, 1],
    together with its image id string.
    """

    def __init__(self, root, bboxes, transforms=None):
        super().__init__()

        self.imgs = bboxes['image_id'].unique()
        # Kept for backward compatibility; per-item lookups do not need it.
        self.bboxes = bboxes
        self.root = root
        self.transforms = transforms

    def __getitem__(self, idx):
        img_id = self.imgs[idx]
        img_path = os.path.join(self.root, img_id + ".jpg")

        # Context manager closes the file handle; float32 keeps cv2-based
        # albumentations ops happy (many reject float64 input).
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0

        if self.transforms:
            img = self.transforms(image=img)['image']
            tensor = torch.as_tensor(img, dtype=torch.float32)
        else:
            # Without a transform, the HWC array must still become CHW for the detector.
            tensor = torch.as_tensor(
                np.ascontiguousarray(img.transpose(2, 0, 1)), dtype=torch.float32
            )

        return tensor, img_id

    def __len__(self):
        return len(self.imgs)

In [5]:
# Albumentations
def get_test_transform():
    """Build the inference-time Albumentations pipeline.

    The only step converts the HWC numpy image into a CHW torch tensor. No
    resizing is applied, so predicted boxes stay in the original image's
    pixel coordinates.
    """
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps)

In [6]:
# Build the Faster R-CNN architecture without downloading any weights (neither the
# COCO detector nor the ImageNet backbone); the fine-tuned checkpoint loaded below
# supplies the parameters.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_classes = 2  # 1 class (wheat) + background

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the default box predictor with one sized for our classes so the module
# layout matches the fine-tuned checkpoint's state dict.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine
# (the fallback device chosen above).
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(WEIGHTS, map_location=device))
model.to(device)
model.eval()

In [7]:
def collate_fn(batch):
    """Transpose a list of (image, image_id) samples into (images, image_ids) tuples.

    Images are left as separate tensors rather than stacked, because the
    detection model accepts a list of tensors and the images need not share a size.
    """
    columns = zip(*batch)
    return tuple(columns)

test_dataset = WheatTestDataset(
    root=TEST_DIR,
    bboxes=test_df,
    transforms=get_test_transform(),
)

# shuffle=False and drop_last=False: every test image is visited exactly once,
# in a deterministic order.
test_data_loader = DataLoader(
    dataset=test_dataset,
    collate_fn=collate_fn,
    batch_size=4,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

In [8]:
def format_prediction_string(boxes, scores):
    """Serialize detections as space-separated "score x y w h" groups.

    Args:
        boxes: iterable of 4-element boxes, written out with their default str form.
        scores: iterable of confidences paired with ``boxes`` and formatted to 4 decimals.

    Returns:
        One string with all groups joined by single spaces ("" when there are no boxes).
    """
    return " ".join(
        f"{score:.4f} {box[0]} {box[1]} {box[2]} {box[3]}"
        for score, box in zip(scores, boxes)
    )

In [13]:
# Only detections at or above this confidence go into the submission.
detection_threshold = 0.5
results = []

# Pure inference: no_grad stops autograd from recording the forward pass, which
# would otherwise retain activations for every batch and waste (GPU) memory.
with torch.no_grad():
    for images, image_ids in test_data_loader:

        images = [image.to(device) for image in images]
        outputs = model(images)

        for image_id, output in zip(image_ids, outputs):

            boxes = output['boxes'].cpu().numpy()
            scores = output['scores'].cpu().numpy()

            keep = scores >= detection_threshold
            boxes = boxes[keep].astype(np.int32)
            scores = scores[keep]

            # torchvision detectors return (x_min, y_min, x_max, y_max); the
            # submission string uses (x_min, y_min, width, height).
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]

            results.append({
                'image_id': image_id,
                'PredictionString': format_prediction_string(boxes, scores),
            })

In [14]:
# One row per test image. NOTE(review): this rebinds test_df, replacing the
# sample-submission frame loaded at the top of the notebook.
test_df = pd.DataFrame(data=results, columns=['image_id', 'PredictionString'])
test_df.head()

Unnamed: 0,image_id,PredictionString
0,aac893a91,0.9825 614 916 80 108 0.9505 552 523 129 202 0...
1,51f1be19e,0.9773 608 77 166 190 0.9718 497 462 221 118 0...
2,f5a1f0358,0.9779 138 750 160 124 0.9776 283 453 170 112 ...
3,796707dd7,0.9630 892 332 118 95 0.7622 48 82 158 125 0.7...
4,51b3e36ab,0.9949 231 644 95 161 0.9927 0 436 106 347 0.9...


In [15]:
# Write the competition submission file without the pandas index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)