# RT-DETR Pretraining with SHIFT-Discrete Dataset

## Imports

In [None]:
from os import path

import torch
from torch import nn, optim
from torch.utils.data import Dataset

from ttadapters.models import RTDetr50ForObjectDetection
from ttadapters.datasets import DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTDiscreteDatasetForObjectDetection
from transformers import Trainer, TrainingArguments, DefaultDataCollator, EarlyStoppingCallback

import numpy as np
import pandas as pd

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

### Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Select the compute device for this notebook.
#   DEVICE_NUM     -- index of the primary GPU (0~7 on this machine).
#   ADDITIONAL_GPU -- number of extra GPUs after DEVICE_NUM to spread the model
#                     over with nn.DataParallel (0 = single-GPU run).
DEVICE_NUM = 0
ADDITIONAL_GPU = 0

if torch.cuda.is_available():
    _last_gpu = DEVICE_NUM + ADDITIONAL_GPU
    # Fail early with a clear message instead of an "invalid device ordinal"
    # error deep inside training.
    if not 0 <= DEVICE_NUM <= _last_gpu < torch.cuda.device_count():
        raise ValueError(
            f"Requested GPUs {DEVICE_NUM}..{_last_gpu}, but only "
            f"{torch.cuda.device_count()} CUDA device(s) are visible."
        )
    if ADDITIONAL_GPU:
        # Multi-GPU: make DEVICE_NUM the current device so the generic
        # "cuda" device refers to it.
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    # CPU fallback: there are no GPUs, so there is nothing to add either.
    # Resetting ADDITIONAL_GPU keeps later `if ADDITIONAL_GPU:` branches
    # (and the info line below) from acting on a device index of -1.
    device = torch.device("cpu")
    DEVICE_NUM = -1
    ADDITIONAL_GPU = 0

print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

## Define Dataset

In [None]:
# Dataset root: a "data" directory under the current working directory.
DATA_ROOT = path.join(path.curdir, "data")

# Build each SHIFT-Discrete split from the same root (train first, then
# valid), selecting the split via its boolean flag, and bundle them together.
_train_split = SHIFTDiscreteDatasetForObjectDetection(root=DATA_ROOT, train=True)
_valid_split = SHIFTDiscreteDatasetForObjectDetection(root=DATA_ROOT, valid=True)
dataset = DatasetHolder(train=_train_split, valid=_valid_split)

In [None]:
# Inspect which fields one training sample exposes for the 'front' camera.
dataset.train[1]['front'].keys()

In [None]:
# Display one whole training sample (a dict keyed by camera, e.g. 'front').
dataset.train[1000]

In [None]:
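# NOTE(review): stale visualization snippet. `train_dataset`, `IMG_NORM` and
# `IMG_SIZE` are not defined earlier in this notebook; update it before re-enabling.
# for _ in range(2):
#     selected_idx = train_dataset.output_sampling(img_norm=IMG_NORM, imgsize=(IMG_SIZE, IMG_SIZE))
#     print(f"Visualized pair index: {selected_idx}")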

## DataLoader

In [None]:
class DatasetAdapter(Dataset):
    """Adapt one camera view of a SHIFT-Discrete dataset to RT-DETR samples.

    Each item is a dict with:
      * ``pixel_values`` -- the camera image run through the RT-DETR image
        processor and resized to ``image_size`` x ``image_size``;
      * ``labels`` -- the per-box class ids (``boxes2d_classes``) as int64;
      * ``boxes`` -- ``boxes2d`` converted from absolute
        (x_min, y_min, x_max, y_max) pixels to normalized
        (center_x, center_y, width, height) in [0, 1].

    The normalized center format is what Hugging Face DETR-family losses
    (including RT-DETR) expect for ``labels[i]["boxes"]``. The resize maps
    straight to ``image_size`` x ``image_size`` with no padding, so
    normalizing by the original width/height gives the same coordinates as
    normalizing in the resized frame.

    Args:
        shift_dataset: indexable dataset whose items are dicts keyed by camera.
        camera: camera key to read from each item (default ``'front'``).
        image_size: square side length the images are resized to (default 640).
    """

    preprocessor = RTDetr50ForObjectDetection.image_processor

    def __init__(self, shift_dataset, camera='front', image_size=640):
        self.dataset = shift_dataset
        self.camera = camera
        self.image_size = image_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx][self.camera]
        # NOTE(review): assumes images are laid out (H, W, C). A (C, H, W)
        # tensor would make shape[:2] == (C, H). Confirm against the SHIFT loader.
        orig_h, orig_w = item["images"].shape[:2]

        processed_image = self.preprocessor(
            item["images"],
            do_resize=True,
            size={"height": self.image_size, "width": self.image_size},
            return_tensors="pt",
        ).pixel_values.squeeze(0)

        # reshape(-1, 4) keeps images without any boxes as a valid (0, 4) tensor.
        # A fresh tensor is built below, so the source boxes are never mutated.
        boxes_xyxy = torch.as_tensor(item["boxes2d"], dtype=torch.float32).reshape(-1, 4)
        x_min, y_min, x_max, y_max = boxes_xyxy.unbind(-1)
        boxes = torch.stack(
            [
                (x_min + x_max) / 2 / orig_w,  # center_x
                (y_min + y_max) / 2 / orig_h,  # center_y
                (x_max - x_min) / orig_w,      # width
                (y_max - y_min) / orig_h,      # height
            ],
            dim=-1,
        )

        return {
            "pixel_values": processed_image,
            # Class ids must be integer (long) indices for the detection loss.
            "labels": torch.as_tensor(item["boxes2d_classes"], dtype=torch.long),
            "boxes": boxes,
        }

In [None]:
class ObjectDetectionDataCollator:
    """Collate adapter samples into the batch layout used for RT-DETR training.

    Images are stacked into one ``pixel_values`` tensor with a new leading
    batch dimension. Annotations stay per image because the number of boxes
    varies. ``labels`` is therefore a list, in batch order, of
    ``{"class_labels": ..., "boxes": ...}`` dicts.
    """

    def __call__(self, batch):
        stacked_images = torch.stack([sample["pixel_values"] for sample in batch])
        per_image_targets = [
            {"class_labels": sample["labels"], "boxes": sample["boxes"]}
            for sample in batch
        ]
        return {"pixel_values": stacked_images, "labels": per_image_targets}

In [None]:
# Set Batch Size
# Per-device batch sizes as (train, eval, third entry).
# Only the first two entries are read (by the TrainingArguments cell); the
# third entry is not consumed in the cells shown.
BATCH_SIZE = (8, 4, 1)

## Load Model

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 20
LEARNING_RATE = 1e-4, 1e-6  # Initial LR, minimum LR

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    learning_rate=LEARNING_RATE[0],
    remove_unused_columns=False,
    save_strategy="epoch",
    logging_dir="./logs",
)

In [None]:
# Build the detector with one output class per dataset category and hand it,
# together with the adapted train/valid splits, to the Hugging Face Trainer.
model = RTDetr50ForObjectDetection.from_pretrained(num_labels=len(dataset.train.categories))
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=DatasetAdapter(dataset.train),
    eval_dataset=DatasetAdapter(dataset.valid),
    # The collator is invoked as data_collator(batch), so it must be an
    # *instance*. Calling the class itself with a batch raises
    # "TypeError: ObjectDetectionDataCollator() takes no arguments".
    data_collator=ObjectDetectionDataCollator(),
    # NOTE(review): training_args sets no evaluation strategy; re-enabling
    # early stopping likely needs one. Confirm with the EarlyStoppingCallback docs.
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

# NOTE(review): `trainer` holds a reference to the *unwrapped* model passed
# above. Rebinding `model` to an nn.DataParallel wrapper here does not change
# what trainer.train() runs. Only the in-place `.to(device)` move below
# reaches trainer.model. Configure multi-GPU through the Trainer/launcher
# before relying on ADDITIONAL_GPU.
if ADDITIONAL_GPU:
    model = nn.DataParallel(model, device_ids=list(range(DEVICE_NUM, DEVICE_NUM+ADDITIONAL_GPU+1)))
model.to(device)

## Train

In [None]:
# Run the training loop configured in the cells above.
trainer.train()