In [None]:
# List the GPUs on this machine so a free one can be chosen for DEVICE_NUM in a later cell.
!nvidia-smi

In [None]:
import copy
import os
import warnings
from os import path
from pathlib import Path

import torch
from safetensors.torch import load_file
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessorFast,
    RTDetrConfig,
)

from rtdetr_baseline import DirectMethod, ActMAD, NORM, DUA, MeanTeacher

In [None]:
# Set CUDA Device Number
DEVICE_NUM = 2  # Change to your desired GPU number

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized in this
# process. `torch` and the project modules were imported in an earlier cell, and re-running
# this cell after the model is on the GPU is a common mistake. Warn instead of silently
# doing nothing in that case.
if torch.cuda.is_initialized():
    warnings.warn(
        f"CUDA was already initialized before CUDA_VISIBLE_DEVICES={DEVICE_NUM} was set; "
        "restart the kernel and run this cell before any CUDA work for it to take effect.",
        RuntimeWarning,
    )
os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_NUM)
os.environ["CUDA_VISIBLE_DEVICES"]

In [None]:
# Fall back to the CPU when no CUDA device is visible to this process.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"INFO: Using device - {device}")

In [None]:
import os

# Working directory that the relative DATA_ROOT in the next cell resolves against.
# Set the PTTA_WORK_DIR environment variable to override it instead of editing the notebook.
# NOTE: `rtdetr_baseline` was imported in an earlier cell, i.e. it was resolved before this
# directory change, so changing directories here does not affect that import.
WORK_DIR = os.environ.get("PTTA_WORK_DIR", "/workspace/ptta")
os.chdir(WORK_DIR)

In [None]:
# Root directory handed to every TTA method as `data_root`. It is relative, so it
# resolves against the current working directory.
DATA_ROOT = path.join(path.curdir, "data")
print(f"Data root: {DATA_ROOT}")

## Load the RT-DETR Model

In [None]:
# RT-DETR Configuration
REFERENCE_MODEL_ID = "PekingU/rtdetr_r50vd"
IMAGE_SIZE = 800
CLASS_NUM = 6
MODEL_STATES_PATH = "/workspace/ptta/RT-DETR_R50vd_SHIFT_CLEAR.safetensors"  # Change to your model path

# Only the configuration is taken from the reference checkpoint. The detector below is
# instantiated from that config, so its parameters are freshly initialized until the
# fine-tuned state dict is loaded.
reference_config = RTDetrConfig.from_pretrained(REFERENCE_MODEL_ID, torch_dtype=torch.float32, return_dict=True)
reference_config.num_labels = CLASS_NUM
reference_config.image_size = IMAGE_SIZE

model = RTDetrForObjectDetection(config=reference_config)

# Load the fine-tuned weights if available
if os.path.exists(MODEL_STATES_PATH):
    model_states = load_file(MODEL_STATES_PATH)
    # strict=False tolerates partial mismatches, but it would also accept a checkpoint in
    # which *no* key matches (e.g. a different key prefix) and leave the model untrained
    # without complaint. Report what did and did not line up.
    load_report = model.load_state_dict(model_states, strict=False)
    matched_count = len(model_states) - len(load_report.unexpected_keys)
    print(f"Loaded model weights from {MODEL_STATES_PATH}")
    print(
        f"  matched {matched_count}/{len(model_states)} checkpoint tensors; "
        f"{len(load_report.missing_keys)} model keys missing, "
        f"{len(load_report.unexpected_keys)} checkpoint keys unused"
    )
    if matched_count == 0:
        print("WARNING: no checkpoint tensor matched the model; it is still randomly initialized!")
    elif load_report.missing_keys:
        print(f"  first missing keys: {load_report.missing_keys[:5]}")
else:
    print(f"Model weights not found at {MODEL_STATES_PATH}, using randomly initialized weights")

model.to(device)
print(f"Model loaded and moved to {device}")

## RT-DETR Baseline Methods

Here we demonstrate how to use different TTA (Test-Time Adaptation) methods with RT-DETR:

### 1. Direct Method (No Adaptation)

In [None]:
# Direct Method - No adaptation, just evaluation.
# Hand it a private deep copy: `model` is also used by the adaptation methods in later
# cells, and any in-place changes they may make (at `.load` time or while evaluating)
# must not leak into this baseline. This is harmless if `.load` already copies internally.
direct_method = DirectMethod.load(
    model=copy.deepcopy(model),
    data_root=DATA_ROOT,
    device=device,
    batch_size=4,
    image_size=IMAGE_SIZE,
    reference_model_id=REFERENCE_MODEL_ID,
    class_num=CLASS_NUM
)

print("Direct Method initialized successfully!")

### 2. ActMAD (Activation Matching to Align Distributions)

In [None]:
# ActMAD - Activation-based adaptation.
# Pass a private deep copy. ActMAD presumably updates the network at test time (it takes a
# learning_rate), and `model` is shared with the other methods in this notebook. Adapting
# it in place would let updates leak between methods and skew the comparison.
actmad = ActMAD.load(
    model=copy.deepcopy(model),
    data_root=DATA_ROOT,
    device=device,
    batch_size=4,
    learning_rate=0.001,
    clean_bn_extract_batch=8,
    image_size=IMAGE_SIZE,
    reference_model_id=REFERENCE_MODEL_ID,
    class_num=CLASS_NUM
)

print("ActMAD initialized successfully!")

### 3. NORM (Normalization Adaptation)

In [None]:
# NORM - Normalization layer adaptation.
# Pass a private deep copy. NORM adapts normalization layers, and `model` is shared with the
# other methods in this notebook. Changing it in place would let this adaptation leak
# into other methods' results, and theirs into this one.
norm = NORM.load(
    model=copy.deepcopy(model),
    data_root=DATA_ROOT,
    device=device,
    batch_size=4,
    source_sum=128,  # NORM-specific hyperparameter
    image_size=IMAGE_SIZE,
    reference_model_id=REFERENCE_MODEL_ID,
    class_num=CLASS_NUM
)

print("NORM initialized successfully!")

### 4. DUA (Dynamic Unsupervised Adaptation)

In [None]:
# DUA - Dynamic momentum-based adaptation.
# Pass a private deep copy. DUA presumably updates model state as it adapts (see the
# momentum/decay hyperparameters), and `model` is shared with the other methods in this
# notebook. Changing it in place would let adaptation leak between methods.
dua = DUA.load(
    model=copy.deepcopy(model),
    data_root=DATA_ROOT,
    device=device,
    batch_size=4,
    decay_factor=0.94,
    mom_pre=0.01,
    min_momentum_constant=0.0001,
    image_size=IMAGE_SIZE,
    reference_model_id=REFERENCE_MODEL_ID,
    class_num=CLASS_NUM
)

print("DUA initialized successfully!")

### 5. Mean-Teacher

In [None]:
# Mean-Teacher - Student-teacher framework with pseudo-labeling.
# Pass a private deep copy. The student is presumably trained at test time (it takes a
# learning_rate), and `model` is shared with the other methods in this notebook. Training
# it in place would let updates leak between methods.
mean_teacher = MeanTeacher.load(
    model=copy.deepcopy(model),
    data_root=DATA_ROOT,
    device=device,
    batch_size=4,
    learning_rate=0.0001,
    ema_alpha=0.99,          # Teacher model EMA coefficient
    conf_threshold=0.5,      # Pseudo-label confidence threshold
    image_size=IMAGE_SIZE,
    reference_model_id=REFERENCE_MODEL_ID,
    class_num=CLASS_NUM
)

print("Mean-Teacher initialized successfully!")

## Evaluation on All Tasks

Now let's evaluate all methods on all corruption tasks:

### Direct Method Evaluation

In [None]:
print("Running Direct Method evaluation...")
direct_results = direct_method.evaluate_all_tasks()
print("Direct Method evaluation completed!")
# End the cell on the results so they are rendered in the saved notebook, not just the
# status messages. Their structure is defined in `rtdetr_baseline`, which is not shown here.
direct_results

### ActMAD Evaluation

In [None]:
print("Running ActMAD evaluation...")
actmad_results = actmad.evaluate_all_tasks()
print("ActMAD evaluation completed!")
# End the cell on the results so they are rendered in the saved notebook, not just the
# status messages. Their structure is defined in `rtdetr_baseline`, which is not shown here.
actmad_results

### NORM Evaluation

In [None]:
print("Running NORM evaluation...")
norm_results = norm.evaluate_all_tasks()
print("NORM evaluation completed!")
# End the cell on the results so they are rendered in the saved notebook, not just the
# status messages. Their structure is defined in `rtdetr_baseline`, which is not shown here.
norm_results

### DUA Evaluation

In [None]:
print("Running DUA evaluation...")
dua_results = dua.evaluate_all_tasks()
print("DUA evaluation completed!")
# End the cell on the results so they are rendered in the saved notebook, not just the
# status messages. Their structure is defined in `rtdetr_baseline`, which is not shown here.
dua_results

### Mean-Teacher Evaluation

In [None]:
print("Running Mean-Teacher evaluation...")
mean_teacher_results = mean_teacher.evaluate_all_tasks()
print("Mean-Teacher evaluation completed!")
# End the cell on the results so they are rendered in the saved notebook, not just the
# status messages. Their structure is defined in `rtdetr_baseline`, which is not shown here.
mean_teacher_results