In [None]:
import os

# Work from the project root so the relative paths used below (e.g. "./data") resolve.
# Override with the PTTA_ROOT environment variable on other machines
# (e.g. PTTA_ROOT=/home/ubuntu/test-time-adapters) instead of editing this cell.
os.chdir(os.environ.get("PTTA_ROOT", "/home/elicer/ptta"))

In [None]:
# Restrict this process to one physical GPU via CUDA_VISIBLE_DEVICES.
# The mask is only honoured if it is set before CUDA is initialised in this process,
# which is why this cell runs before torch is imported.
from os import environ

DEVICE_NUM = 0
environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_NUM}"

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch

from ttadapters import datasets, models
from ttadapters.models.base import ModelProvider
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios

from ttadapters.methods.other_method.baseline import ActMADConfig, ActMAD, NORMConfig, NORM, DUAConfig, DUA, MeanTeacherConfig, MeanTeacher, WHWConfig, WHW
from ttadapters.methods.other_method.our_method import Ours, OursConfig


In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

In [None]:
# ---- Run configuration ----

# Model selection. Note: the model-initialisation cell only handles "rcnn", "swinrcnn" and
# "rtdetr"; choosing "yolo11" raises ValueError there.
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[0]

# True -> load the SHIFT-clear checkpoints; False -> load the ImageNet/COCO pre-trained weights.
TEST_MODE = True

# Source domain (dataset family) and where its files live.
SOURCE_DOMAIN = datasets.SHIFTDataset
DATA_ROOT = path.join(path.curdir, "data")

# Batch sizes as a 3-tuple; only BATCH_SIZE[2] is used in the cells shown (evaluation loader).
# NOTE(review): presumably ordered (train, valid, test) — confirm.
BATCH_SIZE = (2, 8, 1)  # Local
# BATCH_SIZE = (40, 200, 1)  # A100 or H100
ACCUMULATE_STEPS = 1  # not referenced in the cells shown

In [None]:
# Device / precision configuration.
# DEVICE_NUM is the *physical* GPU index; keep it equal to the value exported to
# CUDA_VISIBLE_DEVICES in the earlier cell.
DEVICE_NUM = 0
ADDITIONAL_GPU = 0  # truthy -> select the GPU with torch.cuda.set_device and use the bare "cuda" device
DATA_TYPE = torch.float32

# CUDA renumbers the GPUs listed in CUDA_VISIBLE_DEVICES starting from 0, so translate the
# physical index into this process's ordinal. Without a mask the two are identical.
visible_gpus = [d.strip() for d in environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if d.strip()]
local_device_num = visible_gpus.index(str(DEVICE_NUM)) if str(DEVICE_NUM) in visible_gpus else DEVICE_NUM

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        torch.cuda.set_device(local_device_num)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{local_device_num}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1

# Only append the GPU index when a bare "cuda" device was selected (never on CPU).
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU and device.type == "cuda" else ""))
print(f"INFO: Using data precision - {DATA_TYPE}")

In [None]:
# Fast download patch
# Apply ttadapters' fast-download patch for the object-detection datasets before any dataset
# below is instantiated. NOTE(review): exact effect is defined in ttadapters.datasets — confirm.
datasets.patch_fast_download_for_object_detection()

In [None]:
# Basic pre-training dataset
match SOURCE_DOMAIN:
    case datasets.SHIFTDataset:
        dataset = DatasetHolder(
            train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
            valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
            test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
        )
    case datasets.CityscapesDataset:
        pass
    case _:
        raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Dataset info
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Initialize model.
# Each entry: (class name in `models`, weights used when TEST_MODE, weights used otherwise).
# Names are resolved lazily with getattr so only the selected model class is touched.
MODEL_SPECS = {
    "rcnn": ("FasterRCNNForObjectDetection", "SHIFT_CLEAR_NATUREYOO", "IMAGENET_OFFICIAL"),
    "swinrcnn": ("SwinRCNNForObjectDetection", "SHIFT_CLEAR_NATUREYOO", "IMAGENET_XIAOHU2015"),
    "rtdetr": ("RTDetrForObjectDetection", "SHIFT_CLEAR", "COCO_OFFICIAL"),
}
if MODEL_TYPE not in MODEL_SPECS:
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

model_class_name, test_weights_name, base_weights_name = MODEL_SPECS[MODEL_TYPE]
model = getattr(models, model_class_name)(dataset=SOURCE_DOMAIN)
selected_weights = getattr(model.Weights, test_weights_name if TEST_MODE else base_weights_name)
# strict=False: missing/unexpected keys are reported in load_result rather than raising.
load_result = model.load_from(**vars(selected_weights), strict=False)

print("INFO: Model state loaded -", load_result)
model.to(device)

In [None]:
# config = OursConfig(
#     # ============ Model & Device ============
#     model_type="rcnn",
#     data_root='./data',
#     device=torch.device("cuda"),
#     batch_size=1,  # Must be 1 for tracking

#     # ============ Adaptation Layers ============
#     adapt_bn=True,     # ✅ train BN/LayerNorm
#     adapt_conv=False,  # BN alone is sufficient for conv layers
#     adapt_linear=False,

#     # ============ Optimizer ============
#     optimizer_option="AdamW",
#     lr=1e-4,  # ⬆️ raised from 1e-6 (an earlier note said 1e-5, but 1e-4 is the value set here); stable because the delta loss is strong
#     momentum=0.9,
#     weight_decay=1e-4,

#     # ============ Tracking Parameters ============
#     iou_threshold=0.3,
#     min_confidence=0.5,
#     max_age=30,

#     # ============ Loss Settings ============
#     bbox_loss_weight=1.0,
#     smooth_l1_beta=1.0,
#     use_delta_loss=True,  # ✅ NEW: use delta encoding (effective for BN)

#     # ============ Kalman Update Strategy ============
#     use_model_for_kalman_update=True,  # ✅ CHANGED: set to True (uses detections)
#     kalman_detection_blend=1.0,  # ✅ CHANGED: 1.0 (pure detection)

#     # ============ Batch Accumulation ============
#     batch_accumulation_steps=4,  # ✅ NEW: accumulate 4 frames (stabilizes gradients)

#     # ============ Innovation Weighting ============
#     use_covariance_weighting=False,
#     use_innovation_weighting=True,
#     max_innovation=100.0,
#     min_innovation_weight=0.2,

#     # ============ Detection Confidence Gating ============
#     confidence_penalty_exponent=2.0,

#     # ============ Scene Change Detection ============
#     shift_detection_window=10,
#     shift_detection_threshold=0.3,
#     shift_detection_min_matches=2,
#     reset_tracker_on_shift=False,

#     # ============ Quality Filter ============
#     enable_quality_filter=True,
#     min_matches=4,  # ✅ CHANGED: 2 → 4 (more stable)

#     min_track_hits=2.0,  # 3.0 → 2.0 (slightly relaxed)
#     min_match_iou=0.5,
#     max_innovation_cv=0.8,
#     min_avg_innovation=5.0,
#     max_avg_innovation=80.0,
#     max_outlier_ratio=3.0,
# )

# adaptive_model = Ours(model, config)

In [None]:
# config = ActMADConfig()
# config.data_root = DATA_ROOT
# config.lr = 0.00001
# config.clear_dataset = model.DataPreparation(dataset.train, evaluation_mode=True)
# adaptive_model = ActMAD(model, config)

In [None]:
# config = NORMConfig()
# config.data_root = DATA_ROOT
# adaptive_model = NORM(model, config)

In [None]:
# config = DUAConfig()
# config.data_root = DATA_ROOT
# adaptive_model = DUA(model, config)

In [None]:
# config = MeanTeacherConfig()
# config.data_root = DATA_ROOT
# config.lr = 0.0001
# adaptive_model = MeanTeacher(model, config)

In [None]:
# # WHW without redundancy skipping ("skip x": skip_redundant=None)
# config = WHWConfig()
# config.model_type = "rcnn"
# config.data_root = "./data"
# config.device = torch.device("cuda:0")

# # Optimizer
# config.lr = 2e-3
# config.optimizer_option = "SGD"
# config.momentum = 0.9
# config.weight_decay = 1e-4

# # Adaptation
# config.adaptation_where = "adapter"
# config.adapter_bottleneck_ratio = 24 # [16, 24, 32]

# # Skip settings
# config.skip_redundant = None # "stat+period+ema"
# config.skip_beta = 1.05      # SKIP_BETA
# config.skip_period = 10      # SKIP_PERIOD
# config.skip_tau = 1.1        # SKIP_TAU

# # Loss settings
# config.fg_align = "KL"
# config.gl_align = "KL"
# config.alpha_fg = 1.0
# config.alpha_gl = 0.5 # original : 1.0
# config.ema_gamma = 128 # [64, 96, 128]
# config.freq_weight = True

# # Dataset
# config.num_classes = 6
# config.clear_dataset = model.DataPreparation(dataset.train, evaluation_mode=True)
# config.clear_statistics_batch = 64
# config.output_path = "./whw_source_statistics_clear.pt"

# adaptive_model = WHW(model, config)

In [None]:
# WHW with redundancy skipping enabled ("skip o")
config = WHWConfig()
# Shared settings come from the notebook config instead of being hard-coded, so changing
# MODEL_TYPE / DATA_ROOT or running without CUDA keeps WHW consistent with the rest of the run.
config.model_type = MODEL_TYPE
config.data_root = DATA_ROOT
config.device = device

# Optimizer
config.lr = 2e-3
config.optimizer_option = "SGD"
config.momentum = 0.9
config.weight_decay = 1e-4

# Adaptation
config.adaptation_where = "adapter"
config.adapter_bottleneck_ratio = 24  # candidates: [16, 24, 32]

# Skip settings
config.skip_redundant = "stat+period+ema"
config.skip_beta = 1.05      # SKIP_BETA
config.skip_period = 10      # SKIP_PERIOD
config.skip_tau = 1.1        # SKIP_TAU

# Loss settings
config.fg_align = "KL"
config.gl_align = "KL"
config.alpha_fg = 1.0
config.alpha_gl = 0.5  # original: 1.0
config.ema_gamma = 128  # candidates: [64, 96, 128]
config.freq_weight = True

# Dataset
# NOTE(review): presumably must equal NUM_CLASSES from the dataset cell — confirm and derive it.
config.num_classes = 6
config.clear_dataset = model.DataPreparation(dataset.train, evaluation_mode=True)
config.clear_statistics_batch = 64
config.output_path = "./whw_source_statistics_clear.pt"

adaptive_model = WHW(model, config)

In [None]:
# # Compile model
# adaptive_model = torch.compile(adaptive_model)

In [None]:
# Instantiate the continuous SHIFT subset once up front so its split is prepared before the
# scenarios are built (required because the Scenario classes work with coroutines).
# A named variable is used instead of `_`: assigning `_` shadows IPython's last-output alias
# and stops IPython from updating `_`, `__` and `___` for the rest of the session.
shift_continuous_train_split = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Evaluation-mode preprocessing shared by the scenarios and the evaluator's data loader.
data_preparation = model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

# Both scenarios read the SHIFT valid split with the same transforms; only the domain order differs.
# NOTE(review): discrete_scenario is not used by the cells shown — only continuous_scenario is played.
scenario_kwargs = dict(root=DATA_ROOT, valid=True, transforms=data_preparation.transforms)
discrete_scenario = scenarios.SHIFTDiscreteScenario(
    order=scenarios.SHIFTDiscreteScenario.WHWPAPER, **scenario_kwargs
)
continuous_scenario = scenarios.SHIFTContinuousScenario(
    order=scenarios.SHIFTContinuousScenario.DEFAULT, **scenario_kwargs
)

In [None]:
# Evaluator around the adaptive model. no_grad=False keeps autograd enabled during evaluation,
# presumably because the adaptation step needs gradients — confirm against DetectionEvaluator.
evaluator = validator.DetectionEvaluator(
    adaptive_model,
    classes=CLASSES,
    data_preparation=data_preparation,
    no_grad=False,
    dtype=DATA_TYPE,
    device=device,
)

# DataLoader options passed to each scenario run (ordered, evaluation batch size).
evaluator_loader_params = {
    "batch_size": BATCH_SIZE[2],
    "shuffle": False,
    "collate_fn": data_preparation.collate_fn,
}

In [None]:
# Expose the base model's model_provider on the adaptive wrapper.
# NOTE(review): presumably WHW / the evaluator look this attribute up on the wrapper — confirm.
adaptive_model.model_provider = model.model_provider

In [None]:
# Round 1: play the continuous SHIFT scenario through the evaluator and plot the metrics.
# The same adaptive_model/evaluator objects are reused by every round and nothing here resets
# them, so presumably adaptation state carries over between rounds — confirm.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 2: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 3: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 4: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 5: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 6: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 7: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 8: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 9: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# Round 10: replay the continuous SHIFT scenario with the same (already adapted) objects.
round_metrics = continuous_scenario(**evaluator_loader_params).play(
    evaluator, index=[adaptive_model.__class__.__name__]
)
visualizer.visualize_metrics(round_metrics)

In [None]:
# visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=[adaptive_model.__class__.__name__]))