## Import Libraries

In [None]:
from pathlib import Path
import sys

# The notebook lives in ptta/other_method/ActMAD/, so the repository root
# (ptta/) sits two directory levels above the resolved working directory.
THIS_DIR = Path.cwd().resolve()
PROJECT_ROOT = THIS_DIR.parents[1]  # -> ptta/
print(PROJECT_ROOT)

# Prepend so that project packages win over same-named installed packages.
_project_root_str = str(PROJECT_ROOT)
sys.path.insert(0, _project_root_str)

In [None]:
import sys
from pathlib import Path

# Current folder: ptta/other_method/ActMAD/
# parents[1] of that folder is ptta/ itself (the same directory the previous
# cell already inserted, there via the resolved cwd).
# NOTE(review): the original comment said the *parent* of ptta/ should be added;
# if that is what the imports below need, this should be parents[2] — confirm.
PROJECT_PARENT = Path.cwd().parents[1]  # -> ptta/ (see the note above)
sys.path.insert(0, str(PROJECT_PARENT))

from os import path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection, SHIFTDiscreteSubsetForObjectDetection

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

# import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# CUDA device selection (valid indices 0~7 on this machine).
DEVICE_NUM = 2
ADDITIONAL_GPU = 0
DATA_TYPE = torch.bfloat16

# Fall back to CPU (DEVICE_NUM = -1) when CUDA is unavailable. With
# ADDITIONAL_GPU set, the default CUDA device is switched instead of being
# addressed explicitly.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    DEVICE_NUM = -1
elif ADDITIONAL_GPU:
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    device = torch.device(f"cuda:{DEVICE_NUM}")

_device_suffix = f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""
print(f"INFO: Using device - {device}{_device_suffix}")

In [None]:
# Experiment identifiers: project name and run name.
PROJECT_NAME, RUN_NAME = "ActMAD test", "RT-DETR_R50_ActMAD"

## Dataset

In [None]:
# Dataset root: <cwd>/../../data, lexically normalized.
DATA_ROOT = path.normpath(path.join(Path.cwd(), "..", "..", "data"))
print(DATA_ROOT)

# Clear-weather SHIFT for train/valid; the corrupted SHIFT split (valid=True) is the test set.
_clear_train = SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True)
_clear_valid = SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True)
_corrupted_test = SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
dataset = DatasetHolder(train=_clear_train, valid=_clear_valid, test=_corrupted_test)

In [None]:
from typing import Iterable, List

# Task name -> attribute name on SHIFTDiscreteSubsetForObjectDetection.SubsetType.
# Aliases ("rain", "dawn/dusk") resolve to the same subset as their canonical name;
# "rain" is accepted because the ActMAD task loop uses that spelling.
TASK_SUBSET_NAMES = {
    # weather
    "cloudy": "CLOUDY_DAYTIME",
    "overcast": "OVERCAST_DAYTIME",
    "rainy": "RAINY_DAYTIME",
    "rain": "RAINY_DAYTIME",
    "foggy": "FOGGY_DAYTIME",
    # time of day
    "night": "CLEAR_NIGHT",
    "dawn": "CLEAR_DAWN",
    "dawn/dusk": "CLEAR_DAWN",
    "clear": "CLEAR_DAYTIME",
    # coarse splits
    "normal": "NORMAL",
    "corrupted": "CORRUPTED",
}


def task_to_subset_types(task: str):
    """Map a task name to its ``SHIFTDiscreteSubsetForObjectDetection.SubsetType`` member.

    Args:
        task: one of the keys of ``TASK_SUBSET_NAMES`` (e.g. "cloudy", "rain",
            "night", "dawn/dusk", "corrupted").

    Returns:
        The matching ``SubsetType`` enum member.

    Raises:
        ValueError: if ``task`` is not a known task name.
    """
    try:
        member_name = TASK_SUBSET_NAMES[task]
    except KeyError:
        raise ValueError(f"Unknown task: {task}") from None
    # Resolved lazily so the lookup table itself has no import-time dependency.
    return getattr(SHIFTDiscreteSubsetForObjectDetection.SubsetType, member_name)

In [None]:
from typing import Optional, Callable

class SHIFTCorruptedTaskDatasetForObjectDetection(SHIFTDiscreteSubsetForObjectDetection):
    """SHIFT discrete-subset detection dataset selected by a task name.

    ``task`` (e.g. "cloudy", "night", "clear") is converted to a ``SubsetType``
    with ``task_to_subset_types``. All other arguments are forwarded unchanged to
    ``SHIFTDiscreteSubsetForObjectDetection``.
    """

    def __init__(
        self,
        root: str,
        force_download: bool = False,
        train: bool = True,
        valid: bool = False,
        transform: Optional[Callable] = None,
        task: str = "clear",
        target_transform: Optional[Callable] = None,
    ):
        subset_type = task_to_subset_types(task)
        super().__init__(
            root=root,
            force_download=force_download,
            train=train,
            valid=valid,
            subset_type=subset_type,
            transform=transform,
            target_transform=target_transform,
        )

## Dataloader

In [None]:
# Batch-size presets per GPU. Only the first three entries are reported below
# (train / valid / test); the fourth entry is not read in this cell.
# BATCH_SIZE = 2, 8, 8, 8  # 4070 Ti
BATCH_SIZE = 32, 64, 64, 32  # A6000

# Dataset configs
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

_train_bs, _valid_bs, _test_bs = BATCH_SIZE[:3]
print(f"INFO: Set batch size - Train: {_train_bs}, Valid: {_valid_bs}, Test: {_test_bs}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Expose one camera view of a SHIFT sample in the COCO-detection layout.

    Each item is ``{"image": tensor, "target": {"image_id": idx, "annotations": [...]}}``.
    Every annotation carries a COCO ``[x, y, width, height]`` bbox, converted from
    the sample's Pascal VOC ``(x1, y1, x2, y2)`` boxes.
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _voc_to_coco_annotation(box, cls):
        """Build one COCO annotation dict from a VOC box tensor and a class tensor."""
        left, top, right, bottom = box.tolist()
        w = right - left
        h = bottom - top
        return dict(
            bbox=[left, top, w, h],
            category_id=cls.item(),
            area=w * h,
            iscrowd=0,
        )

    def __getitem__(self, idx):
        view = self.dataset[idx][self.camera]
        # squeeze(0) removes dim 0 only when it has size 1.
        picture = view['images'].squeeze(0)

        coco_annotations = [
            self._voc_to_coco_annotation(box, cls)
            for box, cls in zip(view['boxes2d'], view['boxes2d_classes'])
        ]

        # Layout expected by prepare_coco_detection_annotation. The RT-DETR image
        # processor turns COCO boxes into (cx, cy, w, h) for the model, and its
        # post-processing returns Pascal VOC (x1, y1, x2, y2) boxes again.
        return dict(image=picture, target=dict(image_id=idx, annotations=coco_annotations))

In [None]:
def _labels_from_item(item):
    """Build a ``{"class_labels", "boxes"}`` label dict for one collated item.

    Items produced by ``DatasetAdapterForTransformers`` carry
    ``target["annotations"]``. Their COCO ``[x, y, w, h]`` boxes are passed through
    unchanged: no normalization or center-format conversion happens here. Other
    items are read from their raw ``boxes2d`` / ``boxes2d_classes`` fields.
    """
    if 'target' in item:
        annotations = item['target']['annotations']
        class_labels = torch.tensor([a['category_id'] for a in annotations], dtype=torch.long)
        # reshape keeps an image without boxes as a (0, 4) tensor instead of (0,).
        boxes = torch.tensor([a['bbox'] for a in annotations], dtype=torch.float32).reshape(-1, 4)
        return dict(class_labels=class_labels, boxes=boxes)
    return dict(
        class_labels=item['boxes2d_classes'].long(),
        boxes=item["boxes2d"].float(),
    )


def collate_fn(batch, preprocessor=None):
    """Collate dataset items into a model-ready batch.

    Args:
        batch: list of dicts, each with an ``'image'`` tensor and either a
            COCO-style ``'target'`` dict or raw ``'boxes2d'`` / ``'boxes2d_classes'``.
        preprocessor: optional image processor (e.g. an RT-DETR image processor).
            If given, it receives ``images=``, ``annotations=`` (the ``'target'``
            dicts) and ``return_tensors="pt"``, and its output is returned as-is.

    Returns:
        The preprocessor output. Without a preprocessor, returns a dict with
        ``pixel_values``, the stacked image tensor (images must share one shape),
        and ``labels``, one ``{"class_labels", "boxes"}`` dict per item.
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    # pixel_values must be a plain tensor (not a nested dict) so that
    # `inputs['pixel_values'].to(device)` works the same in both branches.
    return dict(
        pixel_values=torch.stack(images),
        labels=[_labels_from_item(item) for item in batch],
    )

## Define Model
### APT: Adaptive Plugin for TTA (Test-Time Adaptation)

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig
from transformers.image_utils import AnnotationFormat
from safetensors.torch import load_file

In [None]:
# Square input size (height == width) used for the RT-DETR config and image processor.
IMG_SIZE: int = 800

In [None]:
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration (config only; weights are loaded separately).
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# NOTE(review): hard-coded class count; presumably equal to NUM_CLASSES for SHIFT and
# to the classification head of the checkpoint loaded later. Confirm before changing.
reference_config.num_labels = 6

# Use IMG_SIZE for both the config and the preprocessor so the two cannot drift apart
# (this was previously a separate hard-coded 800).
reference_config.image_size = IMG_SIZE

# Load the reference model image processor.
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": IMG_SIZE, "width": IMG_SIZE}
# NOTE(review): with do_resize=False, the `size` above is presumably not applied to
# the images. Confirm this is intended.
reference_preprocessor.do_resize = False

In [None]:
# Build RT-DETR from the reference config and load the SHIFT-clear checkpoint.
MODEL_CKPT_PATH = "/workspace/ptta/RT-DETR_R50vd_SHIFT_CLEAR_42.42.safetensors"

model_pretrained = RTDetrForObjectDetection(config=reference_config)
model_states = load_file(MODEL_CKPT_PATH, device="cpu")

# strict=False tolerates key mismatches. Report them instead of ignoring them
# silently; otherwise a mismatched checkpoint would leave layers randomly initialized.
load_result = model_pretrained.load_state_dict(model_states, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored, "
          f"e.g. {load_result.unexpected_keys[:5]}")

# Freeze every parameter.
for param in model_pretrained.parameters():
    param.requires_grad = False

# Move the model to the selected device.
model_pretrained.to(device)

### Training

## ActMAD

In [None]:
class AverageMeter:
    """Track the latest value and a count-weighted running mean.

    ``update(val, n)`` records ``val`` as the latest value. It adds ``val * n`` to
    the running sum and ``n`` to the count, then refreshes ``avg = sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class SaveOutput:
    """Forward-hook callable that records a module's outputs.

    Register an instance with ``module.register_forward_hook(instance)``. Each
    forward pass appends a clone of the module output. ``get_out_mean`` and
    ``get_out_var`` stack the recorded outputs with ``torch.vstack`` and reduce
    over dim 0.
    """

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Clone so later in-place changes to the module output cannot alter the record.
        self.outputs.append(module_out.clone())

    def clear(self):
        self.outputs = []

    def _stacked(self):
        return torch.vstack(self.outputs)

    def get_out_mean(self):
        return self._stacked().mean(dim=0)

    def get_out_var(self):
        # torch.var's default is the unbiased estimator: a single stacked row yields NaN.
        return self._stacked().var(dim=0)

In [None]:
class LabelDataset(BaseDataset):
    """Expose only the ground-truth annotations of one camera view.

    ``dataset[idx]`` returns the tuple ``(boxes2d, boxes2d_classes)`` for that sample.
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        view = self.dataset[idx][self.camera]
        boxes, classes = view['boxes2d'], view['boxes2d_classes']
        return boxes, classes

In [None]:
def naive_collate_fn(batch):
    """Identity collate: return the list of samples as-is, without stacking.

    Useful when samples (e.g. per-image box tensors of differing lengths)
    should not be stacked into a single tensor.
    """
    collated = batch
    return collated

In [None]:
from functools import partial

def test(model, task, batch_size):
    """Evaluate ``model`` on the SHIFT discrete subset for ``task`` and print mAP.

    The model is switched to eval mode and left in eval mode.

    Args:
        model: RT-DETR detection model, called as ``model(pixel_values=...)``.
        task: task name understood by ``task_to_subset_types``.
        batch_size: evaluation batch size.

    Returns:
        dict with ``"mAP@0.95"`` (supervision's ``map50_95``, i.e. mAP averaged over
        IoU 0.50:0.95), ``"mAP50"``, ``"mAP75"`` and ``"per_class_mAP@0.95"``
        (``{"<class>_mAP@0.95": value}``).
    """
    targets = []
    predictions = []
    # NOTE(review): both train=True and valid=True are passed. Confirm which split
    # SHIFTDiscreteSubsetForObjectDetection selects in that case.
    task_dataset = SHIFTCorruptedTaskDatasetForObjectDetection(root=DATA_ROOT, train=True, valid=True, task=task)

    # Two unshuffled loaders over the same dataset, iterated in lockstep: one yields
    # raw GT boxes/classes, the other yields preprocessed model inputs.
    raw_data = DataLoader(LabelDataset(task_dataset), batch_size=batch_size, collate_fn=naive_collate_fn)
    dataloader_discrete = DataLoader(
        DatasetAdapterForTransformers(task_dataset),
        batch_size=batch_size,
        collate_fn=partial(collate_fn, preprocessor=reference_preprocessor),
    )

    # Callers such as ActMAD may leave the model in train mode; evaluate in eval mode.
    model.eval()
    for labels, inputs in tqdm(zip(raw_data, dataloader_discrete), total=len(raw_data)):
        sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

        with torch.no_grad():
            outputs = model(pixel_values=inputs['pixel_values'].to(device))

        results = reference_preprocessor.post_process_object_detection(
            outputs, target_sizes=sizes, threshold=0.0
        )

        # Iterate over the actual batch contents: the last batch can be smaller
        # than batch_size, so range(batch_size) would raise IndexError.
        predictions.extend(Detections.from_transformers(result) for result in results)
        targets.extend(
            Detections(xyxy=boxes.cpu().numpy(), class_id=classes.cpu().numpy())
            for boxes, classes in labels
        )

    map_result = MeanAveragePrecision().update(
        predictions=predictions,
        targets=targets,
    ).compute()
    # ap_per_class rows are aligned by position with matched_classes; they are not
    # indexed by class id, so pair them up instead of indexing with the id.
    per_class_map = {
        f"{CLASSES[class_id]}_mAP@0.95": class_ap.mean()
        for class_id, class_ap in zip(map_result.matched_classes, map_result.ap_per_class)
    }

    print(f"mAP@0.95_{task}: {map_result.map50_95:.2f}")
    print(f"mAP50_{task}: {map_result.map50:.2f}")
    print(f"mAP75_{task}: {map_result.map75:.2f}")
    for key, value in per_class_map.items():
        print(f"{key}_{task}: {value:.2f}")

    return {"mAP@0.95": map_result.map50_95,
            "mAP50": map_result.map50,
            "mAP75": map_result.map75,
            "per_class_mAP@0.95": per_class_map,
            }

In [None]:
from collections import defaultdict

def agg_per_class(dicts):
    """Average per-class mAP values across several runs.

    Args:
        dicts: list of per-class mAP dicts, e.g.
            ``[{"car_mAP@0.95": 0.41, ...}, {...}]``.

    Returns:
        ``{class_key: mean}``. Each class is averaged only over the runs in which
        it appears.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for d in dicts:
        for cls, val in d.items():
            sums[cls] += float(val)
            counts[cls] += 1
    return {cls: sums[cls] / counts[cls] for cls in sums}


def aggregate_runs(results_list):
    """Aggregate ``test()`` results from several tasks/runs.

    Args:
        results_list: list of dicts as returned by ``test()``: ``"mAP@0.95"``,
            ``"mAP50"``, ``"mAP75"`` and ``"per_class_mAP@0.95"``.

    Returns:
        dict with ``"overall_sum"`` (metric sums), ``"overall_mean"`` (metric means,
        0.0 for an empty list) and ``"per_class_mean@0.95"`` (per-class means
        across all runs).
    """
    overall_sum = {"mAP@0.95": 0.0, "mAP50": 0.0, "mAP75": 0.0}
    per_class_maps = []

    for r in results_list:
        for key in overall_sum:
            overall_sum[key] += float(r[key])
        # Collect every run's per-class dict; averaging a single dict (as before)
        # iterated its keys and failed, and kept only the last run.
        per_class_maps.append(r["per_class_mAP@0.95"])

    n = len(results_list)
    overall_mean = {k: (v / n if n > 0 else 0.0) for k, v in overall_sum.items()}

    return {
        "overall_sum": overall_sum,                              # {"mAP@0.95": ..., "mAP50": ..., "mAP75": ...}
        "overall_mean": overall_mean,                            # means of the sums above
        "per_class_mean@0.95": agg_per_class(per_class_maps),    # {"car_mAP@0.95": mean, ...}
    }


def print_results(result):
    """Print mAP metrics from either ``aggregate_runs()`` output or a single ``test()`` result.

    Aggregated results store the metrics under ``"overall_mean"`` and the per-class
    values under ``"per_class_mean@0.95"``. A single ``test()`` result stores them at
    the top level and under ``"per_class_mAP@0.95"``.
    """
    metrics = result.get("overall_mean", result)
    print(f"mAP@0.95: {float(metrics['mAP@0.95']):.2f}")
    print(f"mAP50: {float(metrics['mAP50']):.2f}")
    print(f"mAP75: {float(metrics['mAP75']):.2f}")

    per_class = result.get("per_class_mean@0.95", result.get("per_class_mAP@0.95", {}))
    for k, v in per_class.items():
        print(f"{k}: {float(v):.2f}")

In [None]:
# SHIFT-Discrete와 ActMAD를 활용하여 TTA
# 이후 SHIFT-Continuous를 이용하여 최종 확인 

from transformers.models.rt_detr.modeling_rt_detr import RTDetrFrozenBatchNorm2d

save_dir = '/workspace/ptta/other_method/ActMAD'

# SHIFT-discrete tasks adapted to in sequence. Names must be accepted by
# task_to_subset_types (the previous "rain" was not).
ACTMAD_TASKS = ("cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear")


def _run_with_hooks(model, layers, img):
    """Run ``model(img)`` with a ``SaveOutput`` forward hook on every layer in ``layers``.

    Returns one ``SaveOutput`` recorder per layer. Hooks are removed even if the
    forward pass raises.
    """
    savers = [SaveOutput() for _ in layers]
    handles = [layer.register_forward_hook(saver) for layer, saver in zip(layers, savers)]
    try:
        model(img)
    finally:
        for handle in handles:
            handle.remove()
    return savers


def _save_checkpoint(model, ckpt_path):
    """Save ``{'net': model.state_dict()}`` to ``ckpt_path``, creating parent directories."""
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({'net': model.state_dict()}, ckpt_path)


def ActMAD(model, actmad_save='end_batch', half_precision=False):
    """Continual test-time adaptation with ActMAD (activation matching) on SHIFT tasks.

    1. Record the mean and variance of every ``nn.LayerNorm`` output on the clean
       SHIFT training split (no gradients).
    2. For each task in ``ACTMAD_TASKS``, update *all* model parameters with SGD so
       that the LayerNorm output statistics of the task batches match the clean
       ones (L1 loss on mean + variance). Then evaluate with ``test``.
    3. Print metrics aggregated over all tasks.

    The model is adapted in place and is not reset between tasks.

    Args:
        model: RT-DETR detection model.
        actmad_save: ``'each_batch'`` writes a checkpoint after every adaptation
            step; any other value (default ``'end_batch'``) writes one checkpoint
            per task after its last step. Checkpoints go to
            ``<save_dir>/results_stf_ttt/models/<task>.pt``.
            NOTE(review): the former 'each_batch' path *loaded* a checkpoint that
            this function never wrote; it now saves instead.
        half_precision: run the model in fp16 (only on non-CPU devices).
    """
    print('ActMAD start')

    all_ap = []

    device = next(model.parameters()).device

    half = device.type != 'cpu' and half_precision  # half precision only supported on CUDA
    if half:
        # NOTE(review): `test` feeds float32 pixel_values, which will not match an fp16 model.
        model.half()

    dataloader_train = DataLoader(
        DatasetAdapterForTransformers(dataset.train),
        batch_size=4,
        collate_fn=partial(collate_fn, preprocessor=reference_preprocessor),
    )

    l1_loss = nn.L1Loss(reduction='mean')

    # Unfreeze the whole model: ActMAD updates every parameter.
    for param in model.parameters():
        param.requires_grad = True

    # All nn.LayerNorm modules are aligned. No later-layer cutoff is applied here.
    chosen_layers = [m for m in model.modules() if isinstance(m, nn.LayerNorm)]
    if not chosen_layers:
        raise ValueError("ActMAD: model has no nn.LayerNorm layers to align")

    clean_mean_meters = [AverageMeter() for _ in chosen_layers]
    clean_var_meters = [AverageMeter() for _ in chosen_layers]

    # Extract the clean activation statistics from the training dataset.
    print("Start extracting BN statistics from the training dataset")
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader_train):
            img = batch['pixel_values'].to(device, non_blocking=True)
            img = img.half() if half else img.float()
            if img.shape[0] < 2:
                # Unbiased variance over a single sample is NaN and would poison the statistics.
                continue
            savers = _run_with_hooks(model, chosen_layers, img)
            for mean_meter, var_meter, saver in zip(clean_mean_meters, clean_var_meters, savers):
                mean_meter.update(saver.get_out_mean())
                var_meter.update(saver.get_out_var())

    clean_mean_list_final = [meter.avg for meter in clean_mean_meters]
    clean_var_list_final = [meter.avg for meter in clean_var_meters]

    ckpt_dir = Path(save_dir) / 'results_stf_ttt' / 'models'
    for task in ACTMAD_TASKS:
        # A fresh optimizer per task (momentum restarts); model weights carry over.
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4, nesterov=True)

        dataloader_discrete = DataLoader(
            DatasetAdapterForTransformers(
                SHIFTCorruptedTaskDatasetForObjectDetection(root=DATA_ROOT, train=True, valid=False, task=task)
            ),
            batch_size=4,
            collate_fn=partial(collate_fn, preprocessor=reference_preprocessor),
        )
        ckpt_path = ckpt_dir / f'{task}.pt'

        print('Starting TEST TIME ADAPTATION WITH ActMAD...')
        for batch in tqdm(dataloader_discrete):
            img = batch['pixel_values'].to(device, non_blocking=True)
            img = img.half() if half else img.float()
            if img.shape[0] < 2:
                # A single sample gives a NaN variance -> NaN loss -> NaN weights.
                continue

            model.train()
            for layer in chosen_layers:
                layer.eval()

            optimizer.zero_grad()
            savers = _run_with_hooks(model, chosen_layers, img)

            # Accumulate out of place. The previous in-place `+=` on a leaf tensor that
            # requires grad raises a RuntimeError on CPU.
            loss_mean = torch.zeros((), device=device)
            loss_var = torch.zeros((), device=device)
            for saver, clean_mean, clean_var in zip(savers, clean_mean_list_final, clean_var_list_final):
                loss_mean = loss_mean + l1_loss(saver.get_out_mean(), clean_mean)
                loss_var = loss_var + l1_loss(saver.get_out_var(), clean_var)
            loss = loss_mean + loss_var

            loss.backward()
            optimizer.step()

            # Release the recorded activations (and their autograd graph) right away.
            for saver in savers:
                saver.clear()

            if actmad_save == 'each_batch':
                _save_checkpoint(model, ckpt_path)

        # One write per task gives the same final file as saving after every step.
        if actmad_save != 'each_batch':
            _save_checkpoint(model, ckpt_path)

        # test() switches the model to eval mode before evaluating.
        all_ap.append(test(model, task, batch_size=8))

    print_results(aggregate_runs(all_ap))


In [None]:
# Run ActMAD adaptation + per-task evaluation on the pretrained model (fp32).
ActMAD(model_pretrained, actmad_save='end_batch', half_precision=False)

In [None]:
# targets = []
# predictions = []
# batch_size = 8
# from functools import partial
# model_pretrained = RTDetrForObjectDetection(config=reference_config)
# ckpt = torch.load("/workspace/ptta/other_method/ActMAD/best_actmad.pt", map_location=device)

# model_pretrained.load_state_dict(ckpt["model"], strict=True)
# model_pretrained.to(device)

# model_pretrained.eval()
# raw_data = DataLoader(LabelDataset(dataset.test), batch_size=batch_size, collate_fn=naive_collate_fn)
# dataloader_discrete = DataLoader(DatasetAdapterForTransformers(dataset.test), batch_size=batch_size, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))
# for idx, lables, inputs in zip(tqdm(range(len(raw_data))), raw_data, dataloader_discrete):
#     sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

#     with torch.no_grad():
#         outputs = model_pretrained(pixel_values=inputs['pixel_values'].to(device))

#     results = reference_preprocessor.post_process_object_detection(
#         outputs, target_sizes=sizes, threshold=0.0
#     )

#     detections = [Detections.from_transformers(results[i]) for i in range(batch_size)]
#     annotations = [Detections(
#         xyxy=lables[i][0].cpu().numpy(),
#         class_id=lables[i][1].cpu().numpy(),
#     ) for i in range(batch_size)]

#     targets.extend(annotations)
#     predictions.extend(detections)

In [None]:
# len(predictions) == len(targets), len(predictions), len(targets)

In [None]:
# mean_average_precision = MeanAveragePrecision().update(
#     predictions=predictions,
#     targets=targets,
# ).compute()
# per_class_map = {
#     f"{CLASSES[idx]}_mAP@0.95": mean_average_precision.ap_per_class[idx].mean()
#     for idx in mean_average_precision.matched_classes
# }

# print(f"mAP@0.95: {mean_average_precision.map50_95:.2f}")
# print(f"map50: {mean_average_precision.map50:.2f}")
# print(f"map75: {mean_average_precision.map75:.2f}")
# for key, value in per_class_map.items():
#     print(f"{key}: {value:.2f}")