# Test-Time Adaptation (TTA) Example

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch
from torch import nn, optim
from torchinfo import summary

from ttadapters import datasets, models, methods
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios

In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Default batch sizes, ordered as (train, valid, test)
BATCH_SIZE = (2, 8, 1)  # Local
# BATCH_SIZE = (40, 200, 1)  # A100 or H100
ACCUMULATE_STEPS = 1  # gradient accumulation steps

# Root directory handed to every dataset / scenario constructor
DATA_ROOT = path.join(path.curdir, "data")

# Source (training) domain; the --dataset flag may override it
SOURCE_DOMAIN = datasets.SHIFTDataset

# Test-only run mode; the --test-only flag may override it
TEST_MODE = False

# Supported architectures; the last entry is the default choice
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[-1]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Adaptation experiment script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args(["--test-only"] if TEST_MODE else [])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityScapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Set test mode - {TEST_MODE} for {SOURCE_DOMAIN.dataset_name} dataset")

### Check GPU Availability

In [None]:
# Show GPU status. This is an IPython shell escape: it only works inside Jupyter/IPython, not as plain Python.
!nvidia-smi

In [None]:
# Resolve the compute device and default precision from the parsed CLI arguments.
DEVICE_NUM = args.device  # CUDA ordinal to use (CLI default 0)
ADDITIONAL_GPU = args.additional_gpu  # "Additional CUDA device count"; 0 means single device
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Multi-GPU: make DEVICE_NUM the current CUDA device and use the
        # index-less "cuda" device, which resolves to that current device.
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # sentinel: no CUDA device in use

# Append the ordinal only when the device string lacks one (index-less "cuda"),
# so a CPU run never prints a bogus "cpu:-1".
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if device.type == "cuda" and device.index is None else ""))
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Apply the library's fast-download patch for object-detection datasets
# (presumably speeds up dataset downloads). It must run before the next cell,
# which constructs the datasets under DATA_ROOT.
datasets.patch_fast_download_for_object_detection()

In [None]:
# Basic pre-training dataset: clean train/valid splits, with a corrupted
# validation split exposed as `dataset.test`.
if SOURCE_DOMAIN == datasets.SHIFTDataset:
    # Discrete SHIFT splits
    dataset = DatasetHolder(
        train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
        test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    )
    # Continuous SHIFT variants: constructed only for their side effects
    # (presumably fetching/preparing data under DATA_ROOT); objects are discarded.
    _ = datasets.SHIFTContinuous100DatasetForObjectDetection(root=DATA_ROOT)  # 100
    _ = datasets.SHIFTContinuous10DatasetForObjectDetection(root=DATA_ROOT)  # 10
    _ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT)  # 1 + split
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    dataset = DatasetHolder(
        train=datasets.CityScapesDatasetForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.CityScapesDatasetForObjectDetection(root=DATA_ROOT, valid=True),
        test=datasets.CityScapesCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    )
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Class metadata comes from the test split
CLASSES = dataset.test.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Display one raw test sample to check its annotation keys and values.
# NOTE(review): index 999 is hard-coded; this cell fails if the test split has fewer samples.
dataset.test[999]

In [None]:
# Display the shape of the first element (the image) of the same test sample.
dataset.test[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Visualize the test split as a sequence of frames with bounding boxes
# (behaviour defined by visualizer.visualize_bbox_frames).
visualizer.visualize_bbox_frames(dataset.test)

## Load Base Model

In [None]:
# Initialize base_model and load its starting checkpoint. Outside TEST_MODE the
# model's COCO_* weights entry is loaded; in TEST_MODE the entry for the source
# domain (SHIFT clear or Cityscapes) is loaded instead.
def _select_weights(weights, coco_name, shift_name):
    """Pick which entry of a model's ``Weights`` namespace to load.

    Args:
        weights: The ``Weights`` attribute of the constructed model.
        coco_name: Attribute name to use when TEST_MODE is off.
        shift_name: Attribute name to use in TEST_MODE with the SHIFT source domain.

    Returns:
        The selected entry; ``weights.CITYSCAPES`` in TEST_MODE for any other
        source domain. Only the chosen attribute is looked up.
    """
    if not TEST_MODE:
        return getattr(weights, coco_name)
    if SOURCE_DOMAIN == datasets.SHIFTDataset:
        return getattr(weights, shift_name)
    return weights.CITYSCAPES


if MODEL_TYPE == "rcnn":
    base_model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
    selected_weights = _select_weights(base_model.Weights, "COCO_OFFICIAL", "SHIFT_CLEAR_NATUREYOO")
elif MODEL_TYPE == "swinrcnn":
    base_model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
    selected_weights = _select_weights(base_model.Weights, "COCO_XIAOHU2015", "SHIFT_CLEAR_NATUREYOO")
elif MODEL_TYPE == "yolo11":
    DATA_TYPE = torch.bfloat16  # bf16 default
    base_model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
    selected_weights = _select_weights(base_model.Weights, "COCO_OFFICIAL", "SHIFT_CLEAR")
elif MODEL_TYPE == "rtdetr":
    DATA_TYPE = torch.bfloat16  # bf16 default
    base_model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
    selected_weights = _select_weights(base_model.Weights, "COCO_OFFICIAL", "SHIFT_CLEAR")
else:
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

load_result = base_model.load_from(**vars(selected_weights), strict=False)
print("INFO: Model state loaded -", load_result)
base_model.to(device)

In [None]:
# Show a torchinfo summary (layers / parameter counts) of the loaded base model.
summary(base_model)

## Load Adaptation Method

In [None]:
# Wrap the base model with the APT test-time adapter and move it to `device`.
# NOTE(review): APTConfig and APTPlugin are not imported in this notebook's
# import cell, so this cell raises NameError as written. Add the correct import
# (possibly from ttadapters.methods; verify against the package).
adaptive_config = APTConfig()
adaptive_model = APTPlugin(base_model, adaptive_config)
adaptive_model.to(device)

## Evaluation

In [None]:
# Put the baseline model in eval mode and switch the adapter to its online mode
# (presumably enabling updates during evaluation).
# NOTE(review): an earlier comment here said pretrained APT weights are loaded and
# the FPN/encoder is un-frozen, but this cell loads no weights. Confirm what
# `online()` actually enables.
base_model.eval()
adaptive_model.online()

### Load Scenarios

In [None]:
# Ensure the continuous SHIFT subset is split beforehand (required because the
# Scenario class works with coroutines). Only SHIFT scenarios are built below,
# so skip this for other source domains instead of touching SHIFT data needlessly.
if SOURCE_DOMAIN == datasets.SHIFTDataset:
    _ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Evaluation-mode preprocessing, shared by the scenarios and the evaluator.
data_preparation = base_model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

if SOURCE_DOMAIN == datasets.SHIFTDataset:
    # Discrete shifts replayed in the WHWPAPER order
    discrete_scenario = scenarios.SHIFTDiscreteScenario(
        root=DATA_ROOT,
        valid=True,
        order=scenarios.SHIFTDiscreteScenario.WHWPAPER,
        transforms=data_preparation.transforms,
    )
    # Continuous shifts replayed in the DEFAULT order
    continuous_scenario = scenarios.SHIFTContinuousScenario(
        root=DATA_ROOT,
        valid=True,
        order=scenarios.SHIFTContinuousScenario.DEFAULT,
        transforms=data_preparation.transforms,
    )
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    # No scenarios are defined for Cityscapes here; both stay None.
    discrete_scenario = continuous_scenario = None
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

In [None]:
methods = {
    'Direct-Test': base_model,
    'Adaptive-Test': adaptive_model
}

In [None]:
evaluator = validator.DetectionEvaluator(list(methods.values()), classes=CLASSES, data_preparation=data_preparation, dtype=DATA_TYPE, device=device, no_grad=False)
evaluator_loader_params = dict(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=data_preparation.collate_fn)

In [None]:
visualizer.visualize_metrics(discrete_scenario(**evaluator_loader_params).play(evaluator, index=methods.keys()))

In [None]:
visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=methods.keys()))