# RT-DETR Pretraining with SHIFT-Discrete Dataset

## Check GPU Availability

In [1]:
!nvidia-smi

Thu Sep 18 23:27:29 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off |   00000000:05:00.0 Off |                    0 |
| N/A   75C    P0            303W /  300W |    4509MiB /  81920MiB |     86%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00

In [2]:
# Set CUDA Device Number
DEVICE_NUM = 1

from os import environ
environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_NUM)
environ["CUDA_VISIBLE_DEVICES"]

'1'

## Imports

In [3]:
import os
os.chdir("/home/ubuntu/test-time-adapters")

In [4]:
from os import path
import math

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import (
    SHIFTDataset,
    SHIFTClearDatasetForObjectDetection,
    SHIFTCorruptedDatasetForObjectDetection,
    SHIFTDiscreteSubsetForObjectDetection
)
from ttadapters import datasets

from ttadapters.models.rcnn import FasterRCNNForObjectDetection, SwinRCNNForObjectDetection

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"INFO: Using device - {device}")

INFO: Using device - cuda


In [6]:
PROJECT_NAME = "detectron_test"
RUN_NAME = "Faster-RCNN_R50"

## Define Dataset

In [7]:
DATA_ROOT = path.join(".", "data")

## DataLoader

In [8]:
# # Set Batch Size
# BATCH_SIZE = 2, 8, 8, 8
# BATCH_SIZE = 50, 200, 200, 200  # A100 or H100
# BATCH_SIZE = 40, 120, 120, 120  # Half of A100 or H100

# # Dataset Configs
# CLASSES = dataset.train.classes
# NUM_CLASSES = len(CLASSES)

# print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
# print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [9]:
from detectron2.structures import ImageList

def collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return dict(
        pixel_values=ImageList.from_tensors(images, size_divisibility=32),
        labels=[dict(
            class_labels=item['boxes2d_classes'].long(),
            boxes=item["boxes2d"].float()
        ) for item in targets]
    )

In [10]:
from detectron2.structures import Boxes, Instances
from torchvision.tv_tensors import Image, BoundingBoxes

def collate_fn(batch: list[Image, BoundingBoxes]):
    batched_inputs = []
    for image, metadata in batch:
        original_height, original_width = image.shape[-2:]
        instances = Instances(image_size=(original_height, original_width))
        instances.gt_boxes = Boxes(metadata["boxes2d"])  # xyxy
        instances.gt_classes = metadata["boxes2d_classes"]
        batched_inputs.append({
            "image": image,
            "instances": instances,
            "height": original_height,
            "width": original_width
        })
    return batched_inputs

## Load Model

In [11]:
USE_SWIN_T_BACKBONE = False

In [12]:
if USE_SWIN_T_BACKBONE:
    model = SwinRCNNForObjectDetection(dataset=SHIFTDataset)
else:
    model = FasterRCNNForObjectDetection(dataset=SHIFTDataset)

model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

FasterRCNNForObjectDetection(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): Bo

### Direct Test

In [13]:
def task_to_subset_types(task: str):
    T = SHIFTDiscreteSubsetForObjectDetection.SubsetType

    # weather
    if task == "cloudy":
        return T.CLOUDY_DAYTIME
    if task == "overcast":
        return T.OVERCAST_DAYTIME
    if task == "rainy":
        return T.RAINY_DAYTIME
    if task == "foggy":
        return T.FOGGY_DAYTIME

    # time
    if task == "night":
        return T.CLEAR_NIGHT
    if task in {"dawn", "dawn/dusk"}:
        return T.CLEAR_DAWN
    if task == "clear":
        return T.CLEAR_DAYTIME
    
    # simple
    if task == "normal":
        return T.NORMAL
    if task == "corrupted":
        return T.CORRUPTED

    raise ValueError(f"Unknown task: {task}")

In [16]:
# dataset
class SHIFTCorruptedDatasetForObjectDetection(SHIFTDiscreteSubsetForObjectDetection):
    def __init__(
            self, root: str, force_download: bool = False,
            train: bool = True, valid: bool = False,
            transform= None, target_transform = None, transforms = None,
            task = "clear"
    ):
        super().__init__(
            root=root, force_download=force_download,
            train=train, valid=valid, subset_type=task_to_subset_types(task),
            transform=transform, target_transform=target_transform, transforms=transforms
        )

In [17]:
import time
import gc

def evaluate_for(self, loader, loader_length, threshold=0.0, dtype=torch.float32, device=torch.device("cuda")):
    torch.cuda.empty_cache()
    gc.collect()

    self.eval()

    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []
    collapse_time = 0

    with torch.inference_mode():
        for batch in tqdm(loader, total=loader_length, desc="Evaluation"):
            with torch.autocast(device_type=device.type, dtype=dtype):
                start = time.time()
                outputs = self(batch)
                collapse_time += time.time() - start

            for i, (output, input_data) in enumerate(zip(outputs, batch)):
                instances = output['instances']
                mask = instances.scores > threshold

                pred_detection = Detections(
                    xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                    class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                    confidence=instances.scores[mask].detach().cpu().numpy()
                )
                gt_instances = input_data['instances']
                target_detection = Detections(
                    xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                    class_id=gt_instances.gt_classes.detach().cpu().numpy()
                )

                predictions_list.append(pred_detection)
                targets_list.append(target_detection)

        map_metric.update(predictions=predictions_list, targets=targets_list)
        m_ap = map_metric.compute()

        per_class_map = {
            f"{CLASSES[idx]}_mAP@0.50:0.95": m_ap.ap_per_class[idx].mean().item()
            for idx in m_ap.matched_classes
        }
        performances = {
            "collapse_time": collapse_time,
            "fps": loader_length / collapse_time
        }

        return {
            "mAP@0.50:0.95": m_ap.map50_95.item(),
            "mAP@0.50": m_ap.map50.item(),
            "mAP@0.75": m_ap.map75.item(),
            **per_class_map,
            **performances
        }

In [19]:
for task in ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]:
    dataset=SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT, valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task
    )
    print(f"start {task}")
    CLASSES = dataset
    NUM_CLASSES = len(CLASSES)
    
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = math.ceil(len(dataset)/4)
    result = evaluate_for(model, dataloader, dataloader.valid_len)
    print(result)

[09/18/2025 23:32:33] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:32:33] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/cloudy_daytime/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:32:33] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/cloudy_daytime/discrete/images/val/front/det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to ./data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to ./data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[09/18/2025 23:32:34] SHIFT DevKit - INFO - Loading annotation takes 0.62 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['075f-be61']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -0.74     159.04
boxes2d              torch.Size([1, 9, 4])                     0.00    1044.00
boxes2d_classes      torch.Size([1, 9])                        0.00       2.00
boxes2d_track_ids    torch.Size([1, 9])                        0.00       8.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 075f-be61
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/600 [00:00<?, ?it/s]

[09/18/2025 23:34:21] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:34:21] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/overcast_daytime/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:34:21] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/overcast_daytime/discrete/images/val/front/det_2d.json' Done.


{'mAP@0.50:0.95': 0.47546814352822453, 'mAP@0.50': 0.6840113446259135, 'mAP@0.75': 0.5301818352816422, "(Image([[[178., 178., 178.,  ..., 162., 162., 162.],\n        [178., 178., 178.,  ..., 163., 162., 162.],\n        [178., 178., 178.,  ..., 163., 163., 162.],\n        ...,\n        [ 91.,  96., 100.,  ..., 184., 184., 184.],\n        [ 95.,  97.,  99.,  ..., 183., 183., 183.],\n        [ 98.,  95.,  93.,  ..., 182., 182., 182.]],\n\n       [[140., 140., 140.,  ..., 132., 132., 132.],\n        [140., 140., 140.,  ..., 133., 132., 132.],\n        [140., 140., 140.,  ..., 133., 133., 132.],\n        ...,\n        [ 90.,  95.,  99.,  ..., 185., 185., 185.],\n        [ 94.,  96.,  98.,  ..., 184., 184., 184.],\n        [ 97.,  94.,  92.,  ..., 183., 183., 183.]],\n\n       [[ 98.,  98.,  98.,  ..., 113., 113., 113.],\n        [ 98.,  98.,  98.,  ..., 114., 113., 113.],\n        [ 98.,  98.,  98.,  ..., 114., 114., 113.],\n        ...,\n        [ 92.,  97., 101.,  ..., 189., 189., 189.],\

[09/18/2025 23:34:21] SHIFT DevKit - INFO - Loading annotation takes 0.38 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0aee-69fd']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -5.76     162.09
boxes2d              torch.Size([1, 5, 4])                   255.00     881.00
boxes2d_classes      torch.Size([1, 5])                        1.00       5.00
boxes2d_track_ids    torch.Size([1, 5])                        0.00       4.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 0aee-69fd
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/400 [00:00<?, ?it/s]

[09/18/2025 23:35:31] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:35:31] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/foggy_daytime/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:35:31] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/foggy_daytime/discrete/images/val/front/det_2d.json' Done.


{'mAP@0.50:0.95': 0.368158237161888, 'mAP@0.50': 0.5359163384473442, 'mAP@0.75': 0.40063993080367194, "(Image([[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        ...,\n        [145., 146., 146.,  ..., 107., 107., 107.],\n        [145., 145., 146.,  ..., 106., 107., 107.],\n        [145., 145., 145.,  ..., 106., 106., 106.]],\n\n       [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        ...,\n        [142., 143., 143.,  ..., 102., 102., 102.],\n        [142., 142., 143.,  ..., 101., 102., 102.],\n        [142., 142., 142.,  ..., 101., 101., 101.]],\n\n       [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],\n        ...,\n        [144., 145., 145.,  ..., 104., 104., 104.],\n

[09/18/2025 23:35:31] SHIFT DevKit - INFO - Loading annotation takes 0.67 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0188-aef6']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                  -136.69       1.57
boxes2d              torch.Size([1, 5, 4])                   275.00     661.00
boxes2d_classes      torch.Size([1, 5])                        0.00       1.00
boxes2d_track_ids    torch.Size([1, 5])                        0.00       4.00
images               torch.Size([1, 3, 800, 1280])            17.00     255.00

Video name: 0188-aef6
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/663 [00:00<?, ?it/s]

[09/18/2025 23:37:24] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:37:24] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/rainy_daytime/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:37:24] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/rainy_daytime/discrete/images/val/front/det_2d.json' Done.


{'mAP@0.50:0.95': 0.1940205289304845, 'mAP@0.50': 0.270254085727999, 'mAP@0.75': 0.22102926175687232, "(Image([[[191., 191., 191.,  ..., 189., 189., 189.],\n        [191., 191., 191.,  ..., 190., 189., 189.],\n        [191., 191., 191.,  ..., 190., 190., 189.],\n        ...,\n        [139., 140., 140.,  ...,  83.,  81.,  79.],\n        [139., 139., 140.,  ...,  85.,  83.,  81.],\n        [139., 139., 139.,  ...,  88.,  85.,  84.]],\n\n       [[186., 186., 186.,  ..., 183., 183., 183.],\n        [186., 186., 186.,  ..., 184., 183., 183.],\n        [186., 186., 186.,  ..., 184., 184., 183.],\n        ...,\n        [133., 134., 134.,  ..., 104., 102., 100.],\n        [133., 133., 134.,  ..., 109., 107., 105.],\n        [133., 133., 133.,  ..., 112., 109., 108.]],\n\n       [[187., 187., 187.,  ..., 184., 184., 184.],\n        [187., 187., 187.,  ..., 185., 184., 184.],\n        [187., 187., 187.,  ..., 185., 185., 184.],\n        ...,\n        [134., 135., 135.,  ..., 105., 103., 101.],\n

[09/18/2025 23:37:25] SHIFT DevKit - INFO - Loading annotation takes 0.81 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['01cb-e76a']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -2.40      50.31
boxes2d              torch.Size([1, 19, 4])                   97.00    1248.00
boxes2d_classes      torch.Size([1, 19])                       0.00       4.00
boxes2d_track_ids    torch.Size([1, 19])                       0.00      18.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 01cb-e76a
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/800 [00:00<?, ?it/s]

[09/18/2025 23:39:50] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:39:50] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_dawn/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:39:50] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_dawn/discrete/images/val/front/det_2d.json' Done.


{'mAP@0.50:0.95': 0.4260381618120778, 'mAP@0.50': 0.6043350308727273, 'mAP@0.75': 0.4768644026216432, "(Image([[[101., 101., 101.,  ..., 101., 101., 101.],\n        [101., 101., 102.,  ..., 101., 101., 101.],\n        [101., 102., 102.,  ..., 101., 101., 101.],\n        ...,\n        [165., 165., 165.,  ..., 180., 178., 177.],\n        [164., 164., 164.,  ..., 181., 180., 180.],\n        [163., 163., 163.,  ..., 182., 182., 182.]],\n\n       [[ 96.,  96.,  96.,  ...,  95.,  95.,  95.],\n        [ 96.,  96.,  97.,  ...,  95.,  95.,  95.],\n        [ 96.,  97.,  97.,  ...,  95.,  95.,  95.],\n        ...,\n        [168., 168., 168.,  ..., 189., 187., 186.],\n        [167., 167., 167.,  ..., 190., 189., 189.],\n        [166., 166., 166.,  ..., 191., 191., 191.]],\n\n       [[ 95.,  95.,  95.,  ...,  96.,  96.,  96.],\n        [ 95.,  95.,  96.,  ...,  96.,  96.,  96.],\n        [ 95.,  96.,  96.,  ...,  96.,  96.,  96.],\n        ...,\n        [173., 173., 173.,  ..., 202., 200., 199.],\n

[09/18/2025 23:39:51] SHIFT DevKit - INFO - Loading annotation takes 0.26 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0eda-a492']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                  -186.75      69.93
boxes2d              torch.Size([1, 1, 4])                     0.00     759.00
boxes2d_classes      torch.Size([1, 1])                        1.00       1.00
boxes2d_track_ids    torch.Size([1, 1])                        0.00       0.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 0eda-a492
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/350 [00:00<?, ?it/s]

[09/18/2025 23:40:56] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:40:56] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_night/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:40:56] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_night/discrete/images/val/front/det_2d.json' Done.
[09/18/2025 23:40:56] SHIFT DevKit - INFO - Loading annotation takes 0.20 seconds.


{'mAP@0.50:0.95': 0.35371351489794356, 'mAP@0.50': 0.49906488862668574, 'mAP@0.75': 0.394712322738314, "(Image([[[24., 24., 25.,  ..., 27., 27., 28.],\n        [28., 28., 27.,  ..., 27., 27., 27.],\n        [28., 28., 27.,  ..., 26., 26., 25.],\n        ...,\n        [ 0.,  0.,  0.,  ...,  3.,  0.,  0.],\n        [ 1.,  1.,  0.,  ...,  1.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  4.]],\n\n       [[21., 21., 22.,  ..., 23., 23., 24.],\n        [25., 25., 24.,  ..., 23., 23., 23.],\n        [25., 25., 24.,  ..., 22., 22., 21.],\n        ...,\n        [ 1.,  1.,  0.,  ...,  5.,  2.,  0.],\n        [ 3.,  3.,  2.,  ...,  3.,  0.,  2.],\n        [ 0.,  0.,  0.,  ...,  1.,  1.,  6.]],\n\n       [[43., 43., 44.,  ..., 42., 42., 43.],\n        [47., 47., 46.,  ..., 42., 42., 42.],\n        [47., 47., 46.,  ..., 41., 41., 40.],\n        ...,\n        [ 2.,  2.,  1.,  ...,  6.,  3.,  1.],\n        [ 4.,  4.,  3.,  ...,  4.,  1.,  3.],\n        [ 1.,  1.,  1.,  ...,  2.,  2.,  7.]]], 

Evaluation:   0%|          | 0/300 [00:00<?, ?it/s]

[09/18/2025 23:41:42] SHIFT DevKit - INFO - Base: ./data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7fdda882be60>
[09/18/2025 23:41:42] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_daytime/discrete/images/val/front/det_2d.json' ...
[09/18/2025 23:41:42] SHIFT DevKit - INFO - Loading annotation from './data/SHIFT_SUBSET/clear_daytime/discrete/images/val/front/det_2d.json' Done.


{'mAP@0.50:0.95': 0.19603334026554445, 'mAP@0.50': 0.3102387417072683, 'mAP@0.75': 0.2168836471238425, "(Image([[[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        ...,\n        [33., 33., 33.,  ..., 16., 15., 15.],\n        [33., 33., 33.,  ..., 13., 12., 11.],\n        [33., 33., 33.,  ..., 12., 10.,  8.]],\n\n       [[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        ...,\n        [46., 46., 46.,  ..., 22., 21., 21.],\n        [46., 46., 46.,  ..., 19., 18., 17.],\n        [46., 46., 46.,  ..., 18., 16., 14.]],\n\n       [[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n        ...,\n        [54., 54., 54.,  ..., 27., 26., 26.],\n        [54., 54., 54.,  ..., 24., 23., 22.],\n        [54., 54., 54.,  ..., 23., 21., 19.]]], 

[09/18/2025 23:41:43] SHIFT DevKit - INFO - Loading annotation takes 0.69 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0116-4859']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -0.90     138.34
boxes2d              torch.Size([1, 6, 4])                   246.00     859.00
boxes2d_classes      torch.Size([1, 6])                        1.00       5.00
boxes2d_track_ids    torch.Size([1, 6])                        0.00       5.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 0116-4859
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

Evaluation:   0%|          | 0/700 [00:00<?, ?it/s]

{'mAP@0.50:0.95': 0.4476200365545435, 'mAP@0.50': 0.6706051661613364, 'mAP@0.75': 0.4868022294929105, "(Image([[[ 48.,  51.,  70.,  ..., 168., 168., 168.],\n        [ 52.,  45.,  54.,  ..., 169., 168., 168.],\n        [ 55.,  54.,  55.,  ..., 169., 169., 168.],\n        ...,\n        [191., 191., 191.,  ..., 188., 188., 188.],\n        [188., 189., 189.,  ..., 188., 188., 188.],\n        [187., 187., 188.,  ..., 188., 188., 188.]],\n\n       [[ 63.,  62.,  74.,  ..., 136., 136., 136.],\n        [ 72.,  63.,  63.,  ..., 137., 136., 136.],\n        [ 87.,  81.,  76.,  ..., 137., 137., 136.],\n        ...,\n        [194., 194., 194.,  ..., 195., 195., 195.],\n        [191., 192., 192.,  ..., 195., 195., 195.],\n        [190., 190., 191.,  ..., 195., 195., 195.]],\n\n       [[ 36.,  36.,  49.,  ..., 107., 107., 107.],\n        [ 43.,  34.,  36.,  ..., 108., 107., 107.],\n        [ 52.,  47.,  43.,  ..., 108., 108., 107.],\n        ...,\n        [198., 198., 198.,  ..., 198., 198., 198.],\n

In [None]:
evaluate_for(model, dataloader.valid, dataloader.valid_len)

In [None]:
evaluate_for(model, dataloader.test, dataloader.test_len)

## Train

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 10
REAL_BATCH = BATCH_SIZE[-1]
LEARNING_RATE = 2e-5

training_args = dict(
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.1,
    weight_decay=0.15,
    max_grad_norm=1.0,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    gradient_accumulation_steps=REAL_BATCH//BATCH_SIZE[0],
    eval_accumulation_steps=BATCH_SIZE[1],
    batch_eval_metrics=True,
    remove_unused_columns=False,
    optim="adamw_torch",
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=50,
    save_total_limit=100,
    load_best_model_at_end=True,
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    run_name=RUN_NAME,
    bf16=True,
)

testing_args = dict(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from torchvision.ops import box_convert
from dataclasses import dataclass


def de_normalize_boxes(boxes, height, width):
    # 1. cxcywh → xyxy
    boxes_xyxy_norm = box_convert(boxes, 'cxcywh', 'xyxy')

    # 2. de-normalize (convert to actual pixel coordinates)
    boxes_xyxy_norm[:, [0, 2]] *= width
    boxes_xyxy_norm[:, [1, 3]] *= height
    return boxes_xyxy_norm


def map_compute_metrics(preprocessor, threshold=0.0):
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def calc(eval_pred, compute_result=False):
        nonlocal map_metric

        if compute_result:
            m_ap = map_metric.compute()
            map_metric.reset()

            per_class_map = {
                f"{CLASSES[idx]}_mAP@0.50:0.95": m_ap.ap_per_class[idx].mean()
                for idx in m_ap.matched_classes
            }

            return {
                "mAP@0.50:0.95": m_ap.map50_95,
                "mAP@0.50": m_ap.map50,
                "mAP@0.75": m_ap.map75,
                **per_class_map
            }
        else:
            preds = ModelOutput(*eval_pred.predictions[1:3])
            labels = eval_pred.label_ids
            sizes = [label['orig_size'].cpu().tolist() for label in labels]

            results = post_process(preds, target_sizes=sizes, threshold=threshold)
            predictions = [Detections.from_transformers(result) for result in results]
            targets = [Detections(
                xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
                class_id=label['class_labels'].cpu().numpy(),
            ) for label in labels]

            map_metric.update(predictions=predictions, targets=targets)
            return {}
    return calc, map_metric

In [None]:
from detectron2.utils.events import EventStorage
import torch

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for iteration, data in enumerate(dataloader.train):
    with EventStorage(iteration) as storage:
        # Forward pass
        loss_dict = model(data)

        # 모든 loss를 합산
        losses = sum(loss_dict.values())

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # 로깅 (선택사항)
        if iteration % 20 == 0:
            print(f"Iteration {iteration}: {loss_dict}")

## Evaluate

In [None]:
targets = []
predictions = []
batch_size = 32

raw_data = DataLoader(LabelDataset(dataset.valid), batch_size=batch_size, collate_fn=naive_collate_fn)
loader = DataLoader(DatasetAdapterForTransformers(dataset.valid), batch_size=batch_size, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))
for idx, lables, inputs in zip(tqdm(range(len(raw_data))), raw_data, loader):
    sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device))

    results = reference_preprocessor.post_process_object_detection(
        outputs, target_sizes=sizes, threshold=0.3
    )

    detections = [Detections.from_transformers(results[i]) for i in range(batch_size)]
    annotations = [Detections(
        xyxy=lables[i][0].cpu().numpy(),
        class_id=lables[i][1].cpu().numpy(),
    ) for i in range(batch_size)]

    targets.extend(annotations)
    predictions.extend(detections)

In [None]:
len(predictions) == len(targets), len(predictions), len(targets)

In [None]:
mean_average_precision = MeanAveragePrecision().update(
    predictions=predictions,
    targets=targets,
).compute()
per_class_map = {
    f"{CLASSES[idx]}_mAP@0.95": mean_average_precision.ap_per_class[idx].mean()
    for idx in mean_average_precision.matched_classes
}

print(f"mAP@0.95: {mean_average_precision.map50_95:.2f}")
print(f"map50: {mean_average_precision.map50:.2f}")
print(f"map75: {mean_average_precision.map75:.2f}")
for key, value in per_class_map.items():
    print(f"{key}: {value:.2f}")