# Test-Time Adaptation Evaluation of R-CNN Detectors on the SHIFT-Discrete Dataset

## Check GPU Availability

In [1]:
!nvidia-smi

Mon Sep 22 07:00:32 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   78C    P0             247W / 250W |   7930MiB / 16384MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [None]:
# Set CUDA Device Number
# Expose a single physical GPU to this process. CUDA reads
# CUDA_VISIBLE_DEVICES when it initialises, so this must run before any
# torch.cuda call.
DEVICE_NUM = 6

from os import environ
environ.update(CUDA_VISIBLE_DEVICES=str(DEVICE_NUM))
environ["CUDA_VISIBLE_DEVICES"]

## Imports

In [None]:
import os

# Run from the repository root so relative paths such as "./data" resolve.
# Alternative checkout used on another machine: /home/ubuntu/test-time-adapters
os.chdir(os.path.join("/workspace", "ptta"))

In [None]:
from os import path
import math

import torch
from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import (
    SHIFTDataset,
    SHIFTClearDatasetForObjectDetection,
    SHIFTCorruptedDatasetForObjectDetection,
    SHIFTDiscreteSubsetForObjectDetection
)
from ttadapters import datasets

from ttadapters.models.rcnn import FasterRCNNForObjectDetection, SwinRCNNForObjectDetection

from supervision.metrics.mean_average_precision import MeanAveragePrecision
from supervision.detection.core import Detections

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
from detectron2.layers import FrozenBatchNorm2d
from detectron2.utils.events import EventStorage

In [None]:
from ttadapters.methods.other_method import utils

In [None]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import math
from tqdm import tqdm
from detectron2.structures import Instances
import torchvision.transforms as T

In [None]:
# Prefer CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"INFO: Using device - {device}")

In [None]:
# Experiment identifiers.
PROJECT_NAME, RUN_NAME = "detectron_test", "Faster-RCNN_R50"

## Define Dataset

In [None]:
DATA_ROOT = os.path.join(".", "data")  # relative to the working directory

## DataLoader

In [None]:
from detectron2.structures import ImageList

def collate_fn(batch):
    """Collate ``(image, target)`` pairs into a padded ImageList plus label dicts.

    NOTE: this definition is overwritten by the ``collate_fn`` defined in the
    next cell, so later cells do not use it.

    Returns a dict with ``pixel_values`` (images padded to a multiple of 32)
    and ``labels`` (one ``{"class_labels", "boxes"}`` dict per sample, built
    from the target's ``boxes2d_classes`` / ``boxes2d`` entries).
    """
    padded_images = ImageList.from_tensors(
        [sample[0] for sample in batch], size_divisibility=32
    )
    labels = []
    for sample in batch:
        target = sample[1]
        labels.append({
            "class_labels": target["boxes2d_classes"].long(),
            "boxes": target["boxes2d"].float(),
        })
    return {"pixel_values": padded_images, "labels": labels}

In [None]:
from detectron2.structures import Boxes, Instances
from torchvision.tv_tensors import Image, BoundingBoxes

def collate_fn(batch: list[tuple[Image, dict]]):
    """Convert ``(image, metadata)`` samples into detectron2-style input dicts.

    ``metadata`` must provide ``"boxes2d"`` (xyxy boxes) and
    ``"boxes2d_classes"`` (class ids). Each returned dict carries the image,
    an ``Instances`` with ``gt_boxes`` / ``gt_classes``, and ``height`` /
    ``width`` taken from the image tensor's last two dimensions.
    """
    batched_inputs = []
    for image, metadata in batch:
        original_height, original_width = image.shape[-2:]
        instances = Instances(image_size=(original_height, original_width))
        instances.gt_boxes = Boxes(metadata["boxes2d"])  # xyxy
        # int64 class ids, as in the earlier collate_fn (cross-entropy style
        # losses expect Long targets).
        instances.gt_classes = metadata["boxes2d_classes"].long()
        batched_inputs.append({
            "image": image,
            "instances": instances,
            "height": original_height,
            "width": original_width
        })
    return batched_inputs

## Load Model

In [None]:
# True -> SwinRCNNForObjectDetection, False -> FasterRCNNForObjectDetection
USE_SWIN_T_BACKBONE: bool = False

In [None]:
# Pick the detector class from the backbone flag, then load the NATUREYOO
# checkpoint (weights under the "model" key) and move it to `device`.
model = (
    SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
)(dataset=SHIFTDataset)

model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

### Direct Test

In [None]:
def task_to_subset_types(task: str):
    """Map a task name to a ``SHIFTDiscreteSubsetForObjectDetection.SubsetType``.

    Accepted names: weather ("cloudy", "overcast", "rainy", "foggy"), time of
    day ("night", "dawn" / "dawn/dusk", "clear"), and the coarse splits
    ("normal", "corrupted").

    Raises:
        ValueError: if ``task`` is not one of the names above.
    """
    subset_types = SHIFTDiscreteSubsetForObjectDetection.SubsetType
    member_by_task = {
        # weather
        "cloudy": "CLOUDY_DAYTIME",
        "overcast": "OVERCAST_DAYTIME",
        "rainy": "RAINY_DAYTIME",
        "foggy": "FOGGY_DAYTIME",
        # time
        "night": "CLEAR_NIGHT",
        "dawn": "CLEAR_DAWN",
        "dawn/dusk": "CLEAR_DAWN",
        "clear": "CLEAR_DAYTIME",
        # simple
        "normal": "NORMAL",
        "corrupted": "CORRUPTED",
    }
    member_name = member_by_task.get(task)
    if member_name is None:
        raise ValueError(f"Unknown task: {task}")
    # Attribute lookup stays lazy: only the requested member is resolved.
    return getattr(subset_types, member_name)

In [None]:
# dataset
class SHIFTCorruptedDatasetForObjectDetection(SHIFTDiscreteSubsetForObjectDetection):
    """SHIFT discrete subset selected by a human-readable task name.

    NOTE: this notebook-local class shadows the name imported from
    ``ttadapters.datasets``. It only converts ``task`` into a ``subset_type``
    via ``task_to_subset_types`` and forwards every other argument unchanged
    to ``SHIFTDiscreteSubsetForObjectDetection``.
    """

    def __init__(
            self, root: str, force_download: bool = False,
            train: bool = True, valid: bool = False,
            transform=None, target_transform=None, transforms=None,
            task="clear"
    ):
        subset_type = task_to_subset_types(task)
        super().__init__(
            root=root,
            force_download=force_download,
            train=train,
            valid=valid,
            subset_type=subset_type,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )

In [None]:
import time
import gc

def evaluate_for(self, loader, loader_length, threshold=0.0, dtype=torch.float32,
                 device=torch.device("cuda"), class_names=None):
    """Run inference over ``loader`` and report mAP plus forward-pass throughput.

    Args:
        self: detector called as ``self(batch)``; must return one dict with an
            ``"instances"`` entry per input in the batch.
        loader: iterable of batches as produced by ``collate_fn`` (lists of
            dicts carrying ground-truth ``"instances"``).
        loader_length: number of batches in ``loader`` (progress-bar total).
        threshold: predictions with score <= ``threshold`` are dropped.
        dtype: autocast dtype for the forward pass.
        device: its ``type`` selects the autocast backend; for CUDA the
            device is synchronised around the timed forward pass.
        class_names: optional sequence mapping class id -> display name for
            the per-class keys; when omitted the numeric class id is used.

    Returns:
        dict with "mAP@0.50:0.95", "mAP@0.50", "mAP@0.75", one
        "<class>_mAP@0.50:0.95" entry per matched class, "collapse_time"
        (seconds spent in forward passes) and "fps" (images per second).
    """
    torch.cuda.empty_cache()
    gc.collect()

    self.eval()

    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []
    collapse_time = 0.0
    num_images = 0
    # CUDA kernels are asynchronous; synchronise so the timing covers GPU work.
    sync_cuda = device.type == "cuda" and torch.cuda.is_available()

    with torch.inference_mode():
        for batch in tqdm(loader, total=loader_length, desc="Evaluation"):
            with torch.autocast(device_type=device.type, dtype=dtype):
                if sync_cuda:
                    torch.cuda.synchronize()
                start = time.perf_counter()
                outputs = self(batch)
                if sync_cuda:
                    torch.cuda.synchronize()
                collapse_time += time.perf_counter() - start
            num_images += len(batch)

            for output, input_data in zip(outputs, batch):
                instances = output['instances']
                mask = instances.scores > threshold

                pred_detection = Detections(
                    xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                    class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                    confidence=instances.scores[mask].detach().cpu().numpy()
                )
                gt_instances = input_data['instances']
                target_detection = Detections(
                    xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                    class_id=gt_instances.gt_classes.detach().cpu().numpy()
                )

                predictions_list.append(pred_detection)
                targets_list.append(target_detection)

    map_metric.update(predictions=predictions_list, targets=targets_list)
    m_ap = map_metric.compute()

    # Rows of `ap_per_class` follow the order of `matched_classes`, so pair
    # them positionally rather than indexing rows by class id.
    per_class_map = {}
    for row, class_id in enumerate(m_ap.matched_classes):
        label = class_names[class_id] if class_names is not None else class_id
        per_class_map[f"{label}_mAP@0.50:0.95"] = float(m_ap.ap_per_class[row].mean())

    performances = {
        "collapse_time": collapse_time,
        # Images (frames) per second; a batch may hold several images.
        "fps": num_images / collapse_time if collapse_time > 0 else 0.0,
    }

    return {
        "mAP@0.50:0.95": float(m_ap.map50_95),
        "mAP@0.50": float(m_ap.map50),
        "mAP@0.75": float(m_ap.map75),
        **per_class_map,
        **performances
    }

## Direct Method

In [None]:
# direct_method
# Baseline: evaluate the unadapted model on every SHIFT-discrete domain.
for task in ("cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"):
    dataset = SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT,
        valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task,
    )
    print(f"start {task}")
    # Kept as globals: evaluate_for may read CLASSES for per-class keys.
    # NOTE(review): this binds the dataset itself, whose items are
    # (image, metadata) samples rather than class names — confirm intent.
    CLASSES = dataset
    NUM_CLASSES = len(CLASSES)

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = -(-len(dataset) // 4)  # ceil(len / 4) batches
    result = evaluate_for(model, dataloader, dataloader.valid_len)
    print(result)

## ActMAD

In [None]:
def extract_activation_alignment(model, method, data_root, batch_size=16):
    """Collect clean-data activation statistics for ActMAD.

    Runs ``model`` over the SHIFT clear training split. For the later half of
    its ``FrozenBatchNorm2d`` layers, it averages (``utils.AverageMeter``) the
    per-batch output mean and variance reported by ``utils.SaveOutput``
    forward hooks.

    Args:
        model: detector called as ``model(batch)`` on batches from ``collate_fn``.
        method: only ``"actmad"`` selects layers; any other value produces
            empty result lists.
        data_root: dataset root passed to ``SHIFTClearDatasetForObjectDetection``.
        batch_size: batch size for the extraction DataLoader.

    Returns:
        ``(clean_mean_list_final, clean_var_list_final, layer_names)``: three
        parallel lists with one entry per selected layer.

    Side effect: sets ``requires_grad = True`` on every parameter of ``model``.
    """
    dataset = SHIFTClearDatasetForObjectDetection(
        root=data_root, train=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms
    )

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    loader.train_len = math.ceil(len(dataset) / batch_size)

    # model unfreeze
    for param in model.parameters():
        param.requires_grad = True

    chosen_bn_info = []
    if method == "actmad":
        chosen_bn_info = [(name, m) for name, m in model.named_modules()
                          if isinstance(m, FrozenBatchNorm2d)]

    # High-level representations are more sensitive to domain shift, so only
    # the later BN layers are selected; the half-way cutoff is empirical.
    chosen_bn_info = chosen_bn_info[len(chosen_bn_info) // 2:]
    layer_names = [name for name, _ in chosen_bn_info]
    chosen_bn_layers = [module for _, module in chosen_bn_info]
    n_chosen_layers = len(chosen_bn_layers)

    save_outputs = [utils.SaveOutput() for _ in range(n_chosen_layers)]
    clean_mean_act_list = [utils.AverageMeter() for _ in range(n_chosen_layers)]
    clean_var_act_list = [utils.AverageMeter() for _ in range(n_chosen_layers)]

    # extract the activation alignment in train dataset
    print("Start extracting BN statistics from the training dataset")

    model.eval()
    # Register the hooks once and always remove them, even if a forward pass
    # raises, so no stray hooks remain attached to the shared model.
    hooks = [layer.register_forward_hook(saver)
             for layer, saver in zip(chosen_bn_layers, save_outputs)]
    try:
        with torch.no_grad():
            for batch in tqdm(loader, total=loader.train_len, desc="Evaluation"):
                _ = model(batch)
                for saver, mean_meter, var_meter in zip(
                    save_outputs, clean_mean_act_list, clean_var_act_list
                ):
                    mean_meter.update(saver.get_out_mean())  # mean from clean data
                    var_meter.update(saver.get_out_var())  # variance from clean data
                    saver.clear()
    finally:
        for hook in hooks:
            hook.remove()
        for saver in save_outputs:
            saver.clear()

    # author's note on per-layer shape: [C, H, W] (depends on utils.SaveOutput)
    clean_mean_list_final = [meter.avg for meter in clean_mean_act_list]
    clean_var_list_final = [meter.avg for meter in clean_var_act_list]
    return clean_mean_list_final, clean_var_list_final, layer_names

In [None]:
# actmad | extract clear data bn

# hyperparameter
CLEAN_BN_EXTRACT_BATCH = 8
# The cache is keyed by backbone: the stored statistics are tied to the
# FrozenBatchNorm2d layer names of the model they were extracted from.
stats_save_path = Path("/workspace/ptta/ttadapters/methods/other_method") / (
    f"actmad_clean_statistics_{'swin_rcnn' if USE_SWIN_T_BACKBONE else 'faster_rcnn'}.pt"
)
stat_keys = ("clean_mean_list_final", "clean_var_list_final", "layer_names")

statistics = {}

# Reuse previously saved statistics when present
if stats_save_path.exists():
    print(f"Loading saved ActMAD statistics from {stats_save_path}")
    saved_stats = torch.load(stats_save_path)
    statistics = {key: saved_stats[key] for key in stat_keys}
else:
    print("Extracting ActMAD statistics from clean data...")
    statistics = dict(zip(stat_keys, extract_activation_alignment(
        model=model, method="actmad",
        data_root=DATA_ROOT,
        batch_size=CLEAN_BN_EXTRACT_BATCH
    )))

    # Save only the statistics (module objects in chosen_bn_layers are not saved)
    print(f"Saving ActMAD statistics to {stats_save_path}")
    stats_save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({key: statistics[key] for key in stat_keys}, stats_save_path)

In [None]:
clean_mean_list_final = statistics["clean_mean_list_final"]
clean_var_list_final = statistics["clean_var_list_final"]
layer_names = statistics["layer_names"]

# Only layer *names* are persisted, so re-resolve them on the current model.
# A missing layer is dropped together with its reference statistics; dropping
# it from chosen_bn_layers alone would shift every later layer onto the
# wrong reference statistics.
current_bn_dict = {name: module for name, module in model.named_modules()
                   if isinstance(module, FrozenBatchNorm2d)}

chosen_bn_layers = []
clean_means_on_device = []
clean_vars_on_device = []
for layer_name, clean_mean, clean_var in zip(layer_names, clean_mean_list_final, clean_var_list_final):
    if layer_name not in current_bn_dict:
        print(f"Warning: Layer {layer_name} not found!")
        continue
    chosen_bn_layers.append(current_bn_dict[layer_name])
    # Move reference statistics to the device once instead of on every batch.
    clean_means_on_device.append(clean_mean.to(device))
    clean_vars_on_device.append(clean_var.to(device))
n_chosen_layers = len(chosen_bn_layers)

optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
)
# Unfreeze model parameters for ActMAD
for param in model.parameters():
    param.requires_grad = True

l1_loss = nn.L1Loss(reduction="mean")
threshold = 0.0

for task in ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]:
    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []

    # data load
    dataset = SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT, valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task
    )
    print(f"start {task}")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = math.ceil(len(dataset) / 4)

    # The model stays in eval mode; it adapts only through the ActMAD loss,
    # and the weights carry over from one task to the next.
    model.eval()
    for batch in tqdm(dataloader, total=dataloader.valid_len, desc="Evaluation"):
        optimizer.zero_grad()
        save_outputs_tta = [utils.SaveOutput() for _ in range(n_chosen_layers)]
        hook_list_tta = [layer.register_forward_hook(saver)
                         for layer, saver in zip(chosen_bn_layers, save_outputs_tta)]
        try:
            # forward pass; predictions come from this pass, i.e. before this
            # batch's update
            outputs = model(batch)

            # Extract current batch statistics
            batch_mean_tta = [saver.get_out_mean() for saver in save_outputs_tta]
            batch_var_tta = [saver.get_out_var() for saver in save_outputs_tta]

            # ActMAD loss, accumulated out of place from a plain zero. An
            # in-place `+=` on a requires_grad leaf fails on CPU, where
            # `.to(device)` returns the leaf itself.
            loss_mean = torch.zeros((), device=device)
            loss_var = torch.zeros((), device=device)
            for batch_mean, batch_var, clean_mean, clean_var in zip(
                batch_mean_tta, batch_var_tta, clean_means_on_device, clean_vars_on_device
            ):
                loss_mean = loss_mean + l1_loss(batch_mean.to(device), clean_mean)
                loss_var = loss_var + l1_loss(batch_var.to(device), clean_var)
            loss = loss_mean + loss_var

            # Backward and update (no-op when no layer is hooked)
            if loss.requires_grad:
                loss.backward()
                optimizer.step()
        finally:
            # Always detach the hooks, even if the step above raised
            for saver, hook in zip(save_outputs_tta, hook_list_tta):
                saver.clear()
                hook.remove()

        for output, input_data in zip(outputs, batch):
            instances = output['instances']
            mask = instances.scores > threshold

            pred_detection = Detections(
                xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                confidence=instances.scores[mask].detach().cpu().numpy()
            )
            gt_instances = input_data['instances']
            target_detection = Detections(
                xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                class_id=gt_instances.gt_classes.detach().cpu().numpy()
            )

            predictions_list.append(pred_detection)
            targets_list.append(target_detection)

    map_metric.update(predictions=predictions_list, targets=targets_list)
    print(f"start {task} mAP computation")
    m_ap = map_metric.compute()

    # Rows of `ap_per_class` follow `matched_classes`; key by numeric class id
    # (indexing the dataset yields (image, metadata) samples, not class names).
    per_class_map = {
        f"{class_id}_mAP@0.50:0.95": float(m_ap.ap_per_class[row].mean())
        for row, class_id in enumerate(m_ap.matched_classes)
    }

    print({
        "mAP@0.50:0.95": float(m_ap.map50_95),
        "mAP@0.50": float(m_ap.map50),
        "mAP@0.75": float(m_ap.map75),
        **per_class_map,
    })


## NORM

In [None]:
def apply_norm_adaptation(model, source_sum=128):
    """Patch every BatchNorm2d / FrozenBatchNorm2d in ``model`` with NORM (ContinualTTA).

    A patched layer normalises with a blend of its stored running statistics
    and the current batch statistics, weighted by
    ``alpha = N / (source_sum + N)`` where ``N = x.shape[0]``. The stored
    running statistics are not modified. ``model`` is patched in place and
    returned.
    """
    def norm_forward(self, x):
        if getattr(self, "adapt_type", None) == "NORM":
            # Blend source statistics with this batch's statistics
            n = x.shape[0]
            alpha = n / (self.source_sum + n)
            mean = (1 - alpha) * self.running_mean + alpha * x.mean(dim=[0, 2, 3])
            var = (1 - alpha) * self.running_var + alpha * x.var(dim=[0, 2, 3])
        else:
            # Frozen-BN behaviour: stored statistics only
            mean, var = self.running_mean, self.running_var

        scale = self.weight * (var + self.eps).rsqrt()
        shift = self.bias - mean * scale
        out_dtype = x.dtype
        return (x * scale.reshape(1, -1, 1, 1).to(out_dtype)
                + shift.reshape(1, -1, 1, 1).to(out_dtype))

    for name, module in model.named_modules():
        if not isinstance(module, (nn.BatchNorm2d, FrozenBatchNorm2d)):
            continue
        module.adapt_type = "NORM"
        module.source_sum = source_sum
        # Bind the shared forward to this module instance
        module.forward = norm_forward.__get__(module, module.__class__)
        print(f"Applied NORM adaptation to {name}")

    return model

In [None]:
# Start from freshly loaded weights: the ActMAD cell above updates `model` in
# place (optimizer.step() over model.parameters()), which would otherwise
# carry over into this NORM evaluation.
model = (
    SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
)(dataset=SHIFTDataset)
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)
model = apply_norm_adaptation(model, source_sum=128)

threshold = 0.0
for task in ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]:
    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []

    # data load
    dataset = SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT, valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task
    )
    print(f"start {task}")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = math.ceil(len(dataset) / 4)

    model.eval()
    # NORM only mixes batch statistics into the forward pass and updates no
    # parameters, so no autograd graph is needed.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=dataloader.valid_len, desc="Evaluation"):
            outputs = model(batch)

            for output, input_data in zip(outputs, batch):
                instances = output['instances']
                mask = instances.scores > threshold

                pred_detection = Detections(
                    xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                    class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                    confidence=instances.scores[mask].detach().cpu().numpy()
                )
                gt_instances = input_data['instances']
                target_detection = Detections(
                    xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                    class_id=gt_instances.gt_classes.detach().cpu().numpy()
                )

                predictions_list.append(pred_detection)
                targets_list.append(target_detection)

    map_metric.update(predictions=predictions_list, targets=targets_list)
    print(f"start {task} mAP computation")
    m_ap = map_metric.compute()

    # Rows of `ap_per_class` follow `matched_classes`; key by numeric class id
    # (indexing the dataset yields (image, metadata) samples, not class names).
    per_class_map = {
        f"{class_id}_mAP@0.50:0.95": float(m_ap.ap_per_class[row].mean())
        for row, class_id in enumerate(m_ap.matched_classes)
    }

    print({
        "mAP@0.50:0.95": float(m_ap.map50_95),
        "mAP@0.50": float(m_ap.map50),
        "mAP@0.75": float(m_ap.map75),
        **per_class_map,
    })

## DUA

In [None]:
def apply_dua_adaptation(model, decay_factor=0.94, mom_pre=0.01, min_momentum_constant=0.0001):
    """Patch every BatchNorm2d / FrozenBatchNorm2d in ``model`` with DUA.

    On each forward pass a patched layer:
      1. folds the batch mean/variance into its running statistics with
         momentum ``mom_pre + min_momentum_constant``,
      2. decays ``mom_pre`` by ``decay_factor`` (once per pass), and
      3. normalises ``x`` with the updated running statistics.

    A snapshot of the running statistics taken at patch time is stored as
    ``original_running_mean`` / ``original_running_var`` (read back by
    ``reset_dua_momentum``). ``model`` is modified in place and returned.
    """
    def dua_forward(self, x):
        if getattr(self, "adapt_type", None) == "DUA":
            with torch.no_grad():
                current_momentum = self.mom_pre + self.min_momentum_constant
                batch_mean = x.mean(dim=[0, 2, 3])
                batch_var = x.var(dim=[0, 2, 3], unbiased=True)

                # Update running statistics in place (no gradient)
                self.running_mean.mul_(1 - current_momentum).add_(batch_mean, alpha=current_momentum)
                self.running_var.mul_(1 - current_momentum).add_(batch_var, alpha=current_momentum)
                # Decay exactly once per forward pass (a duplicated line used
                # to apply the decay factor twice per step).
                self.mom_pre *= self.decay_factor

        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale

        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        out_dtype = x.dtype
        return x * scale.to(out_dtype) + bias.to(out_dtype)

    for name, module in model.named_modules():
        # Patch normalisation layers only. A duplicated inner loop used to
        # snapshot/patch *every* module, which raised AttributeError on
        # modules without running statistics (starting with the root model).
        if not isinstance(module, (nn.BatchNorm2d, FrozenBatchNorm2d)):
            continue

        module.adapt_type = "DUA"
        module.min_momentum_constant = min_momentum_constant
        module.decay_factor = decay_factor
        module.mom_pre = mom_pre

        if not hasattr(module, 'original_running_mean'):
            module.original_running_mean = module.running_mean.clone()
            module.original_running_var = module.running_var.clone()

        module.forward = dua_forward.__get__(module, module.__class__)
        print(f"Applied DUA adaptation to {name}")
    return model

In [None]:
def reset_dua_momentum(model, mom_pre=0.01):
    """Restore DUA-patched normalisation layers to their pre-adaptation state.

    For every BatchNorm2d / FrozenBatchNorm2d whose ``adapt_type`` is "DUA",
    this resets ``mom_pre``. When a snapshot exists, it also restores
    ``running_mean`` / ``running_var`` from ``original_running_mean`` /
    ``original_running_var``. ``model`` is modified in place; nothing is
    returned.
    """
    for _, layer in model.named_modules():
        is_dua_layer = (
            isinstance(layer, (nn.BatchNorm2d, FrozenBatchNorm2d))
            and getattr(layer, "adapt_type", None) == "DUA"
        )
        if not is_dua_layer:
            continue
        layer.mom_pre = mom_pre
        if hasattr(layer, "original_running_mean"):
            layer.running_mean = layer.original_running_mean.clone()
            layer.running_var = layer.original_running_var.clone()

In [None]:
# Start from freshly loaded weights: earlier cells modify `model` in place
# (ActMAD SGD updates, NORM forward patches), which would otherwise carry
# over into this DUA evaluation.
model = (
    SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
)(dataset=SHIFTDataset)
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)
# NOTE(review): with mom_pre=0.0 the DUA momentum stays at
# min_momentum_constant (1e-4 by default) on every step, because mom_pre only
# decays multiplicatively. Confirm this is the intended setting.
model = apply_dua_adaptation(model, decay_factor=0.94, mom_pre=0.0)

threshold = 0.0
for task in ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]:
    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []

    # data load
    dataset = SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT, valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task
    )
    print(f"start {task}")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = math.ceil(len(dataset) / 4)

    model.eval()
    # DUA updates running statistics inside its own no_grad block; no
    # parameter gradients are needed, so skip building an autograd graph.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=dataloader.valid_len, desc="Evaluation"):
            outputs = model(batch)

            for output, input_data in zip(outputs, batch):
                instances = output['instances']
                mask = instances.scores > threshold

                pred_detection = Detections(
                    xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                    class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                    confidence=instances.scores[mask].detach().cpu().numpy()
                )
                gt_instances = input_data['instances']
                target_detection = Detections(
                    xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                    class_id=gt_instances.gt_classes.detach().cpu().numpy()
                )

                predictions_list.append(pred_detection)
                targets_list.append(target_detection)

    map_metric.update(predictions=predictions_list, targets=targets_list)
    print(f"start {task} mAP computation")
    m_ap = map_metric.compute()

    # Rows of `ap_per_class` follow `matched_classes`; key by numeric class id
    # (indexing the dataset yields (image, metadata) samples, not class names).
    per_class_map = {
        f"{class_id}_mAP@0.50:0.95": float(m_ap.ap_per_class[row].mean())
        for row, class_id in enumerate(m_ap.matched_classes)
    }

    print({
        "mAP@0.50:0.95": float(m_ap.map50_95),
        "mAP@0.50": float(m_ap.map50),
        "mAP@0.75": float(m_ap.map75),
        **per_class_map,
    })

## Mean-Teacher

In [None]:
def setup_mean_teacher_model(model):
    """Create a frozen teacher copy of ``model`` (ContinualTTA-style Mean-Teacher).

    Returns a deep copy of ``model`` in eval mode with all parameters frozen.

    Side effects:
    - ``online_adapt`` (when present) is set to False on both teacher and student.
    - ``training = True`` is set on ``model`` itself and on its
      ``proposal_generator`` / ``roi_heads`` (when present). Other submodules
      keep their current mode.
    """
    teacher = copy.deepcopy(model)
    teacher.eval()
    teacher.requires_grad_(False)

    # Online adaptation off for teacher first, then student
    for net in (teacher, model):
        if hasattr(net, "online_adapt"):
            net.online_adapt = False

    model.training = True
    for head_name in ("proposal_generator", "roi_heads"):
        if hasattr(model, head_name):
            getattr(model, head_name).training = True

    return teacher

def setup_optimizer_for_adaptation(model, lr=1e-4):
    """Build an SGD optimiser over the weights/biases of normalisation layers.

    A module qualifies if it is ``nn.BatchNorm2d`` / ``nn.LayerNorm``, or if
    its qualified name contains "norm" (case-insensitive). Each selected
    non-None ``weight`` / ``bias`` gets ``requires_grad_(True)``.

    Returns:
        ``torch.optim.SGD`` (momentum 0.9, weight decay 1e-4) over the
        selected tensors in module order, or ``None`` if none qualify.
    """
    def is_norm_layer(name, module):
        return isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)) or 'norm' in name.lower()

    params = []
    for name, module in model.named_modules():
        if not is_norm_layer(name, module):
            continue
        for attr in ("weight", "bias"):
            tensor = getattr(module, attr, None)
            if tensor is not None:
                tensor.requires_grad_(True)
                params.append(tensor)

    if not params:
        return None
    return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-4)

def create_strong_augmentation():
    """Return a strong photometric augmentation: colour jitter, then a 3x3 Gaussian blur.

    NOTE: not called in the cells shown here; ``set_pseudo_labels`` reuses
    the unaugmented image.
    """
    colour_jitter = T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2)
    blur = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    return T.Compose([colour_jitter, blur])

def set_pseudo_labels(inputs, outputs, conf_th=0.7):
    """Turn teacher predictions into pseudo-labelled training inputs (ContinualTTA pattern).

    For each ``(input, teacher_output)`` pair, detections with
    ``score > conf_th`` become the ``gt_classes`` / ``gt_boxes`` of a fresh
    ``Instances`` sized like the input's ground truth. All other input keys
    are copied over, and the image is reused unaugmented. Inputs with no
    surviving detection are skipped, so the result may be shorter than
    ``inputs`` or empty.
    """
    pseudo_batch = []

    for sample, prediction in zip(inputs, outputs):
        detections = prediction['instances']
        confident = detections[detections.scores > conf_th]
        if len(confident) == 0:
            continue

        pseudo_sample = {key: value for key, value in sample.items() if key != 'instances'}
        pseudo_sample['image'] = sample['image']

        pseudo_gt = Instances(sample['instances'].image_size)
        pseudo_gt.gt_classes = confident.pred_classes
        pseudo_gt.gt_boxes = confident.pred_boxes

        pseudo_sample['instances'] = pseudo_gt
        pseudo_batch.append(pseudo_sample)

    return pseudo_batch

def update_teacher_with_ema(teacher_model, student_model, ema_beta=0.999):
    """Blend the student into the teacher: ``t <- ema_beta * t + (1 - ema_beta) * s``.

    Applies to every floating-point entry of the state dict, which covers
    parameters AND buffers. Buffers matter here: the adaptation optimiser
    selects the weight/bias of every module whose name contains "norm", and
    detectron2's FrozenBatchNorm2d stores those as buffers. ``parameters()``
    never yields buffers, so a parameters-only EMA would leave the teacher
    unchanged. Non-floating entries (e.g. integer counters) are left as-is.
    """
    student_state = student_model.state_dict()
    with torch.no_grad():
        # state_dict tensors share storage with the module, so in-place
        # updates below modify the teacher itself.
        for key, teacher_tensor in teacher_model.state_dict().items():
            student_tensor = student_state.get(key)
            if student_tensor is None or not teacher_tensor.is_floating_point():
                continue
            # lerp_ computes t + w * (s - t); when s equals t the teacher
            # entry is left unchanged.
            teacher_tensor.lerp_(student_tensor.to(teacher_tensor.dtype), 1.0 - ema_beta)

# Main evaluation code with Mean-Teacher adaptation

# Start from freshly loaded weights: earlier cells modify `model` in place
# (ActMAD SGD updates, NORM / DUA forward patches and running-stat updates).
model = (
    SwinRCNNForObjectDetection if USE_SWIN_T_BACKBONE else FasterRCNNForObjectDetection
)(dataset=SHIFTDataset)
model.load_from(model.Weights.NATUREYOO, weight_key="model")
model.to(device)

# Mean-Teacher hyperparameters
ema_beta = 0.999
conf_threshold = 0.7
learning_rate = 1e-4
threshold = 0.0

for task in ["cloudy", "overcast", "foggy", "rainy", "dawn", "night", "clear"]:
    map_metric = MeanAveragePrecision()
    predictions_list = []
    targets_list = []

    # Data loading
    dataset = SHIFTCorruptedDatasetForObjectDetection(
        root=DATA_ROOT, valid=True,
        transform=datasets.detectron_image_transform,
        transforms=datasets.default_valid_transforms,
        task=task
    )
    print(f"Starting evaluation on {task} with Mean-Teacher adaptation")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    dataloader.valid_len = math.ceil(len(dataset) / 4)

    # NOTE(review): the student is not reset between tasks (continual
    # setting); an earlier TODO suggested reloading it here. Confirm intent.

    # Setup Mean-Teacher
    teacher_model = setup_mean_teacher_model(model)
    optimizer = setup_optimizer_for_adaptation(model, lr=learning_rate)
    adapted_params = (
        [p for group in optimizer.param_groups for p in group["params"]] if optimizer else []
    )

    print(f"Teacher model created, optimizer setup with {len(optimizer.param_groups[0]['params']) if optimizer else 0} parameters")

    for batch_idx, batch in enumerate(tqdm(dataloader, total=dataloader.valid_len, desc=f"Adapting {task}")):

        # Step 1: Get teacher predictions for pseudo-labeling
        teacher_model.eval()
        with torch.no_grad():
            teacher_outputs = teacher_model(batch)

        # Step 2: Generate pseudo labels
        pseudo_inputs = set_pseudo_labels(batch, teacher_outputs, conf_th=conf_threshold)

        # Step 3: Train student model with pseudo labels
        if pseudo_inputs and optimizer is not None:
            model.train()
            if hasattr(model, 'proposal_generator'):
                model.proposal_generator.training = True
            if hasattr(model, 'roi_heads'):
                model.roi_heads.training = True

            # Clear grads of every parameter, not only the optimised ones.
            # Backward also reaches any other parameter with
            # requires_grad=True, and those grads would otherwise accumulate
            # across batches.
            model.zero_grad(set_to_none=True)
            optimizer.zero_grad()

            try:
                # Forward pass with pseudo labels
                losses = model(pseudo_inputs)
                total_loss = sum(losses.values()) if isinstance(losses, dict) else losses

                if total_loss > 0:
                    total_loss.backward()

                    # Clip only what the optimizer actually updates
                    torch.nn.utils.clip_grad_norm_(adapted_params, 1.0)

                    optimizer.step()

                    # Step 4: Update teacher with EMA
                    update_teacher_with_ema(teacher_model, model, ema_beta)

            except Exception as e:
                # Best-effort adaptation: report the failed step and continue
                print(f"Training step failed at batch {batch_idx}: {e}")

        # Step 5: Get final predictions for evaluation (use teacher)
        teacher_model.eval()
        with torch.no_grad():
            outputs = teacher_model(batch)

        # Process outputs for mAP calculation
        for output, input_data in zip(outputs, batch):
            instances = output['instances']
            mask = instances.scores > threshold

            pred_detection = Detections(
                xyxy=instances.pred_boxes.tensor[mask].detach().cpu().numpy(),
                class_id=instances.pred_classes[mask].detach().cpu().numpy(),
                confidence=instances.scores[mask].detach().cpu().numpy()
            )
            gt_instances = input_data['instances']
            target_detection = Detections(
                xyxy=gt_instances.gt_boxes.tensor.detach().cpu().numpy(),
                class_id=gt_instances.gt_classes.detach().cpu().numpy()
            )

            predictions_list.append(pred_detection)
            targets_list.append(target_detection)

    # Compute mAP
    map_metric.update(predictions=predictions_list, targets=targets_list)
    print(f"Computing mAP for {task}")
    m_ap = map_metric.compute()

    # Rows of `ap_per_class` follow `matched_classes`; key by numeric class id
    # (indexing the dataset yields (image, metadata) samples, not class names).
    per_class_map = {
        f"{class_id}_mAP@0.50:0.95": float(m_ap.ap_per_class[row].mean())
        for row, class_id in enumerate(m_ap.matched_classes)
    }

    print(f"Results for {task} with Mean-Teacher:")
    print({
        "mAP@0.50:0.95": float(m_ap.map50_95),
        "mAP@0.50": float(m_ap.map50),
        "mAP@0.75": float(m_ap.map75),
        **per_class_map,
    })
    print("-" * 50)