# Test-Time Adaptation (TTA) Example

## Imports and Configs

In [1]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch
from torchinfo import summary

from ttadapters import datasets, models, methods
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios

In [2]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [3]:
# Default batch sizes, ordered as (train, valid, test)
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = (40, 200, 1)  # A100 or H100
ACCUMULATE_STEPS = 1  # gradient accumulation steps

# Root directory holding every dataset
DATA_ROOT = path.join(".", "data")

# Source (training) domain dataset; overridable via --dataset
SOURCE_DOMAIN = datasets.SHIFTDataset

# Supported architectures; the first entry is the default --model
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[0]

In [4]:
# Create argument parser
parser = ArgumentParser(description="Adaptation experiment script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args([])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityScapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")

INFO: Running in notebook mode with default arguments
INFO: Set batch size - Train: 2, Valid: 8, Test: 1


### Check GPU Availability

In [5]:
# IPython shell escape: print NVIDIA driver/GPU status (requires the `nvidia-smi` tool on PATH)
!nvidia-smi

Wed Nov  5 09:02:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   53C    P8              1W /   78W |       0MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [6]:
# Set CUDA Device Number
# Resolve the compute device and numeric precision from the parsed CLI arguments.
# NOTE: DATA_TYPE may be overridden to bf16 later by the model-loading cell for some architectures.
DEVICE_NUM = args.device
ADDITIONAL_GPU = args.additional_gpu
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Multi-GPU: make DEVICE_NUM the current CUDA device and use the index-less "cuda" device
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1

# Only the index-less "cuda" device needs the index appended for display; never print e.g. "cpu:-1"
if device.type == "cuda" and device.index is None:
    device_label = f"{device}:{DEVICE_NUM}"
else:
    device_label = str(device)
print(f"INFO: Using device - {device_label}")
print(f"INFO: Using data precision - {DATA_TYPE}")
print(f"INFO: Using data precision - {DATA_TYPE}")

INFO: Using device - cuda:0
INFO: Using data precision - torch.float32


## Define Dataset

In [7]:
# Fast download patch
datasets.patch_fast_download_for_object_detection()



In [8]:
# Dataset info
CLASSES = datasets.SHIFTClearDatasetForObjectDetection.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

INFO: Number of classes - 6 ['pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle']


## Load Base Model

In [9]:
# Initialize base_model
match MODEL_TYPE:
    case "rcnn":
        base_model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "swinrcnn":
        base_model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "yolo11":
        DATA_TYPE = torch.bfloat16  # bf16 default
        base_model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "rtdetr":
        DATA_TYPE = torch.bfloat16  # bf16 default
        base_model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case _:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

print("INFO: Model state loaded -", load_result)
base_model.to(device)

INFO: Model state loaded - <All keys matched successfully>


FasterRCNNForObjectDetection(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): Bo

In [10]:
# Per-layer parameter counts of the base model; torchinfo wraps counts of non-trainable layers in parentheses
summary(base_model)

Layer (type:depth-idx)                                  Param #
FasterRCNNForObjectDetection                            --
├─FPN: 1-1                                              --
│    └─Conv2d: 2-1                                      65,792
│    └─Conv2d: 2-2                                      590,080
│    └─Conv2d: 2-3                                      131,328
│    └─Conv2d: 2-4                                      590,080
│    └─Conv2d: 2-5                                      262,400
│    └─Conv2d: 2-6                                      590,080
│    └─Conv2d: 2-7                                      524,544
│    └─Conv2d: 2-8                                      590,080
│    └─LastLevelMaxPool: 2-9                            --
│    └─ResNet: 2-10                                     --
│    │    └─BasicStem: 3-1                              (9,408)
│    │    └─Sequential: 3-2                             (212,992)
│    │    └─Sequential: 3-3                             1,2

## Load Adaptation Method

In [11]:
# Method configuration
adaptive_config = methods.APTConfig(
    # Optimization
    optim="SGD",
    adapt_lr=1e-5,
    backbone_lr=1e-6,
    head_lr=1e-6,

    # Tracking
    max_age=3,
    min_hits=1,
    iou_threshold=0.8,

    # Loss
    loss_type="smooth_l1",
    loss_weight=1.0,
    use_confidence_weighting=True,
    conf_threshold=0.7,
    min_confidence_for_update=0.3,

    # Update strategy
    update_backbone=False,
    update_head=False,
    update_bn=True,
    update_fpn_last_layer=False,
    update_box_regressor_last_layer=False,

    # Memory & Stabilization
    buffer_size=500,
    loss_ema_decay=0.9
)

In [12]:
# Initialize method
adaptive_model = methods.APTEngine(base_model, adaptive_config)
adaptive_model.to(device)

APTEngine(
  (base_model): FasterRCNNForObjectDetection(
    (backbone): FPN(
      (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (top_block): LastLevelMaxPool()
      (bottom_up): ResNet(
        (stem): BasicStem(
          (conv1): Conv2d(
            3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
  

## Evaluation

In [13]:
# Put the detector in eval mode, then switch the APT engine into online-adaptation mode.
# NOTE(review): no pretrained APT weights are loaded here. Which parameters adapt is presumably
# governed by the APTConfig update_* flags (only update_bn=True in the config cell), and the summary
# below shows the FPN convolutions as non-trainable — confirm against APTEngine.online().
base_model.eval()
adaptive_model.online()
summary(adaptive_model)

Layer (type:depth-idx)                                       Param #
APTEngine                                                    --
├─FasterRCNNForObjectDetection: 1-1                          --
│    └─FPN: 2-1                                              --
│    │    └─Conv2d: 3-1                                      (65,792)
│    │    └─Conv2d: 3-2                                      (590,080)
│    │    └─Conv2d: 3-3                                      (131,328)
│    │    └─Conv2d: 3-4                                      (590,080)
│    │    └─Conv2d: 3-5                                      (262,400)
│    │    └─Conv2d: 3-6                                      (590,080)
│    │    └─Conv2d: 3-7                                      (524,544)
│    │    └─Conv2d: 3-8                                      (590,080)
│    │    └─LastLevelMaxPool: 3-9                            --
│    │    └─ResNet: 3-10                                     23,508,032
│    └─RPN: 2-2                     

### Load Scenarios

In [14]:
# Make sure the SHIFT continuous subset split exists before the scenarios are built; this is
# required because the Scenario classes work with coroutines. The dataset object is discarded.
_ = datasets.SHIFTContinuousSubsetForObjectDetection(train=True, root=DATA_ROOT)

[11/05/2025 09:02:36] SHIFT DevKit - INFO - Base: .\data\SHIFT\continuous\images\1x\train. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:36] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\1x\train\front\det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\continuous/1x...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:36] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\daytime_to_night\continuous\images\1x\train\front\det_2d.json' Done.
[11/05/2025 09:02:37] SHIFT DevKit - INFO - Loading annotation takes 0.85 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0039-134e']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                  -191.55      57.56
boxes2d              torch.Size([1, 3, 4])                   112.00     494.00
boxes2d_classes      torch.Size([1, 3])                        0.00       0.00
boxes2d_track_ids    torch.Size([1, 3])                        0.00       2.00
images               torch.Size([1, 3, 800, 1280])             0.00     255.00

Video name: 0039-134e
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

In [15]:
# Evaluation-mode preprocessing; its transforms feed the scenario datasets and its collate_fn the loaders.
data_preparation = base_model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

if SOURCE_DOMAIN == datasets.SHIFTDataset:
    # Discrete SHIFT validation domains, visited in the WHWPAPER order preset
    discrete_scenario = scenarios.SHIFTDiscreteScenario(
        root=DATA_ROOT,
        valid=True,
        order=scenarios.SHIFTDiscreteScenario.WHWPAPER,
        transforms=data_preparation.transforms,
    )
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    # No adaptation scenarios are defined for CityScapes yet
    discrete_scenario = continuous_scenario = None
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

[11/05/2025 09:02:37] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:37] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\cloudy_daytime\discrete\images\val\front\det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:37] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\cloudy_daytime\discrete\images\val\front\det_2d.json' Done.
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation takes 0.39 seconds.
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\overcast_daytime\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\overcast_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation takes 0.44 seconds.
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\foggy_daytime\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:38] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\foggy_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:39] SHIFT DevKit - INFO - Loading annotation takes 0.54 seconds.
[11/05/2025 09:02:39] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:39] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\rainy_daytime\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:39] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\rainy_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation takes 0.89 seconds.
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_dawn\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_dawn\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation takes 0.24 seconds.
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_night\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_night\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation takes 0.36 seconds.
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Base: .\data\SHIFT\discrete\images\val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x000001AAFFCA7140>
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_daytime\discrete\images\val\front\det_2d.json' ...
[11/05/2025 09:02:40] SHIFT DevKit - INFO - Loading annotation from '.\data\SHIFT_SUBSET\clear_daytime\discrete\images\val\front\det_2d.json' Done.


INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to .\data\SHIFT\discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[11/05/2025 09:02:41] SHIFT DevKit - INFO - Loading annotation takes 0.59 seconds.


In [16]:
methods = {
    #'Direct-Test': base_model,
    adaptive_model.model_name: adaptive_model
}

In [17]:
evaluator = validator.DetectionEvaluator(list(methods.values()), classes=CLASSES, data_preparation=data_preparation, dtype=DATA_TYPE, device=device, no_grad=False)
evaluator_loader_params = dict(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=data_preparation.collate_fn)

In [None]:
visualizer.visualize_metrics(discrete_scenario(**evaluator_loader_params).play(evaluator, index=methods.keys()))

Output()

SHIFT Discrete Scenario:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluation for cloudy_daytime:   0%|          | 0/2400 [00:00<?, ?it/s]

W1105 09:02:43.047000 77760 .venv\Lib\site-packages\torch\fx\_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
