## Import Libraries

In [None]:
from pathlib import Path
import sys

# Make the project root importable from this notebook's working directory.
THIS_DIR = Path.cwd().resolve()
PROJECT_ROOT = THIS_DIR.parents[1]  # two levels above the working directory (-> ptta/)
print(PROJECT_ROOT)

# Inserted at index 0 so this directory is searched before the other sys.path entries.
sys.path.insert(0, str(PROJECT_ROOT))

In [None]:
import sys
from pathlib import Path

# Working directory is expected to be ptta/other_method/ActMAD/.
# NOTE(review): the original note said this adds the directory directly above
# ptta/ to sys.path, but parents[1] of that working directory is ptta/ itself —
# confirm which one is intended (the variable name suggests the parent).
PROJECT_PARENT = Path.cwd().parents[1]  # two levels above the working directory
sys.path.insert(0, str(PROJECT_PARENT))

from os import path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection, SHIFTDiscreteSubsetForObjectDetection

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

# import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Device selection.
# DEVICE_NUM: index of the CUDA device to use (0~7).
# ADDITIONAL_GPU: truthy when DEVICE_NUM should become the process-wide default device.
# DATA_TYPE: dtype constant (not referenced elsewhere in the visible cells).
DEVICE_NUM = 0
ADDITIONAL_GPU = 0
DATA_TYPE = torch.bfloat16

if not torch.cuda.is_available():
    # No CUDA runtime: fall back to the CPU and mark the device index as invalid.
    device = torch.device("cpu")
    DEVICE_NUM = -1
elif ADDITIONAL_GPU:
    # Select DEVICE_NUM as the current device, then address it as plain "cuda".
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    device = torch.device(f"cuda:{DEVICE_NUM}")

# The index is appended only in the ADDITIONAL_GPU case, where `device` itself carries none.
_device_suffix = f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""
print(f"INFO: Using device - {device}" + _device_suffix)

In [None]:
# Experiment identifiers.
# NOTE(review): not referenced in the visible cells — presumably meant for
# experiment tracking (the wandb import above is commented out).
PROJECT_NAME = "ActMAD test"
RUN_NAME = "RT-DETR_R50_ActMAD"

## Dataset

In [None]:
# Dataset root: <cwd>/../../data, normalized so the ".." segments are collapsed.
DATA_ROOT = path.normpath(path.join(Path.cwd(), "..", "..", "data"))
print(DATA_ROOT)
# Train/valid use the clear SHIFT dataset; test uses the corrupted one.
# NOTE(review): the test entry is built with valid=True — confirm this is the intended split.
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
    valid=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    test=SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
)

In [None]:
from typing import Iterable, List

def task_to_subset_types(task: str):
    """Map a task name to a ``SHIFTDiscreteSubsetForObjectDetection.SubsetType`` member.

    Accepted names:
        weather -- "cloudy", "overcast", "rainy", "foggy"
        time    -- "night", "dawn", "dawn/dusk", "clear"
        simple  -- "normal", "corrupted"

    Raises:
        ValueError: if ``task`` is none of the names above.
    """
    T = SHIFTDiscreteSubsetForObjectDetection.SubsetType

    # Task name -> SubsetType member name. The member is looked up by name only
    # for the requested task, so no other member is accessed.
    member_by_task = {
        # weather
        "cloudy": "CLOUDY_DAYTIME",
        "overcast": "OVERCAST_DAYTIME",
        "rainy": "RAINY_DAYTIME",
        "foggy": "FOGGY_DAYTIME",
        # time
        "night": "CLEAR_NIGHT",
        "dawn": "CLEAR_DAWN",
        "dawn/dusk": "CLEAR_DAWN",
        "clear": "CLEAR_DAYTIME",
        # simple
        "normal": "NORMAL",
        "corrupted": "CORRUPTED",
    }

    if task not in member_by_task:
        raise ValueError(f"Unknown task: {task}")
    return getattr(T, member_by_task[task])

In [None]:
from typing import Optional, Callable

class SHIFTCorruptedTaskDatasetForObjectDetection(SHIFTDiscreteSubsetForObjectDetection):
    """Discrete SHIFT subset selected by a task name instead of a ``SubsetType``.

    ``task`` is translated with ``task_to_subset_types``; every other argument
    is forwarded unchanged to the parent constructor.
    """

    def __init__(
            self, root: str, force_download: bool = False,
            train: bool = True, valid: bool = False,
            transform: Optional[Callable] = None, task: str = "clear", target_transform: Optional[Callable] = None
    ):
        # Resolve the task name first; unknown names raise ValueError there.
        subset_type = task_to_subset_types(task)
        super().__init__(
            root=root,
            force_download=force_download,
            train=train,
            valid=valid,
            subset_type=subset_type,
            transform=transform,
            target_transform=target_transform,
        )

## Dataloader

In [None]:
# Set Batch Size
# NOTE: the second assignment overrides the first; the two lines act as
# per-GPU presets and only the last one takes effect.
BATCH_SIZE = 2, 8, 8, 8  # 4070 Ti
BATCH_SIZE = 32, 64, 64, 32  # A6000

# Dataset Configs
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

# Only the first three entries (train/valid/test) are reported; the fourth is not printed here.
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Expose one camera of a wrapped dataset as an image plus a COCO-detection target.

    Each item is ``dict(image=..., target=...)`` where ``target`` holds the
    ``image_id`` and a list of COCO-style ``annotations``.
    """

    def __init__(self, original_dataset, camera='front'):
        self.camera = camera
        self.dataset = original_dataset

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_coco_annotation(box, label):
        """Turn one Pascal VOC box (x1, y1, x2, y2) and its class into a COCO annotation."""
        left, top, right, bottom = box.tolist()
        box_w = right - left
        box_h = bottom - top
        return dict(
            bbox=[left, top, box_w, box_h],  # COCO layout: [x, y, width, height]
            category_id=label.item(),
            area=box_w * box_h,
            iscrowd=0
        )

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]
        image = sample['images'].squeeze(0)

        # The target follows the layout prepare_coco_detection_annotation expects.
        # The RT-DETR image processor turns COCO boxes into center format
        # (cx, cy, w, h) while preprocessing, and post-processing converts them
        # back to Pascal VOC (x1, y1, x2, y2).
        coco_annotations = [
            self._to_coco_annotation(box, label)
            for box, label in zip(sample['boxes2d'], sample['boxes2d_classes'])
        ]
        return dict(image=image, target=dict(image_id=idx, annotations=coco_annotations))

In [None]:
def collate_fn(batch, preprocessor=None):
    """Collate dataset items into one model-ready batch.

    Args:
        batch: list of items, each providing an ``'image'`` entry. Without a
            preprocessor, labels are read from the item's ``'boxes2d'`` /
            ``'boxes2d_classes'`` tensors when present, otherwise from the
            COCO-style ``item['target']['annotations']``.
        preprocessor: optional callable invoked as
            ``preprocessor(images=..., annotations=..., return_tensors="pt")``.

    Returns:
        The preprocessor's output when one is given. Otherwise a dict with
        ``pixel_values`` (the stacked image tensor) and ``labels`` (one dict
        of ``class_labels`` / ``boxes`` per item, boxes as (x1, y1, x2, y2)).
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    # No preprocessor: images are assumed to already be equally-sized tensors.
    # pixel_values is the tensor itself (not wrapped in another dict) so the
    # result has the same top-level layout as the preprocessor branch.
    return dict(
        pixel_values=torch.stack(images),
        labels=[_labels_from_item(item) for item in batch],
    )


def _labels_from_item(item):
    """Build the ``class_labels`` / ``boxes`` dict for a single item."""
    if 'boxes2d' in item and 'boxes2d_classes' in item:
        return dict(
            class_labels=item['boxes2d_classes'].long(),
            boxes=item['boxes2d'].float(),
        )

    # Items shaped as dict(image=..., target=...) carry COCO annotations;
    # convert [x, y, w, h] back to corner format to match the branch above.
    annotations = item['target']['annotations']
    class_labels = torch.tensor(
        [ann['category_id'] for ann in annotations], dtype=torch.long
    )
    corners = [
        [x, y, x + w, y + h]
        for x, y, w, h in (ann['bbox'] for ann in annotations)
    ]
    # reshape keeps a (0, 4) shape when the image has no annotations.
    boxes = torch.tensor(corners, dtype=torch.float32).reshape(-1, 4)
    return dict(class_labels=class_labels, boxes=boxes)

## Define Model
### APT: Adaptive Plugin for TTA (Test-Time Adaptation)

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig
from transformers.image_utils import AnnotationFormat
from safetensors.torch import load_file

In [None]:
IMG_SIZE = 800  # side length in pixels; used as both height and width of the preprocessor size

In [None]:
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# NOTE(review): hard-coded class count — confirm it matches NUM_CLASSES and the checkpoint loaded later.
reference_config.num_labels = 6

# Tie the config's image size to IMG_SIZE so it cannot drift from the
# preprocessor size set below.
reference_config.image_size = IMG_SIZE

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": IMG_SIZE, "width": IMG_SIZE}
# NOTE(review): with do_resize=False the size above is presumably not applied to inputs — confirm.
reference_preprocessor.do_resize = False

In [None]:
# Instantiate the architecture from the config, then load weights from a local safetensors file.
model_pretrained = RTDetrForObjectDetection(config=reference_config)
# NOTE(review): absolute, machine-specific checkpoint path.
model_states = load_file("/home/elicer/ptta/RT-DETR_R50vd_SHIFT_CLEAR.safetensors", device="cpu")
# NOTE(review): strict=False tolerates missing/unexpected keys and the returned
# report is discarded here — consider inspecting it to catch a mismatched checkpoint.
model_pretrained.load_state_dict(model_states, strict=False)

for param in model_pretrained.parameters():
    param.requires_grad = False  # Freeze

# Initialize Model
model_pretrained.to(device)

### Training

## Direct method

In [None]:
class LabelDataset(BaseDataset):
    """View of a wrapped dataset that yields only one camera's 2D boxes and classes."""

    def __init__(self, original_dataset, camera='front'):
        self.camera = camera
        self.dataset = original_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]
        boxes, classes = sample['boxes2d'], sample['boxes2d_classes']
        return boxes, classes

In [None]:
def naive_collate_fn(batch):
    """Identity collate function: hand the list of samples back untouched.

    Used so a ``DataLoader`` yields the raw per-sample items instead of
    trying to stack them into tensors.
    """
    return batch

In [None]:
from functools import partial

def test(model, task, batch_size):
    """Evaluate ``model`` on one SHIFT task subset; print and return its mAP metrics.

    Args:
        model: detection model called as ``model(pixel_values=...)``.
        task: task name understood by ``task_to_subset_types`` (e.g. "foggy").
        batch_size: batch size used for both data loaders.

    Returns:
        dict with "mAP@0.95", "mAP50", "mAP75" and "per_class_mAP@0.95"
        (itself a dict keyed by ``"<class name>_mAP@0.95"``).
    """
    # NOTE(review): the model's train/eval mode is left untouched here — callers
    # should make sure it is in the mode they intend before evaluating.
    targets = []
    predictions = []
    # NOTE(review): both train=True and valid=True are passed — confirm which split this selects.
    dataset = SHIFTCorruptedTaskDatasetForObjectDetection(root=DATA_ROOT, train=True, valid=True, task=task)

    # Two loaders over the same dataset are consumed in lockstep: raw
    # boxes/classes for the ground truth, preprocessed tensors for the model.
    # Neither loader shuffles, so their batches line up.
    raw_data = DataLoader(LabelDataset(dataset), batch_size=batch_size, collate_fn=naive_collate_fn)
    dataloader_discrete = DataLoader(
        DatasetAdapterForTransformers(dataset), batch_size=batch_size,
        collate_fn=partial(collate_fn, preprocessor=reference_preprocessor)
    )
    for labels, inputs in tqdm(zip(raw_data, dataloader_discrete), total=len(raw_data)):
        sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

        with torch.no_grad():
            outputs = model(pixel_values=inputs['pixel_values'].to(device))

        # threshold=0.0 keeps every query's prediction for the mAP computation.
        results = reference_preprocessor.post_process_object_detection(
            outputs, target_sizes=sizes, threshold=0.0
        )

        predictions.extend(Detections.from_transformers(result) for result in results)
        targets.extend(
            Detections(xyxy=boxes.cpu().numpy(), class_id=classes.cpu().numpy())
            for boxes, classes in labels
        )

    mean_average_precision = MeanAveragePrecision().update(
        predictions=predictions,
        targets=targets,
    ).compute()

    # matched_classes holds class IDs and corresponds row-by-row to
    # ap_per_class, so rows are addressed by position and names by class ID.
    per_class_map = {
        f"{CLASSES[class_id]}_mAP@0.95": mean_average_precision.ap_per_class[row].mean()
        for row, class_id in enumerate(mean_average_precision.matched_classes)
    }

    print(f"mAP@0.95_{task}: {mean_average_precision.map50_95:.3f}")
    print(f"mAP50_{task}: {mean_average_precision.map50:.3f}")
    print(f"mAP75_{task}: {mean_average_precision.map75:.3f}")
    for key, value in per_class_map.items():
        print(f"{key}_{task}: {value:.3f}")

    return {"mAP@0.95" : mean_average_precision.map50_95,
            "mAP50" : mean_average_precision.map50,
            "mAP75" : mean_average_precision.map75,
            "per_class_mAP@0.95" : per_class_map
            }

In [None]:
from collections import defaultdict

def agg_per_class(dicts):
    """Average per-class scores across several per-class maps.

    Args:
        dicts: list of per_class_map dicts,
            e.g. [{"car_mAP@0.95": 0.41, ...}, {...}].

    Returns:
        dict mapping each key seen in any input to the mean of its values,
        taken over the inputs that contain that key.
    """
    totals = {}
    occurrences = {}
    for per_class in dicts:
        for name, score in per_class.items():
            totals[name] = totals.get(name, 0.0) + float(score)
            occurrences[name] = occurrences.get(name, 0) + 1
    return {name: totals[name] / occurrences[name] for name in totals}


def aggregate_runs(results_list):
    """Aggregate the per-task result dicts into sums and means.

    Args:
        results_list: list of dicts with the keys "mAP@0.95", "mAP50",
            "mAP75" and "per_class_mAP@0.95" (a dict of per-class scores).

    Returns:
        dict with
        - "overall_sum": sum of each overall metric across runs,
        - "overall_mean": mean of each overall metric (0.0 when there are no runs),
        - "per_class_mean@0.95": mean of each per-class score across all runs
          ({} when there are no runs).
    """
    overall_keys = ("mAP@0.95", "mAP50", "mAP75")
    overall_sum = {key: 0.0 for key in overall_keys}
    per_class_maps = []

    for r in results_list:
        for key in overall_keys:
            overall_sum[key] += float(r[key])
        # Collect every run's per-class map; they are averaged together below
        # rather than keeping only the last run's values.
        per_class_maps.append(r["per_class_mAP@0.95"])

    n = len(results_list)
    overall_mean = {k: (overall_sum[k] / n if n > 0 else 0.0) for k in overall_sum}

    # With no runs there is nothing to aggregate.
    per_class_means = agg_per_class(per_class_maps) if per_class_maps else {}

    return {
        "overall_sum": overall_sum,              # {"mAP@0.95": ..., "mAP50": ..., "mAP75": ...}
        "overall_mean": overall_mean,            # mean of the sums above
        "per_class_mean@0.95": per_class_means,  # {"car_mAP@0.95": mean, ...}
    }

def print_results(result):
    """Print the aggregated metrics produced by ``aggregate_runs``.

    The three overall means are printed with three decimals, followed by one
    line per entry of ``result["per_class_mean@0.95"]`` with two decimals.
    """
    overall = result["overall_mean"]
    for metric in ("mAP@0.95", "mAP50", "mAP75"):
        print(f"{metric}: {float(overall[metric]):.3f}")

    per_class = result["per_class_mean@0.95"]
    for name in per_class:
        print(f"{name}: {per_class[name]:.2f}")

In [None]:
import torch.backends.cudnn as cudnn

def create_model(model, ema=False):
    """Return the model to adapt, or a detached copy to act as its EMA teacher.

    Args:
        model: the source ``nn.Module``.
        ema: when False (default) ``model`` itself is returned. When True a
            deep copy is returned whose parameters are detached in place, so
            the teacher neither shares tensors with the student nor takes
            part in backpropagation.

    Returns:
        ``model`` or its detached deep copy.
    """
    if not ema:
        return model

    import copy  # local import: only the EMA branch needs it

    # A separate copy is required: detaching the student's own parameters
    # would alias teacher and student and freeze both.
    ema_model = copy.deepcopy(model)
    for param in ema_model.parameters():
        param.detach_()
    return ema_model

def main():
    """Set up the student model, its EMA counterpart and the SGD optimizer.

    NOTE(review): work in progress — nothing is returned and no training loop
    is started yet, so ``ema_model`` and ``optimizer`` are currently unused.
    """
    model = create_model(model_pretrained)
    # Positional model first, keyword flag second: the reverse order is not valid Python.
    ema_model = create_model(model_pretrained, ema=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4,
                                momentum=0.2, weight_decay=0.05)

    # Let cuDNN benchmark and pick convolution algorithms for the observed input shapes.
    cudnn.benchmark = True

def symmetric_mse_loss(input1, input2):
    """Squared-error loss that lets gradients flow into both inputs.

    The squared differences are summed over every element and divided by the
    size of dimension 1 (the class dimension).

    Note:
    - The result is a sum over all examples, not a mean; divide by the batch
      size afterwards if the mean is wanted.
    - Neither input is detached, so both receive gradients.
    """
    assert input1.size() == input2.size()
    class_count = input1.size(1)
    squared_error = (input1 - input2).pow(2)
    return squared_error.sum() / class_count

def train(train_loader, model, ema_model, optimizer, eopch, training_log):
    """Run one adaptation pass over ``train_loader`` (per-batch step not written yet).

    Args:
        train_loader: iterable of batches.
        model: model being adapted; switched to train mode.
        ema_model: EMA counterpart of ``model``; switched to train mode.
        optimizer: optimizer for ``model`` (unused until the step is implemented).
        eopch: presumably the epoch index (the name looks like a typo of
            "epoch"; kept unchanged for compatibility).
        training_log: log sink (unused until the step is implemented).

    Raises:
        NotImplementedError: on the first batch, because the per-batch
            adaptation step has not been implemented.
    """
    # TODO: choose a loss suited to object detection.
    # Keep a reference to the criterion; it takes two tensors, so it must not
    # be called without arguments here.
    residual_logit_criterion = symmetric_mse_loss

    model.train()
    ema_model.train()

    for i, batch in enumerate(train_loader):
        # TODO: forward both models, compute the loss, step the optimizer and
        # update the EMA weights.
        raise NotImplementedError(
            "train(): the per-batch adaptation step is not implemented yet"
        )



In [None]:
# Run TTA with SHIFT-Discrete and ActMAD,
# then do the final check with SHIFT-Continuous.

from transformers.models.rt_detr.modeling_rt_detr import RTDetrFrozenBatchNorm2d

# NOTE(review): absolute, machine-specific path; not referenced in the visible cells.
save_dir = '/home/elicer/ptta/other_method/ActMAD'

def Direct_method(model, actmad_save='end_batch', half_precision=False,
                  tasks=("cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"),
                  batch_size=16):
    """Evaluate ``model`` on each task without adaptation and print the averaged metrics.

    Args:
        model: detection model passed to ``test`` for every task.
        actmad_save: currently unused; accepted for interface compatibility.
        half_precision: currently unused; accepted for interface compatibility.
        tasks: task names to evaluate, in order. Defaults to the seven
            weather / time-of-day conditions.
        batch_size: batch size forwarded to ``test``.
    """
    print('Direct method start')
    all_ap = [test(model, task, batch_size=batch_size) for task in tasks]

    each_task_mAP_list = aggregate_runs(all_ap)
    print_results(each_task_mAP_list)


In [None]:
# Baseline run: evaluate the frozen source model on every task without adaptation.
Direct_method(model_pretrained, half_precision=False)