# Pre-training with SHIFT-Discrete Dataset (Clear-Daytime)

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch

from ttadapters import datasets, models
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios, SHIFTDataset

In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Default per-device batch sizes, ordered (train, valid, test)
BATCH_SIZE = (2, 8, 1)  # Local workstation
# BATCH_SIZE = (40, 200, 1)  # A100 or H100
ACCUMULATE_STEPS = 1

# Directory that holds the datasets
DATA_ROOT = path.join(".", "data")

# Source domain the model is pre-trained on (overridable with --dataset)
SOURCE_DOMAIN = datasets.SHIFTDataset

# Run mode: True skips training and only evaluates (overridable with --test-only)
TEST_MODE = False

# Supported architectures; the last entry is the default choice
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[-1]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Training script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args(["--test-only"] if TEST_MODE else [])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityScapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Set test mode - {TEST_MODE} for {SOURCE_DOMAIN.dataset_name} dataset")

### Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Set CUDA device number, extra-GPU count and compute precision from the parsed arguments
DEVICE_NUM = args.device
ADDITIONAL_GPU = args.additional_gpu
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Multi-GPU run: make DEVICE_NUM the current CUDA device and address it as plain "cuda"
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1
    # Extra GPUs cannot be used without CUDA; clear the count so the later world-size
    # computations fall back to the single-device path instead of requesting multi-GPU
    if ADDITIONAL_GPU:
        print(f"WARNING: CUDA is not available - ignoring {ADDITIONAL_GPU} additional GPU(s)")
        ADDITIONAL_GPU = 0

# In the multi-GPU branch `device` carries no index, so append DEVICE_NUM for the log line
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Fast download patch: applied before any dataset below is constructed
# (presumably speeds up dataset downloads, per the helper's name — see ttadapters.datasets)
datasets.patch_fast_download_for_object_detection()

In [None]:
# Basic pre-training dataset: build the splits for the selected source domain
if SOURCE_DOMAIN == datasets.SHIFTDataset:
    discrete_dataset = DatasetHolder(
        train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
        test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
    )
    continuous_dataset = DatasetHolder(
        train=datasets.SHIFTContinuous100DatasetForObjectDetection(root=DATA_ROOT),
        valid=datasets.SHIFTContinuous10DatasetForObjectDetection(root=DATA_ROOT),
        test=datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT)
    )
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    # Cityscapes has no target-domain test split and no continuous variant here
    discrete_dataset = DatasetHolder(
        train=datasets.CityScapesForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.CityScapesForObjectDetection(root=DATA_ROOT, valid=True),
        test=None
    )
    continuous_dataset = None
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Pre-training always runs on the discrete splits
dataset = discrete_dataset

# Dataset info
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Check annotation keys-values: display one raw training sample
# (index 999 assumes the training split has at least 1000 samples)
dataset.train[999]

In [None]:
# Check data shape of the first element of a sample (the image)
dataset.train[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Visualize video: render training frames with their bounding boxes
visualizer.visualize_bbox_frames(dataset.train)

## Load Model

In [None]:
import warnings

class DummyYOLO:
    """
    Stand-in for a YOLO model when the Ultralytics library is not installed.

    Creating an instance emits a ``RuntimeWarning`` with installation and
    licensing instructions; calling the instance or its ``predict`` method
    raises ``RuntimeError``.
    """

    def __init__(self, model_name: str = "yolo11n"):
        self.model_name = model_name
        self._show_install_message()

    def _show_install_message(self):
        rule = "=" * 70
        lines = [
            "",
            rule,
            f"YOLO model '{self.model_name}' requires Ultralytics library.",
            rule,
            "",
            "To use YOLO models, install Ultralytics:",
            "    pip install ultralytics",
            "",
            "Note: Ultralytics is licensed under AGPL-3.0.",
            "By installing it, you agree to comply with AGPL-3.0 terms.",
            "See: https://github.com/ultralytics/ultralytics",
            rule,
            "",
        ]
        # stacklevel=2 attributes the warning to the frame that called this helper (__init__)
        warnings.warn("\n".join(lines), RuntimeWarning, stacklevel=2)

    def _raise_unavailable(self):
        """Raise the error shared by every inference entry point."""
        raise RuntimeError(
            f"Cannot run YOLO model '{self.model_name}'. "
            "Install ultralytics first: pip install ultralytics"
        )

    def __call__(self, *args, **kwargs):
        self._raise_unavailable()

    def predict(self, *args, **kwargs):
        self._raise_unavailable()

    def __repr__(self):
        return f"DummyYOLO(model_name='{self.model_name}', installed=False)"

In [None]:
from dataclasses import dataclass

from torchvision.tv_tensors import Image, BoundingBoxFormat, BoundingBoxes
from torchvision.transforms.v2.functional import convert_bounding_box_format

from ultralytics.nn.tasks import DetectionModel
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator



#from ..base import BaseModel, ModelProvider, WeightsInfo
#from ...datasets import BaseDataset, DataPreparation

from ttadapters.models.base import BaseModel, ModelProvider, WeightsInfo
from ttadapters.datasets import BaseDataset, DataPreparation


@dataclass
class YOLOTrainerArguments:
    # Basic training params
    epochs: int = 100
    batch: int = 16

    # Optimizer params
    optimizer: str = "SGD"  # SGD, Adam, AdamW, auto
    lr0: float = 0.01  # initial learning rate
    lrf: float = 0.01  # final learning rate (lr0 * lrf)
    momentum: float = 0.937
    weight_decay: float = 0.0005
    warmup_epochs: int = 3.0
    warmup_momentum: float = 0.8
    warmup_bias_lr: float = 0.1

    # Loss params
    box: float = 7.5  # box loss gain
    cls: float = 0.5  # cls loss gain
    dfl: float = 1.5  # dfl loss gain

    # Validation params
    conf: float = 0.001  # confidence threshold
    iou: float = 0.7  # NMS IoU threshold

    # Other params
    amp: bool = True  # automatic mixed precision
    device: str = ""  # cuda device, e.g. 0 or 0,1,2,3 or cpu
    workers: int = 0  # number of worker threads
    project: str = "./results"  # project name
    name: str = "yolo11_training"  # experiment name
    exist_ok: bool = False  # overwrite existing experiment
    seed: int = 0  # random seed
    deterministic: bool = True  # deterministic mode
    single_cls: bool = False  # train as single-class dataset
    rect: bool = False  # rectangular training
    cos_lr: bool = False  # cosine learning rate scheduler
    close_mosaic: int = 10  # disable mosaic augmentation for final N epochs
    resume: bool = False  # resume training
    save: bool = True  # save checkpoints
    save_period: int = -1  # save checkpoint every N epochs (-1 = disabled)
    cache: bool = False  # cache images for faster training
    val: bool = True  # validate/test during training
    patience: int = 50  # early stopping patience (epochs without improvement)
    plots: bool = True  # save plots during training


class YOLOTrainer(DetectionTrainer):
    """Ultralytics ``DetectionTrainer`` fed from this project's ``DataPreparation`` objects.

    Instead of resolving a YOLO dataset YAML, the dataset hooks below hand back
    ``train_dataset`` for ``mode == 'train'`` and ``eval_dataset`` otherwise.
    """

    def __init__(
        self,
        model: BaseModel,
        classes: list[str],
        train_dataset: DataPreparation | None = None,
        eval_dataset: DataPreparation | None = None,
        args: YOLOTrainerArguments | None = None,
        **kwargs
    ):
        """
        Args:
            model: Model passed to Ultralytics as the ``model`` override.
            classes: Class names, used to label the per-class mAP entries in ``evaluate``.
            train_dataset: Data returned for ``mode == 'train'``.
            eval_dataset: Data returned for every other mode.
            args: Training hyper-parameters; defaults to ``YOLOTrainerArguments()``.
            **kwargs: Accepted for signature compatibility; currently ignored.
        """
        self.classes = classes
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.custom_args = args if args is not None else YOLOTrainerArguments()

        # Ultralytics takes its configuration as a flat ``overrides`` mapping
        overrides = dict(vars(self.custom_args))
        overrides['model'] = model

        # Initialize parent DetectionTrainer
        super().__init__(overrides=overrides)

    def _dataset_for(self, mode):
        """Return the training split for ``mode == 'train'``, else the evaluation split."""
        return self.train_dataset if mode == 'train' else self.eval_dataset

    # NOTE(review): confirm these hook signatures match how the installed Ultralytics
    # version calls get_dataset / build_dataset (argument count may differ).
    def get_dataset(self, dataset_path, mode="train", batch_size=None):
        return self._dataset_for(mode)

    def build_dataset(self, img_path, mode="train", batch=None):
        return self._dataset_for(mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Wrap the split for ``mode`` in a ``DataLoader``; ``dataset_path`` and ``rank`` are ignored.

        Raises:
            ValueError: if no dataset was supplied for ``mode``.
        """
        dataset = self._dataset_for(mode)
        if dataset is None:
            raise ValueError(f"No dataset was provided for mode={mode!r}")

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(mode == 'train'),
            collate_fn=dataset.collate_fn,
            num_workers=self.custom_args.workers
        )

    def evaluate(self):
        """Run validation and return mAP metrics keyed in this project's naming scheme.

        Returns:
            dict with ``mAP@0.50:0.95``, ``mAP@0.50``, ``mAP@0.75`` and, when
            available, ``<class>_mAP@0.50:0.95`` for each name in ``self.classes``.
        """
        # Ultralytics' trainer ``validate()`` returns a ``(metrics_dict, fitness)`` pair rather
        # than a metrics object (verify against the installed version), so the detailed box
        # statistics are read from the validator that just ran.
        self.validate()
        box = self.validator.metrics.box

        # Convert to our format
        results = {
            'mAP@0.50:0.95': box.map,
            'mAP@0.50': box.map50,
            'mAP@0.75': box.map75,
        }

        # Add per-class mAP if available; zip stops at the shorter of names / values
        maps = getattr(box, 'maps', None)
        if maps is not None:
            for class_name, class_map in zip(self.classes, maps):
                results[f'{class_name}_mAP@0.50:0.95'] = class_map

        return results


class YOLODataPreparation(DataPreparation, DetectionValidator):
    """Adapt a project ``BaseDataset`` to the sample/batch layout used by Ultralytics YOLO.

    Each sample is converted to numpy (HWC image, XYXY boxes) and passed through
    Ultralytics' ``v8_transforms``. ``collate_fn`` then assembles a YOLO-style batch
    (``img``, ``batch_idx``, ``cls``, normalized CXCYWH ``bboxes``), and
    ``post_process`` applies Ultralytics' non-maximum suppression.
    """

    def __init__(
        self,
        dataset: BaseDataset,
        dataset_key: dict | None = None,
        img_size: int = 640,
        evaluation_mode: bool = False,
        confidence_threshold: float = 0.001,
        iou_threshold: float = 0.7,
    ):
        """
        Args:
            dataset: Source dataset; its ``dataset_name`` and ``classes`` are copied onto this object.
            dataset_key: Maps the logical fields ``bboxes``, ``classes`` and ``original_size`` to
                keys of the dataset's target dict. ``None`` (the default) selects
                ``boxes2d`` / ``boxes2d_classes`` / ``original_hw``.
            img_size: Passed to ``v8_transforms`` as ``imgsz``.
            evaluation_mode: Stored on the instance (not read within this class).
            confidence_threshold: NMS confidence threshold used by ``post_process``.
            iou_threshold: NMS IoU threshold used by ``post_process``.
        """
        from ultralytics.data.augment import v8_transforms

        self.dataset_name = dataset.dataset_name
        self.classes = dataset.classes

        self.dataset = dataset
        # Build the default per instance: a dict literal as the parameter default would be
        # one object shared (and mutable) across every instance.
        if dataset_key is None:
            dataset_key = dict(bboxes="boxes2d", classes="boxes2d_classes", original_size="original_hw")
        self.dataset_key = dataset_key
        self.img_size = img_size
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.evaluation_mode = evaluation_mode

        # Use YOLO's pre-configured v8_transforms as augmentation.
        # NOTE(review): confirm v8_transforms accepts hyp=None and a non-Ultralytics dataset
        # (mosaic-style transforms typically sample extra images from `dataset`).
        self.augmentation = v8_transforms(
            dataset=dataset,
            imgsz=img_size,
            hyp=None,
            stretch=False
        )

    def __len__(self):
        return len(self.dataset)

    def transforms(self, *args, idx=None):
        """Apply the YOLO augmentation to one ``(image, target)`` sample.

        Accepts either a single ``(image, target)`` tuple or ``image, target`` as two
        positional arguments. Returns the augmentation's output dict in the first case,
        and ``(transformed['img'], transformed)`` in the second.
        """
        image, target = args[0] if len(args) == 1 else args

        bbox = target[self.dataset_key['bboxes']]
        bbox_classes = target[self.dataset_key['classes']]
        # Looked up (so a missing key fails fast) but not used further below
        original_height, original_width = target[self.dataset_key['original_size']]

        # Convert to numpy for YOLO transforms (YOLO uses OpenCV/numpy internally)
        if isinstance(image, torch.Tensor):
            image = image.permute(1, 2, 0).numpy()  # CHW -> HWC

        # Convert bbox to numpy and ensure XYXY format
        if isinstance(bbox, BoundingBoxes):
            if bbox.format != BoundingBoxFormat.XYXY:
                bbox = convert_bounding_box_format(bbox, new_format=BoundingBoxFormat.XYXY)
            bbox = bbox.data.numpy() if isinstance(bbox.data, torch.Tensor) else bbox.data
        elif isinstance(bbox, torch.Tensor):
            bbox = bbox.numpy()

        # Convert bbox_classes to numpy
        if isinstance(bbox_classes, torch.Tensor):
            bbox_classes = bbox_classes.numpy()

        # Apply YOLO augmentation.
        # NOTE(review): Ultralytics transforms may expect an `instances` entry in this dict — verify.
        transformed = self.augmentation({
            'img': image,
            'bboxes': bbox,
            'cls': bbox_classes,
            'batch_idx': idx if idx is not None else 0
        })

        if len(args) == 1:
            return transformed
        else:
            return transformed['img'], transformed

    def __getitem__(self, idx):
        return self.transforms(self.dataset[idx], idx=idx)

    def collate_fn(self, batch):
        """Collate transformed samples into YOLO's native batch dict.

        Expects each item's ``img`` as an HWC numpy array and ``bboxes`` as XYXY pixels.
        Returns ``img`` (stacked, scaled by 1/255), ``batch_idx`` (sample index per box),
        ``cls`` and ``bboxes`` (CXCYWH normalized by image width/height).
        """
        images = []
        batch_idx = []
        cls = []
        bboxes = []

        for i, item in enumerate(batch):
            # Convert from numpy (HWC) to torch (CHW)
            img_tensor = torch.from_numpy(item['img']).permute(2, 0, 1)
            images.append(img_tensor)

            num_objects = len(item['bboxes'])
            batch_idx.extend([i] * num_objects)
            cls.extend(item['cls'].tolist())

            # Convert bboxes using BoundingBoxFormat
            boxes = item['bboxes']
            h, w = item['img'].shape[:2]

            if len(boxes) > 0:
                # Create BoundingBoxes with XYXY format
                boxes_tv = BoundingBoxes(
                    torch.from_numpy(boxes),
                    format=BoundingBoxFormat.XYXY,
                    canvas_size=(h, w)
                )

                # Convert to CXCYWH format
                boxes_cxcywh = convert_bounding_box_format(
                    boxes_tv,
                    new_format=BoundingBoxFormat.CXCYWH
                )

                # Normalize to [0, 1]
                boxes_normalized = boxes_cxcywh.clone()
                boxes_normalized[:, [0, 2]] /= w  # normalize cx, width
                boxes_normalized[:, [1, 3]] /= h  # normalize cy, height

                bboxes.extend(boxes_normalized.tolist())

        # Stack images and normalize (requires every image in the batch to share one size)
        images_tensor = torch.stack(images) / 255.0  # Normalize to [0, 1]

        # Create tensors
        if len(bboxes) > 0:
            batch_idx_tensor = torch.tensor(batch_idx, dtype=torch.long)
            cls_tensor = torch.tensor(cls, dtype=torch.long)
            bboxes_tensor = torch.tensor(bboxes, dtype=torch.float32)
        else:
            batch_idx_tensor = torch.zeros(0, dtype=torch.long)
            cls_tensor = torch.zeros(0, dtype=torch.long)
            bboxes_tensor = torch.zeros((0, 4), dtype=torch.float32)

        return {
            'img': images_tensor,
            'batch_idx': batch_idx_tensor,
            'cls': cls_tensor,
            'bboxes': bboxes_tensor,
        }

    def pre_process(self, batch):
        """Pre-process is handled by collate_fn"""
        return batch

    def post_process(self, outputs, target_sizes=None):
        """Apply Ultralytics NMS with this object's thresholds; ``target_sizes`` is ignored."""
        from ultralytics.utils.ops import non_max_suppression

        # Apply NMS
        predictions = non_max_suppression(
            outputs,
            conf_thres=self.confidence_threshold,
            iou_thres=self.iou_threshold,
            multi_label=False,
            max_det=300
        )

        return predictions


class YOLO11ForObjectDetection(DetectionModel, BaseModel):
    """YOLO11m detector built from ``yolo11m.yaml`` with a class count taken from the dataset.

    The class attributes expose the companion helper classes (data preparation,
    trainer, training arguments) and the known checkpoint locations.
    """

    # Identity and architecture config (the YAML is handed to DetectionModel)
    model_name = "YOLO11m"
    model_config = "yolo11m.yaml"
    model_provider = ModelProvider.Ultralytics
    channel = 3

    # Companion helper classes
    DataPreparation = YOLODataPreparation
    Trainer = YOLOTrainer
    TrainingArguments = YOLOTrainerArguments

    class Weights:
        # Official COCO checkpoint; its state dict lives under the "model" key
        COCO_OFFICIAL = WeightsInfo("https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt", weight_key="model")
        # Placeholders with empty URLs
        SHIFT_CLEAR = WeightsInfo("")
        CITYSCAPES = WeightsInfo("")

    def __init__(self, dataset: BaseDataset):
        class_count = len(dataset.classes)
        super().__init__(self.model_config, ch=self.channel, nc=class_count)

        self.num_classes = class_count
        self.dataset_name = dataset.dataset_name

In [None]:
#TEST_MODE = False
#MODEL_TYPE = "yolo11"

In [None]:
# Initialize model
match MODEL_TYPE:
    case "rcnn":
        model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.COCO_OFFICIAL if not TEST_MODE else model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == SHIFTDataset else model.Weights.CITYSCAPES), strict=False)
    case "swinrcnn":
        model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.COCO_XIAOHU2015 if not TEST_MODE else model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == SHIFTDataset else model.Weights.CITYSCAPES), strict=False)
    case "yolo11":
        model = YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        #model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.COCO_OFFICIAL if not TEST_MODE else model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == SHIFTDataset else model.Weights.CITYSCAPES), strict=False)
    case "rtdetr":
        DATA_TYPE = torch.bfloat16
        model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = model.load_from(**vars(model.Weights.COCO_OFFICIAL if not TEST_MODE else model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == SHIFTDataset else model.Weights.CITYSCAPES), strict=False)
    case _:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

print("INFO: Model state loaded -", load_result)
model.to(device)

## Train

In [None]:
# Project Setup: run name = <model>_<dataset>_<test|train>
PROJECT_NAME = "tta_model_pretraining"
_run_phase = "test" if TEST_MODE else "train"
RUN_NAME = "_".join((model.model_name, SOURCE_DOMAIN.dataset_name, _run_phase))

In [None]:
# WandB Initialization: one run per RUN_NAME under the shared project
import wandb

wandb.init(name=RUN_NAME, project=PROJECT_NAME)

In [None]:
# Training length and base learning rate shared by every trainer below
EPOCHS, LEARNING_RATE = 20, 1e-4

In [None]:
def evaluate_source():
    """No-op source-domain evaluation; replaced by a trainer's method when one is built."""
    return None


def evaluate_target():
    """No-op target-domain evaluation; replaced by an evaluator's method when one is built."""
    return None

### Detectron Trainer

In [None]:
# Define Trainer & Validator (Detectron-based models)
if not TEST_MODE and MODEL_TYPE in ("rcnn", "swinrcnn"):
    # Global batch sizes across all GPUs in use (per-device size x device count)
    ALL_DEVICE_BATCH = BATCH_SIZE[0]*(ADDITIONAL_GPU+1), BATCH_SIZE[1]*(ADDITIONAL_GPU+1)
    # World size handed to the trainers; 0 disables the multi-GPU code path
    WORLD_SIZE = ADDITIONAL_GPU+1 if ADDITIONAL_GPU > 0 else ADDITIONAL_GPU

    trainer = model.Trainer(
        model=model,
        classes=CLASSES,
        train_dataset=model.DataPreparation(dataset.train),
        eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
        args=model.TrainingArguments(
            learning_rate=LEARNING_RATE,
            # NOTE(review): the extra factor of 10 is not explained here — confirm it is intended
            total_steps=EPOCHS*10*len(dataset.train)//ALL_DEVICE_BATCH[0],
            eval_period=100,
            save_period=100,
            train_batch_for_total=ALL_DEVICE_BATCH[0],
            eval_batch_for_total=ALL_DEVICE_BATCH[1],
            multiple_gpu_world_size=WORLD_SIZE,
            momentum=0.9,
            weight_decay=1e-4,
            lr_scheduler_type="WarmupCosineLR",  # WarmupMultiStepLR, WarmupStepWithFixedGammaLR
            cosine_lr_final=LEARNING_RATE/10,
            lr_warmup_method="linear",
            lr_warmup_iters=500,
            use_amp=False,
            output_dir="./results/"+RUN_NAME
        )
    )
    evaluate_source = trainer.test

    # Target-domain evaluation needs a test split; the Cityscapes holder is built with test=None
    if dataset.test is not None:
        evaluator = model.Trainer(
            model=model,
            classes=CLASSES,
            eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
            args=model.TrainingArguments(
                learning_rate=LEARNING_RATE,
                total_steps=1,
                eval_batch_for_total=ALL_DEVICE_BATCH[1],
                multiple_gpu_world_size=WORLD_SIZE,
                use_amp=False,
                output_dir="./results/"+RUN_NAME
            )
        )
        evaluate_target = evaluator.test
    else:
        print("INFO: No test split for this dataset - skipping target-domain evaluation")

### Ultralytics Trainer

### Transformers Trainer

In [None]:
# Define Trainer & Validator (Transformers-based RT-DETR)
if not TEST_MODE and MODEL_TYPE == "rtdetr":
    # Mixed precision follows the active data type
    USE_BF16 = DATA_TYPE == torch.bfloat16

    trainer = model.Trainer(
        model=model,
        classes=CLASSES,
        train_dataset=model.DataPreparation(dataset.train),
        eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
        args=model.TrainingArguments(
            backbone_learning_rate=LEARNING_RATE/10,  # Set backbone learning rate to 1/10th of the main learning rate
            learning_rate=LEARNING_RATE,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            weight_decay=0.1,
            max_grad_norm=0.5,
            num_train_epochs=EPOCHS,
            per_device_train_batch_size=BATCH_SIZE[0],
            per_device_eval_batch_size=BATCH_SIZE[1],
            gradient_accumulation_steps=ACCUMULATE_STEPS,
            eval_accumulation_steps=BATCH_SIZE[1],
            batch_eval_metrics=True,
            remove_unused_columns=False,
            optim="adamw_torch",
            eval_on_start=True,
            eval_strategy="steps",
            save_strategy="steps",
            logging_strategy="steps",
            eval_steps=100,
            save_steps=100,
            logging_steps=100,
            save_total_limit=100,
            load_best_model_at_end=True,
            metric_for_best_model="mAP@0.50:0.95",
            greater_is_better=True,
            report_to="wandb",
            output_dir="./results/"+RUN_NAME,
            logging_dir="./logs/"+RUN_NAME,
            run_name=RUN_NAME,
            bf16=USE_BF16
        )
    )
    evaluate_source = trainer.evaluate

    # Target-domain evaluation needs a test split; the Cityscapes holder is built with test=None
    if dataset.test is not None:
        evaluator = model.Trainer(
            model=model,
            classes=CLASSES,
            eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
            args=model.TrainingArguments(
                per_device_eval_batch_size=BATCH_SIZE[1],
                batch_eval_metrics=True,
                remove_unused_columns=False,
                output_dir="./results/"+RUN_NAME,  # keep evaluator artifacts with the training run
                bf16=USE_BF16
            )
        )
        evaluate_target = evaluator.evaluate
    else:
        print("INFO: No test split for this dataset - skipping target-domain evaluation")

### Run Jobs

In [None]:
# Do train for source domain (skipped in test-only mode)
if not TEST_MODE:
    if MODEL_TYPE in ("rcnn", "swinrcnn"):
        # resume=True: presumably continues from the last checkpoint in output_dir when one exists
        trainer.resume_or_load(resume=True)
        trainer.train()
    elif MODEL_TYPE == "yolo11":
        # No YOLO trainer is built above; say so instead of silently skipping (the model is still saved later)
        print("WARNING: Training is not implemented for yolo11 - skipping the training step")
    elif MODEL_TYPE == "rtdetr":
        # Try to resume; if that raises (e.g. no checkpoint yet), start a fresh run
        try:
            trainer.train(resume_from_checkpoint=True)
        except (FileNotFoundError, ValueError):
            trainer.train()

In [None]:
# Do eval for source domain (a no-op unless a trainer cell above assigned evaluate_source)
evaluate_source()

In [None]:
# Do eval for target domain (a no-op unless a trainer cell above assigned evaluate_target)
evaluate_target()

In [None]:
# Model save: store the weights under RUN_NAME (skipped in test-only mode, where nothing was trained)
if not TEST_MODE:
    model.save_to(version=RUN_NAME)

## Evaluation

In [None]:
# Set model eval mode before running the evaluation scenarios below
model.eval()

### Load Scenarios

|Model|Dataset|Metric|
|---|---|---|
|rcnn|shift|mAP@0.50:0.95|
|swinrcnn|shift|mAP@0.50:0.95|
|yolo11|shift|mAP@0.50:0.95|
|rtdetr|shift|mAP@0.50:0.95|

In [None]:
# Ensure the SHIFT continuous-subset split is prepared up front; this is required because the
# Scenario classes work with coroutines. The instance itself is discarded.
_ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Preparation wrapper around an empty placeholder dataset; the scenarios below supply
# the samples and call its transforms
data_preparation = model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

# Settings common to both SHIFT scenarios (validation split, model-specific transforms)
_scenario_common = dict(root=DATA_ROOT, valid=True, transforms=data_preparation.transforms)

discrete_scenario = scenarios.SHIFTDiscreteScenario(
    order=scenarios.SHIFTDiscreteScenario.WHWPAPER, **_scenario_common
)
continuous_scenario = scenarios.SHIFTContinuousScenario(
    order=scenarios.SHIFTContinuousScenario.DEFAULT, **_scenario_common
)

In [None]:
evaluator = validator.DetectionEvaluator(
    model,
    classes=CLASSES,
    data_preparation=data_preparation,
    dtype=DATA_TYPE,
    device=device,
    no_grad=True,
)
# Loader keyword arguments passed to each scenario call below; BATCH_SIZE[2] is the test batch size
evaluator_loader_params = {
    "batch_size": BATCH_SIZE[2],
    "shuffle": False,
    "collate_fn": data_preparation.collate_fn,
}

In [None]:
# Play the discrete SHIFT scenario with the evaluator (row label "Direct-Test") and plot the metrics
visualizer.visualize_metrics(discrete_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))

In [None]:
# Play the continuous SHIFT scenario with the evaluator (row label "Direct-Test") and plot the metrics
visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))