In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import cv2
import os

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory


#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
import torchvision
import torchvision.transforms as T
import math

In [None]:
# Copy the attached `detr-wheat` input directory into torch's cache location.
# Presumably this lets torch.hub.load() find the DETR repo/weights without network access -- TODO confirm against the dataset layout.
# NOTE(review): `cp -r src/ dst` nests `src` inside `dst` when `dst` already exists, so the resulting layout depends on whether /root/.cache/torch is present.
!cp -r /kaggle/input/detr-wheat/ /root/.cache/torch

In [None]:
# Inspect the contents of the torch cache directory.
ls /root/.cache/torch

In [None]:
class DETRModel(nn.Module):
    """Wrapper around the torch.hub `detr_resnet50` model with a resized head.

    The hub model is loaded with its pretrained weights, after which the
    classification layer and the query embedding are replaced so the network
    predicts `num_classes + 1` logits for each of `num_queries` queries.

    Args:
        num_classes: number of object classes; one extra logit is added.
        num_queries: number of query embeddings to allocate.
    """

    def __init__(self,num_classes,num_queries):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries

        # Other hub entry points tried previously: 'detr_resnet101',
        # 'detr_resnet101_dc5'.
        detr = torch.hub.load(
            'facebookresearch/detr',
            'detr_resnet50',
            pretrained=True,
            force_reload=False,
        )

        # Sizes are read from the pretrained model before its layers are swapped.
        self.in_features = detr.class_embed.in_features
        self.hidden_dim = detr.transformer.d_model

        # Fresh (randomly initialised) head and queries sized for this task.
        detr.class_embed = nn.Linear(
            in_features=self.in_features,
            out_features=self.num_classes + 1,
        )
        detr.num_queries = self.num_queries
        detr.query_embed = nn.Embedding(self.num_queries, self.hidden_dim)

        self.model = detr

    def forward(self,images):
        """Run the wrapped DETR model on `images` and return its output unchanged."""
        return self.model(images)

In [None]:
# Build the detector with a single object class and 150 queries; these must
# match the checkpoint that is loaded into it afterwards.
model = DETRModel(
    num_classes=1,
    num_queries=150,
)

In [None]:
# Load the fine-tuned weights into the model.
# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from: without it, a checkpoint written on a GPU cannot be
# deserialised on a CPU-only kernel, and this notebook runs inference on CPU.
checkpoint_path = "../input/detr-wheat/tf_detr_best_0_round22_best_class.pth"
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)

In [None]:
# Put the model in evaluation mode before running inference.
# As the last expression of the cell this also displays the model's repr.
model.eval()
#model.to('cpu')

In [None]:
# Preprocessing applied to every image before it is fed to the model:
# convert to a tensor, then normalise each of the three channels.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

inf_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(_channel_mean, _channel_std),
])

In [None]:
# Open one test image and shrink it to a fixed 480x360 preview.
img = Image.open('../input/global-wheat-detection/test/f5a1f0358.jpg')
img = img.resize((480,360))

# Keep the preview dimensions around for scaling predicted boxes later.
imw, imh = img.size
(imw, imh)

In [None]:
# Height-to-width ratio of the preview image, used to derive the height when the image is resized for inference.
ratio = imh/imw

In [None]:
# Upscale the preview to a 1024-pixel-wide model input, deriving the height
# from the aspect ratio, and force three RGB channels.
inf_width = 1024
inf_height = int(inf_width * ratio)
img_inf = img.resize((inf_width, inf_height)).convert('RGB')
img_inf

In [None]:
# Preprocess the image and add a leading batch dimension of size 1.
img_tens = torch.unsqueeze(inf_transform(img_inf), 0)

In [None]:
# Forward pass on the single-image batch; gradients are not needed for inference, so autograd tracking is disabled.
with torch.no_grad():
    outputs = model(img_tens)

In [None]:
# Draw the confident detections for the sample image on a copy of the preview.
im2 = img.copy()
drw = ImageDraw.Draw(im2)

# The predicted boxes are (centre-x, centre-y, width, height) as fractions of
# the image size; scale them to pixels of the preview image. new_tensor keeps
# the scale factors on the same device/dtype as the predictions.
pred_boxes = outputs['pred_boxes'][0]
oboxes = pred_boxes * pred_boxes.new_tensor([imw, imh, imw, imh])

# Per-query class probabilities -> best class and its probability.
prob   = outputs['pred_logits'][0].softmax(1).detach().cpu().numpy()
classes = np.argmax(prob, axis=1)
prob   = np.max(prob, axis=1)

# The classification head has num_classes + 1 outputs; the last index is the
# "no object" slot and must not be drawn.
no_object_class = model.num_classes

# Convert boxes to plain Python floats so PIL receives ordinary numbers.
for box, p, c in zip(oboxes.detach().cpu().tolist(), prob, classes):
    if p <= 0.5 or c == no_object_class:
        continue

    print(p, box)
    cx, cy, w, h = box
    # True division: floor-dividing the width/height before halving would
    # shift the corners by up to half a pixel.
    x0, x1 = cx - w / 2, cx + w / 2
    y0, y1 = cy - h / 2, cy + h / 2
    drw.rectangle([x0, y0, x1, y1], outline='red', width=2)
    drw.text((x0+4, y0+4), '%.1f'%(p*100), fill='white')

In [None]:
# Display the annotated copy of the sample image.
im2

In [None]:
# File names of all test images. os.listdir() returns entries in arbitrary
# order, so sort them to make the processing order (and therefore which images
# end up in the preview grid) reproducible between runs.
test_files = sorted(os.listdir('../input/global-wheat-detection/test/'))

In [None]:
# Display the list of test file names.
test_files

In [None]:
# Open only the first test image (if there is one) for a quick visual check.
for file in test_files[:1]:
    img = Image.open(os.path.join('../input/global-wheat-detection/test/', file))
img

In [None]:
def format_prediction_string(boxes, scores):
    """Serialise detections into one space-separated prediction string.

    Each (score, box) pair becomes ``"<score> <b0> <b1> <b2> <b3>"`` with the
    score printed to four decimals and the first four box entries printed
    as-is. Pairs are joined with single spaces; pairing stops at the shorter
    of the two sequences, and empty input yields an empty string.
    """
    return " ".join(
        "{0:.4f} {1} {2} {3} {4}".format(score, box[0], box[1], box[2], box[3])
        for score, box in zip(scores, boxes)
    )

In [None]:
# Run the detector over every test image, collect one submission row per image
# and draw the detections of the first ten images on a 5x2 preview grid.
results = []
fig, ax = plt.subplots(5, 2, figsize=(30, 70))
count = 0

test_dir = '../input/global-wheat-detection/test/'
score_threshold = 0.8
# The classification head has num_classes + 1 outputs; the last index is the
# "no object" slot.
no_object_class = model.num_classes

for file in test_files:
    # Use a context manager so the file handle is released after each image;
    # resize() returns an independent, fully loaded copy.
    with Image.open(test_dir + file) as src:
        im_w, im_h = src.size
        ratio = im_h / im_w
        # NOTE(review): the image is first shrunk to 608 px wide and then
        # enlarged to 800 px. Kept as in the original; confirm whether the
        # intermediate downscale is intentional.
        img = src.resize((608, int(608 * ratio)))
    img = img.resize((800, int(800 * ratio))).convert('RGB')

    img_tens = inf_transform(img).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tens)

    try:
        boxes_norm = outputs['pred_boxes'][0].detach().cpu().numpy()
        class_prob = outputs['pred_logits'][0].softmax(1).detach().cpu().numpy()
        classes = class_prob.argmax(axis=1)
        scores = class_prob.max(axis=1)

        # Keep queries that predict a real object with high confidence.
        keep = (classes != no_object_class) & (scores > score_threshold)
        probs = scores[keep]

        # Boxes are (centre-x, centre-y, width, height) as fractions of the
        # image; scale them to pixels of the ORIGINAL image size.
        scale = np.array([im_w, im_h, im_w, im_h], dtype=boxes_norm.dtype)
        oboxes = (boxes_norm[keep] * scale).astype(np.float64)

        # Centre -> top-left corner, giving (xmin, ymin, width, height).
        oboxes[:, :2] -= oboxes[:, 2:] / 2
        oboxes = oboxes.astype(np.int32)
        oboxes = np.clip(oboxes, a_min=0, a_max=[im_w, im_h, im_w, im_h])

        print(file, len(oboxes), len(probs))

    except Exception as e:
        # Best effort: still emit an (empty) row for this image, but report
        # the failure instead of hiding it.
        print('post-processing failed for %s: %r' % (file, e))
        oboxes = np.array([])
        probs = np.array([])

    image_id = file.split(".")[0]

    result = {'image_id': image_id, 'PredictionString': format_prediction_string(oboxes, probs)}
    results.append(result)

    if count < 10:
        img_ = cv2.imread(test_dir + file)
        img_ = cv2.cvtColor(img_, cv2.COLOR_BGR2RGB)

        for box, p in zip(oboxes, probs):
            # Plain Python ints: OpenCV drawing functions can reject numpy
            # integer scalars as point coordinates.
            x0, y0, w, h = (int(v) for v in box)
            x1, y1 = x0 + w, y0 + h

            cv2.rectangle(img_, (x0, y0), (x1, y1), (220, 0, 0), 2)
            cv2.putText(img_, '%.1f'%(p*100), (x0, y0), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2, cv2.LINE_AA)

        # Show the image once, after all boxes are drawn, so images without
        # any detection are displayed too.
        ax[count%5][count//5].imshow(img_)
        count += 1

In [None]:
# Assemble the per-image result dicts into the submission table.
submission_columns = ['image_id', 'PredictionString']
test_df = pd.DataFrame(results, columns=submission_columns)

In [None]:
# Order the rows by image id. The sorted frame is rebound instead of using
# inplace=True, which pandas discourages and which offers no benefit here.
test_df = test_df.sort_values(by='image_id')
test_df

In [None]:
# WARNING: recursively deletes every non-hidden entry in the current working directory.
# Presumably meant to leave only the submission file in the Kaggle output directory -- do not run this outside a disposable environment.
!rm -rf ./*

In [None]:
# Write the submission file (without the DataFrame index) and preview it.
submission_path = 'submission.csv'
test_df.to_csv(submission_path, index=False)
test_df.head(10)