# RT-DETR Pretraining with SHIFT-Discrete Dataset

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Set CUDA Device Number
# Index of the GPU this notebook should be restricted to.
DEVICE_NUM = 0

from os import environ
# Expose only the selected GPU to CUDA.
# NOTE(review): presumably this has to run before torch initializes CUDA to take effect — confirm.
environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_NUM)
# Last expression of the cell: echoes the value that was just set.
environ["CUDA_VISIBLE_DEVICES"]

## Imports

In [None]:
import os
# Change the working directory so that relative paths (such as the "./data" root
# used later in this notebook) resolve against this directory.
# NOTE(review): hard-coded absolute path — presumably the project checkout; adjust per machine.
os.chdir("/home/ubuntu/test-time-adapters")

In [None]:
from os import path

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder
from ttadapters.datasets import (
    SHIFTDataset,
    SHIFTClearDatasetForObjectDetection,
    SHIFTCorruptedDatasetForObjectDetection
)
from ttadapters import datasets

from ttadapters.models.rcnn import FasterRCNNForObjectDetection, SwinRCNNForObjectDetection

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Use the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"INFO: Using device - {device}")

In [None]:
# Experiment identifiers. RUN_NAME is reused later for the results/log directories
# and as the run name of the training arguments.
PROJECT_NAME = "detectron_test"
RUN_NAME = "Faster-RCNN_R50"

## Define Dataset

In [None]:
# Every split is read from ./data, relative to the current working directory.
DATA_ROOT = path.join(".", "data")

# The validation and test splits share the same options; only the dataset class differs.
_eval_split_options = dict(
    root=DATA_ROOT, valid=True,
    transform=datasets.detectron_image_transform,
    transforms=datasets.default_valid_transforms,
)

# Bundle the three splits: clear-weather train/valid data and the corrupted data for testing.
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(
        root=DATA_ROOT, train=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_train_transforms,
    ),
    valid=SHIFTClearDatasetForObjectDetection(**_eval_split_options),
    test=SHIFTCorruptedDatasetForObjectDetection(**_eval_split_options),
)

In [None]:
# Display one raw training sample to inspect its structure.
dataset.train[999]

In [None]:
# Sanity-check the image tensor shape of a single training sample.
# Samples are keyed by camera name and keep the image under 'images' (the same
# access pattern DatasetAdapterForTransformers uses), so index those keys instead
# of a 'shape' entry, and do not unpack the 4-element shape into two names.
dataset.train[1000]['front']['images'].shape  # should be (batch_size, num_channels, height, width)

## DataLoader

In [None]:
# Set Batch Size
# Tuple layout: (train, valid, test, last). The first three are the per-split batch
# sizes printed below; the last entry is read later as REAL_BATCH, the batch size
# that gradient accumulation is sized against.
BATCH_SIZE = 2, 8, 8, 8
#BATCH_SIZE = 50, 200, 200, 200  # A100 or H100
#BATCH_SIZE = 40, 120, 120, 120  # Half of A100 or H100

# Dataset Configs
# Class names come from the training split; their count is derived from that list.
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Present one camera of a wrapped dataset as ``image`` / ``target`` samples.

    Each sample is a dict with the camera image (leading dimension squeezed) and a
    COCO-detection style target: ``{"image_id": idx, "annotations": [...]}``.
    """

    def __init__(self, original_dataset, camera='front'):
        """Keep a reference to the wrapped dataset and the camera key to read."""
        self.camera = camera
        self.dataset = original_dataset

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``dict(image=..., target=...)``."""
        sample = self.dataset[idx][self.camera]
        picture = sample['images'].squeeze(0)

        # Re-express every box from corner format (x1, y1, x2, y2) as a COCO
        # annotation whose bbox is [x, y, width, height].
        coco_annotations = []
        for corners, label in zip(sample['boxes2d'], sample['boxes2d_classes']):
            left, top, right, bottom = corners.tolist()
            box_w = right - left
            box_h = bottom - top
            coco_annotations.append({
                'bbox': [left, top, box_w, box_h],
                'category_id': label.item(),
                'area': box_w * box_h,
                'iscrowd': 0,
            })

        # This is the layout prepare_coco_detection_annotation expects. The RT-DETR
        # image processor turns the COCO boxes into center format (cx, cy, w, h) while
        # preprocessing and goes back to (x1, y1, x2, y2) in post-processing.
        return {
            'image': picture,
            'target': {'image_id': idx, 'annotations': coco_annotations},
        }

In [None]:
def collate_fn(batch, preprocessor=None):
    """Collate dataset items into a model-ready batch.

    Args:
        batch: list of items shaped like ``dict(image=..., target=...)`` where
            ``target`` holds COCO-style ``annotations`` (bbox as [x, y, w, h]).
            Items that carry raw ``boxes2d`` / ``boxes2d_classes`` tensors are
            accepted in the fallback path as well.
        preprocessor: optional callable invoked as
            ``preprocessor(images=..., annotations=..., return_tensors="pt")``.
            When given, its return value is returned unchanged.

    Returns:
        The preprocessor output, or, without a preprocessor, a dict with
        ``pixel_values`` (the stacked image tensor) and ``labels`` (one dict per
        item with ``class_labels`` as int64 and ``boxes`` as float32
        (x1, y1, x2, y2) of shape ``(num_boxes, 4)``).
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        annotations = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=annotations, return_tensors="pt")

    # If no preprocessor is provided, just assume images are already in tensor format
    labels = []
    for item in batch:
        if 'boxes2d' in item:
            # Raw items already expose the box and class tensors.
            class_labels = item['boxes2d_classes'].long()
            boxes = item['boxes2d'].float()
        else:
            # Adapter items only carry COCO annotations: rebuild corner boxes
            # from [x, y, width, height].
            annotations = item['target']['annotations']
            class_labels = torch.tensor(
                [annotation['category_id'] for annotation in annotations],
                dtype=torch.long,
            )
            boxes = torch.tensor(
                [
                    [x, y, x + w, y + h]
                    for x, y, w, h in (annotation['bbox'] for annotation in annotations)
                ],
                dtype=torch.float32,
            ).reshape(-1, 4)  # keeps a (0, 4) shape for images without boxes
        labels.append(dict(class_labels=class_labels, boxes=boxes))

    # pixel_values must be the tensor itself (not a nested dict) so consumers can
    # call inputs['pixel_values'].to(device).
    return dict(pixel_values=torch.stack(images), labels=labels)

## Load Model

In [None]:
# Build the SwinRCNNForObjectDetection model for the SHIFT dataset, load the weights
# selected by model.Weights.NATUREYOO, and move the model to the chosen device.
# NOTE(review): `model` is rebound by the Faster R-CNN cell that follows, so only the
# model from the most recently executed cell is used afterwards.
model = SwinRCNNForObjectDetection(dataset=SHIFTDataset)
# NOTE(review): weight_key="model" presumably names the entry of the checkpoint that holds the weights — confirm.
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

In [None]:
# Build the FasterRCNNForObjectDetection model for the SHIFT dataset, load the weights
# selected by model.Weights.NATUREYOO, and move the model to the chosen device.
# NOTE(review): this rebinds `model`, replacing the model created by the previous cell.
model = FasterRCNNForObjectDetection(dataset=SHIFTDataset)
# NOTE(review): weight_key="model" presumably names the entry of the checkpoint that holds the weights — confirm.
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 10
# Target batch size per optimizer step: the last entry of BATCH_SIZE.
REAL_BATCH = BATCH_SIZE[-1]
LEARNING_RATE = 2e-5

# Arguments for the training run.
# NOTE(review): TrainingArguments is not imported in the visible cells — presumably
# it comes from the transformers package; confirm the import exists.
training_args = TrainingArguments(
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.1,
    weight_decay=0.15,
    max_grad_norm=1.0,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    # Accumulate train batches until REAL_BATCH samples are covered per update.
    gradient_accumulation_steps=REAL_BATCH//BATCH_SIZE[0],
    eval_accumulation_steps=BATCH_SIZE[1],
    # The metric callback in this notebook accepts a `compute_result` flag and
    # accumulates state between calls, matching batch-wise metric evaluation.
    batch_eval_metrics=True,
    remove_unused_columns=False,
    optim="adamw_torch",
    # Evaluate, save and log on the same 50-step cadence.
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=50,
    save_total_limit=100,
    load_best_model_at_end=True,
    # Must match the key returned by the metric callback ("mAP@0.50:0.95").
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    # Results and logs are written to per-run directories named after RUN_NAME.
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    run_name=RUN_NAME,
    bf16=True,
)

# Arguments for the evaluation-only run on the test split.
# NOTE(review): no output_dir is passed here — confirm the installed library version
# provides a default for it.
testing_args = TrainingArguments(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from transformers.trainer_utils import EvalPrediction
from torchvision.ops import box_convert
from dataclasses import dataclass


@dataclass
class ModelOutput:
    """Minimal prediction container with ``logits`` and ``pred_boxes`` attributes.

    In this notebook it is built from ``eval_pred.predictions[1:3]`` and handed to
    the preprocessor's ``post_process_object_detection``.
    """
    # First of the two prediction tensors taken from the evaluation output.
    logits: torch.Tensor
    # Second of the two prediction tensors taken from the evaluation output.
    pred_boxes: torch.Tensor


def de_normalize_boxes(boxes, height, width):
    """Convert ``(cx, cy, w, h)`` boxes to ``(x1, y1, x2, y2)`` scaled to an image size.

    Args:
        boxes: box tensor in center format; the second dimension holds the four
            coordinates.
        height: factor applied to the y coordinates (image height).
        width: factor applied to the x coordinates (image width).

    Returns:
        The converted box tensor with x columns multiplied by ``width`` and
        y columns multiplied by ``height``.
    """
    x_columns, y_columns = [0, 2], [1, 3]

    # Center format -> corner format, still in the input's scale.
    corner_boxes = box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')

    # Scale the converted tensor in place up to pixel coordinates.
    corner_boxes[:, x_columns] *= width
    corner_boxes[:, y_columns] *= height
    return corner_boxes


def map_compute_metrics(preprocessor=reference_preprocessor, threshold=0.0):
    """Create a batch-wise mean-average-precision metric callback.

    Args:
        preprocessor: object whose ``post_process_object_detection`` turns raw model
            outputs into per-image detections. The default is bound when this
            function is defined.
        threshold: score threshold forwarded to the post-processing step.

    Returns:
        A ``(calc, map_metric)`` tuple: ``calc`` is the callback and ``map_metric``
        is the ``MeanAveragePrecision`` object it accumulates into (exposed so
        callers can reset it).
    """
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def calc(eval_pred: EvalPrediction, compute_result=False):
        """Accumulate one batch, or finalize and return the metrics.

        With ``compute_result`` false the batch in ``eval_pred`` is added to the
        metric state and an empty dict is returned. With ``compute_result`` true
        the accumulated metric is computed, the state is reset, and a dict of
        overall and per-class scores is returned.
        """
        if compute_result:
            m_ap = map_metric.compute()
            map_metric.reset()

            # Rows of ap_per_class follow the order of matched_classes, so pair each
            # row with its class id instead of using the class id as the row index
            # (which is wrong, or out of range, when a class is missing).
            per_class_map = {
                f"{CLASSES[class_id]}_mAP@0.50:0.95": m_ap.ap_per_class[row].mean()
                for row, class_id in enumerate(m_ap.matched_classes)
            }

            return {
                "mAP@0.50:0.95": m_ap.map50_95,
                "mAP@0.50": m_ap.map50,
                "mAP@0.75": m_ap.map75,
                **per_class_map
            }

        # Entries 1 and 2 of the predictions are used as (logits, pred_boxes).
        preds = ModelOutput(*eval_pred.predictions[1:3])
        labels = eval_pred.label_ids
        sizes = [label['orig_size'].cpu().tolist() for label in labels]

        results = post_process(preds, target_sizes=sizes, threshold=threshold)
        predictions = [Detections.from_transformers(result) for result in results]
        # Label boxes are converted to corner format and scaled by orig_size
        # (unpacked as height, width) so they are comparable with the predictions.
        targets = [Detections(
            xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
            class_id=label['class_labels'].cpu().numpy(),
        ) for label in labels]

        map_metric.update(predictions=predictions, targets=targets)
        return {}
    return calc, map_metric

In [None]:
# Step number of the saved checkpoint that the next cell tries to restore.
checkpoint = 50

In [None]:
# Best-effort restore of a saved checkpoint: on any failure the current `model` is
# kept, but the reason is reported instead of being silently discarded.
try:
    model = RTDetrForObjectDetection.from_pretrained(f"{training_args.output_dir}/checkpoint-{checkpoint}/", torch_dtype=torch.float32, return_dict=True, local_files_only=True)
    model.to(device)
except Exception as exc:
    print(f"WARNING: Checkpoint was not loaded; keeping the current model - {type(exc).__name__}: {exc}")

In [None]:
from functools import partial

# One metric callback / accumulator pair, used by both Trainer objects below.
compute_metrics, compute_results = map_compute_metrics(preprocessor=reference_preprocessor)

# Settings common to the training and the test-time Trainer.
_shared_trainer_options = dict(
    model=model,
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics,
)

# Trains on the train split and evaluates on the validation split.
trainer = Trainer(
    args=training_args,
    train_dataset=DatasetAdapterForTransformers(dataset.train),
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
    **_shared_trainer_options,
)

# Evaluation-only Trainer bound to the test split.
tester = Trainer(
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    **_shared_trainer_options,
)

## Train

In [None]:
import gc
# Collect unreferenced Python objects and release cached CUDA memory before training.
gc.collect()
torch.cuda.empty_cache()
# Reset the shared mAP metric object so earlier evaluations do not leak into this run.
compute_results.reset()
# Run training without resuming from a previously saved Trainer checkpoint.
trainer.train(resume_from_checkpoint=False)

## Evaluate

### Auto Evaluation

In [None]:
# Evaluate with the training Trainer, whose eval dataset is the validation split.
trainer.evaluate()

In [None]:
# Evaluate with the test Trainer, whose eval dataset is the corrupted test split.
tester.evaluate()

### Manual Evaluation

In [None]:
# Step number of the saved checkpoint to restore for the manual evaluation below.
checkpoint = 10400

In [None]:
# Best-effort restore of a saved checkpoint: on any failure the current `model` is
# kept, but the reason is reported instead of being silently discarded.
try:
    model = RTDetrForObjectDetection.from_pretrained(f"{training_args.output_dir}/checkpoint-{checkpoint}/", torch_dtype=torch.float32, return_dict=True, local_files_only=True)
    model.to(device)
except Exception as exc:
    print(f"WARNING: Checkpoint was not loaded; keeping the current model - {type(exc).__name__}: {exc}")

In [None]:
class LabelDataset(BaseDataset):
    """Expose only the ground-truth boxes and classes of one camera of a dataset."""

    def __init__(self, original_dataset, camera='front'):
        """Remember the wrapped dataset and which camera entry to read."""
        self.dataset, self.camera = original_dataset, camera

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(boxes2d, boxes2d_classes)`` of sample ``idx`` for the camera."""
        sample = self.dataset[idx][self.camera]
        boxes, classes = sample['boxes2d'], sample['boxes2d_classes']
        return boxes, classes

In [None]:
def naive_collate_fn(batch):
    """Return the list of samples unchanged, without stacking or converting them."""
    return batch

In [None]:
# Manually run the model over the validation split, collecting the predictions and
# the ground-truth annotations as Detections for the mAP computation afterwards.
targets = []
predictions = []
batch_size = 32

# Two loaders walk the same split in the same order: one yields the raw
# (boxes, classes) labels, the other yields the preprocessed model inputs.
raw_data = DataLoader(LabelDataset(dataset.valid), batch_size=batch_size, collate_fn=naive_collate_fn)
loader = DataLoader(DatasetAdapterForTransformers(dataset.valid), batch_size=batch_size, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))

# Inference only: make sure the model is not left in training mode.
model.eval()

for labels, inputs in tqdm(zip(raw_data, loader), total=len(raw_data)):
    sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device))

    results = reference_preprocessor.post_process_object_detection(
        outputs, target_sizes=sizes, threshold=0.3
    )

    # Iterate over what the batch actually contains: the final batch can hold fewer
    # than batch_size samples, so indexing with range(batch_size) would overrun.
    predictions.extend(Detections.from_transformers(result) for result in results)
    targets.extend(
        Detections(
            xyxy=boxes.cpu().numpy(),
            class_id=classes.cpu().numpy(),
        )
        for boxes, classes in labels
    )

In [None]:
# Sanity check: predictions and targets must pair up one-to-one; also show both counts.
len(predictions) == len(targets), len(predictions), len(targets)

In [None]:
# Compute mean average precision over everything collected by the manual loop.
mean_average_precision = MeanAveragePrecision().update(
    predictions=predictions,
    targets=targets,
).compute()

# Rows of ap_per_class follow the order of matched_classes, so pair each row with
# its class id rather than using the class id as the row index. Each value is the
# AP averaged over the IoU thresholds, i.e. mAP@0.50:0.95 for that class.
per_class_map = {
    f"{CLASSES[class_id]}_mAP@0.50:0.95": mean_average_precision.ap_per_class[row].mean()
    for row, class_id in enumerate(mean_average_precision.matched_classes)
}

# map50_95 is the average over IoU 0.50:0.95, so label it accordingly.
print(f"mAP@0.50:0.95: {mean_average_precision.map50_95:.2f}")
print(f"map50: {mean_average_precision.map50:.2f}")
print(f"map75: {mean_average_precision.map75:.2f}")
for key, value in per_class_map.items():
    print(f"{key}: {value:.2f}")