# Imports 


In [None]:
import pandas as pd
import numpy as np
import torch, torchvision, os, math, re, wandb,pdb
from ast import literal_eval as lv
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
!git clone https://github.com/facebookresearch/detr.git -q
import sys
sys.path.append('./detr/')
from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion

* [Initial DataFrame generation](https://www.kaggle.com/code/pestipeti/pytorch-starter-fasterrcnn-train)

* [Curated list of tiny/small object detection resources](https://github.com/kuanhungchen/awesome-tiny-object-detection#tiny-object-detection)

* [DETR baseline notebook](https://www.kaggle.com/code/tanulsingh077/end-to-end-object-detection-with-transformers-detr/notebook)

# Helpers

In [None]:
def read_img(path):
    """Load the image at ``path`` with ``torchvision.io.read_image`` and return it."""
    return torchvision.io.read_image(path)

def display_(path_df, base_dir, disp=False):
    """Load the annotation CSV and expand its ``bbox`` column into numeric columns.

    Each ``bbox`` cell is a string literal ``[x_min, y_min, width, height]``,
    parsed with ``ast.literal_eval``. The following columns are added:

    - ``path``: ``base_dir/<image_id>.jpg``
    - ``x_min``, ``y_min``, ``x_max``, ``y_max``: box corners (the max
      corners are min + width / height)
    - ``area``: width * height

    Args:
        path_df: path to a CSV with at least ``image_id`` and ``bbox`` columns.
        base_dir: directory holding the ``.jpg`` images.
        disp: if True, show the table (IPython ``display``), its ``info()``
            and the number of unique values per column.

    Returns:
        The augmented DataFrame.
    """
    df = pd.read_csv(path_df)
    df['path'] = base_dir + '/' + df.image_id + '.jpg'

    # Parse each bbox string once (previously five literal_eval calls per row).
    # reshape keeps a (0, 4) shape for an empty CSV.
    boxes = np.array(df.bbox.apply(lv).tolist()).reshape(-1, 4)
    x, y, w, h = boxes.T
    df['x_min'] = x
    df['y_min'] = y
    df['x_max'] = x + w
    df['y_max'] = y + h
    df['area'] = w * h

    if disp:
        display(df)
        # DataFrame.info() prints itself and returns None; don't print the None.
        df.info()
        for col in df.columns:
            print(f'unique values in {col} = {df[col].nunique()}')

    return df


In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

def get_train_transform(format_='coco'):
    """Build the training pipeline: resize to 512x512, then convert to a tensor.

    Boxes are interpreted in the albumentations format named by ``format_``.
    They must be accompanied by a ``labels`` field. No boxes are filtered out
    (``min_area`` and ``min_visibility`` are both 0).
    """
    steps = [
        A.Resize(height=512, width=512, p=1.0),
        ToTensorV2(p=1.0),
    ]
    box_params = A.BboxParams(
        format=format_,
        min_area=0,
        min_visibility=0,
        label_fields=['labels'],
    )
    return A.Compose(steps, p=1.0, bbox_params=box_params)

def get_valid_transforms(format_='coco'):
    """Build the validation pipeline: resize to 512x512, then convert to a tensor.

    This mirrors ``get_train_transform`` and applies no augmentation.

    Args:
        format_: albumentations bbox format name of the incoming boxes
            (default ``'coco'``). Pass the same value the dataset was built
            with; it was previously hard-coded to ``'coco'``.

    Boxes must be accompanied by a ``labels`` field. No boxes are filtered
    out (``min_area`` and ``min_visibility`` are both 0).
    """
    return A.Compose([A.Resize(height=512, width=512, p=1.0),
                      ToTensorV2(p=1.0)],
                     p=1.0,
                     bbox_params=A.BboxParams(format=format_, min_area=0, min_visibility=0, label_fields=['labels'])
                     )

# Base pipeline

In [None]:
class basic_pipeline(Dataset):
    """Map-style dataset yielding one image and all of its boxes per item.

    There is one item per unique ``image_id`` in ``df``. ``df`` must have the
    columns produced by ``display_``: ``image_id``, ``path``, ``x_min``,
    ``y_min``, ``x_max``, ``y_max`` and ``area``.

    Args:
        df: annotation table with one row per box.
        apply_transform: run ``transforms`` on image, boxes and labels.
            This is skipped when ``transforms`` is None.
        apply_normalise: divide box x-values (indices 0, 2) by the image width
            and y-values (indices 1, 3) by the image height.
        format_: ``'coco'`` gives ``[x_min, y_min, width, height]``; any other
            value gives ``[x_min, y_min, x_max, y_max]``.
        transforms: albumentations-style callable taking ``image=`` (H, W, C
            numpy array), ``bboxes=`` and ``labels=`` keyword arguments.

    Each item is ``(image, target)``. ``target`` has these keys:

    - ``'bbox'``: float32 tensor of shape (N, 4)
    - ``'area'``: float32 tensor with the ``area`` column of ``df``
    - ``'label'``: int64 tensor of zeros (single class)
    """

    def __init__(self,
                 df,
                 apply_transform=True,
                 apply_normalise=True,
                 format_='coco',
                 transforms=None
                 ):
        super().__init__()
        self.df = df
        # Items are indexed by unique image id, in order of first appearance.
        self.image_id = df['image_id'].unique()
        self.apply_transform = apply_transform
        self.apply_normalise = apply_normalise
        self.format_ = format_
        self.transforms = transforms

    def __len__(self):
        return len(self.image_id)

    def read_img(self, path):
        """Read ``path`` with ``torchvision.io.read_image`` (a (C, H, W) tensor)."""
        return torchvision.io.read_image(path)

    def transform_(self, image, bbox, labels):
        """Run ``self.transforms`` on a (C, H, W) image tensor plus its boxes/labels.

        albumentations works on (H, W, C) numpy arrays, so the tensor is
        permuted first; a trailing ``ToTensorV2`` in the pipeline converts the
        image back to (C, H, W).

        Returns:
            The transform's output dict (``image``, ``bboxes``, ``labels``).
        """
        sample = {
            'image': np.ascontiguousarray(image.permute(1, 2, 0).numpy()),
            'bboxes': bbox,
            'labels': labels,
        }
        return self.transforms(**sample)

    def normalise_(self, sample, image):
        """Scale ``sample['bboxes']`` to [0, 1] using the size of ``image`` (C, H, W).

        Columns 0 and 2 are divided by the width and columns 1 and 3 by the
        height. This is correct for both coco and pascal_voc layouts. It is
        done with numpy directly because albumentations' ``normalize_bboxes``
        signature differs between releases.

        Returns:
            A float64 array of shape (N, 4).
        """
        _, h, w = image.shape
        # np.array copies, so the caller's boxes are never modified in place.
        boxes = np.array(sample['bboxes'], dtype=np.float64).reshape(-1, 4)
        boxes[:, [0, 2]] /= w
        boxes[:, [1, 3]] /= h
        return boxes

    def get_image_bboxes(self, img_unique_df, format_='coco'):
        """Return ``(image, boxes)`` for the rows of a single image.

        ``boxes`` is a list of ``[x_min, y_min, w, h]`` entries when ``format_``
        is ``'coco'``, and ``[x_min, y_min, x_max, y_max]`` entries otherwise.
        """
        x_min = img_unique_df['x_min'].to_numpy()
        y_min = img_unique_df['y_min'].to_numpy()
        x_max = img_unique_df['x_max'].to_numpy()
        y_max = img_unique_df['y_max'].to_numpy()
        if format_ == 'coco':
            columns = (x_min, y_min, x_max - x_min, y_max - y_min)
        else:
            columns = (x_min, y_min, x_max, y_max)
        bbox = np.stack(columns, axis=1).tolist()
        # All rows share one image_id, hence one file path.
        image = self.read_img(img_unique_df['path'].iloc[-1])
        return image, bbox

    def __getitem__(self, idx):
        img_id = self.image_id[idx]
        img_unique_df = self.df[self.df.image_id == img_id]
        image, bbox = self.get_image_bboxes(img_unique_df, self.format_)
        # Single-class problem: every box gets class id 0.
        labels = np.zeros(len(bbox), dtype=np.int64)

        if self.apply_transform and self.transforms is not None:
            sample = self.transform_(image, bbox, labels)
            image, bbox, labels = sample['image'], sample['bboxes'], sample['labels']

        if self.apply_normalise:
            # Works whether or not a transform ran (previously NameError without one).
            bbox = self.normalise_({'bboxes': bbox}, image)

        target = {
            # reshape keeps an (0, 4) shape when an image has no boxes.
            'bbox': torch.as_tensor(np.asarray(bbox, dtype=np.float32).reshape(-1, 4)),
            # Taken from df in original pixel units; not updated by transforms.
            'area': torch.as_tensor(img_unique_df['area'].to_numpy(), dtype=torch.float32),
            'label': torch.as_tensor(np.asarray(labels), dtype=torch.int64),
        }
        return image, target

def display_ds(ds, r=4, c=4, size=40):
    """Plot up to the first ``r * c`` samples of ``ds`` with their boxes drawn.

    ``draw_bounding_boxes`` draws ``target['bbox']`` as absolute-pixel
    ``(x_min, y_min, x_max, y_max)`` boxes on a uint8 image. For a correct
    overlay, build ``ds`` with a non-coco ``format_`` and
    ``apply_normalise=False``.

    Args:
        ds: indexable dataset returning ``(image_tensor, target_dict)``.
        r, c: grid rows and columns.
        size: figure width and height in inches.
    """
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    _, axs = plt.subplots(r, c, figsize=(size, size), squeeze=False)

    for n, ax in enumerate(axs.flatten()):
        ax.axis('off')
        # Leave surplus cells blank when the dataset is smaller than the grid.
        if n >= len(ds):
            continue
        img, target = ds[n]
        img = draw_bounding_boxes(img, target['bbox'])
        ax.imshow(to_pil_image(img))

    plt.tight_layout()
    plt.show()

In [None]:
# Kaggle "global-wheat-detection" input locations.
train_df_path = '../input/global-wheat-detection/train.csv'
submission_df_path = '../input/global-wheat-detection/sample_submission.csv'
train_base_dir = '../input/global-wheat-detection/train'
submission__base_dir = '../input/global-wheat-detection/test'

# Build the annotation table and a validation-transform dataset, then fetch
# one sample as a quick smoke test.
train_df = display_(train_df_path, train_base_dir)
ds = basic_pipeline(train_df, transforms=get_valid_transforms())
i, t = ds[2]

In [None]:
i, t = ds[2]  # re-fetch the same sample

# Model

In [None]:
class DETRModel(pl.LightningModule):
    """Pretrained DETR-ResNet50 from torch.hub with a replaced classification head.

    Args:
        num_classes: output width of the new ``class_embed`` linear head. This
            covers the real classes plus the no-object slot, and should match
            ``classifier(num_classes=...)``.
        num_queries: number of object queries. When it differs from the hub
            model's, the query embedding is replaced (freshly initialised).
    """

    def __init__(self, num_classes, num_queries):
        super(DETRModel, self).__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        # Fetches code and weights through torch.hub (network access needed on first use).
        self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
        self.in_features = self.model.class_embed.in_features
        self.model.class_embed = torch.nn.Linear(in_features=self.in_features, out_features=self.num_classes)
        # DETR's forward uses query_embed.weight; the num_queries attribute alone
        # does not change how many queries are decoded.
        query_embed = self.model.query_embed
        if query_embed.num_embeddings != self.num_queries:
            self.model.query_embed = torch.nn.Embedding(self.num_queries, query_embed.embedding_dim)
        self.model.num_queries = self.num_queries

    def forward(self, images):
        """Run the wrapped DETR model on ``images`` and return its raw outputs."""
        return self.model(images)

# Classifier Logic


In [None]:
class classifier(pl.LightningModule):
    """Lightning module that trains a DETR-style ``model`` with DETR's ``SetCriterion``.

    Args:
        ds: dataset *class* (e.g. ``basic_pipeline``). It is instantiated as
            ``ds(df_split, transforms=...)`` with its own defaults, which yield
            normalised coco boxes under ``'bbox'`` and class ids under ``'label'``.
        df: annotation table with one row per box and an ``image_id`` column.
            It is split into train/val *by image*.
        model: module whose outputs ``SetCriterion`` accepts.
        LR: Adam learning rate.
        null_class_coef: no-object class weight (``eos_coef``).
        num_classes: classification-head width including the no-object slot.
            ``SetCriterion`` receives ``num_classes - 1``.
        num_queries: stored for reference.
        matcher: matcher for ``SetCriterion``; a new ``HungarianMatcher()`` when None.
        weight_dict: loss weights; ``{'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}`` when None.
        losses: losses to compute; ``['labels', 'boxes', 'cardinality']`` when None.
        batch_size: keyword-only DataLoader batch size (default 32).
    """

    # Normalisation statistics applied to uint8 images before the backbone.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        ds,
        df,
        model,
        LR=2e-5,
        null_class_coef=0.5,
        num_classes=2,
        num_queries=100,
        matcher=None,
        weight_dict=None,
        losses=None,
        *args,
        batch_size=32,
    ):
        super().__init__()
        self.ds = ds
        self.df = df
        # Split by image, not by box row: a row split scatters one image's boxes
        # across both sets, so each set would see partial annotations.
        image_ids = self.df['image_id'].unique()
        train_ids, val_ids = train_test_split(image_ids)
        self.train_df = self.df[self.df['image_id'].isin(train_ids)]
        self.val_df = self.df[self.df['image_id'].isin(val_ids)]
        self.model = model
        self.LR = LR
        # Build defaults per instance instead of sharing objects created at def time.
        self.losses = ['labels', 'boxes', 'cardinality'] if losses is None else losses
        self.matcher = HungarianMatcher() if matcher is None else matcher
        self.weight_dict = ({'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
                            if weight_dict is None else weight_dict)
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.null_class_coef = null_class_coef
        self.batch_size = batch_size
        self.criterion = SetCriterion(self.num_classes - 1, self.matcher, self.weight_dict,
                                      eos_coef=self.null_class_coef, losses=self.losses)

    @staticmethod
    def _collate(batch):
        """Batch ``(image, target)`` pairs as two lists; targets differ in box count."""
        images, targets = zip(*batch)
        return list(images), list(targets)

    def train_dataloader(self):
        train_ds = self.ds(self.train_df, transforms=get_train_transform())
        return DataLoader(train_ds, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self._collate)

    def val_dataloader(self):
        val_ds = self.ds(self.val_df, transforms=get_valid_transforms())
        return DataLoader(val_ds, batch_size=self.batch_size, collate_fn=self._collate)

    def _to_detr_inputs(self, images, targets):
        """Convert dataset items into what the DETR model and ``SetCriterion`` read.

        Images: uint8 tensors are scaled to [0, 1] and normalised with the
        ImageNet mean/std. Other dtypes are assumed already prepared.

        Targets: the dataset's normalised ``[x, y, w, h]`` ``'bbox'`` becomes
        normalised ``[cx, cy, w, h]`` ``'boxes'``, and ``'label'`` becomes int64
        ``'labels'``.
        """
        prepared = []
        for img in images:
            if img.dtype == torch.uint8:
                mean = torch.tensor(self._IMAGENET_MEAN, device=img.device).view(-1, 1, 1)
                std = torch.tensor(self._IMAGENET_STD, device=img.device).view(-1, 1, 1)
                img = (img.float() / 255.0 - mean) / std
            prepared.append(img)

        detr_targets = []
        for t in targets:
            boxes = t['bbox'].float().reshape(-1, 4)
            xy, wh = boxes[:, :2], boxes[:, 2:]
            detr_targets.append({
                'boxes': torch.cat([xy + wh / 2, wh], dim=1),
                'labels': t['label'].long(),
            })
        return prepared, detr_targets

    def _shared_step(self, batch):
        """Return ``(weighted total loss, batch size)`` for one batch."""
        images, targets = batch
        images, targets = self._to_detr_inputs(images, targets)
        outputs = self.model(images)
        loss_dict = self.criterion(outputs, targets)
        weight_dict = self.criterion.weight_dict
        loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
        return loss, len(images)

    def training_step(self, batch, batch_idx):
        loss, n = self._shared_step(batch)
        self.log('train_loss', loss, batch_size=n)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n = self._shared_step(batch)
        self.log('val_loss', loss, batch_size=n)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.LR)

# Play

In [None]:
def main(train_df_path='../input/global-wheat-detection/train.csv',
         train_base_dir='../input/global-wheat-detection/train',
         accelerator='cpu',
         **trainer_kwargs):
    """Build the wheat-detection DETR classifier and fit it with Lightning.

    Args:
        train_df_path: annotation CSV with one row per box.
        train_base_dir: directory containing ``<image_id>.jpg`` files.
        accelerator: passed to ``pl.Trainer`` (default ``'cpu'``).
        **trainer_kwargs: extra ``pl.Trainer`` arguments, e.g. ``max_epochs``.
    """
    train_df = display_(train_df_path, train_base_dir)
    detr = DETRModel(num_classes=2, num_queries=100)
    module = classifier(basic_pipeline, train_df, detr)

    trainer = pl.Trainer(accelerator=accelerator, **trainer_kwargs)
    trainer.fit(module)


if __name__ == '__main__':
    main()