# YOLOv11 Pretraining (RT-DETR-Compatible Wrapper) with SHIFT-Discrete Dataset

## Check GPU Availability

In [None]:
# Show GPU status via an IPython shell escape (only valid inside Jupyter/IPython, not plain Python).
!nvidia-smi

In [None]:
# Set CUDA Device Number
DEVICE_NUM = 2  # first physical GPU index to expose
ADDITIONAL_GPU = 0  # number of extra GPUs exposed after DEVICE_NUM

from os import environ

# Expose physical GPUs DEVICE_NUM .. DEVICE_NUM + ADDITIONAL_GPU (inclusive) to this process.
environ["CUDA_VISIBLE_DEVICES"] = ",".join(
    str(gpu_id) for gpu_id in range(DEVICE_NUM, DEVICE_NUM + ADDITIONAL_GPU + 1)
)
environ["CUDA_VISIBLE_DEVICES"]

## Imports

In [None]:
from os import path

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

# import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
if not torch.cuda.is_available():
    # No GPU: run on CPU and mark the device number as unused.
    device = torch.device("cpu")
    DEVICE_NUM = -1
else:
    # Single- and multi-GPU runs both use the generic "cuda" device;
    # CUDA_VISIBLE_DEVICES already limits which physical GPUs are visible.
    device = torch.device("cuda")

print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

In [None]:
# Tqdm Test: confirm progress bars render in this notebook frontend.
for _ in tqdm(range(100)):
    continue

In [None]:
# Experiment identifiers; RUN_NAME also names the output/logging directories below.
PROJECT_NAME, RUN_NAME = "APT_SHIFT_Pretraining", "YOLOv11"

# # WandB Initialization
# wandb.init(project=PROJECT_NAME, name=RUN_NAME)

## Define Dataset

In [None]:
# Dataset root, relative to the notebook's working directory ("./data").
DATA_ROOT = path.join(path.curdir, "data")

# Train/valid come from the SHIFTClear* datasets; the test set is the
# validation split of the SHIFTCorrupted* dataset.
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
    valid=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    test=SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True),
)

In [None]:
# Inspect which fields one training sample exposes for the front camera.
dataset.train[1]['front'].keys()

In [None]:
# Display one complete raw training sample.
dataset.train[999]

In [None]:
# Inspect the image tensor shape of one raw sample.
dataset.train[1000]['front']['images'].shape  # presumably (1, num_channels, height, width); DatasetAdapterForTransformers squeezes the leading dim — confirm

## DataLoader

In [None]:
# Set Batch Size: (train, valid, test, effective train batch reached via gradient accumulation)
BATCH_SIZE = (4, 8, 8, 16)

# Dataset Configs
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Expose one camera of a SHIFT detection dataset as COCO-detection samples.

    Each item is ``{"image": Tensor, "target": {"image_id", "annotations"}}``, the
    structure HF image processors accept via ``annotations=`` with
    ``AnnotationFormat.COCO_DETECTION``. The RT-DETR image processor converts the COCO
    boxes to center format (cx, cy, w, h) during preprocessing; its post-processing
    converts them back to Pascal VOC (x1, y1, x2, y2).
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_coco_annotation(box, cls):
        """Turn one Pascal VOC box (x1, y1, x2, y2) and class tensor into a COCO annotation dict."""
        x1, y1, x2, y2 = box.tolist()
        box_w = x2 - x1
        box_h = y2 - y1
        return dict(
            bbox=[x1, y1, box_w, box_h],  # COCO: [x, y, width, height]
            category_id=cls.item(),
            area=box_w * box_h,
            iscrowd=0,
        )

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]
        # Drop the leading singleton dimension of the stored image tensor.
        image = sample['images'].squeeze(0)
        annotations = [
            self._to_coco_annotation(box, cls)
            for box, cls in zip(sample['boxes2d'], sample['boxes2d_classes'])
        ]
        return dict(image=image, target=dict(image_id=idx, annotations=annotations))

In [None]:
def _coco_target_to_label(target, image_hw):
    """Convert one COCO-detection target into a model label dict.

    Args:
        target: ``{"image_id", "annotations"}``; each annotation has ``bbox`` as
            [x, y, width, height] in pixels and an integer ``category_id``.
        image_hw: (height, width) of the image the boxes refer to.

    Returns:
        ``{"class_labels": LongTensor[N], "boxes": FloatTensor[N, 4] (cx, cy, w, h)
        normalized by image width/height, "orig_size": LongTensor[2] ([H, W])}``.
    """
    height, width = int(image_hw[0]), int(image_hw[1])
    annotations = target['annotations']
    if annotations:
        xywh = torch.tensor([ann['bbox'] for ann in annotations], dtype=torch.float32)
        class_labels = torch.tensor([ann['category_id'] for ann in annotations], dtype=torch.long)
    else:
        # Images without objects still need correctly-shaped (empty) tensors.
        xywh = torch.zeros((0, 4), dtype=torch.float32)
        class_labels = torch.zeros((0,), dtype=torch.long)

    centers = xywh[:, :2] + xywh[:, 2:] / 2
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    boxes = torch.cat([centers, xywh[:, 2:]], dim=1) / scale
    return dict(class_labels=class_labels, boxes=boxes, orig_size=torch.tensor([height, width]))


def collate_fn(batch, preprocessor=None):
    """Collate DatasetAdapterForTransformers samples into a model-ready batch.

    Args:
        batch: list of ``{"image": Tensor, "target": {...}}`` samples; images are
            assumed channel-first with (height, width) as the last two dims.
        preprocessor: optional HF image processor. When given, it performs all image
            and annotation conversion and its output is returned unchanged.

    Returns:
        With a preprocessor: the preprocessor's output (``pixel_values`` + ``labels``).
        Without one: ``{"pixel_values": Tensor[B, ...], "labels": [...]}`` where each
        label holds ``class_labels``, normalized (cx, cy, w, h) ``boxes`` and
        ``orig_size`` — the same keys the rest of this notebook reads from
        preprocessor-produced labels.
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    # No preprocessor: images are assumed to already be same-sized tensors.
    labels = []
    for item, image in zip(batch, images):
        if 'target' not in item:
            # Legacy items that carry raw SHIFT boxes directly.
            labels.append(dict(
                class_labels=item['boxes2d_classes'].long(),
                boxes=item["boxes2d"].float(),
            ))
            continue
        labels.append(_coco_target_to_label(item['target'], image.shape[-2:]))

    return dict(pixel_values=torch.stack(images), labels=labels)

## Load Model

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig
from transformers.image_utils import AnnotationFormat

In [None]:
# When True, the YOLO weights are initialised from the released yolo11m checkpoint.
USE_PRETRAINED_MODEL: bool = True

In [None]:
from transformers.utils.generic import TensorType
from typing import Union
from torchvision.ops import batched_nms
from ultralytics.utils import ops

class Custom_RTDetrImageProcessorFast(RTDetrImageProcessorFast):
    """RT-DETR fast image processor with YOLO-style, class-aware NMS post-processing.

    Preprocessing is inherited unchanged. ``post_process_object_detection`` is replaced
    by ultralytics' ``non_max_suppression`` so the dense per-anchor predictions of the
    YOLO wrapper in this notebook can be reduced to final detections.

    Attributes:
        iou_thres: IoU threshold used by NMS.
        max_det: maximum number of detections kept per image.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.iou_thres: float = 0.45
        self.max_det: int = 300

    def post_process_object_detection(self, outputs, target_sizes=None, threshold: float = 0.25):
        """Convert raw model outputs into per-image detections.

        Args:
            outputs: object exposing ``logits`` [B, N, C] (raw class logits) and
                ``pred_boxes`` [B, N, 4] (normalized cx, cy, w, h).
            target_sizes: per-image (height, width), as a list/tuple of pairs or a
                [B, 2] tensor. If None, boxes are returned in normalized coordinates.
            threshold: minimum class confidence kept by NMS (ultralytics
                ``conf_thres``); the default matches the previously hard-coded 0.25.

        Returns:
            One dict per image: ``{"boxes": [Ni, 4] (x1, y1, x2, y2), "scores": [Ni],
            "labels": [Ni] (long)}``, with boxes clipped to the image when sizes are given.
        """
        num_classes = outputs.logits.shape[-1]
        # Independent per-class probabilities. Softmax would force every anchor's class
        # scores to sum to 1 (no background class), pushing background anchors above the
        # confidence threshold. Cast to float32: NMS kernels may not support bf16.
        class_scores = outputs.logits.float().sigmoid()
        boxes = outputs.pred_boxes.float()

        # ultralytics NMS expects [B, 4 + C, N] with (cx, cy, w, h) first.
        prediction = torch.cat([boxes, class_scores], dim=-1).permute(0, 2, 1)
        detections = ops.non_max_suppression(
            prediction,
            conf_thres=threshold,
            iou_thres=self.iou_thres,
            agnostic=False,
            multi_label=False,
            max_det=self.max_det,
            nc=num_classes,
            in_place=False,
        )

        # Each row of `det` is (x1, y1, x2, y2, score, class).
        results = [
            {"boxes": det[:, :4], "scores": det[:, 4], "labels": det[:, 5].long()}
            for det in detections
        ]
        if target_sizes is None:
            return results

        if torch.is_tensor(target_sizes):
            target_sizes = target_sizes.tolist()
        for result, (img_h, img_w) in zip(results, target_sizes):
            if result["boxes"].numel() == 0:
                continue
            img_h, img_w = float(img_h), float(img_w)
            # Scale normalized corners to pixel coordinates, then clip to the image.
            scaled = result["boxes"] * result["boxes"].new_tensor([img_w, img_h, img_w, img_h])
            scaled[:, 0::2].clamp_(0, img_w)  # x1, x2
            scaled[:, 1::2].clamp_(0, img_h)  # y1, y2
            result["boxes"] = scaled

        return results

In [None]:
reference_model_id = "PekingU/rtdetr_r50vd"

# HF config that backs the YOLO wrapper; num_labels sets the detector's class count.
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
reference_config.num_labels = NUM_CLASSES
reference_config.image_size = 800

# Image processor: COCO-detection annotations in, images kept at native resolution.
reference_preprocessor = Custom_RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION
# Presumably only consulted when resizing is enabled (an 800x800 size was tried previously).
reference_preprocessor.size = {"height": 800, "width": 1280}
reference_preprocessor.do_resize = False

In [None]:
from transformers.models.rt_detr.modeling_rt_detr import RTDetrPreTrainedModel, RTDetrObjectDetectionOutput
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG

from typing import List, Tuple, Union, Optional

class YOLOForObjectDetection(RTDetrForObjectDetection):
    """Ultralytics YOLO detector exposed through the HF RT-DETR object-detection interface.

    RTDetrForObjectDetection is reused only for HF plumbing (config, save/load, Trainer
    compatibility): ``__init__`` skips RT-DETR's layer construction and builds an
    ultralytics ``DetectionModel`` instead. ``forward`` returns an
    ``RTDetrObjectDetectionOutput`` whose ``logits`` are raw per-class logits [B, A, nc]
    and whose ``pred_boxes`` are (cx, cy, w, h) normalized to the input image size
    [B, A, 4], so RT-DETR-style post-processing and metrics can consume it.
    """

    reference_config = reference_config

    def __init__(self, config: str):
        """Build the YOLO network from an ultralytics model definition (e.g. ``"yolo11m.yaml"``)."""
        # Call PreTrainedModel.__init__ directly, skipping RT-DETR's own __init__,
        # so that no RT-DETR layers are created.
        super(RTDetrPreTrainedModel, self).__init__(self.reference_config)
        self.model = DetectionModel(config, ch=3, nc=self.reference_config.num_labels, verbose=False)
        # init_criterion presumably reads loss hyper-parameters from `model.args`, so set it first.
        self.args = DEFAULT_CFG
        self.model.args = self.args
        self.loss_function = self.model.init_criterion()

    def to(self, *args, **kwargs):
        """Move the module, then the loss helper's tensors, which ``nn.Module.to`` presumably does not reach."""
        super().to(*args, **kwargs)
        device = self.device
        self.loss_function.device = device
        self.loss_function.bbox_loss = self.loss_function.bbox_loss.to(device)
        self.loss_function.proj = self.loss_function.proj.to(device)
        return self

    def _to_ultralytics_batch(self, pixel_values, labels):
        """Flatten HF-style per-image labels into the batch dict ultralytics' loss expects."""
        batch_idx = [
            torch.full((lab['class_labels'].size(0),), i, device=self.device, dtype=torch.long)
            for i, lab in enumerate(labels)
        ]
        return {
            'img': pixel_values,
            'batch_idx': torch.cat(batch_idx, 0),
            'cls': torch.cat([lab['class_labels'] for lab in labels], 0),
            'bboxes': torch.cat([lab['boxes'] for lab in labels], 0),
        }

    def _decode_head_outputs(self, feats, image_size):
        """Decode raw Detect-head maps into RT-DETR-style (logits, pred_boxes).

        Args:
            feats: per-scale head outputs, each [B, 4 * reg_max + nc, H_i, W_i]; the
                Detect head stacks box-distribution channels first, then class channels.
            image_size: (height, width) of the network input, used to normalize boxes.

        Returns:
            logits: [B, A, nc] raw class logits.
            pred_boxes: [B, A, 4] (cx, cy, w, h) normalized to [0, 1].
        """
        from ultralytics.utils.tal import dist2bbox, make_anchors

        head = self.model.model[-1]  # Detect head
        batch = feats[0].shape[0]
        flat = torch.cat([f.reshape(batch, head.no, -1) for f in feats], dim=2)  # [B, no, A]
        box_dist, cls_logits = flat.split((head.reg_max * 4, head.nc), dim=1)

        # Same decoding ultralytics applies at inference: DFL -> ltrb distances -> xywh,
        # measured in grid cells and scaled by each level's stride to pixels.
        anchors, strides = make_anchors(feats, head.stride, 0.5)  # [A, 2], [A, 1]
        boxes = dist2bbox(head.dfl(box_dist), anchors.transpose(0, 1).unsqueeze(0), xywh=True, dim=1)
        boxes = boxes * strides.transpose(0, 1)  # [B, 4, A] in input pixels

        height, width = image_size
        scale = boxes.new_tensor([width, height, width, height]).view(1, 4, 1)
        pred_boxes = (boxes / scale).clamp(1e-4, 1 - 1e-4).permute(0, 2, 1)
        return cls_logits.permute(0, 2, 1), pred_boxes

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[List[dict]] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
        """Run YOLO and package its predictions (and loss, if labels are given) RT-DETR-style.

        Args:
            pixel_values: input images [B, C, H, W].
            labels: optional per-image dicts with ``class_labels`` and normalized
                (cx, cy, w, h) ``boxes``; when None, no loss is computed.
            return_dict: return an ``RTDetrObjectDetectionOutput`` instead of a tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1. Inference: the Detect head returns raw per-scale maps in training mode and
        #    (decoded predictions, raw maps) in eval mode; the raw maps feed the loss.
        if self.model.training:
            outputs = self.model(pixel_values)
        else:
            _, outputs = self.model(pixel_values)
        outputs = list(outputs)

        # 2. RT-DETR-compatible predictions.
        logits, pred_boxes = self._decode_head_outputs(outputs, pixel_values.shape[-2:])

        # 3. Loss, only when labels are provided (inference callers pass none).
        loss, loss_dict = None, None
        if labels is not None:
            loss_raw, loss_items = self.loss_function(outputs, self._to_ultralytics_batch(pixel_values, labels))
            if isinstance(loss_raw, torch.Tensor):
                loss = loss_raw if loss_raw.dim() == 0 else loss_raw.sum()
            else:
                loss = sum(loss_raw) if isinstance(loss_raw, (list, tuple)) else loss_raw
            loss_dict = dict(
                box_loss=loss_items[0] if len(loss_items) > 0 else torch.tensor(0.0),
                cls_loss=loss_items[1] if len(loss_items) > 1 else torch.tensor(0.0),
                dfl_loss=loss_items[2] if len(loss_items) > 2 else torch.tensor(0.0)
            )

        if not return_dict:
            output = (logits, pred_boxes) + tuple(outputs)
            return ((loss, loss_dict) + output) if loss is not None else output

        return RTDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes
        )

In [None]:
# Initialize a new model with the reference configuration
model = YOLOForObjectDetection(config="yolo11m.yaml")
if USE_PRETRAINED_MODEL:
    # NOTE: this unpickles a file downloaded from a URL; only use trusted sources.
    pretrained_ckpt = torch.hub.load_state_dict_from_url(
        "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt",
        map_location="cpu",
        progress=True,
    )
    # Ultralytics release files are training checkpoints holding the pickled model under
    # "ema"/"model", not bare state dicts; accept a plain state dict as well.
    if isinstance(pretrained_ckpt, dict) and ("model" in pretrained_ckpt or "ema" in pretrained_ckpt):
        pretrained_state = (pretrained_ckpt.get("ema") or pretrained_ckpt["model"]).float().state_dict()
    else:
        pretrained_state = pretrained_ckpt

    # Keep only tensors whose name and shape both match: the classification branch is sized
    # for NUM_CLASSES, which may differ from the checkpoint's class count.
    target_state = model.model.state_dict()
    compatible_state = {
        name: tensor
        for name, tensor in pretrained_state.items()
        if name in target_state and tensor.shape == target_state[name].shape
    }
    model.model.load_state_dict(compatible_state, strict=False)
    print(f"INFO: Loaded {len(compatible_state)}/{len(target_state)} pretrained tensors")
model.to(device)

In [None]:
# Inspect one adapted training sample (image tensor + COCO-detection target).
test_d = DatasetAdapterForTransformers(dataset.train)[5]
test_d

In [None]:
# Smoke-test the image processor on that single sample and display its output.
reference_preprocessor(images=test_d['image'], annotations=test_d['target'])

In [None]:
from transformers.trainer_utils import EvalPrediction
from torchvision.ops import box_convert
from dataclasses import dataclass


@dataclass
class ModelOutput:
    """Lightweight container mimicking the model output fields read by post-processing.

    Only the two attributes ``post_process_object_detection`` reads are provided.
    """

    logits: torch.Tensor  # class logits per prediction
    pred_boxes: torch.Tensor  # predicted boxes per prediction


def de_normalize_boxes(boxes, height, width):
    """Convert normalized (cx, cy, w, h) boxes into pixel (x1, y1, x2, y2) boxes.

    Args:
        boxes: [N, 4] tensor of center-format boxes in [0, 1] coordinates.
        height: image height in pixels (scales y coordinates).
        width: image width in pixels (scales x coordinates).

    Returns:
        A new [N, 4] tensor of corner-format boxes in pixels.
    """
    corners = box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
    corners[:, 0::2] *= width   # x1, x2
    corners[:, 1::2] *= height  # y1, y2
    return corners


def map_compute_metrics(preprocessor=reference_preprocessor, threshold=0.0):
    """Create a batched ``compute_metrics`` callback that accumulates mAP over an evaluation.

    Intended for ``TrainingArguments(batch_eval_metrics=True)``: every call carries one
    evaluation batch, and the call with ``compute_result=True`` (the last batch) must also
    be accumulated before the summary is computed.

    Args:
        preprocessor: image processor whose ``post_process_object_detection(outputs,
            target_sizes)`` returns per-image ``{"boxes", "scores", "labels"}`` dicts.
        threshold: currently unused, kept for backward compatibility; the confidence
            cut-off is whatever ``post_process_object_detection`` applies by default.

    Returns:
        ``(calc, map_metric)``: the callback and its underlying ``MeanAveragePrecision``.
    """
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def _accumulate(eval_pred):
        # predictions[0] is the loss dict; [1:3] are (logits, pred_boxes).
        preds = ModelOutput(*eval_pred.predictions[1:3])
        labels = eval_pred.label_ids
        sizes = [label['orig_size'].cpu().tolist() for label in labels]

        results = post_process(preds, sizes)
        predictions = [Detections.from_transformers(result) for result in results]
        targets = [Detections(
            xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
            class_id=label['class_labels'].cpu().numpy(),
        ) for label in labels]
        map_metric.update(predictions=predictions, targets=targets)

    def calc(eval_pred: EvalPrediction, compute_result=True):
        # The final batch arrives together with compute_result=True, so always accumulate first.
        _accumulate(eval_pred)
        if not compute_result:
            return {}

        m_ap = map_metric.compute()
        map_metric.reset()

        # ap_per_class rows follow the order of matched_classes (row i <-> class id
        # matched_classes[i]); they are not indexed by class id.
        per_class_map = {
            f"{CLASSES[class_id]}_mAP@0.50:0.95": m_ap.ap_per_class[row].mean()
            for row, class_id in enumerate(m_ap.matched_classes)
        }

        return {
            "mAP@0.50:0.95": m_ap.map50_95,
            "mAP@0.50": m_ap.map50,
            "mAP@0.75": m_ap.map75,
            **per_class_map
        }

    return calc, map_metric

In [None]:
class DifferentiableLRTrainer(Trainer):
    """Trainer whose AdamW optimizer uses a separate learning rate for backbone parameters.

    Parameters whose name contains ``"backbone"`` use ``args.backbone_lr``; all others use
    ``args.learning_rate``. NOTE(review): the YOLO wrapper's parameter names look like
    ``model.model.<idx>...`` and may contain no "backbone" — then every parameter lands in
    the head group. TODO confirm this is intended.
    """

    def create_optimizer(self):
        # Match the base Trainer contract: reuse an optimizer that already exists.
        if self.optimizer is not None:
            return self.optimizer

        # Fall back to the main learning rate when no backbone LR is configured
        # (e.g. plain TrainingArguments), instead of handing AdamW lr=None.
        backbone_lr = getattr(self.args, "backbone_lr", None)
        if backbone_lr is None:
            backbone_lr = self.args.learning_rate

        backbone_params, head_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue  # frozen parameters never receive updates
            (backbone_params if 'backbone' in name else head_params).append(param)

        param_groups = [
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': head_params, 'lr': self.args.learning_rate},
        ]
        self.optimizer = torch.optim.AdamW(
            [group for group in param_groups if group['params']],
            weight_decay=self.args.weight_decay,
        )
        return self.optimizer


@dataclass
class DifferentiableLRTrainingArguments(TrainingArguments):
    """TrainingArguments plus ``backbone_lr`` for DifferentiableLRTrainer.

    Declared as a dataclass field (the usual way to extend TrainingArguments) so it is
    accepted as a keyword argument, set before ``__post_init__`` runs, and included in
    ``to_dict()`` / ``to_json_string()`` output.
    """

    # Learning rate for the "backbone" parameter group read by DifferentiableLRTrainer.create_optimizer.
    backbone_lr: Optional[float] = None

### Training Configuration

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 20
REAL_BATCH = BATCH_SIZE[-1]  # effective train batch, reached via gradient accumulation
LEARNING_RATE = 1e-4

training_args = DifferentiableLRTrainingArguments(
    backbone_lr=LEARNING_RATE/10,  # Set backbone learning rate to 1/10th of the main learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.5,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    gradient_accumulation_steps=REAL_BATCH//BATCH_SIZE[0],
    eval_accumulation_steps=BATCH_SIZE[1],
    batch_eval_metrics=True,  # metrics are accumulated batch-by-batch (see map_compute_metrics)
    remove_unused_columns=False,  # keep custom keys such as `labels` dicts intact
    optim="adamw_torch",
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=100,
    save_total_limit=100,
    load_best_model_at_end=True,
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    # report_to="wandb",
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    run_name=RUN_NAME,
    bf16=True,
)

testing_args = TrainingArguments(
    # Explicit output_dir: older transformers releases require it, newer ones otherwise
    # default to a generic directory in the working directory.
    output_dir="./results/"+RUN_NAME+"_test",
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from functools import partial

# One mAP callback shared by both trainers; `compute_results` is its MeanAveragePrecision accumulator.
compute_metrics, compute_results = map_compute_metrics(preprocessor=reference_preprocessor)

# Training / validation on the SHIFTClear splits.
trainer = DifferentiableLRTrainer(
    model=model,
    args=training_args,
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    train_dataset=DatasetAdapterForTransformers(dataset.train),
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
)

# Evaluation-only trainer on the corrupted test split.
tester = DifferentiableLRTrainer(
    model=model,
    args=testing_args,
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    compute_metrics=compute_metrics
)

## Train

In [None]:
def _resume_or_start_training():
    """Run one ``trainer.train`` call, resuming from the latest checkpoint if one exists.

    Falls back to training from scratch only on the "No valid checkpoint found" error;
    every other failure propagates to the caller.
    """
    try:
        print("INFO: Trying to resume from previous checkpoint")
        compute_results.reset()
        trainer.train(resume_from_checkpoint=True)
    except Exception as e:
        if "No valid checkpoint found" not in str(e):
            raise
        print(f"ERROR: Failed to resume from checkpoint - {e}")
        print("INFO: Starting training from scratch")
        compute_results.reset()
        trainer.train(resume_from_checkpoint=False)


def start_train(max_cuda_retries=3):
    """Train ``trainer`` to completion, retrying after CUDA errors.

    Returns once training finishes. After an error whose message mentions "CUDA", it
    retries (resuming from the latest checkpoint) up to ``max_cuda_retries`` times.
    Any other error is re-raised immediately.

    Args:
        max_cuda_retries: number of retries allowed after CUDA errors before re-raising.
    """
    # Constructed for its side effect of initialising accelerate's process state
    # (presumably relevant under notebook_launcher) — TODO confirm it is still needed.
    Accelerator()

    retries = 0
    while True:
        try:
            _resume_or_start_training()
            return
        except Exception as e:
            if "CUDA" not in str(e) or retries >= max_cuda_retries:
                raise
            retries += 1
            print(f"ERROR: CUDA Error - {e}")
            # Release cached allocator blocks before retrying (helps after OOM errors).
            torch.cuda.empty_cache()

In [None]:
if ADDITIONAL_GPU:
    # CUDA_VISIBLE_DEVICES exposes ADDITIONAL_GPU + 1 GPUs (DEVICE_NUM .. DEVICE_NUM + ADDITIONAL_GPU),
    # so launch one process per visible GPU.
    notebook_launcher(start_train, args=(), num_processes=ADDITIONAL_GPU + 1)
else:
    start_train()

In [None]:
# Plot the accumulated mAP result.
# NOTE(review): the metrics callback resets this accumulator after every evaluation, so this may
# plot an empty result unless predictions were accumulated since the last evaluation — confirm.
compute_results.compute().plot()

## Evaluate

### Auto Evaluation

In [None]:
# Evaluate on the trainer's eval_dataset (SHIFTClear validation split).
trainer.evaluate()

In [None]:
# Evaluate on the tester's eval_dataset (SHIFTCorrupted validation split used as the test set).
tester.evaluate()

### Manual Evaluation

In [None]:
# Training step of the saved checkpoint to reload for manual evaluation.
checkpoint: int = 31100

In [None]:
# Reload the chosen training checkpoint. Checkpoints hold YOLOForObjectDetection weights (keys under
# `model.model.*`), so they must be restored into the same wrapper; RTDetrForObjectDetection would
# ignore every key and end up randomly initialised.
checkpoint_dir = path.join(training_args.output_dir, f"checkpoint-{checkpoint}")
try:
    # Local import: this transformers helper reads both safetensors and .bin weight files.
    from transformers.modeling_utils import load_state_dict

    weights_file = next(
        (
            path.join(checkpoint_dir, file_name)
            for file_name in ("model.safetensors", "pytorch_model.bin")
            if path.isfile(path.join(checkpoint_dir, file_name))
        ),
        None,
    )
    if weights_file is None:
        raise FileNotFoundError(f"No model weights found in {checkpoint_dir}")

    restored_model = YOLOForObjectDetection(config="yolo11m.yaml")
    restored_model.load_state_dict(load_state_dict(weights_file))
    model = restored_model.to(device)
except (OSError, RuntimeError, ValueError) as e:
    # Best effort: keep the in-memory model, but say so instead of failing silently.
    print(f"WARNING: Could not load checkpoint from {checkpoint_dir} - {e}; keeping the current model")

In [None]:
class LabelDataset(BaseDataset):
    """Ground-truth-only view of one camera: yields ``(boxes2d, boxes2d_classes)`` per sample."""

    def __init__(self, original_dataset, camera='front'):
        self.dataset, self.camera = original_dataset, camera

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        frame = self.dataset[idx][self.camera]
        boxes = frame['boxes2d']
        classes = frame['boxes2d_classes']
        return boxes, classes

In [None]:
def naive_collate_fn(batch):
    """Identity collate: hand the list of samples back unchanged (no stacking/padding)."""
    collated = batch
    return collated

In [None]:
# Manual evaluation on the validation split: collect per-image predictions and raw ground truth.
targets = []
predictions = []
batch_size = 32
score_threshold = 0.3  # minimum confidence kept in the manual evaluation

model.eval()  # make sure inference-mode code paths (e.g. no dropout, eval head outputs) are used

# Two loaders over the same split in the same order: raw Pascal VOC labels and model inputs.
raw_data = DataLoader(LabelDataset(dataset.valid), batch_size=batch_size, collate_fn=naive_collate_fn)
loader = DataLoader(DatasetAdapterForTransformers(dataset.valid), batch_size=batch_size, collate_fn=partial(collate_fn, preprocessor=reference_preprocessor))
for raw_labels, inputs in tqdm(zip(raw_data, loader), total=len(loader)):
    sizes = [label['orig_size'].cpu().tolist() for label in inputs['labels']]

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device))

    results = reference_preprocessor.post_process_object_detection(outputs, target_sizes=sizes)

    # Iterate over the actual batch (the last batch may be smaller than batch_size).
    for result, (boxes, classes) in zip(results, raw_labels):
        keep = result["scores"] >= score_threshold
        predictions.append(Detections.from_transformers({key: value[keep] for key, value in result.items()}))
        targets.append(Detections(
            xyxy=boxes.cpu().numpy(),
            class_id=classes.cpu().numpy(),
        ))

In [None]:
# Sanity check: exactly one prediction entry per ground-truth image.
len(predictions) == len(targets), len(predictions), len(targets)

In [None]:
mean_average_precision = MeanAveragePrecision().update(
    predictions=predictions,
    targets=targets,
).compute()
# ap_per_class rows follow the order of matched_classes (row i <-> class id matched_classes[i]);
# each row holds AP over the IoU thresholds 0.50:0.95, so its mean is that class's mAP@0.50:0.95.
per_class_map = {
    f"{CLASSES[class_id]}_mAP@0.50:0.95": mean_average_precision.ap_per_class[row].mean()
    for row, class_id in enumerate(mean_average_precision.matched_classes)
}

print(f"mAP@0.50:0.95: {mean_average_precision.map50_95:.2f}")
print(f"map50: {mean_average_precision.map50:.2f}")
print(f"map75: {mean_average_precision.map75:.2f}")
for key, value in per_class_map.items():
    print(f"{key}: {value:.2f}")