# Pluggable TTA Implementation example

## Import Libraries

In [None]:
from os import path, mkdir

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from torchvision import datasets, utils, transforms
from torchinfo import summary

from ttadapters.datasets import GOT10kDatasetForObjectTracking, PairedGOT10kDataset

import numpy as np
import pandas as pd

import wandb
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Replace the `tqdm` name inside torchvision.datasets.utils with tqdm.auto's
# implementation (presumably so dataset download progress renders with the
# notebook-aware bar — verify against the torchvision version in use).
datasets.utils.tqdm = tqdm

In [None]:
# Project / run identifiers; only consumed by the WandB call below, which is
# currently commented out.
PROJECT_NAME = "APT_PLUGIN"
RUN_NAME = "RT-DETR_R50_APT"

# WandB Initialization
#wandb.init(project=PROJECT_NAME, name=RUN_NAME)

### Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Set CUDA Device Number 0~7
DEVICE_NUM = 0
# Truthy: select the GPU globally via torch.cuda.set_device and use the bare
# "cuda" device. Falsy: address the GPU explicitly as "cuda:<DEVICE_NUM>".
ADDITIONAL_GPU = 0
# dtype constant used by later cells when casting the model and batches.
DATA_TYPE = torch.bfloat16

if torch.cuda.is_available():
    if ADDITIONAL_GPU:
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # marks that no CUDA device is in use

# The ":<DEVICE_NUM>" suffix is appended only when ADDITIONAL_GPU is truthy,
# since in that mode `device` itself prints as plain "cuda".
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

## Load Dataset

### Original Dataset

In [None]:
# All three splits are read from ./data; force_download=False is passed for each.
DATA_ROOT = path.join(".", "data")

# Split selection is done through the `train` / `valid` keyword flags of
# GOT10kDatasetForObjectTracking (exact semantics live in ttadapters — verify there).
train_dataset = GOT10kDatasetForObjectTracking(root=DATA_ROOT, force_download=False, train=True)
valid_dataset = GOT10kDatasetForObjectTracking(root=DATA_ROOT, force_download=False, valid=True)
test_dataset = GOT10kDatasetForObjectTracking(root=DATA_ROOT, force_download=False, train=False)

print(f"INFO: Dataset loaded successfully. Number of samples - Train({len(train_dataset)}), Valid({len(valid_dataset)}), Test({len(test_dataset)})")

In [None]:
train_dataset.targets

In [None]:
train_dataset[0]

In [None]:
train_dataset[1]

In [None]:
test_dataset[-1]

### Paired Dataset

In [None]:
# Define image size for resizing
# ORIGINAL_SIZE is read from the FIRST training sample only (its `.size`
# attribute; presumably a PIL image, i.e. (width, height) — TODO confirm).
# NOTE(review): target_transform divides every box by this single size; confirm
# that all sequences share this resolution, otherwise normalized boxes are off.
ORIGINAL_SIZE = train_dataset[0][0].size
IMG_SIZE = 800  # square side length used by the Resize transforms

# Define image normalization parameters
# Per-channel mean/std passed to transforms.Normalize and reused for
# de-normalization when visualizing.
IMG_NORM = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

print("INFO: Image conversion is set to resze to", (IMG_SIZE, IMG_SIZE), "from", ORIGINAL_SIZE)

In [None]:
from torchvision.ops import box_convert

# Create transforms
# Training pipeline: resize to a square IMG_SIZE x IMG_SIZE, apply photometric
# jitter, convert to tensor, then normalize with IMG_NORM.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Adjust brightness/contrast
    transforms.ToTensor(),
    transforms.Normalize(**IMG_NORM)
])

# Evaluation pipeline: identical to training minus the ColorJitter step.
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(**IMG_NORM)
])

def target_transform(targets):
    """Normalize a GOT10k xywh box by ORIGINAL_SIZE and convert it to cxcywh.

    Args:
        targets: Indexable with at least four entries read as (x, y, w, h).

    Returns:
        Tensor produced by ``box_convert(..., 'xywh', 'cxcywh')`` on the
        normalized box.
    """
    full_width, full_height = ORIGINAL_SIZE[0], ORIGINAL_SIZE[1]
    x, y, w, h = targets[0], targets[1], targets[2], targets[3]

    # Horizontal quantities scale with ORIGINAL_SIZE[0], vertical with ORIGINAL_SIZE[1].
    normalized = torch.tensor([
        x / full_width,
        y / full_height,
        w / full_width,
        h / full_height,
    ])
    return box_convert(normalized, 'xywh', 'cxcywh')

In [None]:
# Create paired datasets with lazy loading
train_pairset = PairedGOT10kDataset(
    base_dataset=train_dataset, transform=train_transform, target_transform=target_transform
)
# Validation pairs are carved out of the training pairset by extract_valid()
# (whether this also removes them from train_pairset is defined in ttadapters — verify).
valid_pairset = train_pairset.extract_valid()
# NOTE(review): the "test" pairset is built on valid_dataset, not test_dataset —
# presumably because ground-truth boxes are needed; confirm this is intended.
test_pairset = PairedGOT10kDataset(
    base_dataset=valid_dataset, transform=test_transform, target_transform=target_transform
)

print(f"INFO: PairedDataset initialized. Total sequences - Train({len(train_pairset)}), Valid({len(valid_pairset)}), Test({len(test_pairset)})")

In [None]:
train_pairset[0]

In [None]:
from torchvision.ops import box_convert

def visualize_frame_pair(pairset, idx=None, figsize=(7, 5)):
    """
    Show two consecutive frames side by side with their boxes drawn in red.

    Boxes are expected in normalized cxcywh form; they are converted to xyxy
    and scaled by IMG_SIZE before drawing.

    Args:
        pairset: PairedGOT10kDataset instance yielding
            (prev_img, curr_img, prev_gt, curr_gt) for an index
        idx: Index of the pair to show; a random index is drawn when None
        figsize: Figure size as (width, height)

    Returns:
        The index that was visualized.
    """
    if idx is None:
        idx = np.random.randint(len(pairset))

    prev_img, curr_img, prev_gt, curr_gt = pairset[idx]

    channel_std = np.array(IMG_NORM['std'])
    channel_mean = np.array(IMG_NORM['mean'])

    def denormalize(img_tensor):
        """Undo Normalize on a CHW tensor and return an HWC array in [0, 1]."""
        hwc = img_tensor.permute(1, 2, 0).numpy()
        return np.clip(hwc * channel_std + channel_mean, 0, 1)

    prev_img = denormalize(prev_img)
    curr_img = denormalize(curr_img)

    def draw_bbox(ax, bbox, color='red'):
        """Draw the four edges of a cxcywh box on the given axes."""
        x1, y1, x2, y2 = (coord * IMG_SIZE for coord in box_convert(bbox, 'cxcywh', 'xyxy'))
        edges = (
            ([x1, x2], [y1, y1]),
            ([x1, x1], [y1, y2]),
            ([x2, x2], [y1, y2]),
            ([x1, x2], [y2, y2]),
        )
        for xs, ys in edges:
            ax.plot(xs, ys, color=color, linewidth=2)

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    panels = (
        (prev_img, prev_gt, 'Previous Frame'),
        (curr_img, curr_gt, 'Current Frame'),
    )
    for ax, (image, box, title) in zip(axes, panels):
        ax.imshow(image)
        draw_bbox(ax, box)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return idx

In [None]:
# Show the training pairs at indices 0 and 1 and echo the index that was drawn.
for _ in range(2):
    selected_idx = visualize_frame_pair(train_pairset, _)
    print(f"Visualized pair index: {selected_idx}")

## DataLoader

In [None]:
# Set Batch Size
# Tuple order follows its use in the DataLoader cell:
# BATCH_SIZE[0] -> train, BATCH_SIZE[1] -> valid, BATCH_SIZE[2] -> test.
BATCH_SIZE = 400, 400, 1

In [None]:
# Use Teacher Forcing
# Sets the `use_teacher_forcing` flag on the train/valid pairsets. The training
# loop unpacks three items per batch (frame, previous box, current box), so this
# flag presumably switches the dataset to that layout — verify in PairedGOT10kDataset.
train_pairset.use_teacher_forcing = True
valid_pairset.use_teacher_forcing = True

In [None]:
MULTI_PROCESSING = True  # Set False if DataLoader is causing issues

from platform import system
# Worker count: every CPU core when multi-processing is enabled on a
# non-Windows host, otherwise 0 (load in the main process).
if MULTI_PROCESSING and system() != "Windows":  # Multiprocess data loading is not supported on Windows
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    print(f"INFO: Number of CPU cores - {cpu_cores}")
else:
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_pairset, batch_size=BATCH_SIZE[0], shuffle=True, num_workers=cpu_cores)
valid_loader = DataLoader(valid_pairset, batch_size=BATCH_SIZE[1], shuffle=False, num_workers=cpu_cores)
test_loader = DataLoader(test_pairset, batch_size=BATCH_SIZE[2], shuffle=False, num_workers=cpu_cores)

## Define Model
### APT: Adaptive Plugin for TTA (Test-time Adaptation)

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig
from transformers.image_utils import AnnotationFormat
from safetensors.torch import load_file

In [None]:
# Hub identifier used for both the configuration and the image processor.
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# Override the number of detection classes to 6 (presumably the label set of
# the checkpoint loaded in the next cell — TODO confirm).
reference_config.num_labels = 6

# Set the image size and preprocessor size
# NOTE(review): literal 800 duplicates IMG_SIZE defined earlier; keep the two in sync.
reference_config.image_size = 800

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format
reference_preprocessor.size = {"height": IMG_SIZE, "width": IMG_SIZE}
# Resizing is disabled on the processor; the torchvision transforms already
# resize frames to IMG_SIZE x IMG_SIZE.
reference_preprocessor.do_resize = False

In [None]:
# Build the detector from the configuration (no hub weights are fetched here),
# then fill it from a local safetensors checkpoint loaded onto the CPU.
model_pretrained = RTDetrForObjectDetection(config=reference_config)
model_states = load_file("RT-DETR_R50vd_SHIFT_CLEAR.safetensors", device="cpu")
# strict=False: missing or unexpected keys do not raise. NOTE(review): the
# returned key report is discarded, so a mismatched checkpoint goes unnoticed.
model_pretrained.load_state_dict(model_states, strict=False)

# Freeze every detector parameter so gradients are not computed for them.
for param in model_pretrained.parameters():
    param.requires_grad = False  # Freeze

# Bare expression: shows the module structure as the notebook cell output.
model_pretrained

In [None]:
class FeatureNormalizationLayer(nn.Module):
    """Reduce a feature map to a fixed-width, layer-normalized vector.

    The input is averaged over its last two (spatial) axes, the resulting
    per-channel values are adaptively average-pooled to ``target_dim`` entries,
    and the result goes through LayerNorm followed by Dropout(0.1).

    Args:
        target_dim: Width of the output vector (default 256).
    """

    def __init__(self, target_dim=256):
        super().__init__()
        self.target_dim = target_dim

        # Global spatial average: leaves one value per channel.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Parameter-free resampling of the channel axis to target_dim entries
        # (adaptive average pooling, despite the attribute name).
        self.linear_compress = nn.AdaptiveAvgPool1d(self.target_dim)

        # Normalization followed by light dropout.
        self.feature_norm = nn.Sequential(nn.LayerNorm(target_dim), nn.Dropout(0.1))

    def forward(self, x):
        """Return the normalized vector for feature map ``x``."""
        # Pool to (..., 1, 1) and drop the two singleton spatial axes.
        per_channel = self.adaptive_pool(x).squeeze(-1).squeeze(-1)
        return self.feature_norm(self.linear_compress(per_channel))

In [None]:
class APT(nn.Module):
    """
    Light-weight Sparse Autoencoder for Adaptation
    which learns how to sniff out the frame changes to predict next bounding boxes.

    Args:
        feature_dim: Width of the normalized feature vector fed to the sniffer.
        bbox_dim: Number of box coordinates (input and output).
        hidden_dim: Width of the fused hidden representation.
        sparsity_param: Target mean activation (rho) for the KL sparsity term;
            must lie strictly between 0 and 1.
    """
    def __init__(self, feature_dim=256, bbox_dim=4, hidden_dim=32, sparsity_param=0.1):
        super().__init__()

        self.feature_dim = feature_dim
        self.bbox_dim = bbox_dim
        self.hidden_dim = hidden_dim
        self.sparsity_param = sparsity_param

        # Split the hidden width between the box branch (a quarter) and the
        # feature branch (the remainder). Using the remainder instead of
        # `hidden_dim // 4 * 3` keeps the concatenation exactly hidden_dim wide
        # even when hidden_dim is not a multiple of 4.
        bbox_hidden = hidden_dim // 4
        sniff_hidden = hidden_dim - bbox_hidden

        # Feature normalization layer for encoder-agnostic adaptation
        self.feature_norm = FeatureNormalizationLayer(target_dim=feature_dim)

        # Lightweight feature sniffer
        self.feature_sniffer = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, sniff_hidden)
        )

        # Previous bbox encoder
        self.bbox_encoder = nn.Sequential(
            nn.Linear(bbox_dim, bbox_hidden),
            nn.ReLU()
        )

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Prediction head
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, bbox_dim),
            nn.Sigmoid()  # Normalize bbox coordinates to [0,1]
        )

        # Holds the most recent fused activation under key 'hidden' so that
        # get_sparsity_loss() can regularize it.
        self.activation = {}

    def forward(self, features, prev_bbox):
        """Predict the next box from encoder features and the previous box.

        Args:
            features: Encoder feature map passed to FeatureNormalizationLayer.
            prev_bbox: Previous box tensor whose last axis has bbox_dim entries.

        Returns:
            Predicted box tensor with values in [0, 1] (sigmoid output).
        """
        # Normalize encoder features to be encoder-agnostic
        norm_features = self.feature_norm(features)

        # Extract relevant features from current frame
        sniffed_features = self.feature_sniffer(norm_features)

        # Encode previous bbox information
        bbox_features = self.bbox_encoder(prev_bbox)

        # Fuse features
        fused = self.fusion(
            torch.cat([sniffed_features, bbox_features], dim=-1)
        )

        # Predict next bbox
        next_bbox = self.predictor(fused)

        # Store activation for sparsity regularization if needed
        self.activation['hidden'] = fused

        return next_bbox

    def get_sparsity_loss(self, eps=1e-6):
        """Calculate the KL sparsity regularization loss of the last forward pass.

        Args:
            eps: Margin used to keep the mean activation strictly inside (0, 1).

        Returns:
            Scalar tensor. It is zero when no forward pass has stored an
            activation yet.
        """
        hidden = self.activation.get('hidden')
        if hidden is None:
            # A scalar tensor (rather than the int 0) keeps callers that apply
            # tensor ops such as torch.mean to the result working.
            return torch.zeros(())

        # The activation comes out of a ReLU, so its batch mean can be exactly
        # 0 or exceed 1; either makes the logs below inf/NaN. Clamp into the
        # open interval, and do so in at least float32 because `1 - eps` is
        # not representable in bfloat16/float16.
        work_dtype = torch.promote_types(hidden.dtype, torch.float32)
        rho_hat = torch.mean(hidden.to(work_dtype), dim=0).clamp(min=eps, max=1 - eps)
        rho = torch.full_like(rho_hat, self.sparsity_param)

        # KL divergence for sparsity regularization
        sparsity_loss = torch.sum(
            rho * torch.log(rho / rho_hat) +
            (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
        )

        return sparsity_loss

In [None]:
from typing import Optional

class TestTimeAdaptiveDETR(nn.Module):
    """Wrap a detector with an APT plugin and fold the plugin losses into ``output.loss``.

    Args:
        pretrained_model: Detector called with the HF-style keyword arguments
            below; its output must expose ``encoder_last_hidden_state`` and
            ``pred_boxes``.
        feature_dim, bbox_dim, hidden_dim, sparsity_param: Forwarded to APT.

    Attributes:
        cache: Previous boxes consumed on the next forward pass. Callers may
            assign a list of box tensors; after each forward it is overwritten
            with ``output.pred_boxes``.
    """
    def __init__(
        self, pretrained_model,
        feature_dim=256, bbox_dim=4, hidden_dim=32, sparsity_param=0.1
    ):
        super().__init__()
        self.model = pretrained_model
        self.apt = APT(
            feature_dim=feature_dim, bbox_dim=bbox_dim,
            hidden_dim=hidden_dim, sparsity_param=sparsity_param
        )
        self.cache = None

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        teacher_forcing_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the detector, then score the APT prediction made from ``self.cache``.

        Args:
            teacher_forcing_labels: When given, the MSE target for the APT
                output. All other arguments are passed to the wrapped detector
                unchanged.

        Returns:
            The detector output with ``loss`` increased by (or set to) the APT
            MSE loss plus 0.01 x the sparsity loss.
        """
        output = self.model(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        apt_loss = 0
        # Explicit None/length test: the cache is overwritten with a tensor at
        # the end of this method, and truth-testing a multi-element tensor
        # raises "Boolean value of Tensor ... is ambiguous".
        if self.cache is not None and len(self.cache) > 0:
            features = output.encoder_last_hidden_state[-1]
            apt_output = torch.stack([self.apt(features, prev_bbox) for prev_bbox in self.cache])

            criterion = nn.MSELoss()
            if teacher_forcing_labels is not None:
                apt_loss = criterion(apt_output, teacher_forcing_labels)
            # NOTE(review): torch.stack never returns None, so this branch is
            # unreachable. It looks like an intended self-supervised fallback,
            # but apt_output and output.pred_boxes would first need matching
            # shapes — left as-is until that is decided.
            elif apt_output is None:
                apt_loss = criterion(apt_output, output.pred_boxes)

        # get_sparsity_loss() may return a plain number when APT has not run
        # yet; torch.mean only accepts tensors, so guard the reduction.
        sparsity = self.apt.get_sparsity_loss()
        sparsity_loss = torch.mean(sparsity) * 0.01 if torch.is_tensor(sparsity) else 0

        # Update loss
        if hasattr(output, 'loss') and output.loss is not None:
            output.loss += apt_loss + sparsity_loss
        else:
            output.loss = apt_loss + sparsity_loss

        # Update cache
        self.cache = output.pred_boxes
        return output

In [None]:
# Initialize Model
# Wrap the frozen detector with the APT plugin and cast the wrapper to
# DATA_TYPE, then move it to the selected device.
model = TestTimeAdaptiveDETR(pretrained_model=model_pretrained).to(dtype=DATA_TYPE)
model.to(device)

## Training

In [None]:
from IPython.display import display
import ipywidgets as widgets

# Interactive Loss Plot Update
def create_plot():
    """Create a live loss plot and return a callback that appends points to it.

    Returns:
        ``update_plot(train_loss=None, valid_loss=None)``: appends each
        non-None value to its series and redraws the figure inside the output
        widget created here.
    """
    train_losses, valid_losses = [], []

    # Enable Interactive Mode
    plt.ion()

    # Loss Plot Setting
    fig, ax = plt.subplots(figsize=(6, 2))
    train_line, = ax.plot(train_losses, label="Train Loss", color="purple")
    valid_line, = ax.plot(valid_losses, label="Valid Loss", color="red")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Model Loss Graph")
    # The line labels are only shown if a legend is drawn.
    ax.legend()

    # Display Plot
    plot = widgets.Output()
    display(plot)

    def update_plot(train_loss=None, valid_loss=None):
        """Append the given losses and refresh the displayed figure."""
        if train_loss is not None:
            train_losses.append(train_loss)
        if valid_loss is not None:
            valid_losses.append(valid_loss)
        train_line.set_ydata(train_losses)
        train_line.set_xdata(range(len(train_losses)))
        valid_line.set_ydata(valid_losses)
        valid_line.set_xdata(range(len(valid_losses)))
        # Recompute data limits so the axes follow the growing series.
        ax.relim()
        ax.autoscale_view()
        with plot:
            plot.clear_output(wait=True)
            display(fig)

    return update_plot

In [None]:
def avg(lst):
    """Return the arithmetic mean of ``lst``, or 0 when the division by its length hits zero."""
    try:
        mean = sum(lst) / len(lst)
    except ZeroDivisionError:
        # Empty input: report 0 instead of raising.
        mean = 0
    return mean

In [None]:
def calculate_iou(box1, box2):
    """
    Intersection-over-union of two boxes given as [x1, y1, x2, y2].

    Returns 0 when the union area is not positive.
    """
    # Corners of the overlap rectangle.
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Negative extents mean the boxes do not overlap on that axis.
    overlap = max(0, right - left) * max(0, bottom - top)

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_first + area_second - overlap

    if union > 0:
        return overlap / union
    return 0

In [None]:
def calculate_ciou(box1, box2):
    """
    Complete IoU (CIoU) of two boxes given as [x, y, w, h] (normalized),
    where (x, y) is the top-left corner.

    The score is IoU minus a center-distance penalty and an aspect-ratio
    penalty, clipped to the [0, 1] range before it is returned.
    """
    eps = 1e-7  # guards every division against a zero denominator

    left_a, top_a, width_a, height_a = box1[0], box1[1], box1[2], box1[3]
    left_b, top_b, width_b, height_b = box2[0], box2[1], box2[2], box2[3]
    right_a, bottom_a = left_a + width_a, top_a + height_a
    right_b, bottom_b = left_b + width_b, top_b + height_b

    # Plain IoU.
    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    overlap = overlap_w * overlap_h
    union = width_a * height_a + width_b * height_b - overlap
    iou = overlap / (union + eps)

    # Squared distance between the two box centers.
    dx = (left_a + right_a) / 2 - (left_b + right_b) / 2
    dy = (top_a + bottom_a) / 2 - (top_b + bottom_b) / 2
    center_gap = dx ** 2 + dy ** 2

    # Squared diagonal of the smallest box enclosing both.
    span_w = max(right_a, right_b) - min(left_a, left_b)
    span_h = max(bottom_a, bottom_b) - min(top_a, top_b)
    diagonal = span_w ** 2 + span_h ** 2

    # Aspect-ratio consistency term and its trade-off weight.
    angle_gap = np.arctan(width_a / (height_a + eps)) - np.arctan(width_b / (height_b + eps))
    v = 4 / (np.pi ** 2) * angle_gap ** 2
    alpha = v / (1 - iou + v + eps)

    score = iou - center_gap / (diagonal + eps) - alpha * v

    # Clip to [0, 1].
    return max(0.0, min(1.0, score))

### Default Pre-training Process
Uses teacher forcing: the ground-truth current-frame box is the regression target for the APT prediction.

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 10
# (initial learning rate, minimum learning rate reached by the cosine schedule)
LEARNING_RATE = 5e-3, 1e-6
WEIGHT_DECAY = 0.05

#wandb.watch(model, log="all", log_freq=10)
# The optimizer receives every parameter of the wrapper, including the frozen
# detector weights (those carry requires_grad=False and receive no gradient).
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE[0], weight_decay=WEIGHT_DECAY)
# T_max counts scheduler.step() calls: with T_max=EPOCHS the cosine decay spans
# the whole run only if step() is called once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LEARNING_RATE[1])

In [None]:
# Pre-train the APT plugin with teacher forcing and validate after every epoch.
train_length, valid_length = map(len, (train_loader, valid_loader))
epochs = tqdm(range(EPOCHS), desc="Running Epochs")
with (tqdm(total=train_length, desc="Training") as train_progress,
        tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars
    update = create_plot()  # Create Loss Plot

    for epoch in epochs:
        train_progress.reset(total=train_length)
        valid_progress.reset(total=valid_length)

        # train_ciou is reported in the log lines but is not accumulated in this loop.
        train_loss, train_ciou = 0, 0

        # Training
        model.train()
        for i, (curr_frame, prev_bbox, curr_bbox) in enumerate(train_loader):
            torch.cuda.empty_cache()  # Clear GPU memory
            optimizer.zero_grad()

            prev_bbox, curr_bbox = prev_bbox.to(device, dtype=DATA_TYPE), curr_bbox.to(device, dtype=DATA_TYPE)
            # Seed the plugin with the previous box supplied by the dataset.
            model.cache = [prev_bbox]
            output = model(curr_frame.to(device, dtype=DATA_TYPE), teacher_forcing_labels=curr_bbox)  # Use Teacher Forcing while training

            output.loss.backward()
            optimizer.step()

            batch_loss = output.loss.item()
            # Running mean over the epoch: each batch contributes 1/train_length.
            train_loss += batch_loss / train_length

            train_progress.update(1)
            #if i != train_length-1: wandb.log({'MSE Loss': output.loss.item()})
            print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{i+1:4}/{train_length}], MSE Loss: {batch_loss:.6f}", end="")

        # The scheduler is configured with T_max=EPOCHS, so it advances once per
        # epoch; stepping it per batch would make the cosine learning rate
        # oscillate every few iterations instead of decaying over the run.
        scheduler.step()

        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], MSE Loss: {train_loss:.6f}, CIoU Loss: {train_ciou:.6f}", end="")
        val_loss, val_ciou = 0, 0

        # Validation
        model.eval()
        with torch.no_grad():
            for curr_frame, prev_bbox, curr_bbox in valid_loader:
                prev_bbox, curr_bbox = prev_bbox.to(device, dtype=DATA_TYPE), curr_bbox.to(device, dtype=DATA_TYPE)
                model.cache = [prev_bbox]
                output = model(curr_frame.to(device, dtype=DATA_TYPE), teacher_forcing_labels=curr_bbox)  # Same teacher-forced objective as training, for a comparable loss

                val_loss += output.loss.item() / valid_length
                valid_progress.update(1)  # keep the validation bar in step with the loop

        update(train_loss=train_loss, valid_loss=val_loss)
        #wandb.log({'Train MSE Loss': train_loss, 'Train CIoU Loss': train_ciou, 'Val MSE Loss': val_loss, 'Val CIoU Loss': val_ciou})
        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], MSE Loss: {train_loss:.6f}, CIoU Loss: {train_ciou:.6f}, Valid MSE Loss: {val_loss:.6f}, Valid CIoU Loss: {val_ciou:.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")

In [None]:
# Create ./models on first use (single level only; mkdir does not create parents).
if not path.isdir(path.join(".", "models")):
    mkdir(path.join(".", "models"))

# Model Save
# Two files are written: the APT plugin weights alone (apt_model.pt) and the
# full wrapper including the detector (apt_model.full.pt).
save_path = path.join(".", "models", "apt_model.pt")
torch.save(model.apt.state_dict(), save_path)
torch.save(model.state_dict(), save_path.replace(".pt", ".full.pt"))
print(f"Model saved to {save_path}")