In [None]:
# Shell escape: run the `nvidia-smi` command-line tool and show its output in the cell.
!nvidia-smi

In [None]:
import sys

from os import path
import torch
from fast_baseline import DIRECT, ActMAD, NORM, DUA, MeanTeacher, WHW, FasterRCNNForObjectDetection, SwinRCNNForObjectDetection
from ttadapters.models.rcnn import FasterRCNNForObjectDetection, SwinRCNNForObjectDetection
from ttadapters.datasets import SHIFTDataset

In [None]:
from os import environ

# Index of the GPU this notebook should use; exported through CUDA_VISIBLE_DEVICES.
# NOTE(review): this variable only takes effect if it is set before CUDA is
# initialized in this process — confirm nothing earlier touches the GPU.
DEVICE_NUM = 6

environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_NUM}"

# Echo the value back as the cell output.
environ["CUDA_VISIBLE_DEVICES"]

In [None]:
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)
print(f"INFO: Using device - {device}")

In [None]:
# Switch the working directory to the project checkout; relative paths used
# afterwards are resolved against it.
# Alternative checkout location: "/home/ubuntu/test-time-adapters"
import os

PROJECT_ROOT = "/workspace/ptta"
os.chdir(PROJECT_ROOT)

In [None]:
# Dataset directory, expressed relative to the current working directory.
DATA_ROOT = path.join(path.curdir, "data")

## Model Load

In [None]:
# Backbone switch: True builds SwinRCNNForObjectDetection, False builds FasterRCNNForObjectDetection.
USE_SWIN_T_BACKBONE = False

In [None]:
# Pick the detector class from the backbone switch, then instantiate it for SHIFT.
model_cls = SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
model = model_cls(dataset=SHIFTDataset)

# Load the NATUREYOO weights (read from the "model" key) into the detector.
model.load_from(model.Weights.NATUREYOO, weight_key="model")

# Move the detector to the selected device; the returned value is the cell output.
model.to(device)

## Baseline

In [None]:
# Build one adapter object per baseline method.

# Keyword arguments shared by every keyword-style `load()` call below.
shared_kwargs = {"model": model, "data_root": DATA_ROOT}
# ActMAD, NORM and DUA additionally receive the device and the same batch size.
device_kwargs = {**shared_kwargs, "device": device, "batch_size": 4}

# Direct_method (the model is passed positionally here, unlike the other loaders)
direct = DIRECT.load(model, data_root=DATA_ROOT, batch_size=4)

# ActMAD
actmad = ActMAD.load(
    **device_kwargs,
    learning_rate=0.001,
    clean_bn_extract_batch=8,
)

# NORM
norm = NORM.load(
    **device_kwargs,
    source_sum=128,  # hyperparameter specific to NORM
)

# DUA
dua = DUA.load(
    **device_kwargs,
    decay_factor=0.94,
    mom_pre=0.01,
    min_momentum_constant=0.0001,
)

# Mean-Teacher
mean_teacher = MeanTeacher.load(
    **shared_kwargs,
    conf_threshold=0.3,      # lower threshold -> more pseudo labels
    augment_strength_n=1,    # less computation
    augment_strength_m=5,    # weak augmentation strength
    cutout_size=8,           # small cutout
    learning_rate=0.001,     # learning rate
    ema_alpha=0.99,          # EMA coefficient (higher is more stable)
)

# WHW
# NOTE(review): `source_feat_stats` looks like a placeholder path — confirm it is
# replaced with the real source-statistics file before running.
whw = WHW.load(
    **shared_kwargs,                 # includes data_root, the data path
    batch_size=4,
    learning_rate=0.0001,
    weight_decay=1e-4,
    momentum=0.9,
    adaptation_where="adapter",      # train only the adapter
    adapter_bottleneck_ratio=32,     # bottleneck ratio (r=32)
    fg_align='KL',                   # Foreground alignment (KL divergence)
    gl_align='KL',                   # Global alignment (KL divergence)
    alpha_fg=1.0,                    # Foreground loss weight
    alpha_gl=1.0,                    # Global loss weight
    ema_gamma=128,                   # EMA coefficient
    source_feat_stats='path/to/source_stats.pt',  # path to the source statistics
)

In [None]:
# Evaluate every baseline on all tasks. Each result is stored under
# `<method>_results`, so the adapter objects stay intact and this cell can be re-run.
# NOTE(review): all adapters were built from the same `model` object — confirm that
# one method's adaptation does not leak into the methods evaluated after it.

# Direct_method
direct_results = direct.evaluate_all_tasks()

# ActMAD
actmad_results = actmad.evaluate_all_tasks()

# NORM
norm_results = norm.evaluate_all_tasks()

# DUA
dua_results = dua.evaluate_all_tasks()

# Mean-Teacher
mean_teacher_results = mean_teacher.evaluate_all_tasks()

# WHW — kept separate from the `whw` adapter instead of overwriting it
whw_results = whw.evaluate_all_tasks()

In [None]:
# Choose which method's results to summarize.
# Other options: direct_results, norm_results, dua_results, mean_teacher_results
results = actmad_results

METRIC_KEY = "mAP@0.50:0.95"

# One line per task: task name padded to 10 characters, metric with 3 decimals.
print("=== mAP@0.50:0.95 Summary ===")
for task_name, task_metrics in results.items():
    score = task_metrics[METRIC_KEY]
    print(f"{task_name:10s}: {score:.3f}")