In [2]:
!nvidia-smi

Tue Jul 29 09:05:49 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   52C    P0              44W / 250W |   5778MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [3]:
# Choose which physical GPU(s) this kernel is allowed to see.
# DEVICE_NUM is the first GPU index; ADDITIONAL_GPU is how many consecutive
# indices after it are exposed as well.
DEVICE_NUM = 3
ADDITIONAL_GPU = 0

from os import environ
environ["CUDA_VISIBLE_DEVICES"] = ",".join(
    str(DEVICE_NUM + offset) for offset in range(ADDITIONAL_GPU + 1)
)
environ["CUDA_VISIBLE_DEVICES"]

'3'

## Import

In [None]:
import sys, os

root = '/workspace/ptta'
if root not in sys.path:
    sys.path.insert(0, root)

from os import path
import argparse
from pathlib import Path
import torch.optim as optim
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
import logging
log = logging.getLogger('ActMAD')

from torch.utils.data import DataLoader

from ttadapters.datasets import BaseDataset, DatasetHolder, DataLoaderHolder
from ttadapters.datasets import SHIFTClearDatasetForObjectDetection, SHIFTCorruptedDatasetForObjectDetection
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

## Dataset

In [5]:
# Root directory holding the SHIFT data; the relative variant is kept for local runs.
# DATA_ROOT = path.join(".", "data")
DATA_ROOT = path.abspath("/workspace/ptta/data")

# Three splits used throughout the notebook:
#   train / valid -> clear-condition data (train and validation flags respectively)
#   test          -> corrupted data, built with valid=True
# NOTE(review): the "test" split is the corrupted dataset's validation split — confirm this is intended.
dataset = DatasetHolder(
    train=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, train=True),
    valid=SHIFTClearDatasetForObjectDetection(root=DATA_ROOT, valid=True),
    test=SHIFTCorruptedDatasetForObjectDetection(root=DATA_ROOT, valid=True)
)

[07/29/2025 09:06:08] SHIFT DevKit - INFO - Base: /workspace/ptta/data/SHIFT/discrete/images/train. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7f88619b1ad0>
[07/29/2025 09:06:08] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/normal/discrete/images/train/front/det_2d.json' ...


INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[07/29/2025 09:06:10] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/normal/discrete/images/train/front/det_2d.json' Done.
[07/29/2025 09:06:25] SHIFT DevKit - INFO - Loading annotation takes 17.19 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0016-1b62']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -7.53     219.91
boxes2d              torch.Size([1, 26, 4])                    5.00     974.00
boxes2d_classes      torch.Size([1, 26])                       0.00       3.00
boxes2d_track_ids    torch.Size([1, 26])                       0.00      25.00
images               torch.Size([1, 1, 3, 800, 1280])          0.00     255.00



[07/29/2025 09:06:30] SHIFT DevKit - INFO - Base: /workspace/ptta/data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7f88619b1ad0>
[07/29/2025 09:06:30] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/normal/discrete/images/val/front/det_2d.json' ...


Video name: 0016-1b62
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[07/29/2025 09:06:31] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/normal/discrete/images/val/front/det_2d.json' Done.
[07/29/2025 09:06:32] SHIFT DevKit - INFO - Loading annotation takes 1.94 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['0116-4859']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                    -0.90     138.34
boxes2d              torch.Size([1, 6, 4])                   246.00     859.00
boxes2d_classes      torch.Size([1, 6])                        1.00       5.00
boxes2d_track_ids    torch.Size([1, 6])                        0.00       5.00
images               torch.Size([1, 1, 3, 800, 1280])          0.00     255.00



[07/29/2025 09:06:33] SHIFT DevKit - INFO - Base: /workspace/ptta/data/SHIFT/discrete/images/val. Backend: <shift_dev.utils.backend.ZipBackend object at 0x7f88619b1ad0>
[07/29/2025 09:06:33] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/corrupted/discrete/images/val/front/det_2d.json' ...


Video name: 0116-4859
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.
INFO: Subset split for 'SHIFT_SUBSET' dataset is already done. Skipping...
INFO: Downloading 'SHIFT_SUBSET' from file server to /workspace/ptta/data/SHIFT/discrete...
INFO: Dataset archive found in the root directory. Skipping download.


[07/29/2025 09:06:35] SHIFT DevKit - INFO - Loading annotation from '/workspace/ptta/data/SHIFT_SUBSET/corrupted/discrete/images/val/front/det_2d.json' Done.
[07/29/2025 09:07:04] SHIFT DevKit - INFO - Loading annotation takes 31.27 seconds.


Batch 0:

Item                 Shape                               Min        Max       
--------------------------------------------------------------------------------
original_hw          [tensor([800]), tensor([1280])]
input_hw             [tensor([800]), tensor([1280])]
frame_ids            torch.Size([1])                           0.00       0.00
name                 ['00000000_img_front.jpg']
videoName            ['007b-4e72']
intrinsics           torch.Size([1, 3, 3])                     0.00     640.00
extrinsics           torch.Size([1, 4, 4])                  -311.22     226.46
boxes2d              torch.Size([1, 3, 4])                   233.00     802.00
boxes2d_classes      torch.Size([1, 3])                        0.00       1.00
boxes2d_track_ids    torch.Size([1, 3])                        0.00       2.00
images               torch.Size([1, 1, 3, 800, 1280])          0.00     255.00

Video name: 007b-4e72
Sample indices within a video: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,

## Dataloader

In [6]:
# Set Batch Size
# Tuple layout is (train, valid, test); it is indexed positionally below and by the training arguments.
BATCH_SIZE = 2, 32, 32

# Dataset Configs
# Class names come from the training split; NUM_CLASSES is later written into the model config.
CLASSES = dataset.train.classes
NUM_CLASSES = len(CLASSES)

print(f"INFO: Set batch size - Train: {BATCH_SIZE[0]}, Valid: {BATCH_SIZE[1]}, Test: {BATCH_SIZE[2]}")
print(f"INFO: Number of classes - {NUM_CLASSES} {CLASSES}")

INFO: Set batch size - Train: 2, Valid: 32, Test: 32
INFO: Number of classes - 6 ['pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle']


In [7]:
class DatasetAdapterForTransformers(BaseDataset):
    """Expose one camera of a wrapped dataset as ``dict(image=..., target=...)`` samples.

    ``target`` follows the COCO detection layout (``image_id`` plus a list of
    ``annotations``) so it can be handed to an image processor configured for
    COCO detection annotations.
    """

    def __init__(self, original_dataset, camera='front'):
        self.camera = camera
        self.dataset = original_dataset

    def __len__(self):
        # The adapter is a 1:1 view over the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx][self.camera]

        def as_coco_annotation(box, cls):
            # Source boxes are corner-format (x1, y1, x2, y2); COCO wants [x, y, width, height].
            left, top, right, bottom = box.tolist()
            box_w, box_h = right - left, bottom - top
            return dict(
                bbox=[left, top, box_w, box_h],
                category_id=cls.item(),
                area=box_w * box_h,
                iscrowd=0
            )

        coco_annotations = [
            as_coco_annotation(box, cls)
            for box, cls in zip(sample['boxes2d'], sample['boxes2d_classes'])
        ]

        # The image processor later converts the COCO boxes to normalized (cx, cy, w, h),
        # and post-processing converts predictions back to (x1, y1, x2, y2).
        return dict(
            image=sample['images'].squeeze(0),  # drop the leading size-1 dimension, if any
            target=dict(image_id=idx, annotations=coco_annotations)
        )

In [8]:
def collate_fn(batch, preprocessor=None):
    """Collate samples into the ``pixel_values`` / ``labels`` mapping expected by the model.

    Args:
        batch: list of sample dicts. Every sample needs an ``image`` entry. With a
            ``preprocessor`` a ``target`` entry (COCO detection layout) is required as well.
            Without one, labels are read from ``boxes2d`` / ``boxes2d_classes`` when the
            sample carries them, otherwise they are rebuilt from ``target['annotations']``.
        preprocessor: optional callable invoked as
            ``preprocessor(images=..., annotations=..., return_tensors="pt")``; its result
            is returned unchanged.

    Returns:
        The preprocessor output, or ``dict(pixel_values=Tensor, labels=list[dict])`` where
        every label dict holds ``class_labels`` (long) and ``boxes`` (float, x1 y1 x2 y2).
    """
    images = [item['image'] for item in batch]
    if preprocessor is not None:
        targets = [item['target'] for item in batch]
        return preprocessor(images=images, annotations=targets, return_tensors="pt")

    def build_labels(item):
        # Raw samples already carry tensors in corner format.
        if 'boxes2d' in item and 'boxes2d_classes' in item:
            return dict(
                class_labels=item['boxes2d_classes'].long(),
                boxes=item['boxes2d'].float()
            )
        # Adapter samples only have COCO annotations: [x, y, w, h] -> [x1, y1, x2, y2].
        annotations = item['target']['annotations']
        boxes = torch.tensor(
            [ann['bbox'] for ann in annotations], dtype=torch.float32
        ).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        class_labels = torch.tensor(
            [ann['category_id'] for ann in annotations], dtype=torch.long
        )
        return dict(class_labels=class_labels, boxes=boxes)

    # No preprocessor: images are assumed to be equally sized tensors already.
    # pixel_values must be the stacked tensor itself, not a nested dict.
    return dict(
        pixel_values=torch.stack(images),
        labels=[build_labels(item) for item in batch]
    )

### 데이터셋에 대한 논의 필요.
#### 이 방법(ActMAD)은 clear dataset으로 미리 train을 진행함. 이후 오염된(corrupted) dataset으로 TTA 진행 -> TEST 진행
#### 기존 우리의 test dataset을 분할하여 TTA를 수행하고, 성능을 측정해야 할 것 같음

### Load Model

In [None]:
from transformers import RTDetrForObjectDetection, RTDetrImageProcessorFast, RTDetrConfig, SwinConfig, ResNetConfig
from transformers.image_utils import AnnotationFormat

In [None]:
# True  -> load the published reference checkpoint with from_pretrained.
# False -> build a fresh RT-DETR from the reference config and load only converted backbone weights.
USE_PRETRAINED_MODEL = False
# Only consulted when USE_PRETRAINED_MODEL is False: Swin backbone config instead of ResNet.
USE_SWIN_T_BACKBONE = False

In [None]:
# Model id whose configuration and image processor are used as the starting point.
reference_model_id = "PekingU/rtdetr_r50vd"

# Load the reference model configuration
reference_config = RTDetrConfig.from_pretrained(reference_model_id, torch_dtype=torch.float32, return_dict=True)
# Replace the reference label count with this dataset's class count.
reference_config.num_labels = NUM_CLASSES

# Load the reference model image processor
reference_preprocessor = RTDetrImageProcessorFast.from_pretrained(reference_model_id)
# Targets produced by the dataset adapter use the COCO detection layout.
reference_preprocessor.format = AnnotationFormat.COCO_DETECTION  # COCO Format / Detection BBOX Format

In [None]:
if USE_PRETRAINED_MODEL:
    # Load the pre-trained model; the detection head is re-initialised when the
    # label count differs from the checkpoint (ignore_mismatched_sizes=True).
    model = RTDetrForObjectDetection.from_pretrained(reference_model_id, config=reference_config, torch_dtype=torch.float32, ignore_mismatched_sizes=True)
else:
    # Set the image size and preprocessor size
    reference_config.image_size = 800
    reference_preprocessor.do_resize = True
    reference_preprocessor.size = {"height": 800, "width": 800}

    # Set the backbone configuration
    if USE_SWIN_T_BACKBONE:
        backbone_url = "https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone_converted/swin_tiny_patch4_window7_shift_from_detectron2.pth"
        reference_config.backbone_config = SwinConfig(
            embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], output_hidden_states=True,
            window_size=7, out_features=["stage1", "stage2", "stage3"]  # stride 8·16·32
        )
    else:
        backbone_url = "https://github.com/robustaim/ContinualTTA_ObjectDetection/releases/download/backbone_converted/resnet50_shift_from_detectron2.pth"
        reference_config.backbone_config = ResNetConfig(depths=[3,4,6,3], out_indices=(2, 3, 4))

    # Initialize a new model with the reference configuration
    model = RTDetrForObjectDetection(config=reference_config)
    backbone_state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
    # strict=False tolerates key mismatches, so surface them instead of silently
    # training from a partially (or not at all) initialised backbone.
    backbone_load_result = model.model.backbone.model.load_state_dict(backbone_state, strict=False)
    if backbone_load_result.missing_keys or backbone_load_result.unexpected_keys:
        print(
            f"WARNING: Backbone weights loaded with mismatches - "
            f"missing: {len(backbone_load_result.missing_keys)}, "
            f"unexpected: {len(backbone_load_result.unexpected_keys)}"
        )

# `device` is not defined by any earlier cell; derive it here so the cell runs on a fresh kernel.
# CUDA_VISIBLE_DEVICES (set at the top of the notebook) decides which GPU "cuda" maps to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
from transformers.trainer_utils import EvalPrediction
from torchvision.ops import box_convert
from dataclasses import dataclass


# Minimal attribute container for raw predictions; built in map_compute_metrics from
# eval_pred.predictions[1:3] and passed to the preprocessor's post_process_object_detection.
@dataclass
class ModelOutput:
    # Class scores as produced by the model.
    logits: torch.Tensor
    # Predicted boxes as produced by the model.
    pred_boxes: torch.Tensor


def de_normalize_boxes(boxes, height, width):
    """Convert normalized (cx, cy, w, h) boxes to pixel-space (x1, y1, x2, y2).

    Args:
        boxes: box tensor in ``cxcywh`` layout, indexed as ``[:, column]``
            (assumes coordinates are normalized to [0, 1] — TODO confirm).
        height: image height used to scale the y coordinates.
        width: image width used to scale the x coordinates.

    Returns:
        A new tensor of corner-format boxes scaled by ``width`` / ``height``.
    """
    corners = box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
    corners[:, 0::2] *= width   # columns 0 and 2 -> x1, x2
    corners[:, 1::2] *= height  # columns 1 and 3 -> y1, y2
    return corners


def map_compute_metrics(preprocessor=reference_preprocessor, threshold=0.0):
    """Build a batch-wise mAP callback for ``Trainer(compute_metrics=...)``.

    NOTE(review): ``MeanAveragePrecision`` and ``Detections`` are not imported in any
    visible cell; the attributes used here match the `supervision` package — confirm
    and add the import to the notebook's import cell.

    Args:
        preprocessor: image processor providing ``post_process_object_detection``.
        threshold: score threshold forwarded to the post-processing step.

    Returns:
        ``(calc, map_metric)``: the metric callback and the metric object it accumulates
        into (exposed so callers can ``reset()`` it).
    """
    map_metric = MeanAveragePrecision()
    post_process = preprocessor.post_process_object_detection

    def calc(eval_pred: EvalPrediction, compute_result=False):
        """Accumulate one batch, or finalise and reset when ``compute_result`` is True."""
        if compute_result:
            m_ap = map_metric.compute()
            map_metric.reset()

            # Rows of ap_per_class follow the order of matched_classes, so the row index
            # (not the class id) must be used to look up each class's AP values.
            per_class_map = {
                f"{CLASSES[class_id]}_mAP@0.50:0.95": m_ap.ap_per_class[row].mean()
                for row, class_id in enumerate(m_ap.matched_classes)
            }

            return {
                "mAP@0.50:0.95": m_ap.map50_95,
                "mAP@0.50": m_ap.map50,
                "mAP@0.75": m_ap.map75,
                **per_class_map
            }

        # predictions[1:3] is unpacked as (logits, pred_boxes).
        preds = ModelOutput(*eval_pred.predictions[1:3])
        labels = eval_pred.label_ids
        sizes = [label['orig_size'].cpu().tolist() for label in labels]

        results = post_process(preds, target_sizes=sizes, threshold=threshold)
        predictions = [Detections.from_transformers(result) for result in results]
        # Ground-truth boxes are scaled back to pixels with each label's orig_size.
        targets = [Detections(
            xyxy=de_normalize_boxes(label['boxes'], *label['orig_size']).cpu().numpy(),
            class_id=label['class_labels'].cpu().numpy(),
        ) for label in labels]

        map_metric.update(predictions=predictions, targets=targets)
        return {}
    return calc, map_metric

In [None]:
class DifferentiableLRTrainer(Trainer):
    """Trainer that gives backbone parameters their own learning rate."""

    def create_optimizer(self):
        """Create (once) an AdamW optimizer with separate backbone / head parameter groups.

        Parameters whose name contains ``'backbone'`` use ``args.backbone_lr``; all others
        use ``args.learning_rate``. When ``backbone_lr`` is absent (plain
        ``TrainingArguments``) or ``None``, the backbone falls back to ``learning_rate``.

        Returns:
            The optimizer stored on ``self.optimizer``.
        """
        # Mirror the base Trainer: never replace an optimizer that already exists.
        if self.optimizer is not None:
            return self.optimizer

        backbone_lr = getattr(self.args, 'backbone_lr', None)
        if backbone_lr is None:
            backbone_lr = self.args.learning_rate

        backbone_params = []
        head_params = []
        for name, param in self.model.named_parameters():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                head_params.append(param)

        self.optimizer = torch.optim.AdamW([
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': head_params, 'lr': self.args.learning_rate}
        ], weight_decay=self.args.weight_decay)

        return self.optimizer


class DifferentiableLRTrainingArguments(TrainingArguments):
    """TrainingArguments carrying an extra ``backbone_lr`` value.

    ``backbone_lr`` is read by ``DifferentiableLRTrainer.create_optimizer`` as the learning
    rate of the backbone parameter group.
    """

    def __init__(self, *args, backbone_lr=None, **kwargs):
        # All standard arguments are forwarded untouched to TrainingArguments.
        super().__init__(*args, **kwargs)
        # Stored as a plain attribute after the parent initialiser has run.
        self.backbone_lr = backbone_lr

### Trainer

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 20
REAL_BATCH = BATCH_SIZE[-1]  # effective train batch size reached through gradient accumulation
LEARNING_RATE = 1e-4

# NOTE(review): RUN_NAME is not defined in any visible cell. Keep an existing value if the
# kernel already has one; otherwise fall back to a default so the cell runs on a fresh kernel.
RUN_NAME = globals().get("RUN_NAME", "rtdetr_shift_actmad")

# bf16 needs hardware support (the GPUs listed by nvidia-smi above are Tesla P100);
# enable it only when the current device reports support instead of hard-coding True.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = DifferentiableLRTrainingArguments(
    backbone_lr=LEARNING_RATE/10,  # Set backbone learning rate to 1/10th of the main learning rate
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.5,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE[0],
    per_device_eval_batch_size=BATCH_SIZE[1],
    # Never let the integer division collapse to 0 accumulation steps.
    gradient_accumulation_steps=max(1, REAL_BATCH//BATCH_SIZE[0]),
    eval_accumulation_steps=BATCH_SIZE[1],
    batch_eval_metrics=True,  # compute_metrics is called per batch and once more with compute_result=True
    remove_unused_columns=False,
    optim="adamw_torch",
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=100,
    save_total_limit=100,
    load_best_model_at_end=True,
    metric_for_best_model="mAP@0.50:0.95",
    greater_is_better=True,
    #metric_for_best_model="eval_loss",
    #greater_is_better=False,
    report_to="wandb",
    output_dir="./results/"+RUN_NAME,
    logging_dir="./logs/"+RUN_NAME,
    run_name=RUN_NAME,
    bf16=USE_BF16,
)

# Evaluation-only arguments used by the tester on the test split.
testing_args = TrainingArguments(
    per_device_eval_batch_size=BATCH_SIZE[2],
    batch_eval_metrics=True,
    remove_unused_columns=False,
)

In [None]:
from functools import partial

# map_compute_metrics returns (metric callback, metric object); the metric object is kept
# under the name compute_results so it can be reset before a training run.
compute_metrics, compute_results = map_compute_metrics(preprocessor=reference_preprocessor)

# Trainer for supervised training on the clear train/valid splits.
trainer = DifferentiableLRTrainer(
    model=model,
    args=training_args,
    train_dataset=DatasetAdapterForTransformers(dataset.train),
    eval_dataset=DatasetAdapterForTransformers(dataset.valid),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
)

# Evaluation-only trainer on the test split; it shares the same model object and the same
# metric callback (and therefore the same underlying metric state) as `trainer`.
tester = DifferentiableLRTrainer(
    model=model,
    args=testing_args,
    eval_dataset=DatasetAdapterForTransformers(dataset.test),
    data_collator=partial(collate_fn, preprocessor=reference_preprocessor),
    compute_metrics=compute_metrics
)

In [None]:
def start_train(max_cuda_retries=3):
    """Train with ``trainer``, resuming from the latest checkpoint when one exists.

    A run that fails with a CUDA-related error is retried (resuming from the last saved
    checkpoint) at most ``max_cuda_retries`` times; every other error is re-raised.

    Args:
        max_cuda_retries: number of CUDA failures tolerated before giving up.

    Returns:
        Whatever ``trainer.train`` returns for the run that completed.
    """
    # NOTE(review): the Accelerator instance itself is not used; the construction is kept
    # in case initialising accelerate state matters under notebook_launcher — confirm.
    Accelerator()

    cuda_failures = 0
    while True:
        try:
            try:
                print("INFO: Trying to resume from previous checkpoint")
                compute_results.reset()
                return trainer.train(resume_from_checkpoint=True)
            except Exception as e:
                # Only a missing checkpoint justifies starting over; anything else must
                # propagate instead of being swallowed.
                if "No valid checkpoint found" not in str(e):
                    raise
                print(f"ERROR: Failed to resume from checkpoint - {e}")
                print("INFO: Starting training from scratch")
                compute_results.reset()
                return trainer.train(resume_from_checkpoint=False)
        except Exception as e:
            if "CUDA" not in str(e) or cuda_failures >= max_cuda_retries:
                raise
            cuda_failures += 1
            print(f"ERROR: CUDA Error - {e}")
            print(f"INFO: Retrying ({cuda_failures}/{max_cuda_retries})")

In [10]:
class AverageMeter(object):
    """Running average tracker.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


class SaveOutput:
    """Forward-hook callable that keeps a copy of every output it is called with."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Store a clone so later changes to the original output do not affect the record.
        self.outputs.append(module_out.clone())

    def clear(self):
        """Drop all recorded outputs."""
        self.outputs = []

    def _stacked(self):
        # All recorded outputs stacked vertically (along the first dimension).
        return torch.vstack(self.outputs)

    def get_out_mean(self):
        """Mean of the recorded outputs over the first dimension."""
        return self._stacked().mean(dim=0)

    def get_out_var(self):
        """Variance of the recorded outputs over the first dimension."""
        return self._stacked().var(dim=0)

In [None]:
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments

from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    check_torch_load_is_safe,
    find_labels,
    is_accelerate_available,
    is_apollo_torch_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_liger_kernel_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_hpu_available,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    logging,
    strtobool,
)

from transformers import TrainerCallback, TrainerControl, TrainerState
from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
from accelerate.utils import (
    AutocastKwargs,
    DistributedDataParallelKwargs,
    DistributedType,
    load_fsdp_model,
    load_fsdp_optimizer,
    save_fsdp_model,
    save_fsdp_optimizer,
)
from typing import TYPE_CHECKING, Any, Callable, Optional, Union


class CustomTrainer(Trainer):
    """Trainer whose loss is the ActMAD activation-statistics alignment objective.

    For every batch the forward outputs of the layers in ``chosen_bn_layers`` are captured
    with forward hooks; the L1 distance between their batch mean / variance and the stored
    clean statistics is the loss. This mirrors the manual adaptation loop of ``ActMAD`` in
    this notebook. Labels are not used.

    Args:
        val_dataloader: dataloader kept on the instance for callers; not used internally.
        chosen_bn_layers: modules whose outputs are aligned.
        clean_mean_list: per-layer clean mean tensors, same order as ``chosen_bn_layers``.
        clean_var_list: per-layer clean variance tensors, same order as ``chosen_bn_layers``.
        device: device handle kept on the instance for callers.
        half: half-precision flag kept on the instance for callers.
        *args, **kwargs: forwarded to ``Trainer.__init__`` (model, args, datasets, ...).

    Raises:
        ValueError: if no layer is given or the three per-layer lists differ in length.
    """

    def __init__(self, val_dataloader, chosen_bn_layers, clean_mean_list, clean_var_list, device, half,
                 *args, **kwargs):
        # Trainer state (args, model, optimizer, accelerator, ...) only exists after the
        # base initialiser has run, so it must be called before anything else.
        super().__init__(*args, **kwargs)

        chosen_bn_layers = list(chosen_bn_layers)
        if not chosen_bn_layers:
            raise ValueError("chosen_bn_layers must contain at least one layer")
        if not (len(chosen_bn_layers) == len(clean_mean_list) == len(clean_var_list)):
            raise ValueError("chosen_bn_layers, clean_mean_list and clean_var_list must have the same length")

        self.val_dataloader    = val_dataloader
        self.chosen_bn_layers  = chosen_bn_layers
        self.clean_mean_list   = clean_mean_list
        self.clean_var_list    = clean_var_list
        self.device            = device
        self.half              = half
        self.n_chosen_layers   = len(chosen_bn_layers)
        # One recorder per aligned layer; filled by forward hooks inside compute_loss.
        self.save_outputs_tta  = [SaveOutput() for _ in range(self.n_chosen_layers)]

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        The step itself is the stock ``Trainer.training_step`` (input preparation, loss
        scaling, backward). The ActMAD-specific work — freezing batch-norm layers and
        recording activations — lives in ``compute_loss``, which the parent step calls.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`dict[str, Union[torch.Tensor, Any]]`):
                The inputs of the model.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        return super().training_step(model, inputs, num_items_in_batch)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """
        Compute the ActMAD loss for one batch.

        Args:
            model (`nn.Module`): The model being adapted.
            inputs (`dict`): Prepared model inputs; a ``labels`` entry, if present, is ignored.
            return_outputs (`bool`): Also return the model outputs when True.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Return:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # training_step puts the whole model in train mode; batch-norm layers must stay
        # in eval mode during adaptation so their running statistics are not updated.
        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        # The objective is unsupervised, so labels are not fed to the model.
        forward_inputs = {key: value for key, value in inputs.items() if key != "labels"}

        hooks = [
            layer.register_forward_hook(recorder)
            for layer, recorder in zip(self.chosen_bn_layers, self.save_outputs_tta)
        ]
        try:
            outputs = model(**forward_inputs)

            terms = []
            for recorder, clean_mean, clean_var in zip(
                self.save_outputs_tta, self.clean_mean_list, self.clean_var_list
            ):
                # NOTE: torch.var is unbiased by default, so a batch holding a single
                # sample yields NaN variance; use a per-device batch size of at least 2.
                batch_mean = recorder.get_out_mean()
                batch_var = recorder.get_out_var()
                terms.append(nn.functional.l1_loss(batch_mean, clean_mean.to(batch_mean)))
                terms.append(nn.functional.l1_loss(batch_var, clean_var.to(batch_var)))
            loss = torch.stack(terms).sum()
        finally:
            # Always detach the hooks and drop the recorded activations, otherwise hooks
            # pile up on the layers and activations accumulate across steps.
            for hook in hooks:
                hook.remove()
            for recorder in self.save_outputs_tta:
                recorder.clear()

        return (loss, outputs) if return_outputs else loss
        
        

## ActMAD

In [None]:
def ActMAD(model, ckpt_path, half_precision=False):
    """Adapt ``model`` at test time by aligning activation statistics (ActMAD).

    Visible steps:
      1. Load weights from ``ckpt_path`` and optionally switch the model to half precision.
      2. Collect all ``nn.BatchNorm2d`` modules and keep those after the first 26.
      3. Run ``dataloader_train`` under ``torch.no_grad`` and average, per kept layer, the
         mean and variance of the hooked outputs (the "clean" statistics).
      4. For every batch of ``dataloader_val`` take one SGD step on the summed L1 distance
         between the batch statistics and the clean statistics, then evaluate and save.

    Args:
        model: model to adapt; its device is taken from its first parameter.
        ckpt_path: path passed to ``torch.load`` for the state dict.
        half_precision: request fp16; only honoured when the model is not on CPU.

    NOTE(review): this looks like a port still in progress — see the review notes inline.
    """
    log.info('ActMAD start')
    
    model.load_state_dict(torch.load(ckpt_path))
    log.info('model load')
    
    # Mean AP of every evaluation so far; used below to decide when to save a "best" model.
    all_ap = []
    
    device = next(model.parameters()).device
    
    half = device.type != 'cpu' and half_precision  # half precision only supported on CUDA
    if half:
        model.half()
        
    # dataloader
    # NOTE(review): dataloader_val and dataloader_test both wrap dataset.test, so adaptation
    # and evaluation use the same data — confirm this is intended.
    # NOTE(review): the DataLoaders use default batch size and collation, and the adapter's
    # __getitem__ returns dict(image=..., target=...); the 4-tuple unpacking in the loops
    # below looks incompatible with that — verify.
    dataloader_train = DataLoader(DatasetAdapterForTransformers(dataset.train))
    dataloader_val = DataLoader(DatasetAdapterForTransformers(dataset.test))
    dataloader_test = DataLoader(DatasetAdapterForTransformers(dataset.test))
    
    l1_loss = nn.L1Loss(reduction='mean')
    
    # pre-trained model : freeze -> unfreeze
    for k, v in model.named_parameters():
            v.requires_grad = True
            
    # Layers whose output statistics are aligned: every BatchNorm2d except the first 26
    # in module traversal order.
    chosen_bn_layers = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            chosen_bn_layers.append(m)
    chosen_bn_layers = chosen_bn_layers[26:]
    n_chosen_layers = len(chosen_bn_layers)
    save_outputs = [SaveOutput() for _ in range(n_chosen_layers)]
    clean_mean_act_list = [AverageMeter() for _ in range(n_chosen_layers)]
    clean_var_act_list = [AverageMeter() for _ in range(n_chosen_layers)]
    clean_mean_list_final = []
    clean_var_list_final = []
    # extract the activation alignment in train dataset
    with torch.no_grad():
        for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader_train)):
            img = img.to(device, non_blocking=True)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            model.eval()
            # Hooks are attached for this batch only and removed again in the loop below.
            hook_list = [chosen_bn_layers[i].register_forward_hook(save_outputs[i]) for i in range(n_chosen_layers)]
            _ = model(img)

            for i in range(n_chosen_layers):
                clean_mean_act_list[i].update(save_outputs[i].get_out_mean())  # compute mean from clean data
                clean_var_act_list[i].update(save_outputs[i].get_out_var())  # compute variance from clean data

                save_outputs[i].clear()
                hook_list[i].remove()
        for i in range(n_chosen_layers):
            clean_mean_list_final.append(clean_mean_act_list[i].avg)  # [C, H, W]
            clean_var_list_final.append(clean_var_act_list[i].avg)  # [C, H, W]
            
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4, nesterov=True)

    log.info('Starting TEST TIME ADAPTATION WITH ActMAD...')
    # ap_epochs = list()

    # TTA start (TODO: this loop should probably be driven by the Trainer)
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader_val)):
        # Train mode for the model, but batch-norm layers are put back in eval mode.
        model.train()
        for m in model.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.eval()

        optimizer.zero_grad()
        save_outputs_tta = [SaveOutput() for _ in range(n_chosen_layers)]

        hook_list_tta = [chosen_bn_layers[x].register_forward_hook(save_outputs_tta[x])
                            for x in range(n_chosen_layers)]
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        # NOTE(review): .cuda() is hard-coded here and below although `device` was derived
        # from the model above; this fails when the model lives on CPU.
        img = img.cuda()
        _ = model(img)
        batch_mean_tta = [save_outputs_tta[x].get_out_mean() for x in range(n_chosen_layers)]
        batch_var_tta = [save_outputs_tta[x].get_out_var() for x in range(n_chosen_layers)]

        # Scalar accumulators for the two loss parts.
        loss_mean = torch.tensor(0, requires_grad=True, dtype=torch.float).float().cuda()
        loss_var = torch.tensor(0, requires_grad=True, dtype=torch.float).float().cuda()

        for i in range(n_chosen_layers):
            loss_mean += l1_loss(batch_mean_tta[i].cuda(), clean_mean_list_final[i].cuda())
            loss_var += l1_loss(batch_var_tta[i].cuda(), clean_var_list_final[i].cuda())

        loss = loss_mean + loss_var

        loss.backward()
        optimizer.step()
        
        # test
        # NOTE(review): the evaluation below is indented inside this per-layer cleanup loop,
        # so it runs once per chosen layer for every batch; it was presumably meant to run
        # once per batch after the loop — confirm.
        # NOTE(review): `test` and `args` are not defined in any visible cell.
        for z in range(n_chosen_layers):
            save_outputs_tta[z].clear()
            hook_list_tta[z].remove()
            
            ap50 = test(batch_size=args.batch_size,
                            imgsz=args.img_size,
                            conf_thres=args.conf_thres,
                            iou_thres=args.iou_thres,
                            augment=args.augment,
                            verbose=False,
                            multi_label=True,
                            model=model,
                            dataloader=dataloader_test,
                            save_dir=args.save_dir)[-1]
            all_ap.append(np.mean(ap50))
            # Per-batch result is always written; the "best" copy and the model weights are
            # written when this evaluation is at least as good as every earlier one.
            Path(f'{args.save_dir}/results_stf_ttt/{args.task}/all/').mkdir(parents=True, exist_ok=True)
            np.save(f'{args.save_dir}/results_stf_ttt/{args.task}/all/{batch_i}.npy', ap50)
            if np.mean(ap50) >= max(all_ap):
                Path(f'{args.save_dir}/results_stf_ttt/{args.task}/best/').mkdir(exist_ok=True, parents=True)
                np.save(f'{args.save_dir}/results_stf_ttt/{args.task}/best/{args.task}.npy', ap50)
                state = {
                    'net': model.state_dict()
                }
                Path(f'{args.save_dir}/results_stf_ttt/models/').mkdir(parents=True, exist_ok=True)
                torch.save(state, f'{args.save_dir}/results_stf_ttt/models/{args.task}.pt')