# TTA Example

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch
from torchinfo import summary

from ttadapters import datasets, models, methods
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, DataLoaderHolder, scenarios

In [None]:
# Export the TorchDynamo capture flags through the process environment.
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

# Set the scalar-output flag directly on the Dynamo config object as well.
torch._dynamo.config.capture_scalar_outputs = True
# NOTE(review): presumably lets failed compilations fall back to eager execution
# instead of raising - confirm against the torch._dynamo documentation.
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Batch sizes, ordered as (train, valid, test)
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = 40, 200, 1  # A100 or H100
ACCUMULATE_STEPS = 1

# Directory that holds the datasets
DATA_ROOT = path.join(path.curdir, "data")

# Dataset class of the source domain
SOURCE_DOMAIN = datasets.SHIFTDataset

# Run mode (forwarded as --test-only when running inside a notebook)
TEST_MODE = False

# Supported architectures; the first entry is the default
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[0]

In [None]:
# Create argument parser
parser = ArgumentParser(description="Adaptation experiment script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
parser.add_argument("--additional_gpu", type=int, default=0, help="Additional CUDA device count")
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" in sys.modules:
    args = parser.parse_args(["--test-only"] if TEST_MODE else [])
    print("INFO: Running in notebook mode with default arguments")
else:
    args = parser.parse_args()

# Update global variables based on parsed arguments
BATCH_SIZE = args.train_batch, args.valid_batch, BATCH_SIZE[2]
ACCUMULATE_STEPS = args.accum_step
DATA_ROOT = args.data_root
TEST_MODE = args.test_only
MODEL_TYPE = args.model
match args.dataset:
    case "shift":
        SOURCE_DOMAIN = datasets.SHIFTDataset
    case "city":
        SOURCE_DOMAIN = datasets.CityscapesDataset
    case _:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Set test mode - {TEST_MODE} for {SOURCE_DOMAIN.dataset_name} dataset")

### Check GPU Availability

In [None]:
!nvidia-smi

In [None]:
# Resolve device index, extra-GPU count and tensor precision from the CLI arguments
DEVICE_NUM = args.device or 0
ADDITIONAL_GPU = args.additional_gpu or 0
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if not torch.cuda.is_available():
    # No CUDA: run on the CPU and mark the device index as invalid
    device = torch.device("cpu")
    DEVICE_NUM = -1
elif ADDITIONAL_GPU:
    # Extra GPUs requested: make DEVICE_NUM the current device and use a bare "cuda" handle
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    device = torch.device(f"cuda:{DEVICE_NUM}")

print(f"INFO: Using device - {device}", f":{DEVICE_NUM}" if ADDITIONAL_GPU else "", sep="")
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Apply the package's fast-download patch for object-detection datasets
datasets.patch_fast_download_for_object_detection()

In [None]:
# Benchmark dataset
match SOURCE_DOMAIN:
    case datasets.SHIFTDataset:
        test_dataset = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT)
    case datasets.CityscapesDataset:
        pass
    case _:
        raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Dataset info
CLASSES = test_dataset.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Display one sample (index 999) to inspect the annotation keys and values
test_dataset[999]

In [None]:
# Display the shape of the first element of the same sample
test_dataset[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Render the benchmark frames with their bounding boxes
visualizer.visualize_bbox_frames(test_dataset)

## Load Base Model

In [None]:
# Build the detector for the selected architecture and load the checkpoint
# described by its `Weights` entry (non-strict state loading).
if MODEL_TYPE == "rcnn":
    base_model = models.FasterRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
    load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO), strict=False)
elif MODEL_TYPE == "swinrcnn":
    base_model = models.SwinRCNNForObjectDetection(dataset=SOURCE_DOMAIN)
    load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR_NATUREYOO), strict=False)
elif MODEL_TYPE == "yolo11":
    base_model = models.YOLO11ForObjectDetection(dataset=SOURCE_DOMAIN)
    load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR), strict=False)
elif MODEL_TYPE == "rtdetr":
    base_model = models.RTDetrForObjectDetection(dataset=SOURCE_DOMAIN)
    load_result = base_model.load_from(**vars(base_model.Weights.SHIFT_CLEAR), strict=False)
else:
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")

# Report what the loader returned, then move the model to the target device
print("INFO: Model state loaded -", load_result)
base_model.to(device)

In [None]:
# Display the torchinfo summary of the base model
summary(base_model)

In [None]:
# Compile model
#base_model = torch.compile(base_model)

## Load Adaptation Method

### APT

#### Fitting Dataset

##### Original Dataset

In [None]:
from torchvision.datasets import utils as tv_utils
from torchvision.transforms import v2 as transforms
from tqdm.auto import tqdm

# Replace the tqdm used by torchvision's dataset utilities with tqdm.auto's implementation
tv_utils.tqdm = tqdm

In [None]:
# Location of the GOT-10k data. The original machine-specific path is kept as the
# default; set the GOT10K_DATA_ROOT environment variable to point elsewhere.
GOT10K_DATA_ROOT = environ.get("GOT10K_DATA_ROOT", r"F:\data")

# Train / valid splits of the tracking dataset (no forced re-download)
got10k = DatasetHolder(
    train=datasets.GOT10kDatasetForObjectTracking(root=GOT10K_DATA_ROOT, force_download=False, train=True),
    valid=datasets.GOT10kDatasetForObjectTracking(root=GOT10K_DATA_ROOT, force_download=False, valid=True),
)

In [None]:
# Read the last two dimensions of the first training frame and define the resize target
ORIGINAL_SIZE = got10k.train[0][0].shape[-2:]
# NOTE(review): the same (800, 1280) value is repeated as APTConfig.img_size - keep them in sync
IMG_SIZE = 800, 1280

print("INFO: Image conversion is set to resize to", IMG_SIZE, "from", ORIGINAL_SIZE)

In [None]:
# Display the targets of the training split
got10k.train.targets

In [None]:
# Display the first training sample before any transform is attached
got10k.train[0]

In [None]:
# Training pipeline: resize to IMG_SIZE, then perturb the image
train_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Give Perturbation
    # NOTE(review): RandomPosterize presumably expects integer-typed images - confirm the dataset dtype
    transforms.RandomPosterize(bits=2),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.))
])

# Validation pipeline: resize only
default_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
])

In [None]:
# Attach the pipelines: perturbing transforms for train, resize-only for valid
got10k.train.transforms = train_transforms
got10k.valid.transforms = default_transforms

In [None]:
# Display the first training sample again, now with the transforms attached
got10k.train[0]

In [None]:
# Display the last validation sample with the resize-only transform attached
got10k.valid[-1]

##### Paired Dataset

In [None]:
# Create paired datasets with lazy loading
train_pairset = datasets.PairedGOT10kDataset(got10k.train)
# NOTE(review): presumably splits a validation part off the training pairs - confirm in PairedGOT10kDataset
valid_pairset = train_pairset.extract_valid()

# The original GOT-10k valid split is wrapped as the paired *test* set
got10k_pairset = DatasetHolder(train=train_pairset, valid=valid_pairset, test=datasets.PairedGOT10kDataset(got10k.valid))
# Drop the temporary names; the holder keeps the references
del train_pairset, valid_pairset

In [None]:
# Display the first paired training sample
got10k_pairset.train[0]

In [None]:
# Display the first paired test sample
got10k_pairset.test[0]

In [None]:
# Visualize frame pairs from the paired test set (no bbox or class keys passed)
visualizer.visualize_bbox_frame_pair(got10k_pairset.test, bbox_key=None, bbox_class_key=None)

#### DataLoader

In [None]:
# Enable teacher forcing on the train and valid pair sets (the test set is left unchanged)
got10k_pairset.train.use_teacher_forcing = True
got10k_pairset.valid.use_teacher_forcing = True

In [None]:
from torch.utils.data import DataLoader

In [None]:
# One DataLoader per split, sized by BATCH_SIZE; only the training loader shuffles
loaders = DataLoaderHolder(
    train=DataLoader(got10k_pairset.train, batch_size=BATCH_SIZE[0], shuffle=True),
    valid=DataLoader(got10k_pairset.valid, batch_size=BATCH_SIZE[1], shuffle=False),
    test=DataLoader(got10k_pairset.test, batch_size=BATCH_SIZE[2], shuffle=False)
)

#### Plugin

In [None]:
from torch import nn

In [None]:
class FeatureNormalizationLayer(nn.Module):
    """Turn a spatial feature map into a fixed-width, layer-normalised vector.

    The two trailing (spatial) axes are averaged away, the remaining channel
    axis is adaptively average-pooled to ``target_dim`` entries, and the result
    passes through ``LayerNorm`` followed by ``Dropout(0.1)``.
    """

    def __init__(self, target_dim=256):
        super().__init__()
        self.target_dim = target_dim

        # Global average over the two trailing axes
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Parameter-free resampling of the channel axis to ``target_dim`` entries
        self.linear_compress = nn.AdaptiveAvgPool1d(target_dim)

        # Normalisation plus light regularisation
        self.feature_norm = nn.Sequential(nn.LayerNorm(target_dim), nn.Dropout(0.1))

    def forward(self, x):
        pooled = self.adaptive_pool(x)

        # Drop the two singleton axes left behind by the pooling
        channel_vector = pooled[..., 0, 0]

        return self.feature_norm(self.linear_compress(channel_vector))

In [None]:
class APT(nn.Module):
    """Light-weight box predictor used for adaptation.

    It learns how to sniff out the frame changes to predict the next bounding
    boxes from the current frame's features and the previous boxes.

    Args:
        feature_dim: Width of the normalised feature vector.
        bbox_dim: Number of values per bounding box.
        hidden_dim: Width of the fused hidden representation. It is split into
            a box branch (``hidden_dim // 4``) and a feature branch (the
            remainder), so any positive value yields a consistent fusion width.
    """
    def __init__(self, feature_dim=256, bbox_dim=4, hidden_dim=32):
        super().__init__()

        self.feature_dim = feature_dim
        self.bbox_dim = bbox_dim
        self.hidden_dim = hidden_dim

        # Branch widths always add up to hidden_dim. For multiples of 4 this equals
        # the previous `hidden_dim // 4 * 3` split; other values used to produce a
        # concatenation narrower than the fusion layer expected.
        bbox_hidden = hidden_dim // 4
        sniff_hidden = hidden_dim - bbox_hidden

        # Feature normalization layer for encoder-agnostic adaptation
        self.feature_norm = FeatureNormalizationLayer(target_dim=feature_dim)

        # Lightweight feature sniffer
        self.feature_sniffer = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, sniff_hidden)
        )

        # Previous bbox encoder
        self.bbox_encoder = nn.Sequential(
            nn.Linear(bbox_dim, bbox_hidden),
            nn.ReLU()
        )

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Prediction head
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, bbox_dim),
            nn.Sigmoid()  # Normalize bbox coordinates to [0,1]
        )

    def forward(self, features, prev_bbox):
        """Predict the next boxes from frame features and the previous boxes.

        Args:
            features: Feature map handed to ``FeatureNormalizationLayer``.
            prev_bbox: Previous boxes; the last axis has ``bbox_dim`` entries.

        Returns:
            Tensor with one ``bbox_dim``-sized, sigmoid-bounded row per previous box.
        """
        # Normalize encoder features to be encoder-agnostic
        norm_features = self.feature_norm(features)

        # Extract relevant features from current frame
        sniffed_features = self.feature_sniffer(norm_features)

        # Encode previous bbox information
        bbox_features = self.bbox_encoder(prev_bbox)

        # Broadcast the frame descriptor over all boxes. `expand` only succeeds when
        # the descriptor has a single row or already has one row per box.
        sniffed_features = sniffed_features.expand(bbox_features.shape[0], -1)

        # Fuse features
        fused = self.fusion(
            torch.cat([sniffed_features, bbox_features], dim=-1)
        )

        # Predict next bbox
        next_bbox = self.predictor(fused)

        return next_bbox

In [None]:
from typing import Optional
from dataclasses import dataclass

from torchvision.ops import box_convert
from scipy.optimize import linear_sum_assignment

from ttadapters.models.base import BaseModel, ModelProvider, DataPreparation
from ttadapters.methods import AdaptationEngine, AdaptationConfig


@dataclass
class APTConfig(AdaptationConfig):
    """Hyper-parameters of the APT adaptation plugin.

    Every option carries a type annotation so that ``@dataclass`` registers it
    as a field; unannotated class attributes are ignored by ``dataclass`` and
    could not be overridden through the constructor
    (e.g. ``APTConfig(bbox_topk=10)``).
    """
    adaptation_name: str = "APT"
    # (height, width): APTPlugin divides x coordinates by img_size[1] and y by img_size[0]
    img_size: tuple[int, int] = (800, 1280)
    feature_dim: int = 256
    bbox_dim: int = 4
    hidden_dim: int = 32
    # Minimum score a cached box needs to be used as a teacher signal
    bbox_conf_threshold: float = 0.1
    # Maximum number of cached boxes fed to APT per frame
    bbox_topk: int = 20


class APTPlugin(AdaptationEngine):
    """Adaptation engine that attaches the APT box predictor to a base detector.

    While ``self.adapt`` is enabled, the detections of the previous frame are
    cached, pushed through APT to obtain "teacher" boxes for the current frame,
    and the detector's current predictions are pulled towards them with an L1 loss.
    """
    model_name: str = "APT"
    model_provider: ModelProvider = ModelProvider.HuggingFace
    DataPreparation = DataPreparation

    class Trainer:
        pass

    def __init__(self, basemodel: BaseModel, config: APTConfig):
        super().__init__(basemodel, config)
        self.apt = APT(
            feature_dim=config.feature_dim, bbox_dim=config.bbox_dim,
            hidden_dim=config.hidden_dim
        )
        self.img_size = config.img_size
        img_size = config.img_size
        # Scale pixel xyxy boxes to [0, 1]: x by width (img_size[1]), y by height (img_size[0])
        self.__bbox_normalize = lambda bbox: bbox / torch.tensor([img_size[1], img_size[0], img_size[1], img_size[0]], device=bbox.device)
        self.bbox_cache = None  # {'boxes': tensor, 'scores': tensor}
        self.adapt = False
        self.bbox_conf_threshold = config.bbox_conf_threshold
        self.bbox_topk = config.bbox_topk

    def online(self, mode=True):
        """Freeze the base model, then delegate the mode switch to the parent engine."""
        # The offline branch used to iterate the notebook-global `base_model`; always
        # use the instance's own base model so the plugin works outside the notebook.
        # NOTE(review): both modes freeze the base model - confirm whether online
        # mode is meant to leave (part of) it trainable.
        for param in self.base_model.parameters():
            param.requires_grad = False  # Freeze
        return super().online(mode)

    def offline(self):
        """Shortcut for ``online(False)``."""
        return self.online(False)

    def train(self, model=True):
        """Forward the train/eval switch to the parent; the per-mode hooks are still empty."""
        # NOTE(review): `self.adapting` is not set in this class (``__init__`` sets
        # `self.adapt`) - presumably provided by AdaptationEngine; confirm.
        out = super().train(model)
        if model:  # train
            if self.adapting:
                pass
            else:
                pass
        else:  # eval
            if self.adapting:
                pass
            else:
                pass
        return out

    def fit(self, *args, **kwargs):
        """Placeholder; not implemented yet."""
        pass

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        bbox_cache: Optional[dict] = None,
        teacher_forcing_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the base model and, while adapting, add the APT consistency loss.

        Args:
            bbox_cache: Optional ``{'boxes': ..., 'scores': ...}`` mapping with the
                previous frame's detections; falls back to ``self.bbox_cache``.
            teacher_forcing_labels: Accepted for API compatibility; not used by
                this implementation.
            Other arguments are passed to the base model unchanged.

        Returns:
            The base model output with ``loss`` updated as described in step 7.
        """
        # An explicitly passed cache takes precedence over the stored one
        bbox_cache = bbox_cache if bbox_cache is not None else self.bbox_cache

        # 1. Run base model (encoder-decoder)
        output = self.model(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        apt_loss = None
        if self.adapt and bbox_cache is not None:
            # 2. Keep only the high-confidence boxes of the previous frame (refine the teacher signal)
            prev_boxes = bbox_cache['boxes']
            prev_scores = bbox_cache['scores']

            high_conf_mask = prev_scores > self.bbox_conf_threshold
            confident_prev_boxes = prev_boxes[high_conf_mask]
            confident_prev_scores = prev_scores[high_conf_mask]

            # Adapt only when at least one confident previous box exists
            if confident_prev_boxes.shape[0] > 0:
                # top-k selection
                scores_for_sorting = confident_prev_scores
                sorted_indices = torch.argsort(scores_for_sorting, descending=True)
                top_k_indices = sorted_indices[:self.bbox_topk]
                top_k_prev_boxes = confident_prev_boxes[top_k_indices]

                # 3. Produce the "teacher" prediction with APT, based on temporal consistency
                features = output.encoder_last_hidden_state[-1]
                # APT input must be in normalized cxcywh format
                normalized_prev_boxes_cxcywh = box_convert(self.__bbox_normalize(top_k_prev_boxes), "xyxy", "cxcywh")
                apt_teacher_boxes = self.apt(features, normalized_prev_boxes_cxcywh) # output: cxcywh

                # 4. Select the current-frame predictions to train on
                # The model's final predictions (logits, pred_boxes) are not post-processed yet;
                # pred_boxes are in normalized cxcywh format
                student_boxes = output.pred_boxes[0] # assumes batch size 1

                # 5. Match "teacher" and "student" predictions (Hungarian algorithm) and compute the loss
                with torch.no_grad():
                    # L1-distance cost matrix
                    cost_matrix = torch.cdist(apt_teacher_boxes, student_boxes, p=1.0)
                    # Find the optimal assignment
                    row_indices, col_indices = linear_sum_assignment(cost_matrix.cpu())

                # L1 loss between the matched student boxes and their teacher boxes
                matched_student_boxes = student_boxes[col_indices]
                apt_loss = nn.functional.l1_loss(matched_student_boxes, apt_teacher_boxes[row_indices])

        # 6. Store the current predictions in bbox_cache for the next frame
        if self.adapt:
            with torch.no_grad():
                sizes = [self.img_size for _ in range(len(output.pred_boxes))]
                # Low threshold so that as many candidates as possible end up in the cache
                results = self.post_process(output, target_sizes=sizes, threshold=0.3)[0]
                self.bbox_cache = results # results has the form {'scores': ..., 'labels': ..., 'boxes': ...}

        # 7. Update the final loss
        # Keep the existing loss (training) or set apt_loss (adaptation)
        if labels is not None and hasattr(output, 'loss'): # regular training
            if apt_loss is not None:
                output.loss += apt_loss # apt_loss may also be added during regular training (optional)
        else: # adaptation
            output.loss = apt_loss if apt_loss is not None else torch.tensor(0.0, device=pixel_values.device)

        if apt_loss is not None:
            print(f"\rINFO: APT Loss - {apt_loss.item():.6f}", end="")

        return output

In [None]:
# Wrap the base model with the APT plugin using the default configuration, then move it to the device
adaptive_config = APTConfig()
adaptive_model = APTPlugin(base_model, adaptive_config)
adaptive_model.to(device)

## Training

In [None]:
from matplotlib import pyplot as plt
import numpy as np

from IPython.display import display
import ipywidgets as widgets


# Interactive Loss Plot Update
def create_plot():
    """Set up a live loss chart and return the callback that feeds it.

    The returned ``update_plot(train_loss=None, valid_loss=None)`` appends every
    value that is not ``None`` to its curve and redraws the figure inside the
    output widget created here.
    """
    train_history, valid_history = [], []

    # Interactive mode on
    plt.ion()

    # One axes object with a curve per loss series
    figure, axes = plt.subplots(figsize=(6, 2))
    (train_curve,) = axes.plot(train_history, label="Train Loss", color="purple")
    (valid_curve,) = axes.plot(valid_history, label="Valid Loss", color="red")
    axes.set(xlabel="Iteration", ylabel="Loss", title="Model Loss Graph")

    # Widget that hosts the redrawn figure
    output_area = widgets.Output()
    display(output_area)

    def update_plot(train_loss=None, valid_loss=None):
        series = (
            (train_history, train_curve, train_loss),
            (valid_history, valid_curve, valid_loss),
        )
        for history, curve, value in series:
            if value is not None:
                history.append(value)
            curve.set_data(range(len(history)), history)

        # Refit the axis limits to the new data
        axes.relim()
        axes.autoscale_view()

        with output_area:
            output_area.clear_output(wait=True)
            display(figure)

    return update_plot

In [None]:
def avg(lst):
    """Return the arithmetic mean of ``lst``, or 0 when it is empty."""
    count = len(lst)
    if count == 0:
        return 0
    return sum(lst) / count

In [None]:
def calculate_iou(box1, box2):
    """
    Intersection-over-union of two boxes given as [x1, y1, x2, y2].

    Returns 0 when the union area is not positive.
    """
    # Edges of the overlap rectangle
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # A negative extent means the boxes do not overlap along that axis
    overlap = max(0, right - left) * max(0, bottom - top)

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    total = area_first + area_second - overlap

    if total > 0:
        return overlap / total
    return 0

In [None]:
def calculate_ciou(box1, box2):
    """
    Complete IoU (CIoU) of two boxes, clipped to [0, 1].

    box format: [x, y, w, h] (normalized); the far corner is taken as
    (x + w, y + h).
    """
    eps = 1e-7  # keeps every division finite

    left_a, top_a, width_a, height_a = box1[0], box1[1], box1[2], box1[3]
    left_b, top_b, width_b, height_b = box2[0], box2[1], box2[2], box2[3]
    right_a, bottom_a = left_a + width_a, top_a + height_a
    right_b, bottom_b = left_b + width_b, top_b + height_b

    # Plain IoU
    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    inter_area = overlap_w * overlap_h
    union_area = width_a * height_a + width_b * height_b - inter_area
    iou = inter_area / (union_area + eps)

    # Squared distance between the two box centres
    dx = (left_a + right_a) / 2 - (left_b + right_b) / 2
    dy = (top_a + bottom_a) / 2 - (top_b + bottom_b) / 2
    center_distance = dx ** 2 + dy ** 2

    # Squared diagonal of the smallest box enclosing both
    span_x = max(right_a, right_b) - min(left_a, left_b)
    span_y = max(bottom_a, bottom_b) - min(top_a, top_b)
    diagonal_distance = span_x ** 2 + span_y ** 2

    # Aspect-ratio consistency term and its weight
    angle_gap = np.arctan(width_a / (height_a + eps)) - np.arctan(width_b / (height_b + eps))
    v = 4 / (np.pi ** 2) * angle_gap ** 2
    alpha = v / (1 - iou + v + eps)

    ciou = iou - center_distance / (diagonal_distance + eps) - alpha * v

    # Clip CIoU to [0,1] range
    return max(0.0, min(1.0, ciou))

### Default Pre-training Process
Pre-train the APT module offline using teacher forcing.

In [None]:
# Switch the plugin to offline mode (APTPlugin.offline() forwards to online(False))
adaptive_model.offline()

In [None]:
from torch import optim

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 10
LEARNING_RATE = 5e-3, 1e-6  # (initial lr, eta_min of the cosine schedule)
WEIGHT_DECAY = 0.05

#wandb.watch(model, log="all", log_freq=10)
# AdamW over every parameter the plugin exposes
optimizer = optim.AdamW(adaptive_model.parameters(), lr=LEARNING_RATE[0], weight_decay=WEIGHT_DECAY)
# T_max is EPOCHS, so one scheduler step corresponds to one epoch of the cosine curve
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LEARNING_RATE[1])

In [None]:
# `train_loader` / `valid_loader` were never defined; take them from the holder built earlier.
# NOTE(review): assumes DataLoaderHolder exposes `.train` / `.valid` the same way
# DatasetHolder does - confirm.
train_loader, valid_loader = loaders.train, loaders.valid

train_length, valid_length = map(len, (train_loader, valid_loader))
epochs = tqdm(range(EPOCHS), desc="Running Epochs")
with (tqdm(total=train_length, desc="Training") as train_progress,
        tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars
    update = create_plot()  # Create Loss Plot

    for epoch in epochs:
        train_progress.reset(total=train_length)
        valid_progress.reset(total=valid_length)

        # train_ciou / val_ciou are reported below but never accumulated (always 0)
        train_loss, train_ciou = 0, 0

        # Training
        adaptive_model.train()
        for i, (curr_frame, prev_bbox, curr_bbox) in enumerate(train_loader):
            torch.cuda.empty_cache()  # Clear GPU memory
            optimizer.zero_grad()

            prev_bbox, curr_bbox = prev_bbox.to(device, dtype=DATA_TYPE), curr_bbox.to(device, dtype=DATA_TYPE)
            # NOTE(review): APTPlugin.forward reads `bbox_cache` (a dict with 'boxes' and
            # 'scores'), not `cache`, and ignores `teacher_forcing_labels` - confirm how
            # the previous boxes are meant to reach the model.
            adaptive_model.cache = [prev_bbox]
            output = adaptive_model(curr_frame.to(device, dtype=DATA_TYPE), teacher_forcing_labels=curr_bbox)  # Use Teacher Forcing while training

            output.loss.backward()
            optimizer.step()

            # Running mean of the batch losses over the epoch
            train_loss += output.loss.item() / train_length

            train_progress.update(1)
            #if i != train_length-1: wandb.log({'MSE Loss': output.loss.item()})
            print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{i+1:4}/{train_length}], MSE Loss: {output.loss.item():.6f}", end="")

        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], MSE Loss: {train_loss:.6f}, CIoU Loss: {train_ciou:.6f}", end="")
        val_loss, val_ciou = 0, 0

        # Validation
        adaptive_model.eval()
        with torch.no_grad():
            for curr_frame, prev_bbox, curr_bbox in valid_loader:
                prev_bbox, curr_bbox = prev_bbox.to(device, dtype=DATA_TYPE), curr_bbox.to(device, dtype=DATA_TYPE)
                adaptive_model.cache = [prev_bbox]
                output = adaptive_model(curr_frame.to(device, dtype=DATA_TYPE), teacher_forcing_labels=curr_bbox)  # Use Teacher Forcing while training

                val_loss += output.loss.item() / valid_length
                valid_progress.update(1)  # the validation bar was never advanced before

        # The scheduler is built with T_max=EPOCHS, so it advances once per epoch.
        # Stepping it per batch ran through the whole cosine cycle within the first
        # EPOCHS batches.
        scheduler.step()

        update(train_loss=train_loss, valid_loss=val_loss)
        #wandb.log({'Train MSE Loss': train_loss, 'Train CIoU Loss': train_ciou, 'Val MSE Loss': val_loss, 'Val CIoU Loss': val_ciou})
        print(f"\rEpoch [{epoch+1:2}/{EPOCHS}], Step [{train_length}/{train_length}], MSE Loss: {train_loss:.6f}, CIoU Loss: {train_ciou:.6f}, Valid MSE Loss: {val_loss:.6f}, Valid CIoU Loss: {val_ciou:.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")

In [None]:
# Save the plugin under ./results/apt/
adaptive_model.save_pretrained("./results/apt/")

## Evaluation

### Load Scenarios

In [None]:
# Put the plugin into evaluation mode and switch it to online adaptation.
# NOTE(review): no weights are loaded in this cell, and APTPlugin.online() sets
# requires_grad = False on the base model parameters rather than un-freezing them -
# confirm the intended behaviour.
adaptive_model.eval()
adaptive_model.online()

In [None]:
# Data preparation in evaluation mode; its transforms are shared by both scenarios
data_preparation = adaptive_model.DataPreparation(datasets.base.BaseDataset(), evaluation_mode=True)

# Discrete SHIFT scenario on the validation data, using the WHWPAPER order
discrete_scenario = scenarios.SHIFTDiscreteScenario(
    root=DATA_ROOT, valid=True, order=scenarios.SHIFTDiscreteScenario.WHWPAPER, transforms=data_preparation.transforms
)
# Continuous SHIFT scenario on the validation data, using the DEFAULT order
continuous_scenario = scenarios.SHIFTContinuousScenario(
    root=DATA_ROOT, valid=True, order=scenarios.SHIFTContinuousScenario.DEFAULT, transforms=data_preparation.transforms
)

In [None]:
# Detection evaluator for the adaptive model, constructed with no_grad=True
evaluator = validator.DetectionEvaluator(adaptive_model, classes=CLASSES, data_preparation=data_preparation, dtype=DATA_TYPE, device=device, no_grad=True)
# Loader settings shared by both scenarios: test batch size, fixed order, the preparation's collate function
evaluator_loader_params = dict(batch_size=BATCH_SIZE[2], shuffle=False, collate_fn=data_preparation.collate_fn)

In [None]:
# Play the discrete scenario with the evaluator and visualize the metrics
visualizer.visualize_metrics(discrete_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))

In [None]:
# Play the continuous scenario with the evaluator and visualize the metrics
visualizer.visualize_metrics(continuous_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"]))