# Overall Performance

In [1]:
# Imports
import sys
import os

# Make the project root importable. Assumes the kernel's working directory is
# the notebook's own directory, one level below the project root.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))
sys.path.append(project_root)

# Baseline implementations: `cotta`, `plf` and the `cfgs.*` config modules
# imported below are expected to be found through these two directories.
# NOTE(review): both directories are appended to sys.path; if each one ships
# its own `cfgs` package, the cotta-main one shadows the PLF one — verify that
# `cfgs.conf_plf` resolves to the intended file.
cotta_root = os.path.join(project_root, 'baseline_code/cotta-main/cifar')
sys.path.append(cotta_root)

plf_root = os.path.join(project_root, 'baseline_code/PLF-main/cifar')
sys.path.append(plf_root)



import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm, trange
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

import math

# Project-local model / data / evaluation helpers.
from core_model.custom_model import ClassifierWrapper, load_custom_model
from core_model.dataset import get_dataset_loader

from core_model.train_test import model_test

from configs import settings


import pandas as pd


# Baseline TTA methods (resolved through cotta_root / plf_root above).
import cotta
import plf

import torch.optim as optim

# Global configs for the two baselines; read by setup_cotta / setup_plf.
from cfgs.conf_cotta import cfg as cfg_cotta
from cfgs.conf_plf import cfg as cfg_plf


from args_paser import parse_args


# Select the compute device.
# NOTE(review): torch is already imported above, so CUDA_VISIBLE_DEVICES is
# only honored if CUDA has not been initialized yet — consider setting it
# before `import torch`. '1' pins this notebook to physical GPU 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Example of building CLI args inside the notebook (kept for reference):
# sys.argv = ['', '--dataset','cifar-10', '--model', 'cifar-resnet18']
# custom_args = parse_args()
# custom_args


## 通用函数

In [10]:
def _build_optimizer(cfg, params):
    """Create the test-time optimizer described by ``cfg.OPTIM``.

    Shared by the CoTTA and PLF setups, whose configs expose the same
    ``OPTIM`` fields. When unsure of the settings, use the ones from the end
    of training, or start from a low learning rate (e.g. 0.001).

    Args:
        cfg: config object with ``OPTIM.METHOD`` ('Adam' or 'SGD') and the
            matching hyper-parameters (LR, BETA, MOMENTUM, DAMPENING, WD,
            NESTEROV).
        params: iterable of parameters to optimize.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        NotImplementedError: if ``cfg.OPTIM.METHOD`` is neither 'Adam' nor 'SGD'.
    """
    opt_cfg = cfg.OPTIM
    if opt_cfg.METHOD == "Adam":
        return optim.Adam(
            params,
            lr=opt_cfg.LR,
            betas=(opt_cfg.BETA, 0.999),
            weight_decay=opt_cfg.WD,
        )
    if opt_cfg.METHOD == "SGD":
        return optim.SGD(
            params,
            lr=opt_cfg.LR,
            momentum=opt_cfg.MOMENTUM,
            dampening=opt_cfg.DAMPENING,
            weight_decay=opt_cfg.WD,
            nesterov=opt_cfg.NESTEROV,
        )
    raise NotImplementedError(f"Unsupported optimizer method: {opt_cfg.METHOD!r}")


def setup_cotta(model, args):
    """Wrap ``model`` in a ``cotta.CoTTA`` test-time-adaptation model.

    ``cotta.configure_model`` prepares the network (this may modify it, which
    is why callers build a fresh model for every evaluation). The parameters
    chosen by ``cotta.collect_params`` then get an optimizer built from
    ``cfg_cotta``. Finally the wrapper is created with ``cfg_cotta``'s
    OPTIM.STEPS / MODEL.EPISODIC / OPTIM.MT / OPTIM.RST / OPTIM.AP settings.

    Args:
        model: classifier to adapt.
        args: parsed command-line arguments forwarded to ``cotta.CoTTA``.

    Returns:
        The ``cotta.CoTTA`` wrapper.
    """
    model = cotta.configure_model(model)
    params, _ = cotta.collect_params(model)
    optimizer = _build_optimizer(cfg_cotta, params)
    return cotta.CoTTA(
        model,
        optimizer,
        args,
        steps=cfg_cotta.OPTIM.STEPS,
        episodic=cfg_cotta.MODEL.EPISODIC,
        mt_alpha=cfg_cotta.OPTIM.MT,
        rst_m=cfg_cotta.OPTIM.RST,
        ap=cfg_cotta.OPTIM.AP,
    )


def setup_plf(model, custom_args, num_classes):
    """Wrap ``model`` in a ``plf.PLF`` test-time-adaptation model.

    Follows the same steps as ``setup_cotta``, but uses the ``plf`` module and
    ``cfg_plf``, and also passes ``num_classes`` to the wrapper.

    Args:
        model: classifier to adapt.
        custom_args: parsed command-line arguments forwarded to ``plf.PLF``.
        num_classes: number of output classes of ``model``.

    Returns:
        The ``plf.PLF`` wrapper.
    """
    model = plf.configure_model(model)
    params, _ = plf.collect_params(model)
    optimizer = _build_optimizer(cfg_plf, params)
    return plf.PLF(
        model,
        optimizer,
        custom_args,
        steps=cfg_plf.OPTIM.STEPS,
        episodic=cfg_plf.MODEL.EPISODIC,
        mt_alpha=cfg_plf.OPTIM.MT,
        rst_m=cfg_plf.OPTIM.RST,
        ap=cfg_plf.OPTIM.AP,
        num_classes=num_classes,
    )


def clean_accuracy(model: nn.Module,
                   x: torch.Tensor,
                   y: torch.Tensor,
                   batch_size: int = 100,
                   device: torch.device = None,
                   show_progress: bool = True) -> float:
    """Return the top-1 accuracy of ``model`` on ``(x, y)`` as a fraction in [0, 1].

    Samples are fed to the model in consecutive slices of ``batch_size`` along
    dim 0, under ``torch.no_grad()``. For test-time-adaptation wrappers
    (CoTTA / PLF) a forward call may also update the model, so batch order and
    size can affect the result.

    Args:
        model: callable whose output's argmax over dim 1 is the predicted class.
        x: input samples, indexed along dim 0.
        y: integer labels aligned with ``x``.
        batch_size: number of samples per forward pass.
        device: device each batch is moved to; defaults to ``x.device``.
        show_progress: show a tqdm progress bar over the batches.

    Returns:
        Number of correct predictions divided by ``x.shape[0]``.

    Raises:
        ValueError: if ``x`` contains no samples (accuracy is undefined).
    """
    if device is None:
        device = x.device
    n_samples = x.shape[0]
    if n_samples == 0:
        raise ValueError("clean_accuracy needs at least one sample")
    n_batches = math.ceil(n_samples / batch_size)
    batch_ids = trange(n_batches) if show_progress else range(n_batches)

    correct = 0
    with torch.no_grad():
        for counter in batch_ids:
            start = counter * batch_size
            x_curr = x[start:start + batch_size].to(device)
            y_curr = y[start:start + batch_size].to(device)
            output = model(x_curr)
            # Count with a Python int: exact, unlike a running float32 sum.
            correct += (output.max(1)[1] == y_curr).sum().item()
    return correct / n_samples

In [11]:
def get_test_acc(test_loader, model, device):
    """Compute the overall top-1 test accuracy of ``model``, in percent.

    Evaluation loop taken from the Run-Experiment code. It reports only the
    overall accuracy and is kept here to check whether errors originate in the
    evaluation step (it may not be needed otherwise).

    Args:
        test_loader: iterable of ``(inputs, targets)`` batches; ``len()`` of it
            sizes the progress bar.
        model: classifier; switched to eval mode and moved to ``device``.
        device: device every batch is moved to.

    Returns:
        Accuracy as a percentage in [0, 100] (also printed).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval().to(device)
    correct_test = 0
    total_test = 0
    with torch.no_grad(), tqdm(total=len(test_loader), desc="Testing") as pbar:
        for test_inputs, test_targets in test_loader:
            test_inputs = test_inputs.to(device)
            test_targets = test_targets.to(device)
            test_outputs = model(test_inputs)
            # The loss is only shown in the progress bar.
            loss = criterion(test_outputs, test_targets)
            predicted_test = test_outputs.argmax(1)
            total_test += test_targets.size(0)
            correct_test += (predicted_test == test_targets).sum().item()

            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            pbar.update(1)

    if total_test == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    test_accuracy = 100 * correct_test / total_test
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    return test_accuracy
        


class BaseTensorDataset(Dataset):
    """In-memory dataset that stores samples and labels as tensors.

    Mirrors the custom dataset from the Run-Experiment code. It is kept here
    only to check whether evaluation errors originate in the dataset (it may
    not be needed otherwise).

    Args:
        data: array-like of samples, converted with ``torch.as_tensor``.
        labels: array-like of labels, converted with ``torch.as_tensor``.
        transforms: optional callable applied to each sample in ``__getitem__``.
        device: optional device the tensors are created on.
    """

    def __init__(self, data, labels, transforms=None, device=None):
        self.data = torch.as_tensor(data, device=device)
        self.labels = torch.as_tensor(labels, device=device)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transforms is not None:
            # Transforms return the new sample rather than modifying it in
            # place, so the result must be kept.
            data = self.transforms(data)
        return data, self.labels[index]
    
    

In [27]:
def get_suffix(method, step):
    """Return the checkpoint suffix to evaluate for ``method`` at ``step``.

    Step 0 always uses the ``worker_restore`` weights. From step 1 on, the
    test-time-adaptation methods (cotta, plf, contra) are evaluated with their
    ``worker_tta`` weights. For example, contra saves both restore and tta
    models, and only the tta one counts in the overall comparison. Every other
    method keeps ``worker_restore``.
    """
    uses_tta_weights = step != 0 and method in ['cotta', 'plf', 'contra']
    return "worker_tta" if uses_tta_weights else "worker_restore"
    
# Methods evaluated by default by eva_test_acc (cotta / plf can be passed explicitly).
_DEFAULT_EVAL_METHODS = ('raw', 'coteaching', 'coteaching_plus', 'jocor', 'contra')
# Methods whose accuracy is measured through their test-time-adaptation wrapper.
_TTA_WRAPPED_METHODS = ('cotta', 'plf')


def _load_weights(model, ckpt_path):
    """Load the state dict at ``ckpt_path`` into ``model``; return False if the file is missing."""
    try:
        # Map storages to the CPU so checkpoints saved on any GPU index can be
        # read (load_state_dict copies them onto the model's own device), and
        # unpickle tensors only, which is all a state dict needs.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    except FileNotFoundError:
        print(f"Cannot find the weight file at {ckpt_path}. Just SKIP.")
        return False
    # strict=False is kept from the original loading code; report missing keys
    # instead of silently evaluating partially initialised weights.
    incompatible = model.load_state_dict(checkpoint, strict=False)
    missing_keys = getattr(incompatible, 'missing_keys', None)
    if missing_keys:
        print(f"Warning: {len(missing_keys)} parameter(s) not found in {ckpt_path}, "
              f"e.g. {missing_keys[:3]}")
    return True


def _parse_tta_args(dataset_name, model_name, method):
    """Run ``parse_args`` on a synthetic command line and restore ``sys.argv`` afterwards."""
    # A notebook has no real command line, so the arguments parse_args needs
    # are injected through sys.argv temporarily.
    saved_argv = sys.argv
    sys.argv = ['', '--dataset', dataset_name, '--model', model_name,
                '--uni_name', method, '--balanced']
    try:
        return parse_args()
    finally:
        sys.argv = saved_argv


def _tta_accuracy(model, method, dataset_name, model_name, num_classes,
                  test_data, test_labels):
    """Wrap ``model`` in its CoTTA/PLF adapter and return its test accuracy as a fraction."""
    custom_args = _parse_tta_args(dataset_name, model_name, method)
    model.eval().to(device)
    if method == 'cotta':
        model_aug = setup_cotta(model, custom_args)
    else:
        model_aug = setup_plf(model, custom_args, num_classes)

    try:
        model_aug.reset()
    except Exception as exc:  # best-effort: evaluation proceeds without a reset
        print(f'Failed to reset: {exc}')

    x_test = torch.from_numpy(test_data).to(device)
    y_test = torch.from_numpy(test_labels).to(device)
    return clean_accuracy(model_aug, x_test, y_test, batch_size=1024, device=device)


def eva_test_acc(dataset_name, model_name, noise_type='symmetric',
                 noise_ratio=0.2, methods=None, steps=None,
                 results_dir='./results_main'):
    """Evaluate the test accuracy of every method at every repair step.

    For each (method, step) pair, a fresh model is built and its checkpoint is
    loaded. Missing checkpoints are skipped and left as NaN in the table.
    ``cotta`` and ``plf`` are evaluated through their test-time-adaptation
    wrappers with ``clean_accuracy``. Every other method, ``contra``
    included, is evaluated directly with ``model_test``. The table is written
    to ``{results_dir}/{dataset_name}_{cls|rtv}.csv``.

    Args:
        dataset_name: dataset key understood by ``settings``/``get_dataset_loader``.
        model_name: architecture name passed to ``load_custom_model``.
        noise_type: 'symmetric' is saved as a classification result (``cls``);
            any other value is saved as a retrieval result (``rtv``).
        noise_ratio: label-noise ratio used to select the experiment case.
        methods: method names to evaluate; defaults to ``_DEFAULT_EVAL_METHODS``.
        steps: step indices to evaluate; defaults to 0-3.
        results_dir: directory the CSV file is written to (created if needed).

    Returns:
        pandas.DataFrame indexed by method, with one column per step.
    """
    methods = list(_DEFAULT_EVAL_METHODS if methods is None else methods)
    steps = list(range(4)) if steps is None else list(steps)

    case = settings.get_case(noise_ratio=noise_ratio, noise_type=noise_type, balanced=True)
    num_classes = settings.num_classes_dict[dataset_name]
    print(f'目前测试的数据集：{dataset_name}, case模式：{case}')

    # Same test-set loading call as core.py (mean/std left as None).
    test_data, test_labels, test_dataloader = get_dataset_loader(
        dataset_name, "test", case, None, None, None, batch_size=128, shuffle=False
    )

    df = pd.DataFrame(index=methods, columns=steps)
    for method in methods:
        for step in steps:
            model_suffix = get_suffix(method, step)
            # Rebuild the model every time: cotta/plf modify the architecture,
            # which would corrupt the evaluation of the following checkpoints.
            model = load_custom_model(model_name, num_classes, load_pretrained=False)
            model = ClassifierWrapper(model, num_classes)

            ckpt_path = settings.get_ckpt_path(
                dataset_name, case, model_name,
                model_suffix=model_suffix, step=step, unique_name=method,
            )
            print(f'Evaluating {ckpt_path}')
            if not _load_weights(model, ckpt_path):
                continue

            if method in _TTA_WRAPPED_METHODS:
                test_acc = _tta_accuracy(model, method, dataset_name, model_name,
                                         num_classes, test_data, test_labels)
            else:
                test_acc = model_test(test_dataloader, model, device=device)['global']

            print(f"测试集Acc：{test_acc}")
            df.loc[method, step] = test_acc

    os.makedirs(results_dir, exist_ok=True)
    # cls: classification task, rtv: retrieval task
    mission_type = 'cls' if noise_type == 'symmetric' else 'rtv'
    df.to_csv(os.path.join(results_dir, f'{dataset_name}_{mission_type}.csv'))
    return df

## Cifar-10 分类

In [28]:
# CIFAR-10 with symmetric label noise (classification task).
dataset_name, model_name, noise_type = 'cifar-10', 'cifar-resnet18', 'symmetric'
eva_test_acc(dataset_name=dataset_name, model_name=model_name, noise_type=noise_type)

目前测试的数据集：cifar-10, case模式：nr_0.2_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/raw/cifar-resnet18_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 72.19
label: 0, acc: 75.30
label: 1, acc: 89.80
label: 2, acc: 58.90
label: 3, acc: 42.60
label: 4, acc: 65.60
label: 5, acc: 64.50
label: 6, acc: 81.30
label: 7, acc: 75.30
label: 8, acc: 87.30
label: 9, acc: 81.30
测试集Acc：0.7219
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/raw/cifar-resnet18_worker_restore.pth
test_acc: 76.56
label: 0, acc: 82.20
label: 1, acc: 91.80
label: 2, acc: 65.90
label: 3, acc: 48.00
label: 4, acc: 71.90
label: 5, acc: 69.40
label: 6, acc: 87.30
label: 7, acc: 80.60
label: 8, acc: 85.50
label: 9, acc: 83.00
测试集Acc：0.7656
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/raw/cifar-resnet18_worker_restore.pth
test_acc: 74.12
label: 0, acc: 80.80
label: 1, acc: 91.40
label: 2, acc: 67.50
label: 3, acc: 43.70
label: 4, acc: 70.60
label: 5, acc: 62.10
label: 6, acc: 84.00
label: 7, acc: 76.00
label: 8, acc: 84.20
label: 9, acc: 80.90
测试集Acc：0.7412
Evaluating /nvme/sun

## Pet-37 分类

In [None]:
# Pet-37 with symmetric label noise (classification task).
dataset_name, model_name, noise_type = 'pet-37', 'wideresnet50', 'symmetric'
eva_test_acc(dataset_name=dataset_name, model_name=model_name, noise_type=noise_type)

## Cifar-100 检索

In [None]:
# CIFAR-100 with asymmetric label noise (retrieval task).
dataset_name, model_name, noise_type = 'cifar-100', 'cifar-wideresnet40', 'asymmetric'
eva_test_acc(dataset_name=dataset_name, model_name=model_name, noise_type=noise_type)

## Pet-37 检索

In [None]:
# Pet-37 with asymmetric label noise (retrieval task).
dataset_name, model_name, noise_type = 'pet-37', 'wideresnet50', 'asymmetric'
eva_test_acc(dataset_name=dataset_name, model_name=model_name, noise_type=noise_type)