# Enhanced APT Example with Ideas A, B, C, D

## Improvements Overview

**Idea A: Loss-based Adaptation Control**
- Loss spike detection for domain change → optimizer reset
- Loss threshold for skipping bad updates
- Loss history tracking for statistical analysis

**Idea B: Extended Adaptation Scope**
- Conv layers before BatchNorm
- MLP layers after LayerNorm/BatchNorm
- Better feature adaptation capability

**Idea C: Gradient Scaling**
- Inverse relationship with loss magnitude
- Small losses → larger gradients (fine-tuning)
- Large losses → smaller gradients (stability)

**Idea D: BN Statistics Update**
- Running mean/var updated via backprop
- Better adaptation to distribution shift

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch
from torchinfo import summary

from ttadapters import datasets, models, methods
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios

In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Default batch sizes as a (train, valid, test) tuple
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = (40, 200, 1)  # A100 or H100

# Default accumulation step count (overridable via --accum-step)
ACCUMULATE_STEPS = 1

# Default dataset root directory: ./data
DATA_ROOT = path.join(".", "data")

# Default source-domain dataset class
SOURCE_DOMAIN = datasets.SHIFTDataset

# Supported model identifiers; the first entry is the default model
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[0]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Enhanced APT experiment script")

parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE)
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0])
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1])
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS)
parser.add_argument("--data-root", type=str, default=DATA_ROOT)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--additional_gpu", type=int, default=0)
parser.add_argument("--use-bf16", action="store_true")

if "ipykernel" in sys.modules:
    args = parser.parse_args([])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityScapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")

### Check GPU

In [None]:
# IPython shell escape: run `nvidia-smi` and show its output in the notebook
!nvidia-smi

In [None]:
# Resolve device index, multi-GPU flag and tensor precision from the CLI args.
DEVICE_NUM = args.device or 0
ADDITIONAL_GPU = args.additional_gpu or 0
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        # Make DEVICE_NUM the current CUDA device and address it as bare "cuda".
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # sentinel: no CUDA device in use

# Only the bare "cuda" device lacks an index in its string form, so the index
# is appended just for that case; the -1 sentinel is never appended to "cpu".
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU and device.type == "cuda" else ""))
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Apply the ttadapters download patch before any dataset object is created below.
# NOTE(review): the exact effect is defined in ttadapters.datasets and is not visible here.
datasets.patch_fast_download_for_object_detection()

In [None]:
# Class list and class count used by the evaluator.
# NOTE(review): the classes always come from the SHIFT clear dataset, even when
# --dataset city is selected - confirm this is intended.
CLASSES = datasets.SHIFTClearDatasetForObjectDetection.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

## Load Base Model

In [None]:
match MODEL_TYPE:
    case "rcnn":
        base_model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "swinrcnn":
        base_model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "yolo11":
        DATA_TYPE = torch.bfloat16
        base_model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case "rtdetr":
        DATA_TYPE = torch.bfloat16
        base_model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
        load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR if SOURCE_DOMAIN == datasets.SHIFTDataset else base_model.Weights.CITYSCAPES), strict=False)
    case _:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

print("INFO: Model state loaded -", load_result)
base_model.to(device)

## Load Enhanced APT Method

In [None]:
# Build the APT configuration through the with_idea_B() factory.
# NOTE(review): presumably enables only Idea B (extended adaptation scope) from
# the overview at the top of the notebook - confirm against methods.APTConfig.
adaptive_config = methods.APTConfig.with_idea_B()

In [None]:
# Bare expression: the notebook displays the configuration's adaptation name
adaptive_config.adaptation_name

In [None]:
# Wrap the base model in the APT engine and move the wrapper to the target device
adaptive_model = methods.APTEngine(base_model, adaptive_config)
adaptive_model.to(device)

print(f"\nModel: {adaptive_model.model_name}")
# Total element count over the parameters yielded by online_parameters()
_adaptable_param_count = sum(param.numel() for param in adaptive_model.online_parameters())
print(f"Number of adaptable parameters: {_adaptable_param_count}")

## Evaluation

In [None]:
# Put the base model in eval mode, then call online() on the APT wrapper.
# NOTE(review): online() is defined by methods.APTEngine; presumably it switches
# the wrapper into its test-time adaptation mode - confirm.
base_model.eval()
adaptive_model.online()
# Last expression of the cell: the notebook displays the torchinfo summary
summary(adaptive_model)

### Load Scenarios

In [None]:
# Instantiate the SHIFT continuous subset (train split) and discard the object.
# NOTE(review): presumably done only for its side effects (e.g. download or
# caching under DATA_ROOT) - confirm against the dataset implementation.
_ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Evaluation-mode data preparation; its transforms and collate_fn are reused below
data_preparation = base_model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

# Build the discrete evaluation scenario for the selected source domain.
if SOURCE_DOMAIN == datasets.SHIFTDataset:
    discrete_scenario = scenarios.SHIFTDiscreteScenario(
        root=DATA_ROOT,
        valid=True,
        order=scenarios.SHIFTDiscreteScenario.WHWPAPER,
        transforms=data_preparation.transforms,
    )
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    # No scenario is built for CityScapes; any later discrete_scenario(...) call
    # cannot run while this is None.
    discrete_scenario = None
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

In [None]:
# Methods to evaluate, keyed by display name; the base-model-only entry is
# currently commented out, so only the APT engine is evaluated.
methods_dict = {
    #'Direct-Test': base_model,
    adaptive_model.model_name: adaptive_model
}

In [None]:
# Evaluator over every registered method.
# NOTE(review): no_grad=False presumably keeps gradients enabled so the APT
# engine can update its parameters during evaluation - confirm.
evaluator = validator.DetectionEvaluator(
    [*methods_dict.values()],
    classes=CLASSES,
    data_preparation=data_preparation,
    dtype=DATA_TYPE,
    device=device,
    no_grad=False,
)

# Keyword arguments forwarded to the scenario call; uses the test batch size
# (third entry of BATCH_SIZE) and a fixed, unshuffled order.
evaluator_loader_params = {
    "batch_size": BATCH_SIZE[2],
    "shuffle": False,
    "collate_fn": data_preparation.collate_fn,
}

In [None]:
# Fail with a clear message instead of "'NoneType' object is not callable"
# when no scenario was built for the selected dataset.
if discrete_scenario is None:
    raise RuntimeError(f"No discrete scenario is available for {SOURCE_DOMAIN}; cannot run evaluation")

print("Starting evaluation with enhanced APT...")
print("Monitoring: domain changes, skipped updates, loss statistics\n")

# Play the scenario through the evaluator, indexed by the method names, and
# pass the outcome to the metrics visualizer.
results = visualizer.visualize_metrics(
    discrete_scenario(**evaluator_loader_params).play(evaluator, index=methods_dict.keys())
)

In [None]:
# Repeated evaluation run; `results` from any earlier run is overwritten.
# NOTE(review): presumably repeated on purpose to observe the adaptive model's
# behaviour on a further pass over the scenario - confirm.

# Fail with a clear message instead of "'NoneType' object is not callable"
# when no scenario was built for the selected dataset.
if discrete_scenario is None:
    raise RuntimeError(f"No discrete scenario is available for {SOURCE_DOMAIN}; cannot run evaluation")

print("Starting evaluation with enhanced APT...")
print("Monitoring: domain changes, skipped updates, loss statistics\n")

# Play the scenario through the evaluator, indexed by the method names, and
# pass the outcome to the metrics visualizer.
results = visualizer.visualize_metrics(
    discrete_scenario(**evaluator_loader_params).play(evaluator, index=methods_dict.keys())
)

In [None]:
# Repeated evaluation run; `results` from any earlier run is overwritten.
# NOTE(review): presumably repeated on purpose to observe the adaptive model's
# behaviour on a further pass over the scenario - confirm.

# Fail with a clear message instead of "'NoneType' object is not callable"
# when no scenario was built for the selected dataset.
if discrete_scenario is None:
    raise RuntimeError(f"No discrete scenario is available for {SOURCE_DOMAIN}; cannot run evaluation")

print("Starting evaluation with enhanced APT...")
print("Monitoring: domain changes, skipped updates, loss statistics\n")

# Play the scenario through the evaluator, indexed by the method names, and
# pass the outcome to the metrics visualizer.
results = visualizer.visualize_metrics(
    discrete_scenario(**evaluator_loader_params).play(evaluator, index=methods_dict.keys())
)

### Check Adaptation Statistics

In [None]:
# Print every adaptation statistic as an aligned "name: value" table.
stats = adaptive_model.get_adaptation_stats()
print("\nAdaptation Statistics:")
print("=" * 60)
for key, value in stats.items():
    if isinstance(value, float):
        rendered = f"{value:10.4f}"
    else:
        try:
            # Integer-style formatting, exactly as before, for anything that supports it
            rendered = f"{value:10d}"
        except (TypeError, ValueError):
            # Values that reject the "d" format code (str, None, lists, ...)
            # are shown via str(), right-aligned to the same width, instead of
            # aborting the whole report.
            rendered = f"{value!s:>10}"
    print(f"{key:30s}: {rendered}")
print("=" * 60)

## Analysis and Comparison

Key metrics to observe:
1. **Domain changes detected**: Should increase when conditions shift dramatically
2. **Skipped updates**: Shows robustness against bad gradients
3. **Loss statistics**: Mean/std/min/max show adaptation behavior
4. **mAP improvements**: Overall performance gain from adaptation