# Pre-training with SHIFT-Discrete Dataset (Clear-Daytime)

## Imports and Configs

In [None]:
import sys
from os import path
from argparse import ArgumentParser

import torch

from ttadapters import datasets, models
from ttadapters.utils import visualizer
from ttadapters.models.base import ModelProvider
from ttadapters.datasets import DatasetHolder, scenarios

### Parse Arguments

In [None]:
# Batch sizes as a (train, valid, test) triple; enable the preset matching the hardware
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = (32, 60, 1)  # A6000
#BATCH_SIZE = (50, 300, 1)  # A100 or H100
ACCUMULATE_STEPS = 1  # gradient accumulation steps

# Root directory for datasets
DATA_ROOT = path.join(".", "data")

# Source-domain dataset class used for pre-training
SOURCE_DOMAIN = datasets.SHIFTDataset

# Test-only run mode flag
TEST_MODE = False

# Supported model identifiers; the first entry is the default model type
MODEL_ZOO = ["rcnn", "swinrcnn", "rtdetr", "hf_rtdetr", "yolo11"]
MODEL_TYPE = MODEL_ZOO[0]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Training script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args([])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityscapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")

### Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Resolve the CUDA device index, the extra-GPU count and the compute dtype from the CLI arguments
DEVICE_NUM = args.device or 0
ADDITIONAL_GPU = args.additional_gpu or 0
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Make DEVICE_NUM the current CUDA device and address it through the generic "cuda" handle
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # sentinel: no CUDA device is in use

# The index is appended only for the bare "cuda" handle; without the device-type check a
# CPU fallback combined with --additional-gpu would be reported as "cpu:-1".
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU and device.type == "cuda" else ""))

## Define Dataset

In [None]:
# Apply the package's fast-download patch before any dataset object is created
# NOTE(review): the patch is implemented in ttadapters.datasets; its exact effect is not visible here
datasets.patch_fast_download_for_object_detection()

In [None]:
# Basic pre-training dataset
match SOURCE_DOMAIN:
    case datasets.SHIFTDataset:
        dataset = DatasetHolder(
            train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
            valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
            test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
        )
    case datasets.CityscapesDataset:
        pass
    case _:
        raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Dataset info
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Inspect one training sample (index 999) to check its annotation keys and values
# NOTE(review): assumes the training split holds at least 1000 samples
dataset.train[999]

In [None]:
# Check the shape of the first element (the image) of the same training sample
dataset.train[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Render the training split through the bounding-box frame visualizer
visualizer.visualize_bbox_frames(dataset.train)

## Load Model

### YOLO11

In [None]:
from ultralytics.models.yolo.detect import DetectionTrainer

In [None]:
from torchvision.tv_tensors import BoundingBoxFormat
from torchvision.transforms.v2.functional import convert_bounding_box_format
import torch


def _boxes_to_cxcywh(boxes):
    """Return ``boxes`` as a plain float32 ``[N, 4]`` tensor in CXCYWH format.

    Inputs that expose a ``format`` attribute (torchvision ``BoundingBoxes``)
    are converted by torchvision; anything else is treated as XYXY corners.
    """
    if getattr(boxes, "format", None) is not None:
        # Imported lazily so plain-tensor inputs do not require torchvision.
        from torchvision.tv_tensors import BoundingBoxFormat
        from torchvision.transforms.v2.functional import convert_bounding_box_format

        converted = convert_bounding_box_format(boxes, new_format=BoundingBoxFormat.CXCYWH)
        return converted.as_subclass(torch.Tensor).to(torch.float32).reshape(-1, 4)

    corners = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    x_min, y_min, x_max, y_max = corners.unbind(dim=-1)
    return torch.stack(
        ((x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min), dim=-1
    )


def collate_fn(batch):
    """Collate ``(image, metadata)`` samples into a YOLO-style batch dictionary.

    Images are padded with zeros on the right/bottom to the largest height and
    width in the batch, so absolute box coordinates stay valid; boxes are then
    normalised by the padded width/height.

    Args:
        batch: sequence of ``(image, metadata)`` pairs. ``image`` is a
            ``[..., H, W]`` tensor; ``metadata`` provides ``'original_hw'``,
            ``'boxes2d'`` and ``'boxes2d_classes'``.

    Returns:
        dict with ``'img'``, ``'batch_idx'``, ``'cls'``, ``'bboxes'``,
        ``'ori_shapes'`` and ``'ratio_pads'`` tensors (shapes noted below).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("collate_fn received an empty batch")

    first_image = batch[0][0]
    padded_height = max(image.shape[-2] for image, _ in batch)
    padded_width = max(image.shape[-1] for image, _ in batch)
    images = torch.zeros(
        (len(batch), *first_image.shape[:-2], padded_height, padded_width),
        dtype=first_image.dtype, device=first_image.device
    )
    box_scale = torch.tensor([padded_width, padded_height, padded_width, padded_height], dtype=torch.float32)

    batch_idx = []
    cls = []
    bboxes = []
    ori_shapes = []
    ratio_pads = []

    for idx, (image, metadata) in enumerate(batch):
        resized_height, resized_width = image.shape[-2:]
        original_height, original_width = (int(size) for size in metadata['original_hw'])

        images[idx, ..., :resized_height, :resized_width] = image
        ori_shapes.append([original_height, original_width])
        # Padding is appended at the right/bottom only, hence a zero (pad_w, pad_h) offset.
        # NOTE(review): the ratio is stored as (height, width) resize factors — confirm
        # this matches the layout the consumer of 'ratio_pads' expects.
        ratio_pads.append([
            [resized_height / original_height, resized_width / original_width],
            [0.0, 0.0]
        ])

        boxes_cxcywh = _boxes_to_cxcywh(metadata["boxes2d"])
        bboxes.append(boxes_cxcywh / box_scale)
        cls.append(torch.as_tensor(metadata["boxes2d_classes"]).to(torch.long).reshape(-1))
        batch_idx.append(torch.full((boxes_cxcywh.shape[0],), idx, dtype=torch.long))

    return {
        'img': images,                          # Shape: [batch_size, 3, height, width]
        'batch_idx': torch.cat(batch_idx),      # Shape: [num_objects] - batch indices
        'cls': torch.cat(cls),                  # Shape: [num_objects] - class indices
        'bboxes': torch.cat(bboxes),            # Shape: [num_objects, 4] - normalized cxcywh (0~1)
        'ori_shapes': torch.tensor(ori_shapes), # Shape: [batch_size, 2] - original (height, width)
        'ratio_pads': torch.tensor(ratio_pads)  # Shape: [batch_size, 2, 2] - [[ratio, ratio], [pad_w, pad_h]]
    }

In [None]:
from ultralytics.nn.tasks import DetectionModel
from ttadapters.models.base import BaseModel, WeightsInfo
from ttadapters.datasets import BaseDataset


class YOLO11ForObjectDetection(DetectionModel, BaseModel):
    """Ultralytics YOLO11 (medium config) detection model sized to a dataset's class list."""
    model_name = "YOLO11"
    model_config = "yolo11m.yaml"
    model_provider = ModelProvider.Ultralytics
    channel = 3  # number of input image channels handed to DetectionModel

    class Weights:
        COCO = WeightsInfo("https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt", weight_key="model")
        # The model-initialisation cell reads `Weights.COCO_OFFICIAL` for every COCO-initialised
        # model, so expose that name here as an alias; `COCO` is kept for existing references.
        COCO_OFFICIAL = COCO
        SHIFT_CLEAR = WeightsInfo("")  # empty location: no SHIFT checkpoint is registered yet

    def __init__(self, dataset: BaseDataset):
        """Build the network with one output class per entry in ``dataset.classes``."""
        nc = len(dataset.classes)
        super().__init__(self.model_config, ch=self.channel, nc=nc)

        self.dataset_name = dataset.dataset_name
        self.num_classes = nc

### Faster-RCNN

In [None]:
import warnings

from torch import nn
from torchvision.tv_tensors import Image, BoundingBoxFormat, BoundingBoxes

from detectron2.engine import DefaultTrainer
from detectron2.structures import Boxes, Instances
from detectron2.data.transforms import ResizeShortestEdge, ConvertRGBtoBGR

from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool
from detectron2.modeling import GeneralizedRCNN, SwinTransformer
from detectron2.modeling import build_model
from detectron2.config import get_cfg
from detectron2 import model_zoo

#from ..base import BaseModel, ModelProvider, WeightsInfo
#from ...datasets import BaseDataset, DataPreparation


class DetectronTrainer(DefaultTrainer):
    """Project-local alias of Detectron2's ``DefaultTrainer``; adds no behaviour yet."""
    pass


class DetectronDataPreparation(DataPreparation):
    """Transforms and batching that turn dataset samples into Detectron2 model inputs.

    NOTE(review): ``T`` and ``DataPreparation`` must already be in scope when this
    cell runs; this cell's own import lines do not provide them.
    """

    def __init__(
        self, bbox_key: str = "boxes2d", class_key: str = "boxes2d_classes", original_size_key: str = "original_hw"
    ):
        """Store the metadata keys used to look up boxes, classes and the original size.

        Args:
            bbox_key: metadata key holding the bounding boxes.
            class_key: metadata key holding the per-box class indices.
            original_size_key: metadata key holding the original ``(height, width)``.
        """
        super().__init__()
        self.bbox_key = bbox_key
        self.class_key = class_key
        self.original_size_key = original_size_key

    detectron_image_transform = T.Compose([
        ConvertRGBtoBGR()
    ])

    default_train_transforms = T.Compose([
        ResizeShortestEdge([640, 672, 704, 736, 768, 800], max_size=1333, box_key='boxes2d'),  # Detectron2 Faster R-CNN default training transform
        T.RandomHorizontalFlip(p=0.5)  # Random horizontal flip with 50% probability
    ])

    default_valid_transforms = T.Compose([
        ResizeShortestEdge(800, max_size=1333, box_key='boxes2d')  # Detectron2 Faster R-CNN default validation transform
    ])

    def collate_fn(self, batch: list[tuple[Image, dict]]):
        """Convert ``(image, metadata)`` samples into Detectron2 ``batched_inputs``.

        Args:
            batch: samples whose metadata carries the keys configured in ``__init__``.

        Returns:
            One dict per sample with ``"image"``, ``"instances"`` (ground-truth
            boxes in XYXY plus classes) and the original ``"height"``/``"width"``.
        """
        # Local import: this cell's import block does not bring the functional converter in.
        from torchvision.transforms.v2.functional import convert_bounding_box_format

        batched_inputs = []
        for image, metadata in batch:
            resized_height, resized_width = image.shape[-2:]
            original_height, original_width = metadata[self.original_size_key]

            bboxes: BoundingBoxes = metadata[self.bbox_key]
            if bboxes.format != BoundingBoxFormat.XYXY:
                # Detectron2 `Boxes` stores XYXY corners
                bboxes = convert_bounding_box_format(bboxes, new_format=BoundingBoxFormat.XYXY)

            instances = Instances(image_size=(resized_height, resized_width))
            # Hand `Boxes` the actual coordinates as a plain tensor; it was previously built empty
            instances.gt_boxes = Boxes(bboxes.as_subclass(torch.Tensor))  # xyxy
            instances.gt_classes = metadata[self.class_key]
            batched_inputs.append({
                "image": image,
                "instances": instances,
                "height": original_height,
                "width": original_width
            })
        return batched_inputs


class FasterRCNNForObjectDetection(GeneralizedRCNN, BaseModel):
    """Detectron2 Faster R-CNN (R50-FPN config) with its ROI heads sized to a dataset."""
    model_name = "Faster_R-CNN-R50"
    model_config = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
    model_provider = ModelProvider.Detectron2
    # NOTE(review): this binds the generic `DataPreparation` and leaves `Trainer` unset,
    # although Detectron-specific classes are defined in this cell — confirm intent.
    DataPreparation = DataPreparation
    Trainer = None

    class Weights:
        IMAGENET_OFFICIAL = WeightsInfo("detectron2://ImageNetPretrained/MSRA/R-50.pkl")
        SHIFT_CLEAR_NATUREYOO = WeightsInfo("https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone/Faster_R-CNN_Resnet_50_SHIFT.pth", weight_key="model")

    def __init__(self, dataset: BaseDataset):
        """Build the model from the zoo config with one ROI-head class per dataset class."""
        class_count = len(dataset.classes)

        config = get_cfg()
        config.merge_from_file(model_zoo.get_config_file(self.model_config))
        config.MODEL.ROI_HEADS.NUM_CLASSES = class_count

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message='.*To copy construct from a tensor.*')
            # Let Detectron2 assemble a reference model, then adopt its parts as our own
            reference = build_model(config)
            components = dict(
                backbone=reference.backbone,
                proposal_generator=reference.proposal_generator,
                roi_heads=reference.roi_heads,
                pixel_mean=reference.pixel_mean,
                pixel_std=reference.pixel_std,
                input_format=reference.input_format,
                vis_period=reference.vis_period,
            )
            super().__init__(**components)

        self.dataset_name = dataset.dataset_name
        self.num_classes = class_count


class SwinRCNNForObjectDetection(GeneralizedRCNN, BaseModel):
    """Faster R-CNN whose backbone is a Swin Transformer wrapped in an FPN.

    The proposal generator and ROI heads come from a model built with the
    ResNet-50 FPN base config; only the backbone is replaced.
    """
    model_name = "SwinT_R-CNN-Tiny"
    model_provider = ModelProvider.Detectron2
    DataPreparation = DataPreparation
    Trainer = None
    # Keyword arguments passed verbatim to `SwinTransformer(...)` in `__init__`
    default_params = dict(
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        frozen_stages=-1,
        out_indices=(0, 1, 2, 3)
    )

    class Weights:
        # The box-predictor weights are excluded when loading this checkpoint
        IMAGENET_XIAOHU2015 = WeightsInfo("https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.1/faster_rcnn_swint_T.pth", weight_key="model", exclude_keys = [
                "roi_heads.box_predictor.cls_score.weight",
                "roi_heads.box_predictor.cls_score.bias",
                "roi_heads.box_predictor.bbox_pred.weight",
                "roi_heads.box_predictor.bbox_pred.bias"
            ]
        )
        SHIFT_CLEAR_NATUREYOO = WeightsInfo("https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone/Faster_R-CNN_SwinT_Tiny_SHIFT.pth", weight_key="model")

    def __init__(self, dataset: BaseDataset):
        """Assemble the Swin+FPN backbone and reuse the RPN / ROI heads of a config-built model.

        Args:
            dataset: provides ``classes`` (sets the ROI-head class count) and ``dataset_name``.
        """
        num_classes = len(dataset.classes)

        cfg = get_cfg()
        base_config = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
        cfg.merge_from_file(model_zoo.get_config_file(base_config))

        cfg.MODEL.MASK_ON = False
        cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
        cfg.MODEL.PIXEL_STD = [57.375, 57.120, 58.395]

        cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=FutureWarning)
            warnings.filterwarnings('ignore', category=UserWarning)
            # Built only as a donor for proposal_generator / roi_heads / pixel statistics below
            modules = build_model(cfg)
            # Changed after build_model, so it only affects the FPN constructed in this method
            cfg.MODEL.FPN.IN_FEATURES = ["stage2", "stage3", "stage4", "stage5"]

            swin_backbone = SwinTransformer(**self.default_params)
            # Re-label the backbone outputs as "stage{i+2}" and describe their channels and
            # strides under those names, matching the IN_FEATURES list handed to the FPN
            swin_backbone._out_features = ["stage{}".format(i+2) for i in swin_backbone.out_indices]
            swin_backbone._out_feature_channels = {
                "stage{}".format(i+2): swin_backbone.embed_dim * 2**i
                for i in swin_backbone.out_indices
            }
            swin_backbone._out_feature_strides = {
                "stage{}".format(i+2): 2 ** (i + 2)
                for i in swin_backbone.out_indices
            }
            original_forward = swin_backbone.forward

            def patched_forward(x):
                # Rename the original "p{i}" output keys to the "stage{i+2}" names declared above
                outs_orig = original_forward(x)
                outs = {}
                for i in swin_backbone.out_indices:
                    outs["stage{}".format(i+2)] = outs_orig["p{}".format(i)]
                return outs

            # Instance-level override: only this backbone object gets the renaming forward
            swin_backbone.forward = patched_forward

            super().__init__(
                backbone=FPN(
                    bottom_up=swin_backbone,
                    in_features=cfg.MODEL.FPN.IN_FEATURES,
                    out_channels=cfg.MODEL.FPN.OUT_CHANNELS,
                    norm=cfg.MODEL.FPN.NORM,
                    top_block=LastLevelMaxPool(),
                    fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
                ),
                proposal_generator=modules.proposal_generator,
                roi_heads=modules.roi_heads,
                pixel_mean=modules.pixel_mean,
                pixel_std=modules.pixel_std,
                input_format=modules.input_format,
                vis_period=modules.vis_period,
            )

        self.dataset_name = dataset.dataset_name
        self.num_classes = num_classes


In [None]:
# Load the HF RT-DETR model state from a local results directory (non-strict loading)
# NOTE(review): the following cell re-creates `model` from MODEL_TYPE, replacing this one
model = models.HFRTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
load_result = model.load_from("./results/models/rt_detr", strict=False)
print("INFO: Model state loaded -", load_result)
model.to(device)

In [None]:
# Initialize model
match MODEL_TYPE:
    case "rcnn":
        model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR_NATUREYOO if TEST_MODE else model.Weights.IMAGENET_OFFICIAL), strict=False)
    case "swinrcnn":
        model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR_NATUREYOO if TEST_MODE else model.Weights.IMAGENET_XIAOHU2015), strict=False)
    case "rtdetr":
        model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR if TEST_MODE else model.Weights.COCO_OFFICIAL), strict=False)
    case "hf_rtdetr":
        model = models.HFRTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR if TEST_MODE else model.Weights.COCO_OFFICIAL), strict=False)
    case "yolo11":
        model = YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        #model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.SHIFT_CLEAR if TEST_MODE else model.Weights.COCO_OFFICIAL), strict=False)
    case _:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

print("INFO: Model state loaded -", load_result)
model.to(device)

## Train

In [None]:
import os
# Set the WANDB_DISABLED environment variable for this process before the trainer is created
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Experiment naming: "<model name>_<dataset name>_<test|train>"
PROJECT_NAME = "tta_model_pretraining"
RUN_NAME = "_".join([model.model_name, SOURCE_DOMAIN.dataset_name]) + ("_test" if TEST_MODE else "_train")

# WandB Initialization
#import wandb
#wandb.init(project=PROJECT_NAME, name=RUN_NAME)

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 30  # passed to the trainer as num_train_epochs
LEARNING_RATE = 3e-4  # main learning rate; the backbone uses one tenth of it

In [None]:
# Trainer for the source domain: trains on dataset.train, evaluates on dataset.valid
trainer = model.Trainer(
    model=model,
    classes=CLASSES,
    train_dataset=model.DataPreparation(dataset.train, strong_augment_threshold_epoch=20),
    eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
    args=model.TrainingArguments(
        # Run identity and output locations
        run_name=RUN_NAME,
        output_dir=f"./results/{RUN_NAME}",
        logging_dir=f"./logs/{RUN_NAME}",
        #report_to="wandb",
        # Optimisation
        optim="adamw_torch",
        learning_rate=LEARNING_RATE,
        backbone_learning_rate=LEARNING_RATE / 10,  # backbone runs at 1/10th of the main learning rate
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        weight_decay=0.1,
        max_grad_norm=0.5,
        num_train_epochs=EPOCHS,
        # Batching
        per_device_train_batch_size=BATCH_SIZE[0],
        per_device_eval_batch_size=BATCH_SIZE[1],
        gradient_accumulation_steps=ACCUMULATE_STEPS,
        eval_accumulation_steps=BATCH_SIZE[1],
        batch_eval_metrics=True,
        remove_unused_columns=False,
        # Evaluate, save and log once per epoch
        eval_on_start=True,
        eval_strategy="epoch",  #"steps",
        save_strategy="epoch",  #"steps",
        logging_strategy="epoch",  #"steps",
        #eval_steps=100,
        #save_steps=100,
        #logging_steps=100,
        save_total_limit=100,
        # Best-checkpoint selection
        load_best_model_at_end=True,
        metric_for_best_model="mAP@0.50:0.95",
        greater_is_better=True,
    )
)

# Evaluation-only trainer over dataset.test
validator = model.Trainer(
    model=model,
    classes=CLASSES,
    eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
    args=model.TrainingArguments(
        per_device_eval_batch_size=BATCH_SIZE[1],
        batch_eval_metrics=True,
        remove_unused_columns=False
    )
)

In [None]:
# Try to resume from an existing checkpoint; start a fresh run when that raises FileNotFoundError
# NOTE(review): only FileNotFoundError triggers the fallback — confirm which exception
# model.Trainer raises when the output directory exists but holds no checkpoint
try:
    trainer.train(resume_from_checkpoint=True)
except FileNotFoundError:
    trainer.train()

In [None]:
# Evaluate with the trainer, whose eval dataset wraps dataset.valid
trainer.evaluate()

In [None]:
# Evaluate with the validator, whose eval dataset wraps dataset.test
validator.evaluate()

## Evaluation

### Load Scenarios

In [None]:
# Data preparation built over an empty dataset: only its transforms, collate_fn and
# post-processing are used by the scenarios and the evaluator below
data_preparation = model.DataPreparation((), evaluation_mode=True)

# Discrete-shift scenario on the validation split, using the WHWPAPER ordering
discrete_scenario = scenarios.SHIFTDiscreteScenario(
    root=DATA_ROOT, valid=True, order=scenarios.SHIFTDiscreteScenario.WHWPAPER, transforms=data_preparation.transforms
)
# Continuous-shift scenario on the validation split, using the DEFAULT ordering
continuous_scenario = scenarios.SHIFTContinuousScenario(
    root=DATA_ROOT, valid=True, order=scenarios.SHIFTContinuousScenario.DEFAULT, transforms=data_preparation.transforms
)

In [None]:
import copy
import time
import gc
import asyncio
import nest_asyncio
from typing import Callable

from tqdm.auto import tqdm

import torch
from torch import OutOfMemoryError
from torch.utils.data import DataLoader
from torchvision.ops import box_convert

from supervision.detection.core import Detections
from supervision.metrics.mean_average_precision import MeanAveragePrecision

#from ..models.base import BaseModel, ModelProvider
#from ..datasets import DataPreparation
from ttadapters.models.base import BaseModel, ModelProvider
from ttadapters.datasets import DataPreparation


class DetectionEvaluator:
    def __init__(
        self, model: BaseModel | list[BaseModel], classes: list[str], data_preparation: DataPreparation, required_reset: bool = False, 
        dtype=torch.float32, device=torch.device("cuda"), synchronize: bool = True, no_grad: bool = True
    ):
        self.do_parallel = isinstance(model, list)
        self.model = [m.to(device).to(dtype) for m in model] if self.do_parallel else model.to(device).to(dtype)
        self.data_preparation = data_preparation
        self.classes = classes
        self.required_reset = required_reset
        self.dtype = dtype
        self.device = device
        self.synchronize = synchronize
        self.no_grad = no_grad

    @staticmethod
    def evaluate_with_reset(
        model: BaseModel, desc: str, loader: DataLoader, loader_length: int, classes: list[str], data_preparation: DataPreparation,
        reset: bool = True, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda"),
        synchronize: bool = True, no_grad: bool = True, clear_tqdm_when_oom: bool = False
    ):
        torch.cuda.empty_cache(); torch.cuda.empty_cache(); torch.cuda.empty_cache()
        gc.collect(); gc.collect(); gc.collect()

        if reset:
            try:
                model.reset_adaptation()
            except NotImplementedError:
                print("WARNING: reset_adaptation() is not implemented for this model. Assuming the evaluation is running with deep-copy mode.")
                model = copy.deepcopy(model)

        model = model.to(device).to(dtype)
        model.eval()

        map_metric = MeanAveragePrecision()
        predictions_list = []
        targets_list = []
        total_images = 0
        collapse_time = 0

        if no_grad:  # use no_grad for inference
            disable_grad = torch.no_grad
        else:  # let model decide gradient requirement
            disable_grad = lambda: (yield)

        tqdm_loader = tqdm(loader, total=loader_length, desc=f"Evaluation for {desc}")
        try:
            with disable_grad():
                for batch in tqdm_loader:
                    if model.model_provider == ModelProvider.HuggingFace:
                        batch = {
                            k: v.to(device) if isinstance(v, torch.Tensor) else v 
                            for k, v in batch.items()
                        }
                        batch['labels'] = [{
                            k: v.to(device) if isinstance(v, torch.Tensor) else v
                            for k, v in label.items()
                        } for label in batch['labels']]
                        total_images += len(batch['labels'])
                    else:
                        total_images += len(batch)

                    with torch.autocast(device_type=device.type, dtype=dtype):
                        start = time.time()
                        match model.model_provider:
                            case ModelProvider.Detectron2:
                                outputs = model(batch)
                            case ModelProvider.Ultralytics:
                                outputs = model(*batch)
                            case ModelProvider.HuggingFace:
                                outputs = model(**batch)
                            case _:
                                raise ValueError(f"Unsupported model provider: {model.model_provider}")

                        if device.type == "cuda" and synchronize:
                            torch.cuda.synchronize()

                        collapse_time += time.time() - start

                    match model.model_provider:
                        case ModelProvider.Detectron2:
                            for output, input_data in zip(outputs, batch):
                                output = data_preparation.post_process(output)
                                predictions_list.append(Detections.from_detectron2(output))
                                targets_list.append(target_detection = Detections(
                                    xyxy=input_data['instances'].gt_boxes.tensor.detach().cpu().numpy(),
                                    class_id=input_data['instances'].gt_classes.detach().cpu().numpy()
                                ))
                        case ModelProvider.Ultralytics:
                            output = data_preparation.post_process(output)
                            pred_detection = Detections.from_ultralytics(output)
                            target_detection = Detections(
                                xyxy=input_data['bboxes'].detach().cpu().numpy(),
                                class_id=input_data['cls'].detach().cpu().numpy()
                            )
                            raise NotImplementedError("Ultralytics post_process is not implemented yet.")
                        case ModelProvider.HuggingFace:
                            sizes = [label['orig_size'].cpu().tolist() for label in batch['labels']]
                            outputs = data_preparation.post_process(outputs, target_sizes=sizes)
                            predictions_list.extend([Detections.from_transformers(output) for output in outputs])
                            targets_list.extend([Detections(
                                xyxy=(box_convert(label['boxes'], "cxcywh", "xyxy") * label['orig_size'].flip(0).repeat(2)).cpu().numpy(),
                                class_id=label['class_labels'].cpu().numpy(),
                            ) for label in batch['labels']])
                        case _:
                            raise ValueError(f"Unsupported model provider: {model.model_provider}")
        except OutOfMemoryError as e:  # catch OOM error to close tqdm properly
            tqdm_loader.close()
            if clear_tqdm_when_oom:
                tqdm_loader.container.close()
            raise e

        map_metric.update(predictions=predictions_list, targets=targets_list)
        m_ap = map_metric.compute()

        per_class_map = {
            f"{classes[idx]}_mAP@0.50:0.95": m_ap.ap_per_class[idx].mean().item()
            for idx in m_ap.matched_classes
        }
        performances = {
            "fps": total_images / collapse_time,
            "collapse_time": collapse_time
        }

        result = {
            "mAP@0.50:0.95": m_ap.map50_95.item(),
            "mAP@0.50": m_ap.map50.item(),
            "mAP@0.75": m_ap.map75.item(),
            **performances,
            **per_class_map
        }
        return result

    @staticmethod
    def evaluate(
        model: BaseModel, desc: str, loader: DataLoader, loader_length: int, classes: list[str], data_preparation: DataPreparation,
        dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda"),
        synchronize: bool = True, no_grad: bool = True, clear_tqdm_when_oom: bool = False
    ):
        return DetectionEvaluator.evaluate_with_reset(
            model, desc, loader, loader_length, classes=classes, data_preparation=data_preparation,
            reset=False, dtype=dtype, device=device,
            synchronize=synchronize, no_grad=no_grad, clear_tqdm_when_oom=clear_tqdm_when_oom
        )

    async def evaluate_recursively(self, module: BaseModel | list[BaseModel], *args, **kwargs):
        if isinstance(module, list):
            try:  # run all
                return await asyncio.gather(*[self.evaluate_recursively(m, *args, **kwargs) for m in module])
            except OutOfMemoryError:  # on OOM, try to run half
                if self.device.type == "cuda":
                    torch.cuda.synchronize()  # ensure all coroutine are finished
                results = []
                sub_modules = [module[:len(module)//2], module[len(module)//2:]]
                sub_modules[0] = sub_modules[0] if len(sub_modules[0]) else sub_modules[0][0]
                sub_modules[1] = sub_modules[1] if len(sub_modules[1]) else sub_modules[1][0]

                for sub_module in sub_modules:
                    result = await self.evaluate_recursively(sub_module, *args, **kwargs)
                    if isinstance(result, list):
                        results.extend(result)
                    else:
                        results.append(result)
            except KeyboardInterrupt:  # handle keyboard interrupt
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                raise
            return results
        else:
            return await asyncio.to_thread(
                self.evaluate_with_reset,
                module, *args, **kwargs, reset=self.required_reset, classes=self.classes, data_preparation=self.data_preparation,
                dtype=self.dtype, device=self.device, synchronize=self.synchronize, no_grad=self.no_grad, clear_tqdm_when_oom=True
            )

    def __call__(self, *args, **kwargs):
        if self.do_parallel:
            nest_asyncio.apply()
            try:
                return asyncio.run(self.evaluate_recursively(self.model, *args, **kwargs))
            except KeyboardInterrupt:
                print("\nEvaluation interrupted by user")
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                raise
        return self.evaluate_with_reset(
            self.model, *args, **kwargs, reset=self.required_reset, classes=self.classes, data_preparation=self.data_preparation,
            dtype=self.dtype, device=self.device, synchronize=self.synchronize, no_grad=self.no_grad
        )

In [None]:
# Single-model evaluator running in the configured dtype on the selected device
evaluator = DetectionEvaluator(model, classes=CLASSES, data_preparation=data_preparation, dtype=DATA_TYPE, device=device)
# Loader settings handed to each scenario: test batch size, fixed order, provider-specific collation
evaluator_loader_params = dict(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=data_preparation.collate_fn)

In [None]:
# Play the discrete scenario with the evaluator and visualize the metrics under the "Direct-Test" index
visualizer.visualize_metrics(discrete_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))

In [None]:
# Play the continuous scenario with the evaluator and visualize the metrics under the "Direct-Test" index
visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))