# RT-DETR Pretraining with SHIFT-Discrete Dataset

In [None]:
import os
# Set before the training libraries below are imported so that Weights & Biases
# reporting stays off (the wandb calls in this notebook are commented out too).
os.environ["WANDB_DISABLED"] = "true"

## Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Choose which physical GPU(s) this process may see.
#   DEVICE_NUM     - index of the first GPU to expose
#   ADDITIONAL_GPU - how many consecutive GPUs after DEVICE_NUM to expose as well
DEVICE_NUM = 0
ADDITIONAL_GPU = 0

from os import environ

# Comma-separated list DEVICE_NUM, DEVICE_NUM + 1, ..., DEVICE_NUM + ADDITIONAL_GPU.
environ["CUDA_VISIBLE_DEVICES"] = ",".join(
    str(DEVICE_NUM + offset) for offset in range(ADDITIONAL_GPU + 1)
)
# Echo the final value as the cell output.
environ["CUDA_VISIBLE_DEVICES"]

## Imports

In [None]:
from os import path

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

#import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device. With or without additional GPUs the plain "cuda"
# device is used: CUDA_VISIBLE_DEVICES set earlier already restricts which
# physical GPUs are visible, so no explicit index is attached here.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    DEVICE_NUM = -1  # marks "no GPU in use"
else:
    device = torch.device("cuda")

# The device number is appended to the message only when ADDITIONAL_GPU is non-zero.
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

In [None]:
# Tqdm Test
# Smoke test: run an empty 100-step loop just to confirm the progress bar renders here.
for _ in tqdm(range(100)):
    pass

In [None]:
# Identifiers for this experiment. RUN_NAME is reused further down to build
# the output, checkpoint, and log directory paths.
PROJECT_NAME = "APT_SHIFT_Pretraining"
RUN_NAME = "RT-DETR_R50"

# WandB Initialization
# (left disabled; in the code shown here PROJECT_NAME is referenced only by this commented-out call)
#wandb.init(project=PROJECT_NAME, name=RUN_NAME)

## Define Dataset

In [None]:
# Root directory handed to every dataset constructor as `root`.
DATA_ROOT = path.join(".", "data")

# Bundle the three splits used in this notebook:
#   train / valid - SHIFTClearDatasetForObjectDetection
#   test          - SHIFTCorruptedDatasetForObjectDetection (note: requested with valid=True)
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
    valid=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    test=SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
)

In [None]:
# Inspect which fields a single 'front' camera sample provides.
dataset.train[1]['front'].keys()

In [None]:
# Display one complete raw sample for manual inspection.
dataset.train[999]

In [None]:
# Sanity-check the image tensor layout of one 'front' camera sample.
dataset.train[1000]['front']['images'].shape  # should be (batch_size, num_channels, height, width)

## DataLoader

In [None]:
# Set Batch Size
# Each tuple is a per-hardware preset. The assignments run in order, so only
# the LAST one takes effect; the earlier ones are overwritten immediately.
# Indices 0/1/2 are reported below as Train/Valid/Test; the training cell
# further down reads index -1 as the effective ("real") batch size.
BATCH_SIZE = 2, 8, 8, 8  # 4070 Ti
BATCH_SIZE = 32, 64, 64, 32  # A6000
BATCH_SIZE = 32, 32, 32, 32  # A6000

# Dataset Configs
# Class names come from the training split; their count sizes the detection head later on.
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
class DatasetAdapterForTransformers(BaseDataset):
    """Present one camera of the wrapped dataset in COCO-detection layout.

    Each item of the wrapped dataset is indexed by camera name; the selected
    camera entry must provide 'images', 'boxes2d' and 'boxes2d_classes'.

    Args:
        original_dataset: Dataset to wrap; must support ``len()`` and indexing.
        camera: Key of the camera entry to read from every item.
    """

    def __init__(self, original_dataset, camera='front'):
        self.dataset = original_dataset
        self.camera = camera

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``dict(image=..., target=...)`` for sample ``idx``.

        ``target`` follows the layout expected by the image processor's COCO
        detection path: ``image_id`` plus a list of ``annotations`` with
        ``bbox`` as [x, y, width, height], ``category_id``, ``area`` and
        ``iscrowd``.
        """
        sample = self.dataset[idx][self.camera]
        # Drop the leading singleton dimension of the stored image tensor.
        picture = sample['images'].squeeze(0)

        def as_coco_annotation(corner_box, label):
            # Source boxes are corner format (x1, y1, x2, y2); COCO wants
            # the top-left corner plus width/height.
            left, top, right, bottom = corner_box.tolist()
            box_w = right - left
            box_h = bottom - top
            return dict(
                bbox=[left, top, box_w, box_h],
                category_id=label.item(),
                area=box_w * box_h,
                iscrowd=0
            )

        coco_annotations = [
            as_coco_annotation(corner_box, label)
            for corner_box, label in zip(sample['boxes2d'], sample['boxes2d_classes'])
        ]

        # The RT-DETR image processor converts these COCO boxes to the centre
        # format (cx, cy, w, h) during preprocessing; post-processing later
        # returns corner-format boxes again.
        return dict(
            image=picture,
            target=dict(image_id=idx, annotations=coco_annotations)
        )

In [None]:
def collate_fn(batch, preprocessor=None):
    """Collate a list of dataset items into one model input batch.

    Args:
        batch: List of dict items. Every item must provide 'image'.
        preprocessor: Optional callable invoked as
            ``preprocessor(images=..., annotations=..., return_tensors="pt")``.
            When given, every item must also provide 'target' and the
            preprocessor's return value is returned unchanged.

    Returns:
        The preprocessor output, or - without a preprocessor - a dict with
        ``pixel_values`` (the images stacked into a single tensor) and
        ``labels`` (one dict per item with ``class_labels`` as int64 and
        ``boxes`` as float32).

    Raises:
        KeyError: If an item lacks a key required by the chosen path.
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    # No preprocessor: assume the images are already equally sized tensors.
    # `pixel_values` must be the stacked tensor itself; it used to be wrapped
    # in a second dict, which handed the model a dict instead of a tensor.
    # NOTE(review): this path reads 'boxes2d' / 'boxes2d_classes' straight from
    # each item and passes the boxes through unconverted. Items produced by
    # DatasetAdapterForTransformers only carry 'image' and 'target', so they
    # need the preprocessor path.
    return dict(
        pixel_values=torch.stack(images),
        labels=[dict(
            class_labels=item['boxes2d_classes'].long(),
            boxes=item["boxes2d"].float()
        ) for item in batch]
    )

## Load Model

In [None]:
from transformers import AutoBackbone, RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig, SwinConfig, ResNetConfig
from transformers.image_utils import AnnotationFormat

In [None]:
# Switches read by the model-construction cell below:
#   USE_PRETRAINED_MODEL    - True: start from the full reference checkpoint;
#                             False: build a new detector and load only backbone weights.
#   LOAD_ONLY_COCO_BACKBONE - consulted only when USE_PRETRAINED_MODEL is True.
#   USE_SWIN_T_BACKBONE     - consulted only when USE_PRETRAINED_MODEL is False:
#                             use a Swin-T backbone config instead of ResNet-50.
#   USE_SHIFT_BACKBONE      - consulted only when USE_PRETRAINED_MODEL is False:
#                             fetch backbone weights from the release URL
#                             instead of the Hugging Face backbone model id.
USE_PRETRAINED_MODEL = True
LOAD_ONLY_COCO_BACKBONE = False
USE_SWIN_T_BACKBONE = False
USE_SHIFT_BACKBONE = False

In [None]:
# Hugging Face model id that supplies the base config and the image processor
# (and, in the next cell, optionally the pre-trained weights).
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# Size the classification head for this dataset's classes.
reference_config.num_labels = NUM_CLASSES

# Set the image size and preprocessor size
reference_config.image_size = 800

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": 800, "width": 800}
# NOTE(review): resizing is switched off here, so the `size` assigned above is
# presumably not applied as a resize step - confirm against the processor docs.
reference_preprocessor.do_resize = False

In [None]:
# Build `model` according to the switches defined above.
if USE_PRETRAINED_MODEL:
    # Load the pre-trained model; heads whose shape differs from the
    # checkpoint (e.g. because num_labels changed) are re-created.
    model = RTDetrForObjectDetection.from_pretrained(reference_model_id, config=reference_config, torch_dtype=torch.float32, ignore_mismatched_sizes=True)
    if LOAD_ONLY_COCO_BACKBONE:
        # Keep only the checkpoint's backbone: overwrite every non-backbone
        # tensor with the values of a freshly initialised detector.
        # (The non-backbone tensors used to be taken from `model` itself,
        # which loaded the model's own weights back into it - a no-op.)
        fresh_state = RTDetrForObjectDetection(config=reference_config).state_dict()
        detector_state = {k: v for k, v in fresh_state.items() if 'backbone' not in k}
        model.load_state_dict(detector_state, strict=False)
        del fresh_state
else:
    # Set the backbone configuration
    if USE_SHIFT_BACKBONE:
        if USE_SWIN_T_BACKBONE:
            backbone_url = "https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone_converted/swin_tiny_patch4_window7_shift_from_detectron2.pth"
            reference_config.backbone_config = SwinConfig(
                embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], output_hidden_states=True,
                window_size=7, out_features=["stage1", "stage2", "stage3"]  # stride 8·16·32
            )
        else:
            backbone_url = "https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone_converted/resnet50_shift_from_detectron2.pth"
            reference_config.backbone_config = ResNetConfig(depths=[3,4,6,3], out_indices=(2, 3, 4))
    else:
        if USE_SWIN_T_BACKBONE:
            backbone_id = "microsoft/swin-tiny-patch4-window7-224"
            reference_config.backbone_config = SwinConfig.from_pretrained(
                backbone_id, output_hidden_states=True, out_features=["stage1", "stage2", "stage3"]
            )
        else:
            backbone_id = "microsoft/resnet-50"
            reference_config.backbone_config = ResNetConfig.from_pretrained(backbone_id, out_indices=(2, 3, 4))

    # Initialize a new model with the reference configuration
    model = RTDetrForObjectDetection(config=reference_config)
    if USE_SHIFT_BACKBONE:
        # Backbone weights converted from detectron2, downloaded from the release URL.
        backbone_state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
        model.model.backbone.model.load_state_dict(backbone_state, strict=False)
    else:
        # Backbone weights from the Hugging Face backbone checkpoint.
        backbone_state = AutoBackbone.from_pretrained(backbone_id, config=reference_config.backbone_config).state_dict()
        model.model.backbone.model.load_state_dict(backbone_state, strict=False)
        if USE_SWIN_T_BACKBONE:
            # Flip the keyword default of the backbone's forward so positional
            # encodings are interpolated without the caller passing the flag.
            model.model.backbone.model.forward.__kwdefaults__['interpolate_pos_encoding'] = True
    # NOTE(review): strict=False ignores missing/unexpected keys silently in both paths.
    del backbone_state

model.to(device)

In [None]:
# Pull one adapted sample (a dict with 'image' and 'target') to inspect the COCO-style conversion.
test_d = DatasetAdapterForTransformers(dataset.train)[5]
test_d

In [None]:
# Dry run: feed the single sample `test_d` through the image processor to check it accepts the adapter's output.
reference_preprocessor(images=test_d['image'], annotations=test_d['target'])

In [None]:
from transformers.trainer_utils import EvalPrediction
from torchvision.ops import box_convert
from dataclasses import dataclass


@dataclass
class ModelOutput:
    """Minimal container holding the two prediction tensors handed to post-processing."""
    logits: torch.Tensor
    pred_boxes: torch.Tensor


def de_normalize_boxes(boxes, height, width):
    """Turn normalized centre-format boxes into pixel-space corner boxes.

    Args:
        boxes: Tensor whose last dimension holds (cx, cy, w, h); indexed here
            as two-dimensional (rows of boxes).
        height: Factor applied to the y coordinates.
        width: Factor applied to the x coordinates.

    Returns:
        A new tensor in (x1, y1, x2, y2) layout; ``boxes`` itself is not
        modified because the format conversion produces a fresh tensor.
    """
    corner_boxes = box_convert(boxes, 'cxcywh', 'xyxy')

    # Columns 0 and 2 are x coordinates, columns 1 and 3 are y coordinates.
    corner_boxes[:, 0::2] *= width
    corner_boxes[:, 1::2] *= height
    return corner_boxes


def map_compute_metrics(preprocessor=reference_preprocessor, threshold=0.0):
    """Create a batch-wise mAP ``compute_metrics`` callback.

    The callback is meant to be used with ``batch_eval_metrics=True``: it is
    called once per evaluation batch and accumulates detections in a shared
    metric object; the call flagged with ``compute_result=True`` additionally
    returns the summary and resets the accumulator.

    Args:
        preprocessor: Object providing ``post_process_object_detection``.
        threshold: Score threshold forwarded to the post-processing step.

    Returns:
        Tuple ``(calc, map_metric)``: the callback and the underlying metric
        accumulator (exposed so callers can reset it explicitly).
    """
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def calc(eval_pred: EvalPrediction, compute_result=False):
        # Always fold the current batch in first. The flagged final call still
        # carries the last evaluation batch; returning the summary without
        # accumulating it would silently drop that batch from the mAP.
        if eval_pred is not None:
            preds = ModelOutput(*eval_pred.predictions[1:3])
            labels = eval_pred.label_ids
            sizes = [label['orig_size'].cpu().tolist() for label in labels]

            results = post_process(preds, target_sizes=sizes, threshold=threshold)
            predictions = [Detections.from_transformers(result) for result in results]
            # Ground-truth boxes are stored normalized in centre format;
            # convert them to pixel-space corner boxes to match the predictions.
            targets = [Detections(
                xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
                class_id=label['class_labels'].cpu().numpy(),
            ) for label in labels]

            map_metric.update(predictions=predictions, targets=targets)

        if not compute_result:
            return {}

        m_ap = map_metric.compute()
        map_metric.reset()

        # Row i of `ap_per_class` belongs to `matched_classes[i]`; the rows are
        # positional, so they must not be indexed by the class id itself.
        per_class_map = {
            f"{CLASSES[int(class_id)]}_mAP@0.50:0.95": m_ap.ap_per_class[row].mean()
            for row, class_id in enumerate(m_ap.matched_classes)
        }

        return {
            "mAP@0.50:0.95": m_ap.map50_95,
            "mAP@0.50": m_ap.map50,
            "mAP@0.75": m_ap.map75,
            **per_class_map
        }
    return calc, map_metric

In [None]:
class DifferentiableLRTrainer(Trainer):
    """Trainer that optimises backbone and non-backbone parameters with separate learning rates.

    Parameters whose name contains 'backbone' use ``args.backbone_lr``; all
    other parameters use ``args.learning_rate``. Both groups share
    ``args.weight_decay``.
    """

    def create_optimizer(self):
        """Create (once) and return the two-group AdamW optimizer."""
        # Keep an optimizer that already exists (e.g. one supplied through the
        # Trainer's `optimizers` argument) instead of overwriting it.
        if self.optimizer is not None:
            return self.optimizer

        # `backbone_lr` defaults to None on the arguments class and is absent
        # on a plain TrainingArguments; fall back to the main learning rate
        # instead of handing `lr=None` to AdamW.
        backbone_lr = getattr(self.args, 'backbone_lr', None)
        if backbone_lr is None:
            backbone_lr = self.args.learning_rate

        backbone_params = []
        head_params = []
        for name, param in self.model.named_parameters():
            (backbone_params if 'backbone' in name else head_params).append(param)

        self.optimizer = torch.optim.AdamW([
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': head_params, 'lr': self.args.learning_rate}
        ], weight_decay=self.args.weight_decay)

        return self.optimizer


class DifferentiableLRTrainingArguments(TrainingArguments):
    """TrainingArguments extended with a separate learning rate for the backbone.

    Args:
        backbone_lr: Keyword-only value stored on the instance as
            ``backbone_lr``; defaults to None. All other positional and
            keyword arguments are forwarded unchanged to TrainingArguments.
    """
    def __init__(self, *args, backbone_lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned as a plain instance attribute after the parent initialiser has run.
        self.backbone_lr = backbone_lr

## Train

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 20
# Target number of samples per optimizer step: the last entry of the BATCH_SIZE preset.
REAL_BATCH = BATCH_SIZE[-1]
LEARNING_RATE = 1e-4

# Arguments for the pre-training run (evaluate, save and log every 100 steps;
# the best checkpoint by overall mAP is reloaded at the end).
training_args = DifferentiableLRTrainingArguments(
    backbone_lr=LEARNING_RATE/10,  # Set backbone learning rate to 1/10th of the main learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.5,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    # Accumulate gradients so that REAL_BATCH samples contribute to each step.
    # NOTE(review): the integer division yields 0 when BATCH_SIZE[0] > REAL_BATCH -
    # confirm every preset keeps this value >= 1.
    gradient_accumulation_steps=REAL_BATCH//BATCH_SIZE[0],
    eval_accumulation_steps=BATCH_SIZE[1],
    # The metric callback accumulates per batch, so batch-wise evaluation is required.
    batch_eval_metrics=True,
    # Keep the custom dataset fields ('image', 'target') for the collate function.
    remove_unused_columns=False,
    optim="adamw_torch",
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=100,
    save_total_limit=100,
    load_best_model_at_end=True,
    # Must match the key returned by the metric callback.
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    #report_to="wandb",
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    #run_name=RUN_NAME,
    bf16=True,
)

# Evaluation-only arguments; everything not listed keeps the library default
# (no output_dir and no bf16 are set here).
testing_args = TrainingArguments(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from functools import partial

# `compute_metrics` is the per-batch callback; `compute_results` is the metric
# accumulator object it updates (despite the name), kept so it can be reset.
compute_metrics, compute_results = map_compute_metrics(preprocessor=reference_preprocessor)

# Pre-training: clear train split for training, clear valid split for evaluation.
trainer = DifferentiableLRTrainer(
    model=model,
    args=training_args,
    train_dataset=DatasetAdapterForTransformers(dataset.train),
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
)

# Evaluation on the corrupted split. It shares `model` and the same metric
# callback (and therefore the same accumulator) with `trainer`.
tester = Trainer(
    model=model,
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

In [None]:
def start_train(max_cuda_retries=3):
    """Run training, resuming from the latest checkpoint whenever possible.

    The first attempt resumes from a previous checkpoint. If none exists,
    training starts from scratch. A failure whose message mentions "CUDA" is
    retried (resuming from the latest checkpoint) at most ``max_cuda_retries``
    times. Any other error is re-raised, and the function returns as soon as
    one attempt completes.

    Args:
        max_cuda_retries: Upper bound on retries after CUDA-related failures.

    Returns:
        Whatever ``trainer.train`` returns for the successful attempt.

    Raises:
        Exception: The original error when it is neither a missing-checkpoint
            nor a CUDA-related failure, or when the CUDA retry budget is
            exhausted.
    """
    # Kept from the original code; the instance is not used directly.
    # NOTE(review): presumably created for its set-up side effects when run
    # through notebook_launcher - confirm before removing.
    accelerator = Accelerator()

    resume = True
    cuda_failures = 0
    while True:
        # Start every attempt with an empty metric accumulator.
        compute_results.reset()
        if resume:
            print("INFO: Trying to resume from previous checkpoint")
        try:
            return trainer.train(resume_from_checkpoint=resume)
        except Exception as e:
            message = str(e)
            if resume and "No valid checkpoint found" in message:
                print(f"ERROR: Failed to resume from checkpoint - {e}")
                print("INFO: Starting training from scratch")
                resume = False
            elif "CUDA" in message and cuda_failures < max_cuda_retries:
                cuda_failures += 1
                print(f"ERROR: CUDA Error - {e}")
                # Retry from the latest checkpoint rather than from step 0.
                resume = True
            else:
                raise

## Adaptation Testing

In [None]:
from safetensors.torch import load_file
from torch import nn

In [None]:
reference_config.return_dict = True  # Ensure the model returns a dictionary
# Build the detector architecture from the config, then fill it from a local safetensors file.
model_pretrained = RTDetrForObjectDetection(config=reference_config)
model_states = load_file("RT-DETR_R50vd_SHIFT_CLEAR.safetensors", device="cpu")
# NOTE(review): strict=False ignores missing/unexpected keys silently, and the
# result of load_state_dict is not inspected here.
model_pretrained.load_state_dict(model_states, strict=False)

# Freeze every parameter; selected parts are re-enabled in a later cell.
for param in model_pretrained.parameters():
    param.requires_grad = False  # Freeze

# Display the module structure as the cell output.
model_pretrained

In [None]:
class FeatureNormalizationLayer(nn.Module):
    """Reduce a feature map to a fixed-length, layer-normalized vector.

    Pipeline: global average over the two spatial dimensions, adaptive
    average pooling of the resulting channel vector to ``target_dim`` entries
    (a parameter-free resampling, not a learned projection), then LayerNorm
    followed by Dropout(0.1).

    Args:
        target_dim: Length of the output vector.
    """

    def __init__(self, target_dim=256):
        super().__init__()
        self.target_dim = target_dim

        # Collapses each channel's spatial map to a single value.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Resamples the channel vector to `target_dim` entries by averaging.
        self.linear_compress = nn.AdaptiveAvgPool1d(self.target_dim)

        # LayerNorm holds the only learnable parameters of this layer.
        self.feature_norm = nn.Sequential(
            nn.LayerNorm(target_dim),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        """Return the normalized ``target_dim``-long descriptor of ``x``."""
        # After pooling the last two dimensions both have size 1; indexing
        # element 0 of each removes them.
        channel_vector = self.adaptive_pool(x)[..., 0, 0]
        resampled = self.linear_compress(channel_vector)
        return self.feature_norm(resampled)

In [None]:
class APT(nn.Module):
    """
    Light-weight box predictor used for adaptation.

    It combines a descriptor of the current frame's features with the boxes
    of the previous frame and predicts one next box per previous box.

    Args:
        feature_dim: Length of the normalized frame descriptor.
        bbox_dim: Number of values per box (input and output).
        hidden_dim: Width of the fused representation. One quarter (rounded
            down) is assigned to the box branch and the remainder to the
            feature branch, so the two always add up to ``hidden_dim``.
    """
    def __init__(self, feature_dim=256, bbox_dim=4, hidden_dim=32):
        super().__init__()

        self.feature_dim = feature_dim
        self.bbox_dim = bbox_dim
        self.hidden_dim = hidden_dim

        # Split the fused width between the two branches. Deriving the feature
        # share as the remainder keeps the concatenation exactly `hidden_dim`
        # wide even when hidden_dim is not a multiple of 4 (for multiples of 4
        # this equals the previous hidden_dim // 4 * 3).
        bbox_width = hidden_dim // 4
        sniffer_width = hidden_dim - bbox_width

        # Feature normalization layer for encoder-agnostic adaptation
        self.feature_norm = FeatureNormalizationLayer(target_dim=feature_dim)

        # Lightweight feature sniffer
        self.feature_sniffer = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, sniffer_width)
        )

        # Previous bbox encoder
        self.bbox_encoder = nn.Sequential(
            nn.Linear(bbox_dim, bbox_width),
            nn.ReLU()
        )

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Prediction head
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, bbox_dim),
            nn.Sigmoid()  # Normalize bbox coordinates to [0,1]
        )

    def forward(self, features, prev_bbox):
        """Predict the next box for every previous box.

        Args:
            features: Current-frame features accepted by the normalization
                layer. The resulting descriptor must have a leading dimension
                of size 1, because it is broadcast across all boxes.
            prev_bbox: Tensor with one row of ``bbox_dim`` values per box.

        Returns:
            Tensor with one row of ``bbox_dim`` values in [0, 1] per input box.
        """
        # Normalize encoder features to be encoder-agnostic
        norm_features = self.feature_norm(features)

        # Extract relevant features from current frame
        sniffed_features = self.feature_sniffer(norm_features)

        # Encode previous bbox information
        bbox_features = self.bbox_encoder(prev_bbox)

        # Repeat the single frame descriptor once per box (no copy is made)
        sniffed_features = sniffed_features.expand(bbox_features.shape[0], -1)

        # Fuse features
        fused = self.fusion(
            torch.cat([sniffed_features, bbox_features], dim=-1)
        )

        # Predict next bbox
        next_bbox = self.predictor(fused)

        return next_bbox

In [None]:
from typing import Optional
from torchvision.ops import box_convert
from scipy.optimize import linear_sum_assignment

class TestTimeAdaptiveDETR(nn.Module):
    """Detector wrapper that adds an APT-based consistency loss for online adaptation.

    With ``adapt`` set to True, each forward pass (a) uses the confident boxes
    cached from the previous frame to let the APT head predict "teacher" boxes
    for the current frame, (b) matches them to the detector's current
    predictions and computes an L1 loss, and (c) caches the current
    predictions for the next frame. With ``adapt`` set to False the wrapped
    detector's output is returned without touching the cache.

    Args:
        pretrained_model: Detector to wrap; called with the usual detection
            keyword arguments and expected to return an attribute-style output.
        img_size: Image size used to normalize cached boxes and to scale
            post-processed boxes; indexed as (height, width).
        feature_dim, bbox_dim, hidden_dim: Forwarded to the APT head.
        bbox_conf_threshold: Minimum score for a cached box to act as teacher input.
        bbox_topk: Maximum number of cached boxes used per frame.
    """
    # Post-processing routine borrowed from the reference image processor; it
    # turns raw outputs into per-image dicts with 'scores', 'labels', 'boxes'.
    post_process = reference_preprocessor.post_process_object_detection

    def __init__(
        self, pretrained_model, img_size=(800, 1280),
        feature_dim=256, bbox_dim=4, hidden_dim=32,
        bbox_conf_threshold=0.7, bbox_topk=5
    ):
        super().__init__()
        self.model = pretrained_model
        self.apt = APT(
            feature_dim=feature_dim, bbox_dim=bbox_dim,
            hidden_dim=hidden_dim
        )
        self.img_size = img_size
        # Divides corner-format pixel boxes by (width, height, width, height).
        self.__bbox_normalize = lambda bbox: bbox / torch.tensor([img_size[1], img_size[0], img_size[1], img_size[0]], device=bbox.device)
        self.bbox_cache = None  # {'boxes': tensor, 'scores': tensor}
        self.adapt = False
        self.bbox_conf_threshold = bbox_conf_threshold
        self.bbox_topk = bbox_topk

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        bbox_cache: Optional[dict] = None,
        teacher_forcing_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the detector and, in adaptation mode, attach the APT loss.

        Args:
            bbox_cache: Optional previous-frame detections with 'boxes' and
                'scores'. When omitted, the cache stored on the module is used.
            teacher_forcing_labels: Accepted for interface compatibility; not
                used by this implementation.
            All remaining arguments are forwarded to the wrapped detector.

        Returns:
            The wrapped detector's output with ``loss`` updated: when labels
            are given and the output has a ``loss`` attribute, the APT loss
            (if any) is added to it; otherwise ``loss`` is set to the APT loss
            or to a zero tensor when no APT loss was computed.
        """
        # An explicitly passed cache takes precedence over the stored one.
        bbox_cache = bbox_cache if bbox_cache is not None else self.bbox_cache

        # 1. Run base model (encoder-decoder)
        output = self.model(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        apt_loss = None
        # Use the resolved cache here; reading self.bbox_cache directly would
        # ignore a cache passed in by the caller.
        if self.adapt and bbox_cache is not None:
            # 2. Keep only the high-confidence boxes of the previous frame (cleans the teacher signal)
            prev_boxes = bbox_cache['boxes']
            prev_scores = bbox_cache['scores']

            high_conf_mask = prev_scores > self.bbox_conf_threshold
            confident_prev_boxes = prev_boxes[high_conf_mask]
            confident_prev_scores = prev_scores[high_conf_mask]

            # Adapt only when at least one confident previous box exists
            if confident_prev_boxes.shape[0] > 0:
                # top-k selection
                scores_for_sorting = confident_prev_scores
                sorted_indices = torch.argsort(scores_for_sorting, descending=True)
                top_k_indices = sorted_indices[:self.bbox_topk]
                top_k_prev_boxes = confident_prev_boxes[top_k_indices]

                # 3. Let APT produce the "teacher" prediction based on temporal consistency
                features = output.encoder_last_hidden_state[-1]
                # APT expects normalized boxes in cxcywh format
                normalized_prev_boxes_cxcywh = box_convert(self.__bbox_normalize(top_k_prev_boxes), "xyxy", "cxcywh")
                apt_teacher_boxes = self.apt(features, normalized_prev_boxes_cxcywh)  # output: cxcywh

                # 4. Current-frame predictions that act as the "student"
                # The raw predictions (logits, pred_boxes) have not been post-processed yet;
                # pred_boxes are normalized cxcywh boxes
                student_boxes = output.pred_boxes[0]  # assumes a batch size of 1

                # 5. Match teacher and student predictions (Hungarian algorithm) to compute the loss
                with torch.no_grad():
                    # L1-distance cost matrix
                    cost_matrix = torch.cdist(apt_teacher_boxes, student_boxes, p=1.0)
                    # Find the optimal assignment; cast to float32 first because
                    # NumPy cannot represent bfloat16 tensors
                    row_indices, col_indices = linear_sum_assignment(cost_matrix.float().cpu().numpy())

                # L1 loss between the matched student boxes and the teacher boxes
                matched_student_boxes = student_boxes[col_indices]
                apt_loss = nn.functional.l1_loss(matched_student_boxes, apt_teacher_boxes[row_indices])

        # 6. Store the current predictions in bbox_cache for the next frame
        if self.adapt:
            with torch.no_grad():
                sizes = [self.img_size for _ in range(len(output.pred_boxes))]
                # A low threshold keeps as many candidates as possible in the cache
                results = self.post_process(output, target_sizes=sizes, threshold=0.3)[0]
                self.bbox_cache = results  # results is {'scores': ..., 'labels': ..., 'boxes': ...}

        # 7. Final loss update
        # Keep the existing loss (regular training) or use apt_loss (adaptation)
        if labels is not None and hasattr(output, 'loss'):  # regular training
            if apt_loss is not None:
                output.loss += apt_loss  # apt_loss may also be added during regular training (optional)
        else:  # adaptation
            # NOTE(review): the zero fallback is a constant tensor that is not
            # connected to any parameter - confirm callers never backpropagate it.
            output.loss = apt_loss if apt_loss is not None else torch.tensor(0.0, device=pixel_values.device)

        if apt_loss is not None:
            print(f"\rINFO: APT Loss - {apt_loss.item():.6f}", end="")

        return output

In [None]:
# Initialize Model
# From here on `model` refers to the adaptation wrapper around the frozen
# detector (it replaces the `model` built in the pre-training part above).
model = TestTimeAdaptiveDETR(pretrained_model=model_pretrained)
model.to(device)

In [None]:
# Load Pretrained APT Weights
apt_weights = torch.load("./models/apt_model.pt")
# NOTE(review): strict=False ignores missing/unexpected keys silently.
model.apt.load_state_dict(apt_weights, strict=False)
# This loop covers ALL parameters of the wrapper (wrapped detector and APT head), not only APT.
for param in model.parameters():
    param.requires_grad = False  # Freeze APT weights
model.to(device)

In [None]:
# Un-Freeze Model Encoder
# Only the detector's encoder is trainable during adaptation; everything else stays frozen.
for param in model.model.model.encoder.parameters():
    param.requires_grad = True  # Allow encoder to adapt during online adaptation

In [None]:
class NoShuffleTrainer(Trainer):
    """Trainer whose training data is served in dataset order."""

    def get_train_dataloader(self):
        """Return a plain DataLoader over ``train_dataset`` with shuffling off.

        Batch size, worker count and pin-memory setting are taken from the
        training arguments; the trainer's data collator builds each batch.
        """
        args = self.args
        loader_options = dict(
            batch_size=args.per_device_train_batch_size,
            shuffle=False,  # frame order must be preserved for online adaptation
            collate_fn=self.data_collator,
            num_workers=args.dataloader_num_workers,
            pin_memory=args.dataloader_pin_memory,
        )
        return DataLoader(self.train_dataset, **loader_options)

In [None]:
from ttadapters.datasets import SHIFTContinuous100DatasetForObjectDetection

# Continuous Dataset
# Serves as the training stream for the online-adaptation trainer below.
continuous_dataset = SHIFTContinuous100DatasetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Arguments for the online-adaptation pass: a single epoch, one sample per
# step, with saving and logging once per epoch.
adapting_args = TrainingArguments(
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.5,
    num_train_epochs=1,
    # The wrapper's matching step reads only the first item of a batch.
    per_device_train_batch_size=1,
    batch_eval_metrics=True,
    remove_unused_columns=False,
    optim="adamw_torch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=100,
    load_best_model_at_end=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    output_dir="./adapts/"+RUN_NAME,
    # Same log directory as the pre-training run.
    logging_dir="./logs/"+RUN_NAME,
    #run_name=RUN_NAME,
    bf16=True,
    dataloader_drop_last=False
)

# Runs the adaptation pass over the continuous dataset in dataset order.
adapter = NoShuffleTrainer(
    model=model,
    args=adapting_args,
    train_dataset=DatasetAdapterForTransformers(continuous_dataset),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

# Evaluates on the corrupted split. This rebinds `tester`, which previously
# wrapped the pre-training model, to the adaptation wrapper.
tester = Trainer(
    model=model,
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

# Evaluates on the clear validation split (used to re-check the source domain).
revaluator = Trainer(
    model=model,
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

In [None]:
# Check model
# Baseline on the clear validation split before adaptation (evaluation call left disabled).
model.adapt = False
#revaluator.evaluate()

In [None]:
# Check shift
# Baseline on the corrupted split before adaptation (evaluation call left disabled).
model.adapt = False
#tester.evaluate()

In [None]:
# Online Adaptation
# Enabling `adapt` switches on the APT loss and the per-frame box cache inside the wrapper.
model.adapt = True
adapter.train()

In [None]:
# Evaluate Adapted Model for shifted domain
# Switch the adaptation logic off so evaluation runs the plain detector path.
model.adapt = False
model.bbox_cache = None  # Reset bbox cache for evaluation
tester.evaluate()

In [None]:
# Save only the wrapped detector (`model.model`); the APT head is not part of this export.
model.model.save_pretrained("./models/RT-DETR_R50vd_SHIFT_ADAPT")

In [None]:
# Evaluate Adapted Model for train domain (catastrophic forgetting)
# Re-checks the clear validation split after adaptation, with the adaptation logic switched off.
model.adapt = False
revaluator.evaluate()