# Faster R-CNN Pretraining with SHIFT-Discrete Dataset

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Choose the GPU index exposed to this process.
# The variable is exported before torch is imported in the cells below.
from os import environ

DEVICE_NUM = 0
environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_NUM}"

# Echo the value so the notebook cell shows what was set
environ["CUDA_VISIBLE_DEVICES"]

## Imports

In [None]:
from os import path

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import (
    SHIFTDataset,
    SHIFTClearDatasetForObjectDetection,
    SHIFTCorruptedDatasetForObjectDetection
)
from ttadapters import datasets

from ttadapters.models.rcnn import FasterRCNNForObjectDetection, SwinRCNNForObjectDetection, collate_fn

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Prefer CUDA when it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"INFO: Using device - {device}")

In [None]:
# Identifiers for this experiment
PROJECT_NAME = "detectron_test"
# RUN_NAME is reused below for the results/log directories and the run name
RUN_NAME = "Faster-RCNN_R50"

## Define Dataset

In [None]:
# All SHIFT splits are read from ./data
DATA_ROOT = path.join(".", "data")

# Both evaluation splits share the same root, split flag and transforms
eval_split_kwargs = dict(
    root=DATA_ROOT,
    valid=True,
    transform=datasets.detectron_image_transform,
    transforms=datasets.default_valid_transforms,
)

train_split = SHIFTClearDatasetForObjectDetection(
    root=DATA_ROOT,
    train=True,
    transform=datasets.detectron_image_transform,
    transforms=datasets.default_train_transforms,
)
valid_split = SHIFTClearDatasetForObjectDetection(**eval_split_kwargs)
# The test split uses the corrupted dataset class with the validation settings
test_split = SHIFTCorruptedDatasetForObjectDetection(**eval_split_kwargs)

dataset = DatasetHolder(train=train_split, valid=valid_split, test=test_split)

In [None]:
# Display one training sample (index 999) to inspect what the dataset returns
dataset.train[999]

In [None]:
dataset.train[999][0].shape  # should be (num_channels, height, width)

## DataLoader

In [None]:
# Batch size presets. Each tuple is indexed as [0]=train, [1]=valid, [2]=test;
# the last entry is read further down as the "real" (effective) batch size.
BATCH_SIZE_PRESETS = {
    "default": (2, 8, 8, 8),
    "a100_or_h100": (50, 200, 200, 200),
    "half_a100_or_h100": (40, 128, 128, 120),
    "online": (1, 1, 1, 1),
}

# Active preset: one sample per step (online setting)
BATCH_SIZE = BATCH_SIZE_PRESETS["online"]

# Class names and their count come from the training split
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

train_bs, valid_bs, test_bs = BATCH_SIZE[0], BATCH_SIZE[1], BATCH_SIZE[2]
print(f"INFO: Set batch size - Train: {train_bs}, Valid: {valid_bs}, Test: {test_bs}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
def _build_loader(split, batch_size, shuffle):
    """Wrap one dataset split in a DataLoader that uses the shared collate_fn."""
    return DataLoader(split, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)


# Only the training loader shuffles; evaluation loaders keep dataset order
dataloader = DataLoaderHolder(
    train=_build_loader(dataset.train, BATCH_SIZE[0], shuffle=True),
    valid=_build_loader(dataset.valid, BATCH_SIZE[1], shuffle=False),
    test=_build_loader(dataset.test, BATCH_SIZE[2], shuffle=False),
)

In [None]:
# Fetch and display the first training batch
next(iter(dataloader.train))

## Load Model

In [None]:
# Model switch: True builds SwinRCNNForObjectDetection, False builds FasterRCNNForObjectDetection
USE_SWIN_T_BACKBONE = False

In [None]:
# Pick the detector class from the backbone switch, then build it for SHIFT
model_cls = SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
model = model_cls(dataset=SHIFTDataset)

# Load the NATUREYOO weights (stored under the "model" key) and move to the device
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

### Direct Test

In [None]:
from tqdm.notebook import tqdm
import pandas as pd
from ipywidgets import Output
from IPython.display import display

import time
import gc

def evaluate_for(self, desc, loader, loader_length, threshold=0.0, dtype=torch.float32, device=torch.device("cuda")):
    """Evaluate a detection model on one loader and report mAP and throughput.

    The function is written like an unbound method: ``self`` is the model, so it
    can be bound with ``functools.partial(evaluate_for, model)``.

    Args:
        self: Callable model. It is put in eval mode and called with a whole
            batch; every output must provide ``output['instances']`` with
            ``scores``, ``pred_boxes`` and ``pred_classes``.
        desc: Label appended to the progress-bar description.
        loader: Iterable of batches. Every batch item must provide
            ``item['instances']`` with ``gt_boxes`` and ``gt_classes``.
        loader_length: Number of batches, used as the progress-bar total.
        threshold: Predictions whose score is not greater than this are dropped.
        dtype: dtype passed to ``torch.autocast``.
        device: Device whose type selects the autocast backend and decides
            whether CUDA is synchronized before the timer is read.

    Returns:
        dict with ``mAP@0.50:0.95``, ``mAP@0.50``, ``mAP@0.75``, one
        ``<class>_mAP@0.50:0.95`` entry per matched class, ``collapse_time``
        (accumulated forward-pass seconds) and ``fps``. The same values are
        displayed as a one-row DataFrame.
    """
    torch.cuda.empty_cache()
    gc.collect()

    self.eval()

    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []
    total_images = 0
    collapse_time = 0.0

    with torch.inference_mode():
        for batch in tqdm(loader, total=loader_length, desc=f"Evaluation for {desc}"):
            total_images += len(batch)
            with torch.autocast(device_type=device.type, dtype=dtype):
                # perf_counter is monotonic, so the interval cannot go negative
                start = time.perf_counter()
                outputs = self(batch)

                # Wait for queued CUDA work so the timer covers the real compute
                if device.type == 'cuda':
                    torch.cuda.synchronize()

                collapse_time += time.perf_counter() - start

            for output, input_data in zip(outputs, batch):
                instances = output['instances']
                keep = instances.scores > threshold

                predictions_list.append(Detections(
                    xyxy=instances.pred_boxes.tensor[keep].detach().cpu().numpy(),
                    class_id=instances.pred_classes[keep].detach().cpu().numpy(),
                    confidence=instances.scores[keep].detach().cpu().numpy()
                ))

                gt_instances = input_data['instances']
                targets_list.append(Detections(
                    xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                    class_id=gt_instances.gt_classes.detach().cpu().numpy()
                ))

    map_metric.update(predictions=predictions_list, targets=targets_list)
    m_ap = map_metric.compute()

    # Rows of ap_per_class follow the order of matched_classes, so pair them up
    # instead of using the class id as a row index.
    per_class_map = {
        f"{CLASSES[class_id]}_mAP@0.50:0.95": class_ap.mean().item()
        for class_id, class_ap in zip(m_ap.matched_classes, m_ap.ap_per_class)
    }
    performances = {
        "collapse_time": collapse_time,
        # Guard against an empty loader, where no time has been accumulated
        "fps": total_images / collapse_time if collapse_time > 0 else 0.0
    }

    result = {
        "mAP@0.50:0.95": m_ap.map50_95.item(),
        "mAP@0.50": m_ap.map50.item(),
        "mAP@0.75": m_ap.map75.item(),
        **per_class_map,
        **performances
    }
    display(pd.DataFrame({k: [v] for k, v in result.items()}))
    return result

In [None]:
#evaluate_for(model, "valid", dataloader.valid, dataloader.valid_len)

In [None]:
#evaluate_for(model, "test", dataloader.test, dataloader.test_len)

In [None]:
from ttadapters.datasets.scenarios import SHIFTDiscreteScenario

In [None]:
# Scenario built from ./data with the validation flag and validation transforms
scenario_kwargs = dict(
    root=DATA_ROOT,
    valid=True,
    transform=datasets.detectron_image_transform,
    transforms=datasets.default_valid_transforms,
)
discrete_scenario = SHIFTDiscreteScenario(**scenario_kwargs)

In [None]:
from functools import partial

# Build the scenario loaders with the validation batch size, unshuffled
loaded_scenario = discrete_scenario.load(
    batch_size=BATCH_SIZE[1],
    shuffle=False,
    collate_fn=collate_fn,
)

# Run evaluate_for with the model bound as its first argument
loaded_scenario.play(partial(evaluate_for, model))

In [None]:
from ttadapters.methods import ActMADConfig, ActMAD

# NOTE(review): this rebinds `model` to a freshly constructed detector. Unlike the
# earlier load cell, no dataset, weights or device are given here - confirm intended.
model = FasterRCNNForObjectDetection()
config = ActMADConfig()
adaptive_model = ActMAD(model, config)

# Replay the discrete scenario, this time driven by the ActMAD wrapper's evaluate_for
adaptation_run = discrete_scenario.load(
    batch_size=BATCH_SIZE[1], shuffle=False, collate_fn=collate_fn
)
adaptation_run.play(script=adaptive_model.evaluate_for)

## Train

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 10
# Effective batch size: the last entry of the BATCH_SIZE tuple
REAL_BATCH = BATCH_SIZE[-1]
LEARNING_RATE = 2e-5

training_args = dict(
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.1,
    weight_decay=0.15,
    max_grad_norm=1.0,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    # Accumulate until REAL_BATCH samples are seen; clamp to 1 so a train batch
    # larger than REAL_BATCH does not produce 0 accumulation steps.
    gradient_accumulation_steps=max(1, REAL_BATCH // BATCH_SIZE[0]),
    eval_accumulation_steps=BATCH_SIZE[1],
    batch_eval_metrics=True,
    remove_unused_columns=False,
    optim="adamw_torch",
    # Evaluate, save and log on the same 50-step cadence
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=50,
    save_total_limit=100,
    # Keep the checkpoint with the highest mAP@0.50:0.95
    load_best_model_at_end=True,
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    # Alternative: select the best checkpoint by lowest eval loss
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    run_name=RUN_NAME,
    bf16=True,
)

# Arguments for the test pass: test batch size, batch-wise metrics
testing_args = dict(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from torchvision.ops import box_convert
from dataclasses import dataclass


def de_normalize_boxes(boxes, height, width):
    """Convert normalized ``cxcywh`` boxes to ``xyxy`` boxes in pixel units.

    Args:
        boxes: Tensor whose last dimension holds ``(cx, cy, w, h)``; any number
            of leading dimensions is accepted.
        height: Image height used to scale the y coordinates.
        width: Image width used to scale the x coordinates.

    Returns:
        Tensor of the same shape holding ``(x1, y1, x2, y2)`` scaled by the
        image size.
    """
    # 1. cxcywh -> xyxy
    boxes_xyxy = box_convert(boxes, 'cxcywh', 'xyxy')

    # 2. de-normalize: scale x by width and y by height. Ellipsis indexing keeps
    #    this valid for a single box as well as for batched boxes.
    boxes_xyxy[..., [0, 2]] *= width
    boxes_xyxy[..., [1, 3]] *= height
    return boxes_xyxy


def map_compute_metrics(preprocessor, threshold=0.0):
    """Build a batch-wise mAP ``compute_metrics`` callback.

    Args:
        preprocessor: Object providing ``post_process_object_detection``, used
            to turn raw model outputs into per-image detections.
        threshold: Score threshold forwarded to the post-processing step.

    Returns:
        Tuple ``(calc, map_metric)``. ``calc(eval_pred, compute_result=False)``
        adds the batch to ``map_metric`` and returns ``{}``; when
        ``compute_result`` is true it then computes the metric, resets it and
        returns the overall and per-class mAP values.
    """
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def accumulate(eval_pred):
        # NOTE(review): ModelOutput is not defined in the visible notebook -
        # confirm it is declared before this callback runs.
        preds = ModelOutput(*eval_pred.predictions[1:3])
        labels = eval_pred.label_ids
        sizes = [label['orig_size'].cpu().tolist() for label in labels]

        results = post_process(preds, target_sizes=sizes, threshold=threshold)
        predictions = [Detections.from_transformers(result) for result in results]
        targets = [Detections(
            xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
            class_id=label['class_labels'].cpu().numpy(),
        ) for label in labels]

        map_metric.update(predictions=predictions, targets=targets)

    def calc(eval_pred, compute_result=False):
        # With batch-wise metrics the final call carries the last batch AND
        # compute_result=True, so the batch is accumulated before computing.
        if eval_pred is not None:
            accumulate(eval_pred)

        if not compute_result:
            return {}

        m_ap = map_metric.compute()
        map_metric.reset()

        # Rows of ap_per_class follow the order of matched_classes, so pair
        # them up instead of using the class id as a row index.
        per_class_map = {
            f"{CLASSES[class_id]}_mAP@0.50:0.95": class_ap.mean()
            for class_id, class_ap in zip(m_ap.matched_classes, m_ap.ap_per_class)
        }

        return {
            "mAP@0.50:0.95": m_ap.map50_95,
            "mAP@0.50": m_ap.map50,
            "mAP@0.75": m_ap.map75,
            **per_class_map
        }
    return calc, map_metric

In [None]:
from detectron2.utils.events import EventStorage
import torch

# Switch to training mode and optimize all parameters with SGD + momentum
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for step, batch in enumerate(dataloader.train):
    # Each step runs inside its own EventStorage context, started at this step
    # (presumably needed by the detectron2 model's internal logging - confirm)
    with EventStorage(step):
        # Forward pass yields a dict of named losses
        loss_dict = model(batch)

        # Add up every loss term
        total_loss = sum(loss_dict.values())

        # Backward pass and parameter update
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Logging (optional): report every 20th step
        if not step % 20:
            print(f"Iteration {step}: {loss_dict}")

## Evaluate

In [None]:
targets = []
predictions = []
batch_size = 32

# NOTE(review): LabelDataset, naive_collate_fn, DatasetAdapterForTransformers and
# reference_preprocessor are not defined in the visible notebook - confirm they exist.
# Both loaders read dataset.valid unshuffled with the same batch size, so their
# batches are consumed in lockstep.
raw_data = DataLoader(LabelDataset(dataset.valid), batch_size=batch_size, collate_fn=naive_collate_fn)
loader = DataLoader(DatasetAdapterForTransformers(dataset.valid), batch_size=batch_size, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))
for labels, inputs in tqdm(zip(raw_data, loader), total=len(raw_data)):
    sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device))

    results = reference_preprocessor.post_process_object_detection(
        outputs, target_sizes=sizes, threshold=0.3
    )

    # Iterate over what the batch actually holds: the last batch can contain
    # fewer than batch_size samples, so a fixed range(batch_size) would overrun.
    for result, label in zip(results, labels):
        predictions.append(Detections.from_transformers(result))
        targets.append(Detections(
            xyxy=label[0].cpu().numpy(),
            class_id=label[1].cpu().numpy(),
        ))

In [None]:
# Sanity check: every prediction should have a matching target
num_predictions, num_targets = len(predictions), len(targets)
num_predictions == num_targets, num_predictions, num_targets

In [None]:
# Compute mAP over all collected predictions and targets
mean_average_precision = MeanAveragePrecision().update(
    predictions=predictions,
    targets=targets,
).compute()

# Rows of ap_per_class follow the order of matched_classes, so pair them up
# instead of using the class id as a row index. Each row is averaged over the
# IoU thresholds, hence the 0.50:0.95 label.
per_class_map = {
    f"{CLASSES[class_id]}_mAP@0.50:0.95": class_ap.mean()
    for class_id, class_ap in zip(
        mean_average_precision.matched_classes,
        mean_average_precision.ap_per_class,
    )
}

print(f"mAP@0.50:0.95: {mean_average_precision.map50_95:.2f}")
print(f"map50: {mean_average_precision.map50:.2f}")
print(f"map75: {mean_average_precision.map75:.2f}")
for key, value in per_class_map.items():
    print(f"{key}: {value:.2f}")