In [None]:
# Clone the DETR GitHub repository (facebookresearch/detr) with a shell command.
# NOTE(review): the import cell also appends '../input/detrfiles/results/' to
# sys.path, so this clone may be redundant when that directory already
# provides the `detr` package -- confirm which copy is meant to be imported.
!git clone https://github.com/facebookresearch/detr.git

# Importing our Dependencies

In [None]:
import os
import numpy as np 
import pandas as pd 
from datetime import datetime
import time
import random
import sys
import numba  # python and numpy code acceleartor
from tqdm.autonotebook import tqdm

#Torch
import torch
import torch. nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler

#CV
import cv2

################# DETR FUNCTIONS FOR LOSS ########################
# Loss components used by DETR
sys.path.append('../input/detrfiles/results/')
sys.path.append('../input/detrfiles/results/detr/')


from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion
################################################################

#Albumentations
import albumentations as A
import matplotlib.pyplot as plt
from albumentations.pytorch.transforms import ToTensorV2

#Glob
from glob import glob

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw
from IPython.display import display

In [None]:
def collate_fn(batch):
    """Regroup a batch of per-sample tuples into one tuple per field.

    ``batch`` is an iterable of equally structured samples, e.g.
    ``[(image_0, id_0), (image_1, id_1)]``; the result is
    ``((image_0, image_1), (id_0, id_1))``. An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def get_test_transforms():
    """Build the albumentations pipeline applied to every test image.

    The pipeline resizes the image to 512x512 and then applies
    ``ToTensorV2``; both steps, and the composition itself, run with
    probability 1.0.
    """
    resize_step = A.Resize(height=512, width=512, p=1.0)
    to_tensor_step = ToTensorV2(p=1.0)
    pipeline = [resize_step, to_tensor_step]
    return A.Compose(pipeline, p=1.0)

def change_to_xmin_ymin_wh(bboxes, size):
    """Convert normalised centre-format boxes to pixel (x_min, y_min, w, h).

    Args:
        bboxes: array-like of boxes in (x_center, y_center, w, h) order, as
            output by DETR. Normally shaped (N, 4); an empty input and a
            single flat box of length 4 are accepted as well.
        size: factor used to denormalise the boxes. Either a scalar applied
            to all four columns, or a ``(width, height)`` pair applied to the
            x/w and y/h columns respectively.

    Returns:
        ``np.ndarray`` of shape (N, 4) holding (x_min, y_min, w, h).
    """
    # reshape(-1, 4) turns an empty input into a (0, 4) array and a flat box
    # into a (1, 4) array, so the column slicing below never raises IndexError.
    box = np.array(bboxes).reshape(-1, 4)
    centre_xy, wh = box[:, :2], box[:, 2:]

    # Shift from the box centre to its top-left corner.
    top_left_xy = centre_xy - (wh / 2)
    xywh = np.concatenate((top_left_xy, wh), axis=1)

    # Denormalise. A plain scalar is multiplied directly so the array dtype
    # is preserved exactly as before.
    if np.ndim(size) == 0:
        return xywh * size
    width, height = size
    return xywh * np.array([width, height, width, height], dtype=xywh.dtype)

# Creating Test Dataset

In [None]:
# Directory containing the test images, one `<image_id>.jpg` per image.
DIR_TEST = "../input/global-wheat-detection/test"


class WheatTestDataset(Dataset):
    """Dataset yielding ``(image, image_id)`` pairs for the test images.

    Args:
        image_ids: sequence of image ids (NumPy array, list or tuple); each
            id maps to ``{DIR_TEST}/{image_id}.jpg``.
        dataframe: stored on the instance as ``self.df``; not read by this
            class.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_ids, dataframe, transforms=None):
        self.image_ids = image_ids
        self.df = dataframe
        self.transforms = transforms

    def __len__(self) -> int:
        # len() works for NumPy arrays as well as plain lists/tuples,
        # whereas `.shape[0]` only works for arrays.
        return len(self.image_ids)

    def __getitem__(self, index):
        """Load image ``index`` as float32 RGB in [0, 1] and apply transforms.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        image_id = self.image_ids[index]
        image_path = f'{DIR_TEST}/{image_id}.jpg'

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Missing or unreadable test image: {image_path}")

        # OpenCV loads BGR; convert to RGB and scale to [0, 1].
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0

        if self.transforms:
            sample = self.transforms(image=image)
            image = sample['image']

        return image, image_id

# Model

In [None]:
class DETRModel(nn.Module):
    """DETR model from torch.hub with a freshly created classification head.

    Args:
        num_classes: number of outputs of the replacement classification
            layer.
        num_queries: value written to the wrapped model's ``num_queries``
            attribute.
        model_name: torch.hub entry point in ``facebookresearch/detr``.
    """

    def __init__(self, num_classes, num_queries, model_name='detr_resnet50'):
        super().__init__()
        self.num_classes, self.num_queries = num_classes, num_queries

        # Load the pretrained DETR model through torch.hub.
        detr = torch.hub.load('facebookresearch/detr', model_name, pretrained=True)
        self.model = detr

        # Swap the classification head for one with `num_classes` outputs,
        # keeping the input width of the original head.
        self.in_features = detr.class_embed.in_features
        detr.class_embed = nn.Linear(in_features=self.in_features,
                                     out_features=self.num_classes)
        detr.num_queries = self.num_queries

    def forward(self, images):
        """Delegate to the wrapped DETR model and return its output."""
        return self.model(images)

In [None]:
# The sample submission lists the ids of the images to predict on.
test_df = pd.read_csv("../input/global-wheat-detection/sample_submission.csv")
print(test_df.shape)

# One entry per distinct image id, in order of first appearance.
image_ids = test_df["image_id"].unique()

In [None]:
# Path to the trained weights.
WEIGHTS_FILE = "../input/detrfiles/detr_best_3.pth"

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 2    # outputs of the replacement classification head
num_queries = 100  # written to the wrapped DETR model's `num_queries`

# Instantiate the model and move it to the selected device.
model = DETRModel(num_classes=num_classes, num_queries=num_queries, model_name='detr_resnet50')
model = model.to(device)

# Load the trained weights. `map_location=device` remaps the stored tensors
# onto the selected device; without it a checkpoint saved from a CUDA model
# cannot be deserialised on a CPU-only machine, which would defeat the CPU
# fallback chosen above.
model.load_state_dict(torch.load(WEIGHTS_FILE, map_location=device))
model.eval()

In [None]:
# Accumulates one {'image_id', 'PredictionString'} row per test image.
clf_gt = list()

# Test dataset: every image id from the sample submission, with the test-time
# transforms applied.
test_dataset = WheatTestDataset(image_ids=image_ids,
                                dataframe=test_df,
                                transforms=get_test_transforms())

# Batches of five samples; collate_fn regroups each batch into
# (images, image_ids) tuples.
data_test_loader = DataLoader(dataset=test_dataset,
                              batch_size=5,
                              collate_fn=collate_fn)

# Making predictions

In [None]:
# image_id -> list of kept boxes as int32 arrays [x_min, y_min, w, h].
bounding_boxes = {}
# Reset the submission rows here so that re-running this cell does not append
# every image a second time.
clf_gt = []

confidence_thrsh = 0.5
MODEL_INPUT_SIZE = 512  # side length the test transforms resize images to
UPSCALE = 2             # factor from model-input pixels to submission pixels
MAX_COORD = 1023        # largest pixel coordinate allowed in the submission

# The per-batch ids are named `batch_image_ids` on purpose: reusing the name
# `image_ids` here would overwrite the global array of all test image ids
# with the ids of the last batch.
for images, batch_image_ids in tqdm(data_test_loader):
    with torch.no_grad():
        images = [image.to(device, dtype=torch.float) for image in images]
        outputs = model(images)

    # Only these two entries of the output dictionary are needed.
    pred_boxes = outputs['pred_boxes'].detach().cpu()
    pred_logits = outputs['pred_logits'].detach().cpu()

    for image_name, bboxes, logits in zip(batch_image_ids, pred_boxes, pred_logits):
        # Normalised (cx, cy, w, h) -> pixel (x_min, y_min, w, h), then
        # rescaled to submission pixels.
        boxes_xywh = change_to_xmin_ymin_wh(bboxes, MODEL_INPUT_SIZE) * UPSCALE

        # Clip the two corners rather than x, y, w, h independently, so a box
        # that crosses the image border is cut at the border instead of being
        # shifted or left sticking out. Boxes fully inside are unchanged.
        top_left = boxes_xywh[:, :2].clip(min=0, max=MAX_COORD)
        bottom_right = (boxes_xywh[:, :2] + boxes_xywh[:, 2:]).clip(min=0, max=MAX_COORD)
        oboxes = np.concatenate((top_left, bottom_right - top_left), axis=1).astype(np.int32)

        # Softmax over the class dimension; column 0 is used as the confidence.
        prob = logits.softmax(1).numpy()[:, 0]

        # Keep only predictions above the confidence threshold (computed once
        # and shared by the submission row and the preview dictionary).
        keep = prob > confidence_thrsh
        kept_boxes, kept_probs = oboxes[keep], prob[keep]

        # Submission row: "conf x y w h conf x y w h ...".
        clf_gt.append({
            'image_id': image_name,
            'PredictionString': ' '.join(
                str(round(float(confidence), 4))
                + ' '
                + ' '.join(str(int(value)) for value in box)
                for box, confidence in zip(kept_boxes, kept_probs)
            ),
        })

        bounding_boxes[image_name] = list(kept_boxes)

In [None]:
# One row per test image, built from the accumulated prediction records.
submission_df = pd.DataFrame(data=clf_gt)

# Replace any missing prediction string with an empty one before saving.
prediction_strings = submission_df['PredictionString'].fillna('')
submission_df['PredictionString'] = prediction_strings

submission_df.to_csv('submission.csv', index=False)
submission_df

In [None]:
def display_bboxes(img, bboxes):
    """Draw bounding boxes on ``img`` and display the result.

    Args:
        img: image to draw on. It is modified in place, so pass a copy when
            the original must stay untouched.
        bboxes: iterable of ``(x_min, y_min, w, h)`` boxes.
    """
    # Draw on the image that was passed in; the previous version ignored the
    # parameter and used a global named `box_img` instead.
    draw = ImageDraw.Draw(img)
    for x, y, w, h in bboxes:
        # Convert (x_min, y_min, w, h) to the two corners rectangle() expects.
        draw.rectangle([x, y, x + w, y + h], outline="blue", width=2)
    display(img)

# Visualizing the predicted Bounding Boxes

In [None]:
# Preview the predictions for the first few test images. One loop replaces
# five copy-pasted cells; slicing also copes with fewer than NUM_PREVIEW ids,
# where a hard-coded index would raise IndexError.
NUM_PREVIEW = 5

for image_id in image_ids[:NUM_PREVIEW]:
    boxes = bounding_boxes[image_id]

    # Open the original image via the shared DIR_TEST constant and close the
    # file handle as soon as an in-memory copy has been taken.
    with Image.open(f"{DIR_TEST}/{image_id}.jpg") as img:
        box_img = img.copy()

    # Boxes are drawn on the copy.
    display_bboxes(box_img, boxes)