In [None]:
from pathlib import Path
import sys

# Resolve the kernel's working directory and treat the directory two levels
# above it as the project root, then put that root first on sys.path.
# NOTE(review): this depends on where the kernel was started (presumably
# ptta/other_method/DUA/) — confirm before running from another location.
THIS_DIR = Path.cwd().resolve()
PROJECT_ROOT = THIS_DIR.parents[1]  # expected to be ptta/
print(PROJECT_ROOT)

sys.path.insert(0, str(PROJECT_ROOT))

In [None]:
import sys
from pathlib import Path

# Current folder: ptta/other_method/DUA/
# Put an ancestor directory on sys.path so the project packages can be imported.
# NOTE(review): the original comment described this as the parent of ptta/, but
# for the folder above, parents[1] is ptta/ itself — confirm which is intended.
PROJECT_PARENT = Path.cwd().parents[1]  # two levels above the working directory
sys.path.insert(0, str(PROJECT_PARENT))

from os import path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection, SHIFTDiscreteSubsetForObjectDetection
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

# import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Device configuration.
#   DEVICE_NUM     - index of the CUDA device to use (0~7)
#   ADDITIONAL_GPU - when truthy, DEVICE_NUM is made the current CUDA device
#                    and the generic "cuda" device is used
#   DATA_TYPE      - dtype constant kept alongside the device settings
DEVICE_NUM = 1
ADDITIONAL_GPU = 0
DATA_TYPE = torch.bfloat16

if not torch.cuda.is_available():
    # No CUDA runtime: fall back to the CPU and mark the index as unused.
    device = torch.device("cpu")
    DEVICE_NUM = -1
elif ADDITIONAL_GPU:
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    device = torch.device(f"cuda:{DEVICE_NUM}")

# The index is appended only in the ADDITIONAL_GPU case, where `device` itself
# prints as plain "cuda".
print(
    f"INFO: Using device - {device}"
    + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else "")
)

In [None]:
# Identifiers for this experiment.
# NOTE(review): neither name is referenced in the visible cells (the wandb
# import is commented out) — presumably meant for experiment tracking; confirm.
PROJECT_NAME = "DUA test"
RUN_NAME = "RT-DETR_R50_DUA"

## Dataset

In [None]:
# Dataset root: the "data" directory two levels above the current working directory.
DATA_ROOT = path.normpath(path.join(Path.cwd(), "..", "..", "data"))
print(DATA_ROOT)
# Bundle the splits: clear-condition train/valid, and the corrupted dataset as "test".
# NOTE(review): the test entry is built with valid=True; whether a separate
# test partition exists is not visible from this file — confirm.
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
    valid=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    test=SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
)
DATA_ROOT  # echoed as the cell output

In [None]:
from typing import Iterable, List

def task_to_subset_types(task: str):
    """Translate a task name into the matching ``SubsetType`` member.

    Recognised names are the weather tasks ("cloudy", "overcast", "rainy",
    "foggy"), the time-of-day tasks ("night", "dawn", "dawn/dusk", "clear")
    and the coarse groups ("normal", "corrupted").

    Args:
        task: Name of the task.

    Returns:
        The corresponding attribute of
        ``SHIFTDiscreteSubsetForObjectDetection.SubsetType``.

    Raises:
        ValueError: If ``task`` is not one of the recognised names.
    """
    subset_enum = SHIFTDiscreteSubsetForObjectDetection.SubsetType

    # Task name -> attribute name. The attribute is resolved with getattr only
    # after the lookup succeeds, so just the requested member is ever accessed.
    member_by_task = {
        # weather
        "cloudy": "CLOUDY_DAYTIME",
        "overcast": "OVERCAST_DAYTIME",
        "rainy": "RAINY_DAYTIME",
        "foggy": "FOGGY_DAYTIME",
        # time
        "night": "CLEAR_NIGHT",
        "dawn": "CLEAR_DAWN",
        "dawn/dusk": "CLEAR_DAWN",
        "clear": "CLEAR_DAYTIME",
        # simple
        "normal": "NORMAL",
        "corrupted": "CORRUPTED",
    }

    member_name = member_by_task.get(task)
    if member_name is None:
        raise ValueError(f"Unknown task: {task}")
    return getattr(subset_enum, member_name)

In [None]:
from typing import Optional, Callable

class SHIFTCorruptedTaskDatasetForObjectDetection(SHIFTDiscreteSubsetForObjectDetection):
    """Discrete SHIFT subset that is selected by task name.

    The ``task`` string is converted with ``task_to_subset_types`` and passed
    to the parent class as ``subset_type``; every other argument is forwarded
    unchanged.
    """

    def __init__(
        self,
        root: str,
        force_download: bool = False,
        train: bool = True,
        valid: bool = False,
        transform: Optional[Callable] = None,
        task: str = "clear",
        target_transform: Optional[Callable] = None,
    ):
        # Resolve the task name first so an unknown task fails before the
        # parent constructor runs.
        subset_type = task_to_subset_types(task)
        super().__init__(
            root=root,
            force_download=force_download,
            train=train,
            valid=valid,
            subset_type=subset_type,
            transform=transform,
            target_transform=target_transform,
        )

In [None]:
# Set Batch Size
# NOTE: the second assignment overrides the first, so the A6000 profile is the
# one in effect; comment it out to use the 4070 Ti profile instead.
BATCH_SIZE = 2, 8, 8, 8  # 4070 Ti
BATCH_SIZE = 32, 64, 64, 32  # A6000

# Dataset Configs: class names come from the training split.
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

# Only the first three BATCH_SIZE entries are reported; the fourth is not used in this cell.
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Expose one camera of a wrapped dataset as COCO-detection style records.

    Each item is ``{"image": ..., "target": {"image_id": idx, "annotations": [...]}}``
    where every annotation carries ``bbox`` ([x, y, width, height]),
    ``category_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]
        image = sample['images'].squeeze(0)

        # Convert Pascal VOC corners (x1, y1, x2, y2) into COCO boxes
        # [x, y, width, height], one annotation per box/class pair.
        annotations = []
        for corners, label in zip(sample['boxes2d'], sample['boxes2d_classes']):
            left, top, right, bottom = corners.tolist()
            box_w = right - left
            box_h = bottom - top
            annotations.append({
                "bbox": [left, top, box_w, box_h],
                "category_id": label.item(),
                "area": box_w * box_h,
                "iscrowd": 0,
            })

        # Layout follows what prepare_coco_detection_annotation expects.
        # The RT-DETR ImageProcessor converts the COCO bbox to center format
        # (cx, cy, w, h) during preprocessing, and post-processing converts it
        # back to Pascal VOC (x1, y1, x2, y2).
        return {
            "image": image,
            "target": {"image_id": idx, "annotations": annotations},
        }

In [None]:
def _labels_from_item(item):
    """Build the ``class_labels`` / ``boxes`` dict for one sample (no-preprocessor path)."""
    if 'boxes2d' in item and 'boxes2d_classes' in item:
        # Raw sample layout: the tensors are already present, only cast them.
        return dict(
            class_labels=item['boxes2d_classes'].long(),
            boxes=item['boxes2d'].float(),
        )

    # Adapter layout ({"image", "target"}): COCO annotations whose bbox is
    # [x, y, width, height]; convert back to (x1, y1, x2, y2) corners.
    annotations = item['target']['annotations']
    class_labels = torch.tensor(
        [ann['category_id'] for ann in annotations], dtype=torch.long
    )
    boxes = torch.tensor(
        [[x, y, x + w, y + h] for x, y, w, h in (ann['bbox'] for ann in annotations)],
        dtype=torch.float32,
    ).reshape(-1, 4)  # keeps shape (0, 4) when a sample has no boxes
    return dict(class_labels=class_labels, boxes=boxes)


def collate_fn(batch, preprocessor=None):
    """Collate a list of samples into a model-ready batch.

    Args:
        batch: List of samples, each a dict with an ``'image'`` entry and
            either a COCO-style ``'target'`` entry or raw ``'boxes2d'`` /
            ``'boxes2d_classes'`` tensors.
        preprocessor: Optional callable invoked as
            ``preprocessor(images=..., annotations=..., return_tensors="pt")``.
            When given, its result is returned as is.

    Returns:
        The preprocessor output, or (without a preprocessor) a dict with
        ``pixel_values`` (the stacked image tensor) and ``labels`` (one dict
        of ``class_labels`` and corner-format ``boxes`` per sample).
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    # No preprocessor: images are assumed to already be equally-sized tensors.
    # pixel_values is the stacked tensor itself (it used to be wrapped in a
    # second dict, which callers reading batch['pixel_values'] cannot use).
    return dict(
        pixel_values=torch.stack(images),
        labels=[_labels_from_item(item) for item in batch],
    )

## Define Model

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig
from transformers.image_utils import AnnotationFormat
from safetensors.torch import load_file

In [None]:
IMG_SIZE = 800

In [None]:
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# NOTE(review): presumably the number of SHIFT detection classes (NUM_CLASSES) — confirm.
reference_config.num_labels = 6

# Keep the model config and the preprocessor on the same resolution by reading
# both from IMG_SIZE instead of repeating the literal.
reference_config.image_size = IMG_SIZE

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": IMG_SIZE, "width": IMG_SIZE}
# Resizing is switched off, so images reach the model at their original size.
reference_preprocessor.do_resize = False

In [None]:
model_pretrained = RTDetrForObjectDetection(config=reference_config)
# NOTE(review): absolute, machine-specific checkpoint path — consider deriving it
# from the project root so the notebook runs on other machines.
model_states = load_file("/workspace/ptta/RT-DETR_R50vd_SHIFT_CLEAR.safetensors", device="cpu")

# strict=False tolerates key mismatches, which would otherwise leave layers at
# their random initialisation without any error. Report mismatches instead of
# discarding the result.
load_result = model_pretrained.load_state_dict(model_states, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: checkpoint/model key mismatch - "
        f"missing: {len(load_result.missing_keys)}, unexpected: {len(load_result.unexpected_keys)}"
    )
    print(f"  missing (first 10): {load_result.missing_keys[:10]}")
    print(f"  unexpected (first 10): {load_result.unexpected_keys[:10]}")

for param in model_pretrained.parameters():
    param.requires_grad = False  # Freeze

# Initialize Model
model_pretrained.to(device)

## NORM

In [None]:
from functools import partial
# Loader over dataset.test (batch size 4); batches are built by the RT-DETR image processor.
# NOTE(review): NORM() builds its own per-task loaders, so this module-level
# loader does not appear to be used in the visible cells — confirm.
dataloader_discrete = DataLoader(DatasetAdapterForTransformers(dataset.test), batch_size=4, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))

In [None]:
class LabelDataset(BaseDataset):
    """View of a wrapped dataset that yields only the ground-truth labels.

    Each item is the ``(boxes2d, boxes2d_classes)`` pair of the selected camera.
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset, self.camera = original_dataset, camera

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]
        boxes = sample['boxes2d']
        classes = sample['boxes2d_classes']
        return boxes, classes

In [None]:
def naive_collate_fn(batch):
    """Identity collate function: return the list of samples as is, without stacking."""
    return batch

In [None]:
from functools import partial

def test(model, task, raw_data, dataloader_discrete):
    """Evaluate ``model`` on one task and report mean average precision.

    ``raw_data`` and ``dataloader_discrete`` are iterated in lockstep, so they
    must cover the same samples in the same order and with the same batch
    size; ``zip`` stops silently at the shorter one.

    Relies on the notebook globals ``device``, ``reference_preprocessor`` and
    ``CLASSES``.

    Args:
        model: Detection model called as ``model(pixel_values=...)``.
        task: Task name, used only to label the printed lines.
        raw_data: Iterable of batches; each batch is a list of
            ``(boxes_xyxy, class_ids)`` pairs (ground truth).
        dataloader_discrete: Iterable of preprocessed batches providing
            ``'pixel_values'`` and ``'labels'`` (each with ``'orig_size'``).

    Returns:
        Dict with ``"mAP@0.95"`` (the ``map50_95`` score), ``"mAP50"``,
        ``"mAP75"`` and ``"per_class_mAP@0.95"`` (per-class AP averaged over
        the IoU thresholds, keyed ``"<class name>_mAP@0.95"``).
    """
    targets = []
    predictions = []

    for _, labels, inputs in zip(tqdm(range(len(raw_data))), raw_data, dataloader_discrete):
        sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

        with torch.no_grad():
            outputs = model(pixel_values=inputs['pixel_values'].to(device))

        # threshold=0.0 keeps every query so the metric sees the full ranking.
        results = reference_preprocessor.post_process_object_detection(
            outputs, target_sizes=sizes, threshold=0.0
        )

        predictions.extend(Detections.from_transformers(result) for result in results)
        targets.extend(
            Detections(
                xyxy=entry[0].cpu().numpy(),
                class_id=entry[1].cpu().numpy(),
            )
            for entry in labels
        )

    mean_average_precision = MeanAveragePrecision().update(
        predictions=predictions,
        targets=targets,
    ).compute()

    # Rows of ap_per_class follow the order of matched_classes, so a row is
    # addressed by its position; the class id is only used to look up the name.
    # (Indexing rows by class id picks the wrong row, or fails, whenever the
    # matched classes are not exactly 0..k-1.)
    per_class_map = {
        f"{CLASSES[class_id]}_mAP@0.95": mean_average_precision.ap_per_class[row].mean()
        for row, class_id in enumerate(mean_average_precision.matched_classes)
    }

    print(f"mAP@0.95_{task}: {mean_average_precision.map50_95:.3f}")
    print(f"mAP50_{task}: {mean_average_precision.map50:.3f}")
    print(f"mAP75_{task}: {mean_average_precision.map75:.3f}")
    for key, value in per_class_map.items():
        print(f"{key}_{task}: {value:.3f}")

    return {"mAP@0.95" : mean_average_precision.map50_95,
            "mAP50" : mean_average_precision.map50,
            "mAP75" : mean_average_precision.map75,
            "per_class_mAP@0.95" : per_class_map
            }

In [None]:
from collections import defaultdict

def agg_per_class(dicts):
    """Average each key over a list of per-class mAP dicts.

    Args:
        dicts: List of per-class maps, e.g.
            ``[{"car_mAP@0.95": 0.41, ...}, {...}]``. A key only counts
            towards the dicts in which it actually appears.

    Returns:
        Dict mapping every key seen to the mean of its values, in order of
        first appearance.
    """
    # key -> [running total, number of occurrences]
    stats = {}
    for per_class in dicts:
        for name, value in per_class.items():
            entry = stats.setdefault(name, [0.0, 0])
            entry[0] += float(value)
            entry[1] += 1
    return {name: total / count for name, (total, count) in stats.items()}


def aggregate_runs(results_list):
    """Combine the result dicts of several runs into sums and means.

    Args:
        results_list: List of dicts as returned by ``test``: each has the
            keys ``"mAP@0.95"``, ``"mAP50"``, ``"mAP75"`` and
            ``"per_class_mAP@0.95"`` (a per-class dict).

    Returns:
        Dict with
        ``"overall_sum"``          - per-metric sum over all runs,
        ``"overall_mean"``         - per-metric mean (0.0 for an empty list),
        ``"per_class_mean@0.95"``  - per-class mean over all runs
                                     (empty dict for an empty list).
    """
    metric_keys = ("mAP@0.95", "mAP50", "mAP75")
    overall_sum = {key: 0.0 for key in metric_keys}
    per_class_maps = []

    for run in results_list:
        for key in metric_keys:
            overall_sum[key] += float(run[key])
        # Collect every run; averaging happens once, after the loop, so the
        # per-class mean covers all runs rather than only the last one.
        per_class_maps.append(run["per_class_mAP@0.95"])

    n = len(results_list)
    overall_mean = {k: (overall_sum[k] / n if n > 0 else 0.0) for k in overall_sum}

    return {
        "overall_sum": overall_sum,            # {"mAP@0.95": ..., "mAP50": ..., "mAP75": ...}
        "overall_mean": overall_mean,          # the same metrics, averaged over the runs
        "per_class_mean@0.95": agg_per_class(per_class_maps),  # {"car_mAP@0.95": mean, ...}
    }

def print_results(result):
    """Print the averaged metrics of an ``aggregate_runs`` result.

    The three overall means are printed first ("mAP@0.95", "mAP50", "mAP75"),
    followed by one line per entry of ``result["per_class_mean@0.95"]``; every
    value is formatted with three decimals.
    """
    overall = result["overall_mean"]
    for metric in ("mAP@0.95", "mAP50", "mAP75"):
        print(f"{metric}: {float(overall[metric]):.3f}")

    per_class = result["per_class_mean@0.95"]
    for name in per_class:
        print(f"{name}: {per_class[name]:.3f}")

In [None]:
class EMABatchNorm(nn.Module):
    """Wrap a ``nn.BatchNorm2d`` so every forward pass first updates its running stats.

    ``forward`` runs the wrapped layer once in train mode (to update the
    running statistics, output discarded) and once in eval mode (output
    returned, normalised with the stored statistics).

    ``adapt_model`` replaces every ``nn.BatchNorm2d`` inside a model with such
    a wrapper, after resetting its running statistics and setting
    ``momentum = None``.
    """

    @staticmethod
    def reset_stats(module):
        """Reset ``module``'s running statistics, set ``momentum`` to None and return it."""
        module.reset_running_stats()
        module.momentum = None
        return module

    @staticmethod
    def find_bns(parent):
        """Collect ``(parent, name, wrapper)`` triples for each BatchNorm2d below ``parent``.

        Every visited child has its parameters frozen via ``requires_grad_(False)``.
        Children that are already wrapped get their statistics reset again but
        are not wrapped a second time.
        """
        if parent is None:
            return []

        replace_mods = []
        for name, child in parent.named_children():
            child.requires_grad_(False)
            if isinstance(child, EMABatchNorm) and isinstance(child.layer, nn.BatchNorm2d):
                # Already adapted (e.g. the cell was re-run): reset the inner
                # layer as a fresh adaptation would, but do not descend into
                # the wrapper, or the inner BatchNorm would be wrapped again.
                EMABatchNorm.reset_stats(child.layer)
            elif isinstance(child, nn.BatchNorm2d):
                module = EMABatchNorm.reset_stats(child)
                module = EMABatchNorm(module)
                replace_mods.append((parent, name, module))
            else:
                replace_mods.extend(EMABatchNorm.find_bns(child))

        return replace_mods

    @staticmethod
    def adapt_model(model):
        """Swap every BatchNorm2d in ``model`` for a wrapper, in place, and return the model."""
        replace_mods = EMABatchNorm.find_bns(model)
        print(f"| Found {len(replace_mods)} modules to be replaced.")
        # Replacement happens after the traversal so the module tree is not
        # modified while it is being iterated.
        for parent, name, child in replace_mods:
            setattr(parent, name, child)
        return model

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # store statistics, but discard result
        self.layer.train()
        self.layer(x)
        # use the stored stats; the wrapped layer is left in eval mode
        self.layer.eval()
        return self.layer(x)

In [None]:
def set_bn_momentum(model, momentum: float | None):
    # momentum=None 이면 CMA(누적 평균), 수치(0<α≤1)이면 EMA
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum

In [None]:
import os, logging
from datetime import datetime
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm  # tqdm → logging 연동

def setup_logger(save_dir, name="dua", level=logging.INFO, mirror_to_stdout=True):
    """Create (or reconfigure) a named logger that writes to a timestamped file.

    Args:
        save_dir: Directory for the log file; created if missing.
        name: Logger name, also used as the log file prefix.
        level: Logging level applied to the logger.
        mirror_to_stdout: When true, also attach a ``logging.StreamHandler``
            (which, constructed without a stream, writes to ``sys.stderr``).

    Returns:
        ``(logger, log_path)`` - the configured logger and the path of its
        log file (``<save_dir>/<name>_<YYYYmmdd_HHMMSS>.log``).
    """
    os.makedirs(save_dir, exist_ok=True)
    log_path = os.path.join(save_dir, f"{name}_{datetime.now():%Y%m%d_%H%M%S}.log")

    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Loggers are process-wide singletons, so a repeated call (e.g. re-running
    # the cell) finds the previous handlers still attached. Close them before
    # dropping them; merely clearing the list would leak the open log files.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
    # File handler (complete record)
    fh = logging.FileHandler(log_path, encoding="utf-8")
    fh.setFormatter(fmt)
    logger.addHandler(fh)

    # Console handler (mirror)
    if mirror_to_stdout:
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        logger.addHandler(sh)

    # Keep records out of the root logger to avoid duplicate output.
    logger.propagate = False
    return logger, log_path

In [None]:
import io, contextlib
from tqdm.contrib.logging import logging_redirect_tqdm  # tqdm도 로깅에 정리되도록

class LoggerWriter(io.TextIOBase):
    def __init__(self, logger, level):
        self.logger = logger
        self.level = level
        self._buf = ""
    def write(self, msg):
        # 줄 단위로 로깅(개행/부분문자 처리)
        self._buf += msg
        while "\n" in self._buf:
            line, self._buf = self._buf.split("\n", 1)
            if line.strip():
                self.level(line)
    def flush(self):
        if self._buf.strip():
            self.level(self._buf.strip())
            self._buf = ""

In [None]:
from transformers.models.rt_detr.modeling_rt_detr import RTDetrFrozenBatchNorm2d
from torchvision import transforms
import os
import copy

def NORM(model, MOMENTOM, save_dir = './norm_checkpoints'):
    """
    Adapt the BatchNorm statistics of a pre-trained model, one task after another.

    For every task the model starts from the state carried over from the
    previous task. Training-split images are streamed one at a time through
    the model with all ``nn.BatchNorm2d`` layers in train mode, which updates
    their running statistics. Every ``EVAL_EVERY`` batches the model is
    evaluated on the task's validation split with ``test``; the best state is
    kept, and adaptation stops early after ``PATIENCE`` evaluations without
    improvement. The best state is restored, evaluated once more, saved to
    ``<save_dir>/model_<task>.pth`` and carried into the next task.

    stdout, stderr and tqdm output are routed into the log file for the
    duration of the run.

    Args:
        model: Pre-trained model; it is modified in place.
        MOMENTOM: Value assigned to ``momentum`` of every ``nn.BatchNorm2d``
            during adaptation.
        save_dir: Directory for the log file and the per-task checkpoints.

    Returns:
        The dict produced by ``aggregate_runs`` over the final per-task results.
    """
    logger, log_path = setup_logger(save_dir, name="norm", mirror_to_stdout=True)
    logger.info(f"Log file: {log_path}")

    TASKS = ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]
    PATIENCE = 10    # evaluations without improvement before stopping a task
    EVAL_EVERY = 10  # adaptation batches between two evaluations

    with contextlib.redirect_stdout(LoggerWriter(logger, logger.info)), \
        contextlib.redirect_stderr(LoggerWriter(logger, logger.error)), \
        logging_redirect_tqdm():  # keep tqdm progress bars tidy in the log as well

        os.makedirs(save_dir, exist_ok=True)

        all_results = []

        device = next(model.parameters()).device
        carry_state = copy.deepcopy(model.state_dict())
        preprocess_collate = partial(collate_fn, preprocessor=reference_preprocessor)

        for task in TASKS:
            logger.info(f"start {task}")
            model.load_state_dict(carry_state)
            best_state = copy.deepcopy(model.state_dict())
            best_mAP95 = -1.0
            no_imp_streak = 0

            # Adaptation stream: training split of the task, one image per step.
            dataloader_discrete = DataLoader(
                DatasetAdapterForTransformers(
                    SHIFTCorruptedTaskDatasetForObjectDetection(root=DATA_ROOT, train=True, valid=False, task=task)
                ),
                batch_size=1,
                collate_fn=preprocess_collate,
            )

            # Validation split of the task: raw labels and preprocessed inputs
            # use the same batch size so test() can iterate them in lockstep.
            dataset_valid = SHIFTCorruptedTaskDatasetForObjectDetection(root=DATA_ROOT, valid=True, task=task)

            raw_data = DataLoader(LabelDataset(dataset_valid), batch_size=16, collate_fn=naive_collate_fn)
            dataloader_discrete_valid = DataLoader(
                DatasetAdapterForTransformers(dataset_valid),
                batch_size=16,
                collate_fn=preprocess_collate,
            )

            for batch_i, batch in enumerate(tqdm(dataloader_discrete)):
                # Everything stays in eval mode except the BatchNorm layers,
                # whose running statistics are the only thing being adapted.
                model.eval()
                for module in model.modules():
                    if isinstance(module, torch.nn.BatchNorm2d):
                        module.momentum = MOMENTOM
                        module.train()
                img = batch['pixel_values'].to(device, non_blocking=True)

                # Running statistics are updated by the forward pass itself;
                # no gradients are needed, so do not record an autograd graph.
                with torch.no_grad():
                    _ = model(img)
                model.eval()

                if batch_i % EVAL_EVERY == 0:
                    current_result = test(model, task, raw_data, dataloader_discrete_valid)
                    current_mAP95 = current_result.get("mAP@0.95", -1.0)

                    # Ties count as an improvement, so the latest equal state is kept.
                    improved = current_mAP95 >= best_mAP95

                    if improved:
                        logger.info(f"[{task}] batch {batch_i}: mAP@0.95 {best_mAP95:.4f} -> {current_mAP95:.4f} ✔")
                        best_mAP95 = current_mAP95
                        best_state = copy.deepcopy(model.state_dict())
                        no_imp_streak = 0
                    else:
                        no_imp_streak += 1
                        logger.info(f"[{task}] batch {batch_i}: mAP@0.95 {current_mAP95:.4f} (no-imp {no_imp_streak}/{PATIENCE})")
                        if no_imp_streak >= PATIENCE:
                            logger.info(f"[{task}] Early stop at batch {batch_i} (no improvement {PATIENCE} times).")
                            break

            model.load_state_dict(best_state)
            with torch.inference_mode():
                final_result = test(model, task, raw_data, dataloader_discrete_valid)

            save_path = os.path.join(save_dir, f"model_{task}.pth")
            torch.save(model.state_dict(), save_path)
            logger.info(f"Saved adapted model after task [{task}] → {save_path}")

            # The adapted state is the starting point of the next task.
            carry_state = copy.deepcopy(model.state_dict())

            all_results.append(final_result)

        each_task_mAP_list = aggregate_runs(all_results)

        print_results(each_task_mAP_list)

    return each_task_mAP_list

In [None]:
# Run the per-task BatchNorm adaptation on the frozen pre-trained model, using momentum 0.2.
NORM(model_pretrained, MOMENTOM=0.2)