# Pre-training with SHIFT-Discrete Dataset (Clear-Daytime)

## Imports and Configs

In [1]:
import sys
from os import path
from argparse import ArgumentParser

import torch
from torch.utils.data import DataLoader

from ttadapters import datasets, models
from ttadapters.utils import visualizer
from ttadapters.models.base import ModelProvider
from ttadapters.datasets import DatasetHolder, DataLoaderHolder

### Parse Arguments

In [2]:
# Set Batch Size
BATCH_SIZE = 2, 4, 1  # Local
#BATCH_SIZE = 32, 60, 1  # A6000
#BATCH_SIZE = 64, 200, 1  # A100 or H100
ACCUMULATE_STEPS = 1

# Set Data Root
DATA_ROOT = path.join(".", "data")

# Set Target Dataset
SOURCE_DOMAIN = datasets.SHIFTDataset

# Set Run Mode
TEST_MODE = False

# Set Model List
MODEL_ZOO = ["rcnn", "swinrcnn", "rtdetr", "hf_rtdetr", "yolo11"]
MODEL_TYPE = MODEL_ZOO[0]

In [3]:
# Create argument parser
parser = ArgumentParser(description="Training script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args([])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityscapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")

INFO: Running in notebook mode with default arguments
INFO: Set batch size - Train: 2, Valid: 4, Test: 1


### Check GPU Availability

In [4]:
!nvidia-smi

Tue Oct  7 03:06:32 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   48C    P8              1W /   78W |       0MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
# Set CUDA Device Number
DEVICE_NUM = 0 if not args.device else args.device
ADDITIONAL_GPU = 0 if not args.additional_gpu else args.additional_gpu
DATA_TYPE = torch.float32 if not args.use_bf16 else torch.bfloat16

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1

print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

INFO: Using device - cuda:0


## Define Dataset

In [6]:
# Fast download patch
datasets.patch_fast_download_for_object_detection()



In [7]:
# Basic pre-training dataset
match SOURCE_DOMAIN:
    case datasets.SHIFTDataset:
        dataset = DatasetHolder(
            train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
            valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
            test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
        )
    case datasets.CityscapesDataset:
        pass
    case _:
        raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Dataset info
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

[10/07/2025 03:06:33] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\train. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:06:33] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\normal\discrete\images\train\front\det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:06:33] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\normal\discrete\images\train\front\det_2d.json' Done.
[10/07/2025 03:06:38] SHIFT DevKit - INFO - Loading annotation takes 5.58 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0016-1b62']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -7.53     219.91
boxes2d              torch.Size([1, 26, 4])                    5.00     974.00
boxes2d_classes      torch.Size([1, 26])                       0.00       3.00
boxes2d_track_ids    torch.Size([1, 26])                       0.00      25.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00



[10/07/2025 03:06:40] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:06:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\normal\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:06:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\normal\discrete\images\val\front\det_2d.json' Done.


Video name: 0016-1b62
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:06:41] SHIFT DevKit - INFO - Loading annotation takes 0.87 seconds.
[10/07/2025 03:06:41] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:06:41] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\corrupted\discrete\images\val\front\det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:06:42] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\corrupted\discrete\images\val\front\det_2d.json' Done.
[10/07/2025 03:06:47] SHIFT DevKit - INFO - Loading annotation takes 6.24 seconds.


INFO: Dataset loaded successfully. Number of samples - Train: 20800, Valid: 2800, Test: 22200

INFO: Number of classes - 6 ['pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle']


In [8]:
# Check annotation keys-values
dataset.train[999]

(Image([[[133., 133., 133.,  ..., 125., 124., 123.],
         [135., 135., 135.,  ..., 125., 124., 123.],
         [140., 140., 139.,  ..., 125., 123., 122.],
         ...,
         [199., 202., 205.,  ..., 208., 211., 213.],
         [204., 205., 206.,  ..., 206., 208., 209.],
         [208., 208., 206.,  ..., 205., 205., 204.]],
 
        [[152., 152., 152.,  ..., 153., 152., 151.],
         [154., 154., 154.,  ..., 153., 152., 151.],
         [157., 157., 156.,  ..., 152., 151., 150.],
         ...,
         [190., 193., 196.,  ..., 194., 196., 198.],
         [195., 196., 197.,  ..., 192., 193., 194.],
         [200., 200., 197.,  ..., 191., 190., 189.]],
 
        [[169., 169., 169.,  ..., 175., 174., 173.],
         [171., 171., 171.,  ..., 175., 174., 173.],
         [173., 173., 172.,  ..., 173., 173., 172.],
         ...,
         [175., 178., 181.,  ..., 185., 191., 193.],
         [178., 179., 180.,  ..., 183., 188., 189.],
         [181., 181., 180.,  ..., 182., 185., 184.]

In [9]:
# Check data shape
dataset.train[999][0].shape  # should be (num_channels, height, width)

torch.Size([3, 800, 1280])

In [10]:
# Visualize video
visualizer.visualize_bbox_frames(dataset.train)

Output()

## Load Model

In [11]:
from ultralytics.nn.tasks import DetectionModel
from ttadapters.models.base import BaseModel, WeightsInfo
from ttadapters.datasets import BaseDataset


class YOLO11ForObjectDetection(DetectionModel, BaseModel):
    model_name = "YOLO11"
    model_config = "yolo11m.yaml"
    model_provider = ModelProvider.Ultralytics
    channel = 3

    class Weights:
        COCO = WeightsInfo("https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt", weight_key="model")
        SHIFT_CLEAR = WeightsInfo("")

    def __init__(self, dataset: BaseDataset):
        nc = len(dataset.classes)
        super().__init__(self.model_config, ch=self.channel, nc=nc)

        self.dataset_name = dataset.dataset_name
        self.num_classes = nc

In [12]:
#MODEL_TYPE = "yolo11"

In [22]:
# Initialize model
match MODEL_TYPE:
    case "rcnn":
        model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR_NATUREYOO if TEST_MODE else model.Weights.IMAGENET_OFFICIAL), strict=False)
    case "swinrcnn":
        model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR_NATUREYOO if TEST_MODE else model.Weights.IMAGENET_XIAOHU2015), strict=False)
    case "rtdetr":
        model = models.RTDETRForObjectDetection(dataset=SOURCE_DOMAIN)
    case "hf_rtdetr":
        model = models.HFRTDETRForObjectDetection(dataset=SOURCE_DOMAIN)
    case "yolo11":
        model = YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        #model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR if TEST_MODE else model.Weights.COCO), strict=False)
    case _:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

print("INFO: Model state loaded -", load_result)
model.to(device)

INFO: Model state loaded - <All keys matched successfully>


FasterRCNNForObjectDetection(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): Bo

In [16]:
from ttadapters.datasets.transform import MaskedImageList
from torchvision.tv_tensors import BoundingBoxFormat
from torchvision.transforms.v2.functional import convert_bounding_box_format
import torch


def collate_fn(batch: list[Image, dict]):
    images = []
    batch_idx = []
    cls = []
    bboxes = []
    ori_shapes = []
    ratio_pads = []

    for idx, (image, metadata) in enumerate(batch):
        resized_height, resized_width = image.shape[-2:]
        original_height, original_width = metadata['original_hw']
        ori_shapes.append([original_height, original_width])

        boxes = metadata["boxes2d"]  # xyxy
        classes = metadata["boxes2d_classes"]
        boxes_cxcywh = convert_bounding_box_format(boxes, new_format=BoundingBoxFormat.CXCYWH)

        images.append(image)
        batch_idx_list.extend([idx] * len(boxes))
        cls_list.extend(classes.tolist())
        bboxes_list.extend(boxes_normalized.tolist())

    images_list = MaskedImageList.from_tensors(images)
    if len(bboxes_list) > 0:
        batch_idx_tensor = torch.tensor(batch_idx_list, dtype=torch.long)
        cls_tensor = torch.tensor(cls_list, dtype=torch.long)
        bboxes_tensor = torch.tensor(bboxes_list, dtype=torch.float32)
    else:  # no objects in the batch
        batch_idx_tensor = torch.zeros(0, dtype=torch.long)
        cls_tensor = torch.zeros(0, dtype=torch.long)
        bboxes_tensor = torch.zeros((0, 4), dtype=torch.float32)

    return {
        'img': images_list.tensor,              # Shape: [batch_size, 3, height, width]
        'batch_idx': batch_idx_tensor,          # Shape: [num_objects] - batch indices
        'cls': cls_tensor,                      # Shape: [num_objects] - class indices
        'bboxes': bboxes_tensor,                # Shape: [num_objects, 4] - normalized cxcywh (0~1)
        'ori_shapes': torch.tensor(ori_shapes), # Shape: [batch_size, 2] - original (height, width)
        'ratio_pads': torch.tensor(ratio_pads)  # Shape: [batch_size, 2, 2] - [[ratio, ratio], [pad_w, pad_h]]
    }

NameError: name 'Image' is not defined

In [23]:
# Image transform and collate function
if model.model_provider == ModelProvider.Detectron2:
    image_transform = datasets.detectron_image_transform
    from ttadapters.models.rcnn import collate_fn
else:
    image_transform = datasets.default_image_transform
    from ttadapters.models.rt_detr import collate_fn

dataset.train.transform = image_transform
dataset.train.transforms = datasets.default_train_transforms
dataset.valid.transform = image_transform
dataset.valid.transforms = datasets.default_valid_transforms
dataset.test.transform = image_transform
dataset.test.transforms = datasets.default_valid_transforms

dataloader = DataLoaderHolder(
    train=DataLoader(dataset.train, batch_size=BATCH_SIZE[0], shuffle=True, collate_fn=collate_fn),
    valid=DataLoader(dataset.valid, batch_size=BATCH_SIZE[1], shuffle=False, collate_fn=collate_fn),
    test=DataLoader(dataset.test, batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=collate_fn)
)

INFO: Loader length - Train: 0, Valid: 0, Test: 0



In [None]:
# Check dataloader
dataloader.train.__iter__().__next__()

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx][self.camera]
        image = item['images'].squeeze(0)

        # Convert to COCO_Detection Format
        annotations = []
        target = dict(image_id=idx, annotations=annotations)
        for box, cls in zip(item['boxes2d'], item['boxes2d_classes']):
            x1, y1, x2, y2 = box.tolist()  # from Pascal VOC format (x1, y1, x2, y2)
            width, height = x2 - x1, y2 - y1
            annotations.append(dict(
                bbox=[x1, y1, width, height],  # to COCO format: [x, y, width, height]
                category_id=cls.item(),
                area=width * height,
                iscrowd=0
            ))

        # Following prepare_coco_detection_annotation's expected format
        # RT-DETR ImageProcessor converts the COCO bbox to center format (cx, cy, w, h) during preprocessing
        # But, eventually re-converts the bbox to Pascal VOC (x1, y1, x2, y2) format after post-processing
        return dict(image=image, target=target)

In [None]:
from transformers import AutoBackbone, RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig, SwinConfig, ResNetConfig
from transformers.image_utils import AnnotationFormat

In [None]:
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
reference_config.num_labels = NUM_CLASSES

# Set the image size and preprocessor size
reference_config.image_size = 800

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": 800, "width": 800}
reference_preprocessor.do_resize = False

In [None]:
if USE_PRETRAINED_MODEL:
    # Load the pre-trained model
    model = RTDetrForObjectDetection.from_pretrained(reference_model_id, config=reference_config, torch_dtype=torch.float32, ignore_mismatched_sizes=True)
    if LOAD_ONLY_COCO_BACKBONE:
        detector_state = RTDetrForObjectDetection(config=reference_config).state_dict()
        detector_state = {k: v for k, v in model.state_dict().items() if 'backbone' not in k}
        model.load_state_dict(detector_state, strict=False)

    # Initialize a new model with the reference configuration
    model = RTDetrForObjectDetection(config=reference_config)
    if USE_SHIFT_BACKBONE:
        backbone_state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
        model.model.backbone.model.load_state_dict(backbone_state, strict=False)
    else:
        backbone_state = AutoBackbone.from_pretrained(backbone_id, config=reference_config.backbone_config).state_dict()
        model.model.backbone.model.load_state_dict(backbone_state, strict=False)
        if USE_SWIN_T_BACKBONE:
            model.model.backbone.model.forward.__kwdefaults__['interpolate_pos_encoding'] = True
    del backbone_state

model.to(device)

In [None]:
test_d = DatasetAdapterForTransformers(dataset.train)[5]
test_d

In [None]:
from transformers.trainer_utils import EvalPrediction
from torchvision.ops import box_convert
from dataclasses import dataclass


@dataclass
class ModelOutput:
    logits: torch.Tensor
    pred_boxes: torch.Tensor


def de_normalize_boxes(boxes, height, width):
    # 1. cxcywh â†’ xyxy
    boxes_xyxy_norm = box_convert(boxes, 'cxcywh', 'xyxy')

    # 2. de-normalize (convert to actual pixel coordinates)
    boxes_xyxy_norm[:, [0, 2]] *= width
    boxes_xyxy_norm[:, [1, 3]] *= height
    return boxes_xyxy_norm


def map_compute_metrics(preprocessor=reference_preprocessor, threshold=0.0):
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def calc(eval_pred: EvalPrediction, compute_result=False):
        nonlocal map_metric

        if compute_result:
            m_ap = map_metric.compute()
            map_metric.reset()

            per_class_map = {
                f"{CLASSES[idx]}_mAP@0.50:0.95": m_ap.ap_per_class[idx].mean()
                for idx in m_ap.matched_classes
            }

            return {
                "mAP@0.50:0.95": m_ap.map50_95,
                "mAP@0.50": m_ap.map50,
                "mAP@0.75": m_ap.map75,
                **per_class_map
            }
        else:
            preds = ModelOutput(*eval_pred.predictions[1:3])
            labels = eval_pred.label_ids
            sizes = [label['orig_size'].cpu().tolist() for label in labels]

            results = post_process(preds, target_sizes=sizes, threshold=threshold)
            predictions = [Detections.from_transformers(result) for result in results]
            targets = [Detections(
                xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
                class_id=label['class_labels'].cpu().numpy(),
            ) for label in labels]

            map_metric.update(predictions=predictions, targets=targets)
            return {}
    return calc, map_metric

In [None]:
class DifferentiableLRTrainer(Trainer):
    def create_optimizer(self):
        backbone_params = []
        head_params = []

        for name, param in self.model.named_parameters():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                head_params.append(param)

        self.optimizer = torch.optim.AdamW([
            {'params': backbone_params, 'lr': self.args.backbone_lr},
            {'params': head_params, 'lr': self.args.learning_rate}
        ], weight_decay=self.args.weight_decay)

        return self.optimizer


class DifferentiableLRTrainingArguments(TrainingArguments):
    def __init__(self, *args, backbone_lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.backbone_lr = backbone_lr

## Train

In [None]:
# WandB Initialization
#wandb.init(project=PROJECT_NAME, name=RUN_NAME)

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 20
REAL_BATCH = BATCH_SIZE[-1]
LEARNING_RATE = 1e-4

training_args = DifferentiableLRTrainingArguments(
    backbone_lr=LEARNING_RATE/10,  # Set backbone learning rate to 1/10th of the main learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.5,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    gradient_accumulation_steps=REAL_BATCH//BATCH_SIZE[0],
    eval_accumulation_steps=BATCH_SIZE[1],
    batch_eval_metrics=True,
    remove_unused_columns=False,
    optim="adamw_torch",
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=100,
    save_total_limit=100,
    load_best_model_at_end=True,
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    #report_to="wandb",
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    #run_name=RUN_NAME,
    bf16=True,
)

testing_args = TrainingArguments(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from functools import partial

compute_metrics, compute_results = map_compute_metrics(preprocessor=reference_preprocessor)

trainer = DifferentiableLRTrainer(
    model=model,
    args=training_args,
    train_dataset=DatasetAdapterForTransformers(dataset.train),
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
)

tester = Trainer(
    model=model,
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

## Evaluation

### Load Scenarios

In [24]:
discrete_scenario = datasets.scenarios.SHIFTDiscreteScenario(
    root=DATA_ROOT, valid=True, order=datasets.scenarios.SHIFTDiscreteScenario.WHWPAPER,
    transform=image_transform, transforms=datasets.default_valid_transforms
)
continuous_scenario = datasets.scenarios.SHIFTContinuousScenario(
    root=DATA_ROOT, valid=True, order=datasets.scenarios.SHIFTContinuousScenario.DEFAULT,
    transform=image_transform, transforms=datasets.default_valid_transforms
)

[10/07/2025 03:09:57] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:09:57] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\cloudy_daytime\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:09:57] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\cloudy_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:09:58] SHIFT DevKit - INFO - Loading annotation takes 1.11 seconds.
[10/07/2025 03:09:58] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:09:58] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\overcast_daytime\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:09:59] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\overcast_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:09:59] SHIFT DevKit - INFO - Loading annotation takes 0.74 seconds.
[10/07/2025 03:09:59] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:09:59] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\foggy_daytime\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:09:59] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\foggy_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:00] SHIFT DevKit - INFO - Loading annotation takes 0.93 seconds.
[10/07/2025 03:10:00] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:00] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\rainy_daytime\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:10:00] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\rainy_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:01] SHIFT DevKit - INFO - Loading annotation takes 0.95 seconds.
[10/07/2025 03:10:01] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:01] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_dawn\discrete\images\val\front\det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:01] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_dawn\discrete\images\val\front\det_2d.json' Done.
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation takes 0.48 seconds.
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_night\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_night\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation takes 0.57 seconds.
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_daytime\discrete\images\val\front\det_2d.json' ...
[10/07/2025 03:10:02] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation takes 0.84 seconds.
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\1x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\1x\val\front\det_2d.json' ...
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\1x\val\front\det_2d.json' Done.


ContinuousSubsetType.DAYTIME_TO_NIGHT
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation takes 0.21 seconds.
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\10x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\10x\val\front\det_2d.json' ...
[10/07/2025 03:10:03] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\10x\val\front\det_2d.json' Done.


ContinuousSubsetType.DAYTIME_TO_NIGHT
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:04] SHIFT DevKit - INFO - Loading annotation takes 0.93 seconds.
[10/07/2025 03:10:04] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\100x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:04] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\100x\val\front\det_2d.json' ...
[10/07/2025 03:10:04] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\100x\val\front\det_2d.json' Done.


ContinuousSubsetType.DAYTIME_TO_NIGHT
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation takes 0.67 seconds.
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\1x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\1x\val\front\det_2d.json' ...
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\1x\val\front\det_2d.json' Done.
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation takes 0.15 seconds.
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\10x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\10x\val\front\det_2d.json' ...


ContinuousSubsetType.CLEAR_TO_FOGGY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
ContinuousSubsetType.CLEAR_TO_FOGGY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:05] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\10x\val\front\det_2d.json' Done.
[10/07/2025 03:10:06] SHIFT DevKit - INFO - Loading annotation takes 1.15 seconds.
[10/07/2025 03:10:06] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\100x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:06] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\100x\val\front\det_2d.json' ...
[10/07/2025 03:10:06] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_foggy\continuous\images\100x\val\front\det_2d.json' Done.


ContinuousSubsetType.CLEAR_TO_FOGGY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation takes 1.72 seconds.
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\1x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\1x\val\front\det_2d.json' ...
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\1x\val\front\det_2d.json' Done.
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation takes 0.15 seconds.
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\10x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\10x\val\front\det_2d.json' ...


ContinuousSubsetType.CLEAR_TO_RAINY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
ContinuousSubsetType.CLEAR_TO_RAINY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/10x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:08] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\10x\val\front\det_2d.json' Done.
[10/07/2025 03:10:09] SHIFT DevKit - INFO - Loading annotation takes 1.23 seconds.
[10/07/2025 03:10:10] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\100x\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x00000167FFCE8560>
[10/07/2025 03:10:10] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\100x\val\front\det_2d.json' ...


ContinuousSubsetType.CLEAR_TO_RAINY
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/100x...
INFO: Dataset archive found in the root directory. Skipping download.


[10/07/2025 03:10:10] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_to_rainy\continuous\images\100x\val\front\det_2d.json' Done.
[10/07/2025 03:10:12] SHIFT DevKit - INFO - Loading annotation takes 2.34 seconds.


In [18]:
import copy
import time
import gc
import asyncio
import nest_asyncio

from tqdm.auto import tqdm

from torch import nn, OutOfMemoryError

from supervision.detection.core import Detections
from supervision.metrics.mean_average_precision import MeanAveragePrecision


class DetectionEvaluator:
    def __init__(
        self, model: nn.Module | list[nn.Module], threshold: float = 0.0, required_reset: bool = False, 
        dtype=torch.float32, device=torch.device("cuda"), synchronize: bool = True, no_grad: bool = True
    ):
        self.do_parallel = isinstance(model, list)
        self.model = [m.to(device).to(dtype) for m in model] if self.do_parallel else model.to(device).to(dtype)
        self.required_reset = required_reset
        self.dtype = dtype
        self.device = device
        self.threshold = threshold
        self.synchronize = synchronize
        self.no_grad = no_grad

    @staticmethod
    def evaluate_with_reset(
        model: nn.Module, desc: str, loader: DataLoader, loader_length: int, threshold: float = 0.0, reset: bool = True,
        dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda"),
        synchronize: bool = True, no_grad: bool = True, clear_tqdm_when_oom: bool = False
    ):
        torch.cuda.empty_cache(); torch.cuda.empty_cache(); torch.cuda.empty_cache()
        gc.collect(); gc.collect(); gc.collect()

        if reset:
            try:
                model.reset_adaptation()
            except NotImplementedError:
                print("WARNING: reset_adaptation() is not implemented for this model. Assuming the evaluation is running with deep-copy mode.")
                model = copy.deepcopy(model)

        model = model.to(device).to(dtype)
        model.eval()

        map_metric = MeanAveragePrecision()
        predictions_list = []
        targets_list = []
        total_images = 0
        collapse_time = 0
        
        if no_grad:  # use no_grad for inference
            disable_grad = torch.no_grad
        else:  # let model decide gradient requirement
            disable_grad = lambda: (yield)

        tqdm_loader = tqdm(loader, total=loader_length, desc=f"Evaluation for {desc}")
        try:
            with disable_grad():
                for batch in tqdm_loader:
                    total_images += len(batch)
                    with torch.autocast(device_type=device.type, dtype=dtype):
                        start = time.time()
                        outputs = model(batch)

                        if device.type == "cuda" and synchronize:
                            torch.cuda.synchronize()

                        collapse_time += time.time() - start

                    for output, input_data in zip(outputs, batch):
                        instances = output['instances']
                        mask = instances.scores > threshold

                        pred_detection = Detections(
                            xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                            class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                            confidence=instances.scores[mask].detach().cpu().numpy()
                        )
                        gt_instances = input_data['instances']
                        target_detection = Detections(
                            xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                            class_id=gt_instances.gt_classes.detach().cpu().numpy()
                        )

                        predictions_list.append(pred_detection)
                        targets_list.append(target_detection)
        except OutOfMemoryError as e:  # catch OOM error to close tqdm properly
            tqdm_loader.close()
            if clear_tqdm_when_oom:
                tqdm_loader.container.close()
            raise e

        map_metric.update(predictions=predictions_list, targets=targets_list)
        m_ap = map_metric.compute()

        per_class_map = {
            f"{CLASSES[idx]}_mAP@0.50:0.95": m_ap.ap_per_class[idx].mean().item()
            for idx in m_ap.matched_classes
        }
        performances = {
            "fps": total_images / collapse_time,
            "collapse_time": collapse_time
        }

        result = {
            "mAP@0.50:0.95": m_ap.map50_95.item(),
            "mAP@0.50": m_ap.map50.item(),
            "mAP@0.75": m_ap.map75.item(),
            **performances,
            **per_class_map
        }
        return result

    @staticmethod
    def evaluate(
        model: nn.Module, desc: str, loader: DataLoader, loader_length: int, threshold: float = 0.0,
        dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda"),
        synchronize: bool = True, no_grad: bool = True, clear_tqdm_when_oom: bool = False
    ):
        return DetectionEvaluator.evaluate_with_reset(
            model, desc, loader, loader_length, threshold, reset=False, dtype=dtype, device=device,
            synchronize=synchronize, no_grad=no_grad, clear_tqdm_when_oom=clear_tqdm_when_oom
        )

    async def evaluate_recursively(self, module: nn.Module | list[nn.Module], *args, **kwargs):
        if isinstance(module, list):
            try:  # run all
                return await asyncio.gather(*[self.evaluate_recursively(m, *args, **kwargs) for m in module])
            except OutOfMemoryError:  # on OOM, try to run half
                if self.device.type == "cuda":
                    torch.cuda.synchronize()  # ensure all coroutine are finished
                results = []
                sub_modules = [module[:len(module)//2], module[len(module)//2:]]
                sub_modules[0] = sub_modules[0] if len(sub_modules[0]) else sub_modules[0][0]
                sub_modules[1] = sub_modules[1] if len(sub_modules[1]) else sub_modules[1][0]

                for sub_module in sub_modules:
                    result = await self.evaluate_recursively(sub_module, *args, **kwargs)
                    if isinstance(result, list):
                        results.extend(result)
                    else:
                        results.append(result)
            except KeyboardInterrupt:  # handle keyboard interrupt
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                raise
            return results
        else:
            return await asyncio.to_thread(
                self.evaluate_with_reset,
                module, *args, **kwargs, threshold=self.threshold, reset=self.required_reset,
                dtype=self.dtype, device=self.device, synchronize=self.synchronize, no_grad=self.no_grad, clear_tqdm_when_oom=True
            )

    def __call__(self, *args, **kwargs):
        if self.do_parallel:
            nest_asyncio.apply()
            try:
                return asyncio.run(self.evaluate_recursively(self.model, *args, **kwargs))
            except KeyboardInterrupt:
                print("\nEvaluation interrupted by user")
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                raise
        return self.evaluate_with_reset(
            self.model, *args, **kwargs, threshold=self.threshold, reset=self.required_reset,
            dtype=self.dtype, device=self.device, synchronize=self.synchronize, no_grad=self.no_grad
        )

In [25]:
evaluator = DetectionEvaluator(model, dtype=DATA_TYPE, device=device)
del model  # let evaluator handle the memory of the model

In [26]:
visualizer.visualize_metrics(discrete_scenario(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=collate_fn).play(evaluator, index=["Direct-Test"]))

Output()

SHIFT Discrete Scenario:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluation for cloudy_daytime:   0%|          | 0/2400 [00:00<?, ?it/s]



Evaluation for overcast_daytime:   0%|          | 0/1600 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
visualizer.visualize_metrics(continuous_scenario(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=collate_fn).play(evaluator, index=["Direct-Test"]))