In [1]:
!git clone -q https://github.com/facebookresearch/detr.git

Reference: [End-to-end object detection with transformers (DETR) — Kaggle notebook](https://www.kaggle.com/code/tanulsingh077/end-to-end-object-detection-with-transformers-detr/notebook)

In [2]:
import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np 
import pandas as pd 
from datetime import datetime
import time
import random
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold

import cv2

import sys
sys.path.append('./detr/')

from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion

import albumentations as A # advanced augmentation framework with PyTorch interface
import matplotlib.pyplot as plt
from albumentations.pytorch.transforms import ToTensorV2

from glob import glob

In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Currently using "{device}" device')

In [5]:
class AverageMeter(object):
    """Running-average tracker: remembers the latest value and a weighted mean of all values seen."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the total weight."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [5]:
# Load the validation labels only to inspect a few images, bounding boxes and classes.
# NOTE(review): the relative path assumes the dataset is mounted under ../input — confirm when running elsewhere.
df = pd.read_csv('../input/self-driving-cars/labels_val.csv')
df.head()

In [6]:
n_folds = 5  # number of cross-validation folds
seed = 42  # seed passed to seed_everything() and StratifiedKFold
num_classes = 6 # 5 object classes + 1 background ("no object") class
num_queries = 50 # max number of objects to detect per image (DETR default is 100); the DETR
# developers recommend changing this parameter only when training from scratch
null_class_coef = 0.5  # background-class weight (original note: repository default; use 0.5 when detecting 2 classes: object and background)
# NOTE(review): null_class_coef is not referenced by the visible code — run() passes eos_coef=1/num_classes instead.
BATCH_SIZE = 16
IMAGE_SIZE = 224 # side length images are resized to; alternative value: 512
LR = 1e-3 # initial AdamW learning rate; alternative value: 2e-5
EPOCHS = 4  # training epochs per fold

In [7]:
# Class-name <-> integer-id lookups. Ids start at 1; the reverse map is derived from the forward one
# so the two can never drift apart.
labels_to_ids = dict(car=1, truck=2, pedestrian=3, bicyclist=4, light=5)
ids_to_labels = {class_id: name for name, class_id in labels_to_ids.items()}

In [9]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices), sets
    PYTHONHASHSEED, and configures cuDNN for reproducible runs.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (a no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels from run to run,
    # which undoes the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False
    
seed_everything(seed)

In [10]:
# base bbox format: draw the corner-style (xmin, ymin, xmax, ymax) boxes of one random validation frame
random_image = df['frame'].sample(1).iloc[0]
df_random = df.loc[df['frame'] == random_image]
sample_image = cv2.cvtColor(
    cv2.imread('../input/self-driving-cars/images/' + random_image, cv2.IMREAD_COLOR),
    cv2.COLOR_BGR2RGB,  # OpenCV loads BGR; matplotlib expects RGB
)

plt.figure(figsize=(8,8))
plt.imshow(sample_image)
ax = plt.gca()

for _, row in df_random.iterrows():
    left, top = row['xmin'], row['ymin']
    box_w, box_h = row['xmax'] - left, row['ymax'] - top
    ax.add_patch(plt.Rectangle((left, top), box_w, box_h,
                               fill=False, color='red', linewidth=3))
    ax.text(left, top, f'Class_id: {row["class_id"]}', fontsize=15,
            bbox=dict(facecolor='yellow', alpha=0.5))
plt.axis('off')
plt.show()

In [10]:
def xyxy_to_xywh(xyxy):
    """Turn corner boxes [x1 y1 x2 y2] into [x1 y1 w h].

    Uses the inclusive-pixel convention: w = x2 - x1 + 1 and h = y2 - y1 + 1.
    A list/tuple is treated as one box (a tuple is returned); a 2-D ndarray is
    treated as one box per row (an ndarray is returned).
    """
    if isinstance(xyxy, np.ndarray):
        # Batch of boxes: keep the top-left corner, replace the far corner by the size.
        top_left = xyxy[:, 0:2]
        size = xyxy[:, 2:4] - top_left + 1
        return np.hstack((top_left, size))
    if isinstance(xyxy, (list, tuple)):
        assert len(xyxy) == 4
        x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
        return (x1, y1, x2 - x1 + 1, y2 - y1 + 1)
    raise TypeError('Argument xyxy must be a list, tuple, or numpy array.')
        
def xyxy_to_xcycwh(xyxy):
    """Convert corner boxes [x1 y1 x2 y2] to centre format [x_center y_center w h].

    Args:
        xyxy: a single box as a list/tuple of 4 numbers, or an ndarray whose last
            axis holds the 4 coordinates — shape (4,) for one box or (N, 4) for a batch.

    Returns:
        A 4-tuple for list/tuple input; an ndarray of the same shape for ndarray input.

    Raises:
        TypeError: if `xyxy` is not a list, tuple or ndarray.
    """
    if isinstance(xyxy, (list, tuple)):
        assert len(xyxy) == 4
        x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
        return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)
    elif isinstance(xyxy, np.ndarray):
        # Index on the last axis so a single 1-D box works as well as a 2-D batch
        # (column slicing with `[:, i]` raised IndexError for 1-D input).
        mins = xyxy[..., 0:2]
        maxs = xyxy[..., 2:4]
        return np.concatenate(((mins + maxs) / 2, maxs - mins), axis=-1)
    else:
        raise TypeError('Argument xyxy must be a list, tuple, or numpy array.')

In [11]:
train = pd.read_csv('../input/self-driving-cars/labels_train.csv')
test = pd.read_csv('../input/self-driving-cars/labels_val.csv')

# Rows whose ymax is 0 are treated as annotation outliers and removed from the training labels.
outliers = train.loc[train['ymax'] == 0].index
train = train.drop(index=outliers)

# Disabled: optional derived size / centre-format columns for train and test.
#train['h'] = train['ymax'] - train['ymin'] + 1
#train['w'] = train['xmax'] - train['xmin'] + 1
#train[['xc', 'yc', 'w', 'h']] = xyxy_to_xcycwh(train[['xmin', 'ymin', 'xmax', 'ymax']].values)
#test['h'] = test['ymax'] - test['ymin'] + 1
#test['w'] = test['xmax'] - test['xmin'] + 1
#test[['xc', 'yc', 'w', 'h']] = xyxy_to_xcycwh(test[['xmin', 'ymin', 'xmax', 'ymax']].values)

train.head()

In COCO format, a bounding box is defined by four pixel values [x_min, y_min, width, height]: the coordinates of the top-left corner followed by the width and height of the box. YOLO format is [x_center, y_center, width, height], with all four values normalized by the image size; this is the format the transforms below expect.

In [16]:
# coco bbox format: [x_min, y_min, width, height] in pixels, drawn for one random training frame
random_image = train['frame'].sample(1).iloc[0]
df_random = train[train['frame'] == random_image]
sample_image = cv2.imread('../input/self-driving-cars/images/' + random_image, cv2.IMREAD_COLOR)
sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

plt.figure(figsize=(8,8))
plt.imshow(sample_image)
ax = plt.gca()

for idx, row in df_random.iterrows():
    # No visible cell creates 'x'/'y'/'w'/'h' columns (the lines that would add 'w'/'h'
    # are commented out), so derive the COCO-style box from the corner columns instead.
    x, y = row['xmin'], row['ymin']
    w, h = row['xmax'] - row['xmin'], row['ymax'] - row['ymin']
    ax.add_patch(plt.Rectangle((x, y), w, h,
                                fill=False, color='red', linewidth=3))
    text = f'Class_id: {row["class_id"]}'
    ax.text(x, y, text, fontsize=15,
            bbox=dict(facecolor='yellow', alpha=0.5))
plt.axis('off')
plt.show()

In [12]:
# Build an n_folds split over frames, stratified by (per-frame class, bbox-count bucket),
# trying to preserve target-value counts across folds.
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

# One row per frame; after the groupby, bbox_count holds the number of boxes in that frame.
df_folds = train[['frame']].copy()
df_folds.loc[:, 'bbox_count'] = 1
df_folds = df_folds.groupby('frame').count()
# Per-frame class used for stratification: the largest class_id present in the frame.
df_folds.loc[:, 'class_id'] = train[['frame', 'class_id']].groupby('frame').max()['class_id']  # alternative: min
# Stratification key "<class_id>_<bucket>", where bucket = bbox_count // ((num_classes-1)*2 + 1).
df_folds.loc[:, 'stratify_group'] = np.char.add(
    df_folds['class_id'].values.astype(str),
    df_folds['bbox_count'].apply(lambda x: f'_{x // ((num_classes-1)*2 + 1)}').values.astype(str)
)
df_folds.loc[:, 'fold'] = 0

# Each frame is labelled with the number of the fold in which it belongs to the validation split.
for fold_number, (train_index, val_index) in enumerate(skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
    df_folds.loc[df_folds.iloc[val_index].index, 'fold'] = fold_number

In [11]:
# check class balances in each fold
# Iterate over the configured number of folds rather than a hard-coded 5, so the check
# stays correct if n_folds changes.
for i in range(n_folds):
    # class counts of the training part (all folds except i)
    print(df_folds[df_folds['fold'] != i].class_id.value_counts(normalize=False))
    # number of training frames, then number of validation frames
    print(df_folds[df_folds['fold'] != i].index.nunique())
    print(df_folds[df_folds['fold'] == i].index.nunique())

In [12]:
train['class_id'].value_counts(normalize=True)

In [45]:
# albumentations transforms for PyTorch
def get_train_transforms():
    """Build the training pipeline: colour jitter, flips, resize, cutout, normalise, to-tensor.

    Boxes are expected in 'yolo' format and are transformed together with the image;
    class ids travel in the 'labels' field.
    """
    # At most one of the two colour augmentations is applied per image.
    colour_jitter = A.OneOf(
        [
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),  # alt. p: 0.9
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
        ],
        p=0.5,  # alt.: 0.9
    )
    steps = [
        colour_jitter,
        A.ToGray(p=0.01),
        A.HorizontalFlip(p=0.1),  # alt.: 0.5
        A.VerticalFlip(p=0.1),  # alt.: 0.5
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1),
        A.Cutout(num_holes=4, max_h_size=32, max_w_size=32, fill_value=0, p=0.1),  # alt.: 8, 64, 64, 0.5
        # max_pixel_value=1.0: CarDataset already divides the image by 255 before the transforms run.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1.0),
        ToTensorV2(p=1.0),
    ]
    box_settings = A.BboxParams(format='yolo', min_area=0, min_visibility=0, label_fields=['labels'])  # alt. format: coco
    return A.Compose(steps, p=1.0, bbox_params=box_settings)

def get_valid_transforms():
    """Build the validation pipeline: resize, normalise, to-tensor (no random augmentation).

    Boxes are expected in 'yolo' format; class ids travel in the 'labels' field.
    """
    steps = [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
        # max_pixel_value=1.0: CarDataset already divides the image by 255 before the transforms run.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1.0),
        ToTensorV2(p=1.0),
    ]
    box_settings = A.BboxParams(format='yolo', min_area=0, min_visibility=0, label_fields=['labels'])  # alt. format: coco
    return A.Compose(steps, p=1.0, bbox_params=box_settings)

Because the classes are extremely imbalanced, it is good practice to rebalance the data when batches are drawn in the dataloader. Below is an example
using `WeightedRandomSampler`, which handles imbalanced classes by weighted oversampling: each sample is weighted inversely to the frequency of its class.
```
# alt.: gives weights > 1
class_counts = y_train.value_counts().to_list()
num_samples = sum(class_counts)
labels = y_train.map(labels_to_int).values

class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(int(num_samples))]
sampler = torch.utils.data.WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
```

In [14]:
# example of sampler, which we can add when defining train dataloader
target_labels = df_folds.loc[df_folds['fold'] == 0, 'class_id']

def get_sampler(target_labels):
    """Build a WeightedRandomSampler that draws every class with equal total probability.

    Each sample is weighted by 1 / (number of samples sharing its label). The labels
    may be any set of values: they need not start at 1, be contiguous, or cover every
    class id.

    Args:
        target_labels: 1-D sequence of per-sample class labels (pandas Series, ndarray or list).

    Returns:
        A torch.utils.data.WeightedRandomSampler drawing len(target_labels) samples per pass.
    """
    labels = np.asarray(target_labels).ravel()
    # `inverse` maps every sample to the position of its label in the sorted unique labels,
    # so the weights are looked up by position instead of assuming that label k sits at
    # index k-1 (which breaks as soon as a class is missing from `target_labels`).
    _, inverse, class_sample_count = np.unique(labels, return_inverse=True, return_counts=True)
    weight = 1. / class_sample_count
    samples_weight = torch.from_numpy(weight[inverse.reshape(-1)])
    return torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))

sampler = get_sampler(target_labels)  # pay attention at the background class

In [87]:
train_path = '../input/self-driving-cars/images'

class CarDataset(Dataset):
    """Detection dataset yielding (image, target, image_id) triples.

    Args:
        image_ids: sequence of frame file names to serve (ndarray or list).
        df: label table with columns 'frame', 'xmin', 'ymin', 'xmax', 'ymax', 'class_id'.
        transforms: optional albumentations pipeline called with image/bboxes/labels.

    Each target is a dict with:
        'boxes'    float32 tensor of shape (num_boxes, 4), normalised [x_center, y_center, w, h]
        'labels'   int64 tensor of class ids shifted to start at 0 (class_id - 1)
        'image_id' tensor holding the dataset index
        'area'     float32 tensor of normalised box areas (w * h)
    """
    def __init__(self, image_ids, df, transforms=None):
        self.image_ids = image_ids
        self.df = df
        self.transforms = transforms

    def __len__(self) -> int:
        # len() works for ndarrays and for plain lists alike.
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        records = self.df[self.df['frame'] == image_id]

        image_file = f'{train_path}/{image_id}'
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising; fail with a readable message.
            raise FileNotFoundError(f'Could not read image: {image_file}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        h, w, _ = image.shape
        image /= 255.0

        # Pixel corner boxes -> normalised corner boxes -> normalised centre format ('yolo').
        boxes = records[['xmin', 'ymin', 'xmax', 'ymax']].values
        boxes = A.augmentations.bbox_utils.normalize_bboxes(boxes, h, w)
        # tuple(box) keeps this working whether a box arrives as a tuple or as an ndarray row.
        boxes = np.array([xyxy_to_xcycwh(tuple(box)) for box in boxes])  # yolo

        # set all labels to 0 if our task is only to detect every object on image without label
        #labels =  np.zeros(len(boxes), dtype=np.int32)
        labels = records['class_id'].values.astype(np.int32) - 1

        if self.transforms:
            sample = self.transforms(image=image, bboxes=boxes, labels=labels)
            image = sample['image']
            boxes = sample['bboxes']
            labels = sample['labels']

        # reshape(-1, 4) keeps the (num_boxes, 4) layout even when no box is left.
        boxes = torch.as_tensor(np.asarray(boxes, dtype=np.float32)).reshape(-1, 4)
        labels = torch.as_tensor(np.asarray(labels), dtype=torch.long)

        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['image_id'] = torch.tensor([index])
        # Computed from the final boxes so that area always lines up with 'boxes',
        # even if the transforms dropped some of them.
        target['area'] = boxes[:, 2] * boxes[:, 3]

        return image, target, image_id

    def collate_fn(self, batch):
        """Regroup a list of (image, target, image_id) samples into three parallel tuples."""
        return tuple(zip(*batch))

#### Define DETR model taken from facebook research

In [15]:
from detr.models.detr import MLP

# continue training model weights
class DETRModel(nn.Module):
    """Pretrained DETR-ResNet50 from torch.hub with its classification head resized to `num_classes`.

    All weights stay trainable ("continue training" variant).
    NOTE(review): a second ``DETRModel`` is defined further down in this file; when the
    cells run in order, that later definition replaces this one.
    """
    def __init__(self, num_classes, num_queries):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries

        detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
        self.in_features = detr.class_embed.in_features

        # Disabled alternative: freeze every pretrained parameter before swapping the head.
        #for param in detr.parameters():
        #    param.requires_grad = False

        # Fresh classification head sized for this dataset.
        detr.class_embed = nn.Linear(in_features=self.in_features, out_features=self.num_classes)
        # NOTE(review): this only overwrites an attribute; the query embedding itself is not
        # rebuilt here (see the disabled lines below) — confirm it has the intended effect.
        detr.num_queries = self.num_queries

        #detr.query_embed = nn.Embedding(self.num_queries, 256)
        #detr.bbox_embed = MLP(256, 256, 4, 3)
        self.model = detr

    def forward(self, images):
        """Run the wrapped DETR model on `images` and return its output unchanged."""
        return self.model(images)

In [88]:
# second way
from detr.models.detr import MLP

class DETRModel(nn.Module):
    """DETR-ResNet50 with a frozen pretrained body and two freshly initialised heads.

    The architecture is built without weights, the published checkpoint is loaded minus
    its classification head, every loaded parameter is frozen, and then new
    classification and box-regression heads are attached.

    Args:
        num_classes: number of outputs of the new classification head.
        num_queries: stored on the instance; not otherwise used by this class.
    """
    def __init__(self, num_classes, num_queries):
        super(DETRModel,self).__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        
        # Build the architecture only; num_classes=50 sizes a temporary head that is replaced below.
        self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=False, num_classes=50)
        checkpoint = torch.hub.load_state_dict_from_url(
                                    url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
                                    map_location=device,
                                    check_hash=True)
        # Drop the checkpoint's classification head: its shape would not match the temporary one.
        del checkpoint["model"]["class_embed.weight"]
        del checkpoint["model"]["class_embed.bias"]
        # strict=False tolerates the two keys removed above.
        self.model.load_state_dict(checkpoint["model"], strict=False)
        # Freeze everything loaded so far; the heads created afterwards are new modules
        # and therefore remain trainable.
        for param in self.model.parameters():
            param.requires_grad = False
        
        self.in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(in_features=self.in_features, out_features=self.num_classes)
        # NOTE(review): this discards the pretrained box head loaded above; 256 presumably
        # matches the model's hidden size — confirm.
        self.model.bbox_embed = MLP(256, 256, 4, 3) # multilayer perceptron
        
    def forward(self,images):
        """Run the wrapped DETR model on `images` and return its output unchanged."""
        return self.model(images)

In [17]:
# The matcher pairs predictions with ground-truth objects inside SetCriterion.
matcher = HungarianMatcher()

# Relative weights of the loss terms summed in train_fn / eval_fn:
# loss_ce for classification, loss_bbox and loss_giou for the boxes.
weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}

# Loss groups SetCriterion is asked to compute.
losses = ['labels', 'boxes', 'cardinality']

#### Train and eval functions. The criterion is switched between train and eval mode together with the model.
mAP evaluation is commented out because `torchmetrics` could not be imported on Kaggle.

In [18]:
def train_fn(dataloader, model, criterion, optimizer, scheduler, epoch):
    """Train `model` for one epoch and return the running loss.

    Args:
        dataloader: yields (images, targets, image_ids) batches.
        model: detection model called with a list of image tensors.
        criterion: returns a dict of losses and exposes `weight_dict`.
        optimizer: optimizer stepped after every batch.
        scheduler: learning-rate scheduler stepped ONCE at the end of the epoch, or None.
        epoch: current epoch number (kept for interface compatibility; not used).

    Returns:
        AverageMeter whose `.avg` is the sample-weighted mean of the total loss.
    """
    model.train()
    criterion.train()

    summary_loss = AverageMeter()

    tk0 = tqdm(dataloader, total=len(dataloader), leave=True)

    for images, targets, image_ids in tk0:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        output = model(images)

        loss_dict = criterion(output, targets)
        weight_dict = criterion.weight_dict
        # Weighted sum over the loss terms that have a weight configured.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Weight by the real batch size: the last batch can be smaller than BATCH_SIZE.
        summary_loss.update(losses.item(), len(images))
        tk0.set_postfix(loss=summary_loss.avg)  # running average, refreshed after every batch

    # Step once per epoch. Stepping per batch made an epoch-based schedule such as
    # StepLR(step_size=1, gamma=0.5) halve the learning rate on every batch, driving it
    # to ~0 within the first epoch.
    if scheduler is not None:
        scheduler.step()

    return summary_loss

In [19]:
@torch.no_grad()
def eval_fn(dataloader, model, criterion):
    """Evaluate `model` on `dataloader` without gradients and return the running loss.

    Args:
        dataloader: yields (images, targets, image_ids) batches.
        model: detection model called with a list of image tensors.
        criterion: returns a dict of losses and exposes `weight_dict`.

    Returns:
        AverageMeter whose `.avg` is the sample-weighted mean of the total loss.
    """
    model.eval()
    criterion.eval()
    summary_loss = AverageMeter()

    tk0 = tqdm(dataloader, total=len(dataloader), leave=True)
    for images, targets, image_ids in tk0:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        output = model(images)

        loss_dict = criterion(output, targets)
        weight_dict = criterion.weight_dict

        # Weighted sum over the loss terms that have a weight configured.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        # Weight by the real batch size: the last batch can be smaller than BATCH_SIZE.
        summary_loss.update(losses.item(), len(images))
        tk0.set_postfix(loss=summary_loss.avg)

    return summary_loss

### Run learning process on n_folds

In [38]:
def run(fold, sample=False):
    """Train and validate DETR on one cross-validation fold.

    Frames whose 'fold' equals `fold` form the validation set; all others are trained on.
    The model weights are saved to 'detr_best_{fold}.pth' whenever the validation loss improves.

    Args:
        fold: number of the fold held out for validation.
        sample: when True, draw training frames with the class-balancing sampler from
            get_sampler; when False, shuffle the training frames uniformly.

    Returns:
        The lowest validation loss reached on this fold.
    """
    df_train = df_folds[df_folds['fold'] != fold]
    df_valid = df_folds[df_folds['fold'] == fold]

    sampler = get_sampler(df_train['class_id']) if sample else None

    train_dataset = CarDataset(
                               image_ids=df_train.index.values,
                               df=train,
                               transforms=get_train_transforms())

    valid_dataset = CarDataset(
                               image_ids=df_valid.index.values,
                               df=train,
                               transforms=get_valid_transforms())

    train_data_loader = DataLoader(
                                   train_dataset,
                                   batch_size=BATCH_SIZE,
                                   # DataLoader forbids shuffle=True together with a sampler, so
                                   # shuffle only when no sampler randomises the order already.
                                   # (Previously training data was never shuffled without a sampler.)
                                   shuffle=sampler is None,
                                   num_workers=4,
                                   sampler=sampler,
                                   collate_fn=train_dataset.collate_fn)

    valid_data_loader = DataLoader(
                                   valid_dataset,
                                   batch_size=BATCH_SIZE,
                                   shuffle=False,
                                   num_workers=4,
                                   collate_fn=valid_dataset.collate_fn)

    model = DETRModel(num_classes=num_classes, num_queries=num_queries).to(device)
    # The criterion is built with num_classes-1 object classes (num_classes counts the background);
    # eos_coef is presumably the weight of the "no object" class — confirm against SetCriterion.
    criterion = SetCriterion(num_classes-1, matcher, weight_dict, eos_coef=1/num_classes, losses=losses).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    best_loss = 10**5
    for epoch in range(EPOCHS):
        train_loss = train_fn(train_data_loader, model, criterion, optimizer, scheduler=scheduler, epoch=epoch)
        valid_loss = eval_fn(valid_data_loader, model, criterion)

        print('|EPOCH {}| TRAIN_LOSS {}| VALID_LOSS {}|'.format(epoch+1, train_loss.avg, valid_loss.avg))
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            print('Best model found for Fold {} in Epoch {}........Saving Model'.format(fold, epoch+1))
            torch.save(model.state_dict(), f'detr_best_{fold}.pth')

    return best_loss

In [90]:
# Train one model per fold. Loop over the configured fold count instead of a hard-coded 5,
# so every fold created by the split is trained if n_folds changes.
for fold in range(n_folds):
    run(fold, sample=True)

* fine-tune the learning rate, number of queries, augmentations and loss weights
* try binary labels only, with eos_coef=0.5: solve separate detection tasks (one model to detect cars, one model to detect trucks, etc.). This would be more accurate, but would significantly increase the total model size

In [26]:
def _concat(x, y):
    """ Concat by the last dimension """
    if isinstance(x, np.ndarray):
        return np.concatenate((x, y), axis=-1)
    elif isinstance(x, torch.Tensor):
        return torch.cat([x, y], dim=-1)
    else:
        raise TypeError("unknown type '{}'".format(type(x)))


def xcycwh_to_xywh(xcycwh):
    """Turn centre boxes [x_c y_c w h] into top-left boxes [x1 y1 w h].

    A flat list/tuple is treated as one box (a list is returned); an ndarray or tensor
    is converted along its last axis and returned as the same kind of array.
    """
    if isinstance(xcycwh, (np.ndarray, torch.Tensor)):
        size = xcycwh[..., 2:4]
        top_left = xcycwh[..., 0:2] - size / 2.
        return _concat(top_left, size)
    if isinstance(xcycwh, (list, tuple)):
        # Only a single flat box is accepted here, not a nested list of boxes.
        assert not isinstance(xcycwh[0], (list, tuple))
        xc, yc, w, h = xcycwh[0], xcycwh[1], xcycwh[2], xcycwh[3]
        return [xc - w / 2., yc - h / 2., w, h]
    raise TypeError('Argument xcycwh must be a list, tuple, or numpy array.')

def xcycwh_to_xyxy(xcycwh):
    """Turn centre boxes [x_c y_c w h] into corner boxes [x1 y1 x2 y2].

    A flat list/tuple is treated as one box (a list is returned); an ndarray or tensor
    is converted along its last axis and returned as the same kind of array.
    """
    if isinstance(xcycwh, (np.ndarray, torch.Tensor)):
        centre = xcycwh[..., 0:2]
        half = xcycwh[..., 2:4] / 2.
        return _concat(centre - half, centre + half)
    if isinstance(xcycwh, (list, tuple)):
        # Only a single flat box is accepted here, not a nested list of boxes.
        assert not isinstance(xcycwh[0], (list, tuple))
        xc, yc, w, h = xcycwh[0], xcycwh[1], xcycwh[2], xcycwh[3]
        return [xc - w / 2., yc - h / 2., xc + w / 2., yc + h / 2.]
    raise TypeError('Argument xcycwh must be a list, tuple, or numpy array.')
        
def xywh_to_xyxy(xywh):
    """Turn top-left boxes [x1 y1 w h] into corner boxes [x1 y1 x2 y2].

    Uses the inclusive-pixel convention x2 = x1 + max(0, w - 1) (same for y).
    A list/tuple is treated as one box (a tuple is returned); a 2-D ndarray is
    treated as one box per row (an ndarray is returned).
    """
    if isinstance(xywh, np.ndarray):
        origin = xywh[:, 0:2]
        extent = np.maximum(0, xywh[:, 2:4] - 1)  # clamp so x2/y2 never fall before x1/y1
        return np.hstack((origin, origin + extent))
    if isinstance(xywh, (list, tuple)):
        assert len(xywh) == 4
        x1, y1 = xywh[0], xywh[1]
        dx = np.maximum(0., xywh[2] - 1.)
        dy = np.maximum(0., xywh[3] - 1.)
        return (x1, y1, x1 + dx, y1 + dy)
    raise TypeError('Argument xywh must be a list, tuple, or numpy array.')

In [59]:
def view_sample(test, model, device, threshold=0.7):
    """Plot ground-truth and predicted boxes for one randomly drawn frame.

    A shuffled batch is taken from `test`, the model is run on the whole batch, and the
    first image of the batch is displayed with its ground-truth boxes in red and the
    kept predictions in blue, labelled with their class name.

    Args:
        test: label dataframe with a 'frame' column plus the columns CarDataset reads.
        model: detection model returning a dict with 'pred_logits' and 'pred_boxes'.
        device: device the model and the batch are moved to.
        threshold: minimum top class probability (background excluded) for a prediction to be drawn.

    Returns:
        A one-element list holding the model's output dict with its values moved to the CPU.
    """

    test_dataset = CarDataset(image_ids=test.frame.values,
                              df=test,
                              transforms=get_valid_transforms())
    
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=4,
                                  collate_fn=test_dataset.collate_fn)
    
    # Only the first shuffled batch is used.
    images, targets, image_ids = next(iter(test_data_loader))
    
    # Reload the first image at its original resolution for display.
    img_to_show = cv2.imread(train_path + '/' + image_ids[0], cv2.IMREAD_COLOR)
    img_to_show = cv2.cvtColor(img_to_show, cv2.COLOR_BGR2RGB)
    h,w,_ = img_to_show.shape
    
    images = list(img.to(device) for img in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    
    # Ground-truth boxes of the first image: normalised centre format -> pixel corner format.
    boxes = targets[0]['boxes'].cpu().numpy()
    boxes = xcycwh_to_xyxy(boxes)
    boxes = [np.array(box).astype(np.int32) for box in A.augmentations.bbox_utils.denormalize_bboxes(boxes,h,w)]
    
    model.eval()
    model.to(device)
    cpu_device = torch.device("cpu")
    
    with torch.no_grad():
        outputs = model(images)
        
    # NOTE(review): assumes every value in the output dict is a tensor — confirm for the model in use.
    outputs = [{k: v.to(cpu_device) for k, v in outputs.items()}]
    
    plt.figure(figsize=(16,8))
    ax = plt.gca()

    # Ground truth in red.
    for box in boxes:
        ax.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, color='red', linewidth=2))
        
    # Class probabilities of the first image in the batch, with the last (background) column dropped.
    probs = outputs[0]['pred_logits'].softmax(-1).detach().cpu().numpy()[0, :, :-1] # discard background class
    keep = probs.max(-1) > threshold
    probs = probs[keep]

    # Predicted boxes that passed the threshold: normalised centre format -> pixel corner format.
    oboxes = outputs[0]['pred_boxes'].detach().cpu().numpy()[0, keep]
    oboxes = xcycwh_to_xyxy(oboxes)
    oboxes = [np.array(box).astype(np.int32) for box in A.augmentations.bbox_utils.denormalize_bboxes(oboxes,h,w)]

    # Index of the highest-scoring non-background class for each kept prediction.
    labels = outputs[0]['pred_logits'][...,:-1].max(-1)[1].cpu().numpy()[0, keep]

    # Predictions in blue; label+1 undoes the "class_id - 1" shift applied in CarDataset.
    for box, prob, label in zip(oboxes, probs, labels):
        ax.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], fill=False, color='blue', linewidth=2))
        text = f'Class_id: {ids_to_labels.get(label+1)}'
        ax.text(box[0], box[1], text, fontsize=10, bbox=dict(facecolor='yellow', alpha=0.5))

    ax.set_axis_off()
    ax.imshow(img_to_show)
    return outputs

In [73]:
# Rebuild the architecture and restore the best weights saved for fold 0.
model = DETRModel(num_classes=num_classes,num_queries=num_queries)
# map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load("./detr_best_0.pth", map_location=device))

In [74]:
view = view_sample(test, model, device, threshold=0.3)