# Pre-training with SHIFT-Discrete Dataset (Clear-Daytime)

## Imports and Configs

In [None]:
import sys
from os import path, environ
from argparse import ArgumentParser

import torch

from ttadapters import datasets, models
from ttadapters.utils import visualizer, validator
from ttadapters.datasets import DatasetHolder, scenarios, SHIFTDataset

In [None]:
environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
environ["TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS"] = "1"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

### Parse Arguments

In [None]:
# Default batch sizes as a (train, valid, test) triple
BATCH_SIZE = (2, 8, 1)  # Local
#BATCH_SIZE = (40, 200, 1)  # A100 or H100
# Default number of gradient accumulation steps
ACCUMULATE_STEPS = 1

# Default root directory for datasets
DATA_ROOT = path.join(".", "data")

# Default source-domain dataset used for pre-training
SOURCE_DOMAIN = datasets.SHIFTDataset

# Default run mode: False trains, True skips training and only evaluates
TEST_MODE = False

# Supported model architectures; the last entry is the default choice
MODEL_ZOO = ["rcnn", "swinrcnn", "yolo11", "rtdetr"]
MODEL_TYPE = MODEL_ZOO[-1]

In [None]:
# Create argument parser; every default comes from the constants defined above
parser = ArgumentParser(description="Training script for Test-Time Adapters")

# Add model arguments
parser.add_argument("--dataset", type=str, choices=["shift", "city"], default="shift", help="Training dataset")
parser.add_argument("--model", type=str, choices=MODEL_ZOO, default=MODEL_TYPE, help="Model architecture")

# Add training arguments
parser.add_argument("--train-batch", type=int, default=BATCH_SIZE[0], help="Training batch size")
parser.add_argument("--valid-batch", type=int, default=BATCH_SIZE[1], help="Validation batch size")
parser.add_argument("--accum-step", type=int, default=ACCUMULATE_STEPS, help="Gradient accumulation steps")
parser.add_argument("--data-root", type=str, default=DATA_ROOT, help="Root directory for datasets")
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
# Hyphenated spelling matches the other flags; the underscore spelling is kept
# as an alias so existing command lines continue to work.
parser.add_argument(
    "--additional-gpu", "--additional_gpu", dest="additional_gpu",
    type=int, default=0, help="Additional CUDA device count",
)
parser.add_argument("--use-bf16", action="store_true", help="Use bfloat16 precision")
parser.add_argument("--test-only", action="store_true", help="Run in test-only mode")

# Parsing arguments
if "ipykernel" not in sys.modules:
    # Regular script execution: read the real command line
    args = parser.parse_args()
else:
    # Running inside a Jupyter kernel: ignore sys.argv and use the defaults,
    # forwarding only the TEST_MODE constant as --test-only
    args = parser.parse_args(["--test-only"] if TEST_MODE else [])
    print("INFO: Running in notebook mode with default arguments")

# Update global variables based on parsed arguments (the test batch size is not configurable)
BATCH_SIZE = (args.train_batch, args.valid_batch, BATCH_SIZE[2])
ACCUMULATE_STEPS, DATA_ROOT = args.accum_step, args.data_root
TEST_MODE, MODEL_TYPE = args.test_only, args.model

# Map the dataset name onto its dataset class
if args.dataset == "shift":
    SOURCE_DOMAIN = datasets.SHIFTDataset
elif args.dataset == "city":
    SOURCE_DOMAIN = datasets.CityScapesDataset
else:
    raise ValueError(f"Unsupported dataset: {args.dataset}")

print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Set test mode - {TEST_MODE} for {SOURCE_DOMAIN.dataset_name} dataset")

### Check GPU Availability

In [None]:
# Show the GPUs visible on this machine (IPython shell escape, so notebook-only syntax)
!nvidia-smi

In [None]:
# Resolve CUDA device number, extra GPU count and compute precision from the arguments
DEVICE_NUM = args.device if args.device else 0
ADDITIONAL_GPU = args.additional_gpu if args.additional_gpu else 0
DATA_TYPE = torch.bfloat16 if args.use_bf16 else torch.float32

if not torch.cuda.is_available():
    # CPU fallback; DEVICE_NUM is reset to -1 in this case
    device = torch.device("cpu")
    DEVICE_NUM = -1
elif ADDITIONAL_GPU:
    # Extra GPUs requested: make DEVICE_NUM the current CUDA device and use an index-less handle
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    # Single GPU: address the requested device explicitly
    device = torch.device(f"cuda:{DEVICE_NUM}")

# The index-less handle does not print its number, so it is appended manually
print(f"INFO: Using device - {device}{f':{DEVICE_NUM}' if ADDITIONAL_GPU else ''}")
print(f"INFO: Using data precision - {DATA_TYPE}")

## Define Dataset

In [None]:
# Fast download patch
# NOTE(review): presumably has to run before any dataset below is constructed
# so their downloads use the patched path — confirm in ttadapters.datasets
datasets.patch_fast_download_for_object_detection()

In [None]:
# Build the train/valid/test holder for the selected source domain
if SOURCE_DOMAIN == datasets.SHIFTDataset:
    # Discrete: clear data for train/valid, corrupted data for test
    dataset = DatasetHolder(
        train=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
        test=datasets.SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    )
    # Continuous: instantiated and discarded
    # NOTE(review): presumably only to trigger download/preparation — confirm
    _ = datasets.SHIFTContinuous100DatasetForObjectDetection(root=DATA_ROOT)  # 100
    _ = datasets.SHIFTContinuous10DatasetForObjectDetection(root=DATA_ROOT)  # 10
    _ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT)  # 1 + split
elif SOURCE_DOMAIN == datasets.CityScapesDataset:
    dataset = DatasetHolder(
        train=datasets.CityScapesDatasetForObjectDetection(root=DATA_ROOT, train=True),
        valid=datasets.CityScapesDatasetForObjectDetection(root=DATA_ROOT, valid=True),
        test=datasets.CityScapesCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    )
else:
    raise ValueError(f"Unsupported dataset: {SOURCE_DOMAIN}")

# Class list and count are taken from the training split
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

In [None]:
# Check annotation keys-values
# (index 999 is an arbitrary sample; as the last expression of the cell it is echoed in a notebook)
dataset.train[999]

In [None]:
# Check data shape of the first element of the same arbitrary sample
dataset.train[999][0].shape  # should be (num_channels, height, width)

In [None]:
# Visualize video
# NOTE(review): presumably renders training frames with their bounding boxes — confirm in ttadapters.utils.visualizer
visualizer.visualize_bbox_frames(dataset.train)

## Load Model

In [None]:
# Initialize model
# Per-architecture spec: (model class name, weights used for training,
# weights used for SHIFT testing, whether bf16 is forced).
# Names are resolved with getattr so only the selected entry is ever looked up.
_MODEL_SPECS = {
    "rcnn": ("FasterRCNNForObjectDetection", "COCO_OFFICIAL", "SHIFT_CLEAR_NATUREYOO", False),
    "swinrcnn": ("SwinRCNNForObjectDetection", "COCO_XIAOHU2015", "SHIFT_CLEAR_NATUREYOO", False),
    "yolo11": ("YOLO11ForObjectDetection", "COCO_OFFICIAL", "SHIFT_CLEAR", True),
    "rtdetr": ("RTDetrForObjectDetection", "COCO_OFFICIAL", "SHIFT_CLEAR", True),
}

if MODEL_TYPE not in _MODEL_SPECS:
    raise ValueError(f"Unsupported model type: {MODEL_TYPE}")
_model_cls_name, _train_weights, _shift_test_weights, _force_bf16 = _MODEL_SPECS[MODEL_TYPE]

if _force_bf16:
    DATA_TYPE = torch.bfloat16  # bf16 default
model = getattr(models, _model_cls_name)(dataset=SOURCE_DOMAIN)

# Training starts from the COCO weights; test mode loads the weights of the source domain
if not TEST_MODE:
    _weights_name = _train_weights
elif SOURCE_DOMAIN == SHIFTDataset:
    _weights_name = _shift_test_weights
else:
    _weights_name = "CITYSCAPES"
load_result = model.load_from(**vars(getattr(model.Weights, _weights_name)), strict=False)

print("INFO: Model state loaded -", load_result)
model.to(device)

## Train

In [None]:
# Project and run naming: <model>_<dataset>_<test|train>
PROJECT_NAME = "tta_model_pretraining"
RUN_NAME = "_".join((model.model_name, SOURCE_DOMAIN.dataset_name)) + ("_test" if TEST_MODE else "_train")

In [None]:
# WandB Initialization
# (wandb is imported at this point rather than with the imports at the top of the file)
import wandb
# Start a run in the shared project; the run name carries model, dataset and mode
wandb.init(project=PROJECT_NAME, name=RUN_NAME)

In [None]:
# Training length (epochs) and base learning rate shared by every trainer below
EPOCHS: int = 20
LEARNING_RATE: float = 1e-4

In [None]:
def _skip_evaluation():
    """Do nothing; stands in until a trainer cell installs a real evaluation method."""
    return None


# Defaults used when no trainer is built (e.g. in test mode)
evaluate_source = _skip_evaluation
evaluate_target = _skip_evaluation

### Detectron Trainer

In [None]:
# Define Trainer & Validator for the Detectron-based models (training runs only)
if MODEL_TYPE in ("rcnn", "swinrcnn") and not TEST_MODE:
    _gpu_count = ADDITIONAL_GPU + 1
    # (train, valid) batch sizes summed over all participating GPUs
    ALL_DEVICE_BATCH = (BATCH_SIZE[0] * _gpu_count, BATCH_SIZE[1] * _gpu_count)
    # Set 0 to disable multi-GPU reference
    _world_size = _gpu_count if ADDITIONAL_GPU > 0 else ADDITIONAL_GPU
    _output_dir = f"./results/{RUN_NAME}"

    # Trainer: trains on the train split, evaluates on the valid split
    trainer = model.Trainer(
        model=model,
        classes=CLASSES,
        train_dataset=model.DataPreparation(dataset.train),
        eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
        args=model.TrainingArguments(
            # Schedule
            total_steps=EPOCHS * 10 * len(dataset.train) // ALL_DEVICE_BATCH[0],
            eval_period=100,
            save_period=100,
            # Batching / devices
            train_batch_for_total=ALL_DEVICE_BATCH[0],
            eval_batch_for_total=ALL_DEVICE_BATCH[1],
            multiple_gpu_world_size=_world_size,
            use_amp=False,
            # Optimizer
            learning_rate=LEARNING_RATE,
            momentum=0.9,
            weight_decay=1e-4,
            # Learning-rate schedule
            lr_scheduler_type="WarmupCosineLR",  # WarmupMultiStepLR, WarmupStepWithFixedGammaLR
            cosine_lr_final=LEARNING_RATE / 10,
            lr_warmup_method="linear",
            lr_warmup_iters=500,
            # Output
            output_dir=_output_dir,
        ),
    )

    # Evaluator: same model, but evaluates on the test split
    evaluator = model.Trainer(
        model=model,
        classes=CLASSES,
        eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
        args=model.TrainingArguments(
            learning_rate=LEARNING_RATE,
            total_steps=1,
            eval_batch_for_total=ALL_DEVICE_BATCH[1],
            multiple_gpu_world_size=_world_size,
            use_amp=False,
            output_dir=_output_dir,
        ),
    )

    # Replace the no-op defaults with the real evaluation entry points
    evaluate_source, evaluate_target = trainer.test, evaluator.test

### Ultralytics Trainer

In [None]:
# Define Trainer & Validator for the Ultralytics model (training runs only)
if not TEST_MODE and MODEL_TYPE == "yolo11":
    # `device.index` is None both for the CPU fallback and for the index-less
    # "cuda" handle used when additional GPUs are requested, so
    # str(device.index) would hand the literal string "None" to the trainer
    # and drop the selected DEVICE_NUM. Resolve an explicit device string.
    if device.type == "cpu":
        ULTRALYTICS_DEVICE = "cpu"
    elif device.index is not None:
        ULTRALYTICS_DEVICE = str(device.index)
    else:
        # torch.cuda.set_device(DEVICE_NUM) was called for the index-less handle
        # NOTE(review): if training should span the additional GPUs as well,
        # a comma-separated device list may be needed here — confirm
        ULTRALYTICS_DEVICE = str(DEVICE_NUM)

    # Mixed precision follows the selected data type
    USE_AMP = DATA_TYPE == torch.bfloat16

    # Trainer: trains on the train split, validates on the valid split
    trainer = model.Trainer(
        model=model,
        classes=CLASSES,
        train_dataset=model.DataPreparation(dataset.train),
        eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
        args=model.TrainingArguments(
            lr0=LEARNING_RATE,
            epochs=EPOCHS,
            batch=BATCH_SIZE[0],
            val_batch=BATCH_SIZE[1],
            optimizer="SGD",
            lrf=LEARNING_RATE/10,
            momentum=0.937,
            weight_decay=0.0005,
            warmup_epochs=3,
            close_mosaic=10,  # strong augmentations
            cos_lr=False,
            save_period=1,
            project="./results",
            name=RUN_NAME,
            plots=True,
            workers=0,
            device=ULTRALYTICS_DEVICE,
            amp=USE_AMP
        )
    )

    # Evaluator: same model, but validates on the test split
    evaluator = model.Trainer(
        model=model,
        classes=CLASSES,
        eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
        args=model.TrainingArguments(
            val_batch=BATCH_SIZE[1],
            workers=0,
            device=ULTRALYTICS_DEVICE,
            amp=USE_AMP
        )
    )

    # Replace the no-op defaults with the real evaluation entry points
    evaluate_source = trainer.validate
    evaluate_target = evaluator.validate

### Transformers Trainer

In [None]:
# Define Trainer & Validator for the Transformers-based model (training runs only)
if MODEL_TYPE == "rtdetr" and not TEST_MODE:
    # bf16 is enabled exactly when the selected data type is bfloat16
    _use_bf16 = DATA_TYPE == torch.bfloat16

    # Trainer: trains on the train split, evaluates on the valid split
    trainer = model.Trainer(
        model=model,
        classes=CLASSES,
        train_dataset=model.DataPreparation(dataset.train),
        eval_dataset=model.DataPreparation(dataset.valid, evaluation_mode=True),
        args=model.TrainingArguments(
            # Optimizer and learning-rate schedule
            optim="adamw_torch",
            learning_rate=LEARNING_RATE,
            backbone_learning_rate=LEARNING_RATE / 10,  # Set backbone learning rate to 1/10th of the main learning rate
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            weight_decay=0.1,
            max_grad_norm=0.5,
            # Epochs and batching
            num_train_epochs=EPOCHS,
            per_device_train_batch_size=BATCH_SIZE[0],
            per_device_eval_batch_size=BATCH_SIZE[1],
            gradient_accumulation_steps=ACCUMULATE_STEPS,
            eval_accumulation_steps=BATCH_SIZE[1],
            batch_eval_metrics=True,
            remove_unused_columns=False,
            bf16=_use_bf16,
            # Evaluate, save and log every 100 steps
            eval_on_start=True,
            eval_strategy="steps",
            save_strategy="steps",
            logging_strategy="steps",
            eval_steps=100,
            save_steps=100,
            logging_steps=100,
            save_total_limit=100,
            # Best-checkpoint selection
            load_best_model_at_end=True,
            metric_for_best_model="mAP@0.50:0.95",
            greater_is_better=True,
            # Reporting and output locations
            report_to="wandb",
            run_name=RUN_NAME,
            output_dir=f"./results/{RUN_NAME}",
            logging_dir=f"./logs/{RUN_NAME}",
        ),
    )

    # Evaluator: same model, but evaluates on the test split
    evaluator = model.Trainer(
        model=model,
        classes=CLASSES,
        eval_dataset=model.DataPreparation(dataset.test, evaluation_mode=True),
        args=model.TrainingArguments(
            per_device_eval_batch_size=BATCH_SIZE[1],
            batch_eval_metrics=True,
            remove_unused_columns=False,
            bf16=_use_bf16,
        ),
    )

    # Replace the no-op defaults with the real evaluation entry points
    evaluate_source, evaluate_target = trainer.evaluate, evaluator.evaluate

### Run Jobs

In [None]:
# Do train for source domain (skipped entirely in test mode).
# Every branch first tries to resume from an existing checkpoint.
if not TEST_MODE:
    if MODEL_TYPE in ("rcnn", "swinrcnn"):
        trainer.resume_or_load(resume=True)
        trainer.train()
    elif MODEL_TYPE == "yolo11":
        trainer.args.plots = False
        try:
            trainer.resume_from_checkpoint()
        except FileNotFoundError:
            print("INFO: No checkpoint found, starting training from scratch.")
        trainer.train()
    elif MODEL_TYPE == "rtdetr":
        try:
            trainer.train(resume_from_checkpoint=True)
        except (FileNotFoundError, ValueError) as resume_error:
            # Report why resuming failed instead of silently restarting, so an
            # unexpected error is visible in the log (mirrors the yolo11 branch).
            print(f"INFO: Could not resume from checkpoint ({resume_error}), starting training from scratch.")
            trainer.train()

In [None]:
# Do eval for source domain
# (a no-op unless one of the trainer cells replaced the default with a real evaluation method)
evaluate_source()

In [None]:
# Do eval for target domain
# (a no-op unless one of the trainer cells replaced the default with a real evaluation method)
evaluate_target()

In [None]:
# Model save
# Training only happens outside test mode, so test-only runs have nothing new to store
if not TEST_MODE:
    model.save_to(version=RUN_NAME)

## Evaluation

In [None]:
# Set model eval mode before running the scenario evaluations
# NOTE(review): presumably torch.nn.Module.eval() — confirm the model wrapper does not override it
model.eval()

### Load Scenarios

|Model|Dataset|Metric|
|---|---|---|
|rcnn|shift|mAP@0.50:0.95|
|swinrcnn|shift|mAP@0.50:0.95|
|yolo11|shift|mAP@0.50:0.95|
|rtdetr|shift|mAP@0.50:0.95|

In [None]:
# Ensure the split exists up front (required because the Scenario class works with coroutines)
# The instance is discarded; constructing it is done only for its side effect
# NOTE(review): presumably this downloads/creates the train split on disk — confirm
_ = datasets.SHIFTContinuousSubsetForObjectDetection(root=DATA_ROOT, train=True)

In [None]:
# Data preparation built around a bare base dataset; its transforms are handed to the scenarios
data_preparation = model.DataPreparation(
    datasets.base.BaseDataset(),
    evaluation_mode=True,
)

# Discrete-shift scenario on the validation split, using the WHWPAPER order
discrete_scenario = scenarios.SHIFTDiscreteScenario(
    root=DATA_ROOT,
    valid=True,
    order=scenarios.SHIFTDiscreteScenario.WHWPAPER,
    transforms=data_preparation.transforms,
)

# Continuous-shift scenario on the validation split, using the DEFAULT order
continuous_scenario = scenarios.SHIFTContinuousScenario(
    root=DATA_ROOT,
    valid=True,
    order=scenarios.SHIFTContinuousScenario.DEFAULT,
    transforms=data_preparation.transforms,
)

In [None]:
# Evaluator used for the scenario runs (replaces any trainer-based `evaluator` defined earlier)
evaluator = validator.DetectionEvaluator(
    model,
    classes=CLASSES,
    data_preparation=data_preparation,
    dtype=DATA_TYPE,
    device=device,
    no_grad=True,
)
# Keyword arguments passed to each scenario call; uses the test batch size
evaluator_loader_params = {
    "batch_size": BATCH_SIZE[2],
    "shuffle": False,
    "collate_fn": data_preparation.collate_fn,
}

In [None]:
# Play the discrete scenario with the evaluator, then visualize the collected metrics
discrete_metrics = discrete_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"])
visualizer.visualize_metrics(discrete_metrics)

In [None]:
# Play the continuous scenario with the evaluator, then visualize the collected metrics
continuous_metrics = continuous_scenario(**evaluator_loader_params).play(evaluator, index=["Direct-Test"])
visualizer.visualize_metrics(continuous_metrics)