# TTA Example

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch
from torch import nn, optim
from torchinfo import summary

from ttadapters import datasets, models, methods
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, DataLoaderHolder, scenarios

In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Per-split batch sizes, ordered (train, valid, test)
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = (40, 200, 1)  # A100 or H100
# Gradient accumulation steps
ACCUMULATE_STEPS = 1

# Root directory for datasets
DATA_ROOT = path.join(".", "data")

# Source-domain dataset (overridden by --dataset when arguments are parsed)
SOURCE_DOMAIN = datasets.SHIFTDataset

# Run mode: when True, notebook runs behave as if --test-only was passed
TEST_MODE = False

# Supported model architectures; the last entry is the default choice
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[len(MODEL_ZOO) - 1]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Adaptation experiment script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args(["--test-only"] if TEST_MODE else [])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityscapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Set test mode - {TEST_MODE} for {SOURCE_DOMAIN.dataset_name} dataset")

### Check GPU Availability

In [None]:
# IPython shell escape (not plain Python): runs the `nvidia-smi` command and shows its output.
!nvidia-smi

In [None]:
# Set CUDA Device Number
DEVICE_NUM = args.device or 0
ADDITIONAL_GPU = args.additional_gpu or 0
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Multi-GPU run: make DEVICE_NUM the current device and use the bare
        # "cuda" device, which then refers to it.
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # sentinel: no CUDA device in use

# Only the bare "cuda" device lacks its index in its string form; never append
# the index on CPU (that used to print "cpu:-1" when --additional_gpu was set).
device_suffix = f":{DEVICE_NUM}" if ADDITIONAL_GPU and device.type == "cuda" else ""
print(f"INFO: Using device - {device}{device_suffix}")
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Fast download patch
# NOTE(review): presumably replaces the default dataset download routine with a
# faster one; the helper lives in ttadapters.datasets — confirm what it patches.
datasets.patch_fast_download_for_object_detection()

In [None]:
# Benchmark dataset
#match SOURCE_DOMAIN:
#    case datasets.SHIFTDataset:
#        discrete_dataset = DatasetHolder(
#            train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
#            valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
#            test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
#        )
#        continuous_dataset = DatasetHolder(
#            train=datasets.SHIFTContinuous100DatasetForObjectDetection(root=DATA_ROOT),
#            valid=datasets.SHIFTContinuous10DatasetForObjectDetection(root=DATA_ROOT),
#            test=datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT)
#        )
#        dataset = continuous_dataset
#    case datasets.CityscapesDataset:
#        pass
#    case _:
#        raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Object classes reported by the evaluation, and how many there are
CLASSES = "pedestrian car truck bus motorcycle bicycle".split()
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Check annotation keys-values
#dataset.test[999]

In [None]:
# Check data shape
#dataset.test[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Visualize video
#visualizer.visualize_bbox_frames(dataset.test)

## Load Base Model

In [None]:
# Initialize model
# Lookup table: model type -> (model class name, pretrained-weights attribute).
# Names are resolved with getattr only for the selected entry, so unused
# architectures are never touched.
_MODEL_SPECS = {
    "rcnn": ("FasterRCNNForObjectDetection", "SHIFT_CLEAR_NATUREYOO"),
    "swinrcnn": ("SwinRCNNForObjectDetection", "SHIFT_CLEAR_NATUREYOO"),
    "yolo11": ("YOLO11ForObjectDetection", "SHIFT_CLEAR"),
    "rtdetr": ("RTDetrForObjectDetection", "SHIFT_CLEAR"),
}
if MODEL_TYPE not in _MODEL_SPECS:
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

_model_class_name, _weights_name = _MODEL_SPECS[MODEL_TYPE]
base_model = getattr(models, _model_class_name)(dataset=SOURCE_DOMAIN)
_pretrained = getattr(base_model.Weights, _weights_name)
load_result = base_model.load_from(**vars(_pretrained), strict=False)

print("INFO: Model state loaded -", load_result)
base_model.to(device)

In [None]:
# Show the torchinfo summary of the base model (no input size is passed).
summary(base_model)

In [None]:
# Compile model
#base_model = torch.compile(base_model)

## Load Adaptation Method

In [None]:
from ttadapters.methods.regularizers.plugin.apt.filter_detr import AdaptiveKalmanFilterCXCYWH, initialize_state_cxcywh
from scipy.optimize import linear_sum_assignment

In [None]:
class APTKalmanFilteredLoss(nn.Module):
    def __init__(self, learnable_uncertainty: bool = True):
        super().__init__()
        self.kalman_filter = AdaptiveKalmanFilterCXCYWH(learnable_uncertainty=learnable_uncertainty)

    def forward(
        self,
        prev_boxes: torch.Tensor,   # [batch, num_queries, 4] normalized [cx, cy, w, h]
        prev_logits: torch.Tensor,  # [batch, num_queries, num_classes]
        curr_boxes: torch.Tensor,
        track_states: dict | None = None,
        image_sizes: torch.Tensor | None = None  # [batch, 2] for denormalization
    ):
        """
        Args:
            prev_boxes: [batch, num_queries, 4] normalized [cx, cy, w, h] in [0, 1]
            prev_logits: [batch, num_queries, num_classes]
            curr_boxes: ground truth boxes
            track_states: dict of {query_idx: (mean, covariance)}
            image_sizes: [batch, 2] tensor of [height, width] for denormalization
        """
        batch_size, num_queries = prev_boxes.shape[:2]
        device = prev_boxes.device

        if track_states is None:
            track_states = {}

        if image_sizes is None:
            image_sizes = torch.tensor([[800, 1280]] * batch_size, device=device)

        # Apply Kalman filter to each prediction
        filtered_boxes = []
        new_track_states = {}

        for b in range(batch_size):
            batch_filtered = []
            h, w = image_sizes[b]

            for q in range(num_queries):
                # Get prediction (normalized)
                pred_box_norm = prev_boxes[b, q]  # [4] [cx, cy, w, h] in [0, 1]

                # Denormalize
                cx = pred_box_norm[0] * w
                cy = pred_box_norm[1] * h
                box_w = pred_box_norm[2] * w
                box_h = pred_box_norm[3] * h
                pred_box_abs = torch.stack([cx, cy, box_w, box_h])

                # Get confidence
                pred_logit = prev_logits[b, q]  # [num_classes]
                confidence = pred_logit.softmax(-1).max()

                # Get or initialize track state
                track_key = f"{b}_{q}"
                if track_key not in track_states:
                    mean, cov = initialize_state_cxcywh(pred_box_abs)
                else:
                    mean, cov = track_states[track_key]

                # Prepare measurement
                measurement = pred_box_abs.unsqueeze(-1)  # [4, 1]

                # Apply Kalman filter
                mean_new, cov_new = self.kalman_filter(
                    mean, cov, measurement, confidence
                )

                # Extract filtered bbox
                bbox_filtered_abs = mean_new[:4, 0]  # [4] [cx, cy, w, h]

                # Normalize back
                cx_norm = bbox_filtered_abs[0] / w
                cy_norm = bbox_filtered_abs[1] / h
                w_norm = bbox_filtered_abs[2] / w
                h_norm = bbox_filtered_abs[3] / h
                bbox_filtered_norm = torch.stack([cx_norm, cy_norm, w_norm, h_norm])

                batch_filtered.append(bbox_filtered_norm)
                new_track_states[track_key] = (mean_new.detach(), cov_new.detach())

            filtered_boxes.append(torch.stack(batch_filtered))

        filtered_boxes = torch.stack(filtered_boxes)  # [batch, num_queries, 4]

        # Now compute loss with filtered boxes
        loss = self.compute_loss(student_boxes=curr_boxes, teacher_boxes=filtered_boxes)

        return loss, new_track_states

    def compute_loss(self, student_boxes, teacher_boxes):
        """
        student_boxes: [batch, num_queries, 4]
        teacher_boxes: [batch, num_queries, 4]  (Kalman filtered)
        """
        batch_size = student_boxes.shape[0]
        batch_losses = []

        for b in range(batch_size):
            teacher_b = teacher_boxes[b]  # [num_queries, 4]
            student_b = student_boxes[b]  # [num_queries, 4]

            # Remove Padding
            valid_teacher = (teacher_b.abs().sum(dim=-1) > 1e-6)
            valid_student = (student_b.abs().sum(dim=-1) > 1e-6)

            teacher_b = teacher_b[valid_teacher]
            student_b = student_b[valid_student]

            if len(teacher_b) == 0 or len(student_b) == 0:
                continue

            with torch.no_grad():
                # Cost matrix (2D!)
                cost_matrix = torch.cdist(teacher_b, student_b, p=1.0)
                # Hungarian matching (NumPy 2D array)
                row_indices, col_indices = linear_sum_assignment(cost_matrix.cpu().numpy())

            # Matched costs
            matched_costs = nn.functional.l1_loss(student_b[col_indices], teacher_b[row_indices])
            #print("matched costs", matched_costs.grad_fn)

            # Distance threshold
            THRESHOLD = 50.0
            valid_matches = matched_costs < THRESHOLD

            if valid_matches.sum() > 0:
                batch_losses.append(matched_costs[valid_matches].mean())

        # Average across batch
        if len(batch_losses) > 0:
            return torch.stack(batch_losses).mean()
        else:
            # No valid matches - return 0 with gradient
            return torch.tensor(0.0, device=student_boxes.device, requires_grad=True)

In [None]:
import time

In [None]:
class RTDetrForObjectDetectionWithAPT(models.RTDetrForObjectDetection):
    """RT-DETR detector that adapts its encoder online with the APT Kalman loss.

    While ``adapting`` is enabled, each forward pass compares the current box
    predictions with the predictions cached from the previous call and, when
    the resulting loss is positive, takes one optimizer step on the encoder
    parameters. The very first call only fills the cache.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Created lazily on first use (see the properties below)
        self._apt_loss_function = None
        self._optimizer = None
        self.adapting = False
        self.apt_lr = 1e-4

        # Predictions of the previous forward call and the Kalman track states
        self.bbox_cache = None
        self.logit_cache = None
        self.track_states = None
        # Counters: optimizer steps taken / calls with no adaptation attempt
        self.adapt_step = 0
        self.skip_step = 0

    @property
    def apt_loss_function(self):
        """Loss module, built on first access on the model's device."""
        # NOTE(review): if the parent is an nn.Module, assigning a module here
        # registers it as a submodule, so its parameters show up in
        # parameters()/state_dict() once created — confirm this is intended.
        if self._apt_loss_function is None:
            self._apt_loss_function = APTKalmanFilteredLoss().to(self.device)
        return self._apt_loss_function

    def online_parameters(self):
        """Parameters updated during online adaptation (the encoder only)."""
        return self.model.encoder.parameters()

    @property
    def optimizer(self):
        """AdamW over ``online_parameters()``, built on first access."""
        if self._optimizer is None:
            self._optimizer = optim.AdamW(self.online_parameters(), lr=self.apt_lr)
        return self._optimizer

    def online(self, mode=True):
        """Switch online adaptation on or off and return ``self``.

        On: eval mode, everything frozen except ``online_parameters()``.
        Off: train mode, every parameter trainable again.
        """
        self.adapting = mode
        if mode:
            self.eval()
            for param in self.parameters():
                param.requires_grad = False
            for param in self.online_parameters():
                param.requires_grad = True
        else:
            self.train()
            for param in self.parameters():
                param.requires_grad = True
        return self

    def offline(self):
        """Shorthand for ``online(False)``."""
        return self.online(False)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the parent forward pass, then optionally one adaptation step.

        All arguments are forwarded unchanged to the parent class. The parent's
        output is returned as is; it must expose ``pred_boxes`` and ``logits``
        attributes, which are cached for the next call.
        """
        t_start = time.perf_counter()
        result = super().forward(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        t_forward = time.perf_counter() - t_start

        if self.bbox_cache is not None and self.adapting:
            t_start = time.perf_counter()
            apt_loss, self.track_states = self.apt_loss_function(
                prev_boxes=self.bbox_cache, prev_logits=self.logit_cache,
                curr_boxes=result.pred_boxes, track_states=self.track_states
            )
            # Read the value once: every .item() call forces a device sync
            loss_value = apt_loss.item()
            t_loss_compute = time.perf_counter() - t_start
            print(f"INFO: APT Loss - {loss_value:.6f}, Track States: {len(self.track_states)}, Adapt Step: {self.adapt_step}, Skip Step: {self.skip_step}")

            t_start = time.perf_counter()
            # Step only on a positive loss that actually carries a graph;
            # backward() would raise if the caller ran this under no_grad.
            stepped = loss_value > 0 and apt_loss.requires_grad
            if stepped:
                optimizer = self.optimizer
                optimizer.zero_grad()
                apt_loss.backward()
                optimizer.step()
                self.adapt_step += 1
            t_backward = time.perf_counter() - t_start

            # Report timings once per 100 completed steps (not on every call
            # while the step counter is standing still).
            if stepped and self.adapt_step % 100 == 0:
                print(f"\n=== Timing (step {self.adapt_step}) ===")
                print(f"Forward:       {t_forward*1000:.2f} ms")
                print(f"APT Loss:      {t_loss_compute*1000:.2f} ms")
                print(f"Backward+Step: {t_backward*1000:.2f} ms")
                print(f"Total:         {(t_forward+t_loss_compute+t_backward)*1000:.2f} ms")
        else:
            self.skip_step += 1

        # Cache graph-free copies of this call's predictions for the next call
        self.bbox_cache = result.pred_boxes.detach().clone()
        self.logit_cache = result.logits.detach().clone()
        return result

In [None]:
# Build the adaptive RT-DETR and load its SHIFT_CLEAR pretrained weights.
adaptive_model = RTDetrForObjectDetectionWithAPT(dataset=SOURCE_DOMAIN)
# Resolve the weights through the adaptive model itself: `base_model` follows
# MODEL_TYPE and may be another architecture whose Weights has no SHIFT_CLEAR.
load_result = adaptive_model.load_from(**vars(adaptive_model.Weights.SHIFT_CLEAR), strict=False)
# NOTE(review): unlike base_model, this model is never moved with .to(device)
# here — confirm the evaluator takes care of device placement.

## Evaluation

### Load Scenarios

In [None]:
# Baseline model: inference mode only.
base_model.eval()
# Adaptive model: online() puts it in eval mode, freezes every parameter except
# the encoder's, and enables the adaptation step inside forward().
adaptive_model.online()

In [None]:
# Evaluation-mode data preparation taken from the base model
data_preparation = base_model.DataPreparation(
    datasets.base.BaseDataset(),
    evaluation_mode=True,
)

# Discrete SHIFT scenario on the validation split, limited to one subset
discrete_scenario = scenarios.SHIFTDiscreteScenario(
    root=DATA_ROOT,
    valid=True,
    order=[datasets.SHIFTDiscreteSubsetForObjectDetection.SubsetType.OVERCAST_DAYTIME],
    transforms=data_preparation.transforms,
)
# continuous_scenario = scenarios.SHIFTContinuousScenario(
#     root=DATA_ROOT, valid=True, order=scenarios.SHIFTContinuousScenario.DEFAULT, transforms=data_preparation.transforms
# )

In [None]:
# Evaluator for the adaptive model; no_grad=False so the adaptation step
# inside forward() can backpropagate.
evaluator = validator.DetectionEvaluator(
    [adaptive_model],
    classes=CLASSES,
    data_preparation=data_preparation,
    dtype=DATA_TYPE,
    device=device,
    no_grad=False,
)
# Keyword arguments used when a scenario is called below
evaluator_loader_params = {
    "batch_size": BATCH_SIZE[2],
    "shuffle": False,
    "collate_fn": data_preparation.collate_fn,
}

In [None]:
# Run the discrete scenario with the evaluator, then visualize the metrics
_discrete_run = discrete_scenario(**evaluator_loader_params)
_discrete_metrics = _discrete_run.play(evaluator, index=["APT-Kalman"])
visualizer.visualize_metrics(_discrete_metrics)

In [None]:
#visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test", "APT-Kalman"]))