# 环境导入

In [1]:
# 导入库
import sys
import os

# Make the project root and the vendored CoTTA / PLF baseline packages
# importable. The project root is taken to be the parent of the current
# working directory (the directory the notebook is launched from).
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))
cotta_root = os.path.join(project_root, 'baseline_code/cotta-main/cifar')
plf_root = os.path.join(project_root, 'baseline_code/PLF-main/cifar')
# Appended (lowest priority) in this order: project root, CoTTA, PLF.
sys.path.extend([project_root, cotta_root, plf_root])



import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm, trange
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

import math

from core_model.custom_model import ClassifierWrapper, load_custom_model
from core_model.dataset import get_dataset_loader

from core_model.train_test import model_test

from configs import settings


import pandas as pd


import cotta
import plf

import torch.optim as optim

from cfgs.conf_cotta import cfg as cfg_cotta
from cfgs.conf_plf import cfg as cfg_plf


from args_paser import parse_args


# Select the compute device.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA
# is first initialised in this process; confirm none of the project imports
# above touch CUDA, otherwise GPU "1" is not actually selected.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# sys.argv = ['', '--dataset','cifar-10', '--model', 'cifar-resnet18']
# custom_args = parse_args()
# custom_args


## 通用函数

In [2]:
def setup_cotta(model, args):
    """Wrap ``model`` in the CoTTA test-time-adaptation wrapper.

    ``cotta.configure_model`` prepares the model. The parameters returned by
    ``cotta.collect_params`` get an optimizer built from ``cfg_cotta.OPTIM``.
    Everything is then wrapped in ``cotta.CoTTA`` together with the
    ``steps``, ``episodic``, ``mt_alpha``, ``rst_m`` and ``ap`` settings taken
    from ``cfg_cotta``.

    Args:
        model: Classifier to adapt.
        args: Parsed command-line arguments, forwarded unchanged to
            ``cotta.CoTTA``.

    Returns:
        The ``cotta.CoTTA`` wrapper around the configured model.

    Raises:
        NotImplementedError: If ``cfg_cotta.OPTIM.METHOD`` is neither
            ``"Adam"`` nor ``"SGD"``.
    """

    def setup_optimizer(params):
        """Create the optimizer selected by ``cfg_cotta.OPTIM.METHOD`` for ``params``."""
        method = cfg_cotta.OPTIM.METHOD
        if method == "Adam":
            return optim.Adam(
                params,
                lr=cfg_cotta.OPTIM.LR,
                betas=(cfg_cotta.OPTIM.BETA, 0.999),
                weight_decay=cfg_cotta.OPTIM.WD,
            )
        if method == "SGD":
            return optim.SGD(
                params,
                lr=cfg_cotta.OPTIM.LR,
                momentum=cfg_cotta.OPTIM.MOMENTUM,
                dampening=cfg_cotta.OPTIM.DAMPENING,
                weight_decay=cfg_cotta.OPTIM.WD,
                nesterov=cfg_cotta.OPTIM.NESTEROV,
            )
        # Name the offending value instead of raising a bare NotImplementedError.
        raise NotImplementedError(
            f"Unsupported CoTTA optimizer {method!r}; expected 'Adam' or 'SGD'."
        )

    model = cotta.configure_model(model)
    # collect_params also returns parameter names; they are not needed here.
    params, _param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    return cotta.CoTTA(
        model,
        optimizer,
        args,
        steps=cfg_cotta.OPTIM.STEPS,
        episodic=cfg_cotta.MODEL.EPISODIC,
        mt_alpha=cfg_cotta.OPTIM.MT,
        rst_m=cfg_cotta.OPTIM.RST,
        ap=cfg_cotta.OPTIM.AP,
    )



def setup_plf(model, custom_args, num_classes):
    """Wrap ``model`` in the PLF test-time-adaptation wrapper.

    ``plf.configure_model`` prepares the model. The parameters returned by
    ``plf.collect_params`` get an optimizer built from ``cfg_plf.OPTIM``.
    Everything is then wrapped in ``plf.PLF`` together with the ``steps``,
    ``episodic``, ``mt_alpha``, ``rst_m`` and ``ap`` settings taken from
    ``cfg_plf``.

    Args:
        model: Classifier to adapt.
        custom_args: Parsed command-line arguments, forwarded unchanged to
            ``plf.PLF``.
        num_classes: Number of output classes, forwarded to ``plf.PLF``.

    Returns:
        The ``plf.PLF`` wrapper around the configured model.

    Raises:
        NotImplementedError: If ``cfg_plf.OPTIM.METHOD`` is neither
            ``"Adam"`` nor ``"SGD"``.
    """

    def setup_optimizer(params):
        """Create the optimizer selected by ``cfg_plf.OPTIM.METHOD`` for ``params``."""
        method = cfg_plf.OPTIM.METHOD
        if method == "Adam":
            return optim.Adam(
                params,
                lr=cfg_plf.OPTIM.LR,
                betas=(cfg_plf.OPTIM.BETA, 0.999),
                weight_decay=cfg_plf.OPTIM.WD,
            )
        if method == "SGD":
            return optim.SGD(
                params,
                lr=cfg_plf.OPTIM.LR,
                momentum=cfg_plf.OPTIM.MOMENTUM,
                dampening=cfg_plf.OPTIM.DAMPENING,
                weight_decay=cfg_plf.OPTIM.WD,
                nesterov=cfg_plf.OPTIM.NESTEROV,
            )
        # Name the offending value instead of raising a bare NotImplementedError.
        raise NotImplementedError(
            f"Unsupported PLF optimizer {method!r}; expected 'Adam' or 'SGD'."
        )

    model = plf.configure_model(model)
    # collect_params also returns parameter names; they are not needed here.
    params, _param_names = plf.collect_params(model)
    optimizer = setup_optimizer(params)
    return plf.PLF(
        model,
        optimizer,
        custom_args,
        steps=cfg_plf.OPTIM.STEPS,
        episodic=cfg_plf.MODEL.EPISODIC,
        mt_alpha=cfg_plf.OPTIM.MT,
        rst_m=cfg_plf.OPTIM.RST,
        ap=cfg_plf.OPTIM.AP,
        num_classes=num_classes,
    )


def clean_accuracy(model: nn.Module,
                   x: torch.Tensor,
                   y: torch.Tensor,
                   batch_size: int = 100,
                   device: torch.device = None) -> float:
    """Return the top-1 accuracy of ``model`` on ``(x, y)`` as a fraction in [0, 1].

    ``x`` and ``y`` are split into consecutive chunks of ``batch_size`` samples.
    Each chunk is moved to ``device`` and passed through ``model`` under
    ``torch.no_grad()``. The prediction is the arg-max over dim 1 of the output.

    NOTE(review): for the CoTTA / PLF wrappers the forward call presumably
    adapts the model as a side effect, so results may depend on batch order
    and ``batch_size``. Confirm in those wrappers.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        x: Inputs, indexed along dim 0.
        y: Integer labels aligned with ``x`` along dim 0.
        batch_size: Number of samples per forward pass.
        device: Target device; defaults to ``x.device``.

    Raises:
        ValueError: If ``x`` is empty or ``batch_size`` is not positive.
    """
    if device is None:
        device = x.device
    num_samples = x.shape[0]
    # Previously an empty input crashed with AttributeError ('float' has no .item()).
    if num_samples == 0:
        raise ValueError("clean_accuracy() received an empty input tensor")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_correct = 0
    n_batches = math.ceil(num_samples / batch_size)
    with torch.no_grad():
        for counter in trange(n_batches):
            start = counter * batch_size
            stop = start + batch_size
            x_curr = x[start:stop].to(device)
            y_curr = y[start:stop].to(device)

            output = model(x_curr)
            # Accumulate an exact Python int rather than a float32 tensor sum.
            num_correct += (output.max(1)[1] == y_curr).sum().item()
    return num_correct / num_samples

In [3]:
def get_test_acc(test_loader, model, device):
    """Return the overall top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    This mirrors the evaluation loop of the Run-Experiment code. It is kept
    only to cross-check whether errors originate in the evaluation itself and
    may not be needed otherwise. Only the overall accuracy is computed.

    The model is put in eval mode and moved to ``device``. The progress bar
    shows the cross-entropy loss of the current batch, and the final accuracy
    is printed.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval().to(device)
    correct_test = 0
    total_test = 0
    with torch.no_grad(), tqdm(total=len(test_loader), desc="Testing") as pbar:
        for test_inputs, test_targets in test_loader:
            test_inputs = test_inputs.to(device)
            test_targets = test_targets.to(device)
            test_outputs = model(test_inputs)
            loss = criterion(test_outputs, test_targets)
            predicted_test = test_outputs.argmax(dim=1)
            total_test += test_targets.size(0)
            correct_test += (predicted_test == test_targets).sum().item()

            # Update the progress bar.
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            pbar.update(1)

    # Previously an empty loader surfaced as a bare ZeroDivisionError.
    if total_test == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    test_accuracy = 100 * correct_test / total_test
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    return test_accuracy
        


class BaseTensorDataset(Dataset):
    """In-memory dataset over a pair of ``(data, labels)`` tensors.

    This is a copy of the custom dataset from the Run-Experiment code. It is
    kept only to check whether evaluation errors come from the dataset and may
    not be needed otherwise.

    Args:
        data: Array-like samples, converted with ``torch.as_tensor``.
        labels: Array-like labels, converted with ``torch.as_tensor``.
        transforms: Optional callable applied to each sample in ``__getitem__``.
        device: Optional device on which both tensors are created.
    """

    def __init__(self, data, labels, transforms=None, device=None):
        self.data = torch.as_tensor(data, device=device)
        self.labels = torch.as_tensor(labels, device=device)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transforms is not None:
            # Keep the transformed sample. The result used to be discarded,
            # so transforms never took effect.
            data = self.transforms(data)

        return data, self.labels[index]
    
    

In [17]:
def get_suffix(method, step):
    """Return the checkpoint-file suffix to evaluate for ``method`` at ``step``.

    Step 0 always uses the ``worker_restore`` checkpoint. For later steps the
    test-time-adaptation methods ('cotta', 'plf', 'contra') are evaluated from
    their ``worker_tta`` checkpoint, and every other method from
    ``worker_restore``. For example, contra stores both a restore and a tta
    model, and only the tta one counts in the overall evaluation.
    """
    uses_tta_checkpoint = step != 0 and method in ('cotta', 'plf', 'contra')
    return "worker_tta" if uses_tta_checkpoint else "worker_restore"
    

In [25]:
def eva_test_acc(dataset_name, model_name, noise_type='symmetric', noise_ratio=0.2, methods=None, steps=None, results_dir=None):
    """Evaluate test accuracy for every (method, step) pair and save it as CSV.

    For each method in ``methods`` and each step in ``steps``, a fresh model is
    built, the matching checkpoint (see ``get_suffix`` /
    ``settings.get_ckpt_path``) is loaded, and its test-set accuracy is
    recorded. Two kinds of methods are scored differently:

    * 'cotta' and 'plf' are wrapped with ``setup_cotta`` / ``setup_plf`` and
      scored with ``clean_accuracy`` on the full test arrays (batch size 1024).
    * All other methods are scored with ``model_test(...)['global']``.

    Checkpoints that are missing or cannot be loaded are skipped, and their
    cell stays NaN.

    The resulting table (rows = methods, columns = steps) is written to
    ``<results_dir>/<dataset_name>_<cls|rtv>.csv``. The suffix is 'cls' when
    ``noise_type == 'symmetric'`` and 'rtv' otherwise.

    Raises:
        AssertionError: If ``methods``, ``steps`` or ``results_dir`` is None.
    """
    assert methods is not None, "请指定要评估的方法"
    assert steps is not None, "请指定要评估的step(实验组)"
    # This used to re-check ``methods``, so a missing results_dir only failed
    # later inside os.makedirs with a TypeError.
    assert results_dir is not None, "请指定评估结果保存的目录"

    case = settings.get_case(noise_ratio=noise_ratio, noise_type=noise_type, balanced=True)
    mean, std = None, None
    num_classes = settings.num_classes_dict[dataset_name]
    print(f'目前测试的数据集：{dataset_name}, case模式：{case}')

    # Load the test split with the same loader used in core.py. The raw
    # np.load + BaseTensorDataset path from run_experiment.py can be swapped
    # in here for cross-checking.
    test_data, test_labels, test_dataloader = get_dataset_loader(
        dataset_name, "test", case, None, mean, std, batch_size=128, shuffle=False
    )

    os.makedirs(results_dir, exist_ok=True)
    df = pd.DataFrame(index=methods, columns=steps)

    for method in methods:
        for step in steps:
            model_suffix = get_suffix(method, step)
            # Rebuild the architecture on every iteration. This is slower, but
            # it keeps the in-place changes CoTTA/PLF make to the model from
            # leaking into later evaluations.
            model = load_custom_model(model_name, num_classes, load_pretrained=False)
            model = ClassifierWrapper(model, num_classes)

            model_repair_save_path = settings.get_ckpt_path(dataset_name, case, model_name, model_suffix=model_suffix, step=step, unique_name=method)
            print(f'Evaluating {model_repair_save_path}')

            # Best-effort: skip checkpoints that are absent or unreadable, but
            # report which of the two happened and never swallow KeyboardInterrupt.
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            try:
                checkpoint = torch.load(model_repair_save_path)
            except FileNotFoundError:
                print(f"Cannot find the weight file at {model_repair_save_path}. Just SKIP.")
                continue
            except Exception as err:
                print(f"Failed to load the weight file at {model_repair_save_path} ({err!r}). Just SKIP.")
                continue

            # strict=False silently ignores key mismatches. Report them so that
            # a checkpoint that does not match the wrapper is not evaluated
            # with random weights unnoticed.
            load_result = model.load_state_dict(checkpoint, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(f"[WARN] {model_repair_save_path}: "
                      f"{len(load_result.missing_keys)} missing / "
                      f"{len(load_result.unexpected_keys)} unexpected keys "
                      f"when loading the state dict (strict=False).")

            if method in ('cotta', 'plf'):
                # These two TTA baselines run inference through their own
                # wrapper (test-time adaptation plus augmentation), so they are
                # scored separately. contra does not need this.
                # parse_args() reads sys.argv, so build a synthetic command line
                # and restore the real one afterwards.
                saved_argv = sys.argv
                sys.argv = ['', '--dataset', dataset_name, '--model', model_name, '--uni_name', method, '--balanced']
                try:
                    custom_args = parse_args()
                    model.eval().to(device)
                    if method == 'cotta':
                        model_aug = setup_cotta(model, custom_args)
                    else:
                        model_aug = setup_plf(model, custom_args, num_classes)

                    try:
                        model_aug.reset()
                    except Exception as err:
                        print(f'Failed to reset: {err!r}')

                    x_test = torch.from_numpy(test_data).to(device)
                    y_test = torch.from_numpy(test_labels).to(device)
                    test_acc = clean_accuracy(model_aug, x_test, y_test, batch_size=1024, device=device)
                finally:
                    sys.argv = saved_argv
            else:
                test_acc = model_test(test_dataloader, model, device=device)['global']

            print(f"测试集Acc：{test_acc}")
            df.loc[method, step] = test_acc

    # cls: classification task; rtv: retrieval task.
    # To save as Excel instead, use:
    #   df.to_excel(os.path.join(results_dir, f'{dataset_name}.xlsx'), engine='openpyxl')
    mission_type = 'cls' if noise_type == 'symmetric' else 'rtv'
    results_file_name = f'{dataset_name}_{mission_type}.csv'
    results_file_dir = os.path.join(results_dir, results_file_name)
    df.to_csv(results_file_dir)

In [26]:
def eva_map(dataset_name, model_name, noise_type='asymmetric', noise_ratio=0.2, top_k=10, methods=None, steps=None, results_dir=None):
    """Evaluate retrieval mAP for every (method, step) pair and save it as CSV.

    For each method in ``methods`` and each step in ``steps``, a fresh model is
    built, the matching checkpoint is loaded, and ``get_map`` scores it. The
    (unshuffled) train split is passed as the first loader and the test split
    as the second, together with ``top_k``. Checkpoints that are missing or
    cannot be loaded are skipped, and their cell stays NaN.

    CoTTA/PLF checkpoints are evaluated directly on their stored weights. No
    test-time augmentation wrapper is applied, because it does not fit the
    retrieval setting well.

    The table (rows = methods, columns = steps) is written to
    ``<results_dir>/<dataset_name>_<cls|rtv>.csv``, with 'cls' for
    ``noise_type == 'symmetric'`` and 'rtv' otherwise.
    NOTE(review): the suffix follows ``noise_type``, not the metric, so a
    symmetric-noise call would overwrite the accuracy CSV that
    ``eva_test_acc`` writes.

    Raises:
        AssertionError: If ``methods``, ``steps`` or ``results_dir`` is None.
    """
    assert methods is not None, "请指定要评估的方法"
    assert steps is not None, "请指定要评估的step(实验组)"
    # This used to re-check ``methods``, so a missing results_dir only failed
    # later inside os.makedirs with a TypeError.
    assert results_dir is not None, "请指定评估结果保存的目录"

    case = settings.get_case(noise_ratio=noise_ratio, noise_type=noise_type, balanced=True)
    mean, std = None, None
    num_classes = settings.num_classes_dict[dataset_name]
    print(f'目前测试的数据集：{dataset_name}, case模式：{case}')

    # Load both splits with the same loader used in core.py. Only the loaders
    # are needed here.
    _, _, train_dataloader = get_dataset_loader(
        dataset_name, "train", case, None, mean, std, batch_size=128, shuffle=False
    )
    _, _, test_dataloader = get_dataset_loader(
        dataset_name, "test", case, None, mean, std, batch_size=128, shuffle=False
    )

    os.makedirs(results_dir, exist_ok=True)
    df = pd.DataFrame(index=methods, columns=steps)

    for method in methods:
        for step in steps:
            model_suffix = get_suffix(method, step)
            # Rebuild the architecture on every iteration so no in-place model
            # changes leak between evaluations.
            model = load_custom_model(model_name, num_classes, load_pretrained=False)
            model = ClassifierWrapper(model, num_classes)

            model_repair_save_path = settings.get_ckpt_path(dataset_name, case, model_name, model_suffix=model_suffix, step=step, unique_name=method)
            print(f'Evaluating {model_repair_save_path}')

            # Best-effort: skip checkpoints that are absent or unreadable, but
            # report which of the two happened and never swallow KeyboardInterrupt.
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            try:
                checkpoint = torch.load(model_repair_save_path)
            except FileNotFoundError:
                print(f"Cannot find the weight file at {model_repair_save_path}. Just SKIP.")
                continue
            except Exception as err:
                print(f"Failed to load the weight file at {model_repair_save_path} ({err!r}). Just SKIP.")
                continue

            # strict=False silently ignores key mismatches; report them.
            load_result = model.load_state_dict(checkpoint, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(f"[WARN] {model_repair_save_path}: "
                      f"{len(load_result.missing_keys)} missing / "
                      f"{len(load_result.unexpected_keys)} unexpected keys "
                      f"when loading the state dict (strict=False).")

            mAP = get_map(model, train_dataloader, test_dataloader, top_k)
            print(f"测试mAP：{mAP}")
            df.loc[method, step] = mAP

    # cls: classification task; rtv: retrieval task.
    # To save as Excel instead, use:
    #   df.to_excel(os.path.join(results_dir, f'{dataset_name}.xlsx'), engine='openpyxl')
    mission_type = 'cls' if noise_type == 'symmetric' else 'rtv'
    results_file_name = f'{dataset_name}_{mission_type}.csv'
    results_file_dir = os.path.join(results_dir, results_file_name)
    df.to_csv(results_file_dir)



# Feature extraction.
def extract_features(feature_extractor, data_loader):
    """Run ``feature_extractor`` over ``data_loader`` and collect flattened features.

    Each batch of images is moved to the module-level ``device`` and passed
    through ``feature_extractor`` under ``torch.no_grad()``. Outputs are
    flattened to ``(batch, -1)``. Features and labels are gathered on the CPU
    as NumPy arrays.

    Returns:
        ``(features, labels)``: the per-batch arrays concatenated along axis 0.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    features = []
    labels = []
    with torch.no_grad():
        for images, targets in tqdm(data_loader):
            outputs = feature_extractor(images.to(device))
            # reshape (not view) also works for non-contiguous outputs.
            features.append(outputs.reshape(outputs.size(0), -1).cpu().numpy())
            # .cpu() first: labels may already live on the GPU (e.g. a dataset
            # created with device=...), and Tensor.numpy() is CPU-only.
            labels.append(targets.cpu().numpy())
    if not features:
        raise ValueError("data_loader yielded no batches; nothing to extract")
    return np.concatenate(features), np.concatenate(labels)

def retrieve(gallery_feats, query_feats, top_k=10):
    """Return, per query, the gallery indices of its ``top_k`` most similar items.

    Similarity is cosine similarity between every query row and every gallery
    row. Each output row lists gallery indices from most to least similar,
    truncated to ``top_k`` columns.
    """
    similarity = cosine_similarity(query_feats, gallery_feats)
    ranking = np.argsort(-similarity, axis=1)
    return ranking[:, :top_k]



def calculate_map(indices, gallery_labels, query_labels):
    """Mean average precision over the retrieved top-k lists.

    For query ``i`` the retrieved labels are ``gallery_labels[indices[i]]``.
    A hit is a retrieved label equal to the query label. AP is the sum of
    precision@rank over hit positions, divided by the number of hits in the
    list. A query with no hits contributes 0. Returns the mean AP over all
    queries.
    """
    per_query_ap = []
    for row, query_label in enumerate(query_labels):
        hits = (gallery_labels[indices[row]] == query_label).astype(int)
        hit_count = hits.sum()
        if hit_count == 0:
            ap = 0
        else:
            ranks = np.arange(1, len(hits) + 1)
            precision_at_rank = np.cumsum(hits) / ranks
            ap = (precision_at_rank * hits).sum() / hit_count
        per_query_ap.append(ap)
    return np.mean(per_query_ap)

def get_map(model, train_loader, test_loader, top_k):
    """Compute retrieval mAP@``top_k`` of ``model`` (train = gallery, test = queries).

    The feature extractor is ``model``'s child modules without the last one,
    chained in an ``nn.Sequential``.
    NOTE(review): this assumes the last child is the classification head and
    that running the remaining children in order reproduces the backbone's
    forward pass. Confirm for ClassifierWrapper, since ``nn.Sequential``
    ignores any custom ``forward`` logic.
    """
    model.to(device)
    model.eval()
    backbone_layers = list(model.children())[:-1]
    feature_extractor = nn.Sequential(*backbone_layers).to(device)
    feature_extractor.eval()

    # Gallery features come from the train split, query features from the test split.
    gallery_feats, gallery_labels = extract_features(feature_extractor, train_loader)
    query_feats, query_labels = extract_features(feature_extractor, test_loader)
    top_indices = retrieve(gallery_feats, query_feats, top_k=top_k)
    return calculate_map(top_indices, gallery_labels, query_labels)
    




## 主实验
### cifar-10-分类

In [32]:
# Main experiment: CIFAR-10 classification (test accuracy).
dataset_name = 'cifar-10'
model_name = 'cifar-resnet18'
noise_type = 'symmetric'

# ------------------------------------------------------------- #
# Experiment area: choose the groups to evaluate.
#   methods: any of ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
#   steps:   any of [0, 1, 2, 3]
# Alternatives:
#   methods = ['contra']
#   methods = ['raw']
#   methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'contra']
methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
steps = list(range(4))
# ------------------------------------------------------------- #
# Directory where the result CSV is written.
results_dir = './results_main'
# ------------------------------------------------------------- #

eva_test_acc(
    dataset_name,
    model_name,
    noise_type=noise_type,
    methods=methods,
    steps=steps,
    results_dir=results_dir,
)

目前测试的数据集：cifar-10, case模式：nr_0.2_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/contra/cifar-resnet18_worker_restore.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/contra/cifar-resnet18_worker_restore.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/contra/cifar-resnet18_worker_tta.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/contra/cifar-resnet18_worker_tta.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/contra/cifar-resnet18_worker_tta.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/contra/cifar-resnet18_worker_tta.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_s

  checkpoint = torch.load(model_repair_save_path)


### Pet-37 分类
> 评估测试集acc

In [10]:
# Main experiment: Pet-37 classification (test accuracy).
dataset_name = 'pet-37'
model_name = 'wideresnet50'
noise_type = 'symmetric'

# ------------------------------------------------------------- #
# Experiment area: choose the groups to evaluate.
#   methods: any of ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
#   steps:   any of [0, 1, 2, 3]
# Alternatives:
#   methods = ['contra']
#   methods = ['raw']
#   methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'contra']
steps = list(range(4))
# ------------------------------------------------------------- #
# Directory where the result CSV is written.
results_dir = './results_main'
# ------------------------------------------------------------- #

eva_test_acc(
    dataset_name,
    model_name,
    noise_type=noise_type,
    methods=methods,
    steps=steps,
    results_dir=results_dir,
)

目前测试的数据集：pet-37, case模式：nr_0.2_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_symmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 92.34
label: 0, acc: 88.78
label: 1, acc: 93.00
label: 2, acc: 52.00
label: 3, acc: 94.00
label: 4, acc: 94.00
label: 5, acc: 92.00
label: 6, acc: 91.00
label: 7, acc: 100.00
label: 8, acc: 92.93
label: 9, acc: 84.00
label: 10, acc: 90.00
label: 11, acc: 87.63
label: 12, acc: 95.00
label: 13, acc: 97.00
label: 14, acc: 100.00
label: 15, acc: 98.00
label: 16, acc: 98.00
label: 17, acc: 100.00
label: 18, acc: 100.00
label: 19, acc: 100.00
label: 20, acc: 84.00
label: 21, acc: 94.00
label: 22, acc: 100.00
label: 23, acc: 92.00
label: 24, acc: 94.00
label: 25, acc: 99.00
label: 26, acc: 75.00
label: 27, acc: 78.00
label: 28, acc: 99.00
label: 29, acc: 100.00
label: 30, acc: 97.98
label: 31, acc: 98.00
label: 32, acc: 92.00
label: 33, acc: 96.00
label: 34, acc: 75.28
label: 35, acc: 96.00
label: 36, acc: 98.00
测试集Acc：0.9234123739438539
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_symmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth
test_acc: 87.03
label

### Cifar-100 检索
> 计算mAP（test_acc仅供参考）

In [11]:
# Main experiment: CIFAR-100 retrieval (mAP).
dataset_name = 'cifar-100'
model_name = 'cifar-wideresnet40'
noise_type = 'asymmetric'

# ------------------------------------------------------------- #
# Experiment area: choose the groups to evaluate.
#   methods: any of ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
#   steps:   any of [0, 1, 2, 3]
# Alternatives:
#   methods = ['contra']
#   methods = ['raw']
#   methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'contra']
steps = list(range(4))
# ------------------------------------------------------------- #
# Directory where the result CSV is written.
results_dir = './results_main'
# ------------------------------------------------------------- #

# Optional reference accuracy:
#   eva_test_acc(dataset_name, model_name, noise_type, methods=methods, steps=steps)
eva_map(
    dataset_name,
    model_name,
    noise_type=noise_type,
    methods=methods,
    steps=steps,
    results_dir=results_dir,
)

目前测试的数据集：cifar-100, case模式：nr_0.2_nt_asymmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_0/raw/cifar-wideresnet40_worker_restore.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_0/raw/cifar-wideresnet40_worker_restore.pth. Just SKIP.


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_1/raw/cifar-wideresnet40_worker_restore.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_1/raw/cifar-wideresnet40_worker_restore.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_2/raw/cifar-wideresnet40_worker_restore.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_2/raw/cifar-wideresnet40_worker_restore.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_3/raw/cifar-wideresnet40_worker_restore.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/step_3/raw/cifar-wideresnet40_worker_restore.pth. Just SKIP.
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/cifar-100/nr_0.2_nt_asymmetric_balanced/s

### Pet-37 检索
> 计算mAP（test_acc仅供参考）

In [23]:
# Main experiment: Pet-37 retrieval (mAP).
dataset_name = 'pet-37'
model_name = 'wideresnet50'
noise_type = 'asymmetric'

# ------------------------------------------------------------- #
# Experiment area: choose the groups to evaluate.
#   methods: any of ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
#   steps:   any of [0, 1, 2, 3]
# Alternatives:
#   methods = ['raw']
#   methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
#   methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'contra']
methods = ['contra']
steps = list(range(4))
# ------------------------------------------------------------- #
# Directory where the result CSV is written.
results_dir = './results_main'
# ------------------------------------------------------------- #
# Optional reference accuracy:
#   eva_test_acc(dataset_name, model_name, noise_type,
#                methods=methods, steps=steps, results_dir=results_dir)
eva_map(
    dataset_name,
    model_name,
    noise_type=noise_type,
    methods=methods,
    steps=steps,
    results_dir=results_dir,
)

目前测试的数据集：pet-37, case模式：nr_0.2_nt_asymmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_asymmetric_balanced/step_0/contra/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 10.18it/s]
100%|██████████| 29/29 [00:02<00:00, 10.80it/s]


测试mAP：0.9304353316940275


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_asymmetric_balanced/step_1/contra/wideresnet50_worker_tta.pth


100%|██████████| 29/29 [00:03<00:00,  8.98it/s]
100%|██████████| 29/29 [00:02<00:00,  9.84it/s]


测试mAP：0.8714756712573686


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_asymmetric_balanced/step_2/contra/wideresnet50_worker_tta.pth


100%|██████████| 29/29 [00:03<00:00,  9.46it/s]
100%|██████████| 29/29 [00:03<00:00,  9.06it/s]


测试mAP：0.8638970780361649
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_asymmetric_balanced/step_3/contra/wideresnet50_worker_tta.pth
Cannot find the weight file at /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.2_nt_asymmetric_balanced/step_3/contra/wideresnet50_worker_tta.pth. Just SKIP.


  checkpoint = torch.load(model_repair_save_path)


# 敏感度实验

In [30]:
# Sensitivity experiment: Pet-37 under several noise ratios.
dataset_name = 'pet-37'
model_name = 'wideresnet50'

# ------------------------------------------------------------- #
# Experiment area: choose the groups to evaluate.
#   methods: any of ['raw', 'coteaching', 'contra']
#   steps:   any of [0, 1, 2, 3]
methods = ['raw', 'coteaching', 'contra']
steps = list(range(4))
# ------------------------------------------------------------- #
# Root directory for results; each noise ratio gets its own sub-directory.
results_dir = './results_sensitivity'
# ------------------------------------------------------------- #

for noise_ratio in [0.1, 0.3, 0.5]:
    results_dir = f'./results_sensitivity/nr_{noise_ratio}'

    # Classification task: test accuracy under symmetric noise.
    eva_test_acc(dataset_name, model_name,
                 noise_ratio=noise_ratio,
                 noise_type='symmetric',
                 methods=methods, steps=steps, results_dir=results_dir)

    # Retrieval task: mAP under asymmetric noise.
    eva_map(dataset_name, model_name,
            noise_ratio=noise_ratio,
            noise_type='asymmetric',
            methods=methods, steps=steps, results_dir=results_dir)

目前测试的数据集：pet-37, case模式：nr_0.1_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_symmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 92.37
label: 0, acc: 87.76
label: 1, acc: 97.00
label: 2, acc: 40.00
label: 3, acc: 95.00
label: 4, acc: 98.00
label: 5, acc: 87.00
label: 6, acc: 91.00
label: 7, acc: 100.00
label: 8, acc: 91.92
label: 9, acc: 83.00
label: 10, acc: 93.00
label: 11, acc: 89.69
label: 12, acc: 96.00
label: 13, acc: 97.00
label: 14, acc: 100.00
label: 15, acc: 99.00
label: 16, acc: 97.00
label: 17, acc: 100.00
label: 18, acc: 97.98
label: 19, acc: 99.00
label: 20, acc: 83.00
label: 21, acc: 93.00
label: 22, acc: 100.00
label: 23, acc: 93.00
label: 24, acc: 98.00
label: 25, acc: 99.00
label: 26, acc: 73.00
label: 27, acc: 85.00
label: 28, acc: 99.00
label: 29, acc: 99.00
label: 30, acc: 100.00
label: 31, acc: 99.00
label: 32, acc: 93.00
label: 33, acc: 95.00
label: 34, acc: 73.03
label: 35, acc: 95.00
label: 36, acc: 100.00
测试集Acc：0.9236849277732352
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_symmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth
test_acc: 88.25
label:

  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.18it/s]
100%|██████████| 29/29 [00:02<00:00, 12.37it/s]


测试mAP：0.9284697437695287


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.28it/s]
100%|██████████| 29/29 [00:02<00:00, 12.36it/s]


测试mAP：0.9246496351510937


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_2/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.18it/s]
100%|██████████| 29/29 [00:02<00:00, 11.80it/s]


测试mAP：0.9224543508817071
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_3/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.72it/s]
100%|██████████| 29/29 [00:02<00:00, 12.03it/s]


测试mAP：0.9123392368036117


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_0/coteaching/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.85it/s]
100%|██████████| 29/29 [00:02<00:00, 12.05it/s]


测试mAP：0.9284697437695287


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_1/coteaching/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.00it/s]
100%|██████████| 29/29 [00:02<00:00, 12.10it/s]


测试mAP：0.9300400975679533
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_2/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.98it/s]
100%|██████████| 29/29 [00:02<00:00, 12.07it/s]


测试mAP：0.9273417615660783
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_3/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.97it/s]
100%|██████████| 29/29 [00:02<00:00, 12.08it/s]


测试mAP：0.9240376218449181


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_0/contra/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.01it/s]
100%|██████████| 29/29 [00:02<00:00, 12.02it/s]


测试mAP：0.9284697437695287
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_1/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.81it/s]
100%|██████████| 29/29 [00:02<00:00, 12.18it/s]


测试mAP：0.8765983173474916
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_2/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.00it/s]
100%|██████████| 29/29 [00:02<00:00, 12.33it/s]


测试mAP：0.8538941357990356


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.1_nt_asymmetric_balanced/step_3/contra/wideresnet50_worker_tta.pth


100%|██████████| 29/29 [00:02<00:00, 11.95it/s]
100%|██████████| 29/29 [00:02<00:00, 12.28it/s]


测试mAP：0.8513735928867777
目前测试的数据集：pet-37, case模式：nr_0.3_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_symmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 93.27
label: 0, acc: 86.73
label: 1, acc: 94.00
label: 2, acc: 52.00
label: 3, acc: 97.00
label: 4, acc: 97.00
label: 5, acc: 93.00
label: 6, acc: 93.00
label: 7, acc: 100.00
label: 8, acc: 92.93
label: 9, acc: 82.00
label: 10, acc: 92.00
label: 11, acc: 90.72
label: 12, acc: 96.00
label: 13, acc: 98.00
label: 14, acc: 100.00
label: 15, acc: 98.00
label: 16, acc: 99.00
label: 17, acc: 100.00
label: 18, acc: 100.00
label: 19, acc: 100.00
label: 20, acc: 83.00
label: 21, acc: 94.00
label: 22, acc: 100.00
label: 23, acc: 91.00
label: 24, acc: 97.00
label: 25, acc: 99.00
label: 26, acc: 78.00
label: 27, acc: 86.00
label: 28, acc: 99.00
label: 29, acc: 100.00
label: 30, acc: 98.99
label: 31, acc: 99.00
label: 32, acc: 95.00
label: 33, acc: 98.00
label: 34, acc: 76.40
label: 35, acc: 96.00
label: 36, acc: 98.00
测试集Acc：0.9326792041428182
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_symmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth
test_acc: 89.45
label

  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.00it/s]
100%|██████████| 29/29 [00:02<00:00, 11.52it/s]


测试mAP：0.9281733177457814
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.89it/s]
100%|██████████| 29/29 [00:02<00:00, 10.64it/s]


测试mAP：0.9224619096752902


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_2/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.55it/s]
100%|██████████| 29/29 [00:02<00:00, 11.90it/s]


测试mAP：0.9180141116135744
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_3/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.90it/s]
100%|██████████| 29/29 [00:02<00:00, 12.13it/s]


测试mAP：0.916256604400237


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_0/coteaching/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.79it/s]
100%|██████████| 29/29 [00:02<00:00, 12.03it/s]


测试mAP：0.9281733177457814
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_1/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.95it/s]
100%|██████████| 29/29 [00:02<00:00, 10.75it/s]


测试mAP：0.9290672783402248
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_2/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.62it/s]
100%|██████████| 29/29 [00:02<00:00, 11.74it/s]


测试mAP：0.9282619881280799
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_3/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.79it/s]
100%|██████████| 29/29 [00:02<00:00, 11.85it/s]


测试mAP：0.923239185393326


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_0/contra/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.73it/s]
100%|██████████| 29/29 [00:02<00:00, 11.92it/s]


测试mAP：0.9281733177457814
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_1/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.82it/s]
100%|██████████| 29/29 [00:02<00:00, 11.92it/s]


测试mAP：0.8770389754808681


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_2/contra/wideresnet50_worker_tta.pth


100%|██████████| 29/29 [00:02<00:00, 11.82it/s]
100%|██████████| 29/29 [00:02<00:00, 12.03it/s]


测试mAP：0.8585111982267039
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.3_nt_asymmetric_balanced/step_3/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.07it/s]
100%|██████████| 29/29 [00:02<00:00, 12.06it/s]


测试mAP：0.8438005808730527
目前测试的数据集：pet-37, case模式：nr_0.5_nt_symmetric_balanced
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_symmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 92.83
label: 0, acc: 89.80
label: 1, acc: 94.00
label: 2, acc: 50.00
label: 3, acc: 95.00
label: 4, acc: 95.00
label: 5, acc: 89.00
label: 6, acc: 90.00
label: 7, acc: 100.00
label: 8, acc: 91.92
label: 9, acc: 85.00
label: 10, acc: 94.00
label: 11, acc: 91.75
label: 12, acc: 95.00
label: 13, acc: 97.00
label: 14, acc: 100.00
label: 15, acc: 99.00
label: 16, acc: 97.00
label: 17, acc: 100.00
label: 18, acc: 97.98
label: 19, acc: 99.00
label: 20, acc: 82.00
label: 21, acc: 96.00
label: 22, acc: 100.00
label: 23, acc: 94.00
label: 24, acc: 98.00
label: 25, acc: 99.00
label: 26, acc: 80.00
label: 27, acc: 85.00
label: 28, acc: 99.00
label: 29, acc: 100.00
label: 30, acc: 96.97
label: 31, acc: 99.00
label: 32, acc: 90.00
label: 33, acc: 95.00
label: 34, acc: 76.40
label: 35, acc: 95.00
label: 36, acc: 98.00
测试集Acc：0.9283183428727174
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_symmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth
test_acc: 87.65
label: 

  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_0/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.99it/s]
100%|██████████| 29/29 [00:02<00:00, 11.99it/s]


测试mAP：0.9293015216266898


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_1/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.02it/s]
100%|██████████| 29/29 [00:02<00:00, 12.21it/s]


测试mAP：0.9242831573767974


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_2/raw/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 11.56it/s]
100%|██████████| 29/29 [00:02<00:00, 12.43it/s]


测试mAP：0.9127190599471469
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_3/raw/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.35it/s]
100%|██████████| 29/29 [00:02<00:00, 12.48it/s]


测试mAP：0.9094722337846438
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_0/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.28it/s]
100%|██████████| 29/29 [00:02<00:00, 12.51it/s]


测试mAP：0.9293015216266898
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_1/coteaching/wideresnet50_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 11.62it/s]
100%|██████████| 29/29 [00:02<00:00, 12.26it/s]


测试mAP：0.9285702208262852


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_2/coteaching/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.28it/s]
100%|██████████| 29/29 [00:02<00:00, 12.13it/s]


测试mAP：0.926886692981605


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_3/coteaching/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.11it/s]
100%|██████████| 29/29 [00:02<00:00, 12.34it/s]


测试mAP：0.9226217378348732


  checkpoint = torch.load(model_repair_save_path)


Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_0/contra/wideresnet50_worker_restore.pth


100%|██████████| 29/29 [00:02<00:00, 12.14it/s]
100%|██████████| 29/29 [00:02<00:00, 12.29it/s]


测试mAP：0.9293015216266898
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_1/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.02it/s]
100%|██████████| 29/29 [00:02<00:00, 12.30it/s]


测试mAP：0.8536171762321111
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_2/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.12it/s]
100%|██████████| 29/29 [00:02<00:00, 12.03it/s]


测试mAP：0.8464472612632064
Evaluating /nvme/sunzekun/Projects/tta-mr/ckpt/pet-37/nr_0.5_nt_asymmetric_balanced/step_3/contra/wideresnet50_worker_tta.pth


  checkpoint = torch.load(model_repair_save_path)
100%|██████████| 29/29 [00:02<00:00, 12.03it/s]
100%|██████████| 29/29 [00:02<00:00, 12.38it/s]


测试mAP：0.8491992120505876


# Ablation Study 评估


In [None]:
# Ablation-study evaluation cell: pick a dataset/model/noise setting, choose
# which ablated method variants and steps to evaluate, and hand them to
# `eva_test_acc` (defined elsewhere in this notebook, not in this cell).
dataset_name = 'pet-37'
model_name = 'wideresnet50'
noise_type = 'symmetric'

# ------------------------------------------------------------- #
# Experiment area: manually specify the groups to evaluate.
#   methods: any subset of ['contra_repair_only', 'contra_tta_only']
#   steps:   any subset of [0, 1, 2, 3]
methods = ['contra_repair_only', 'contra_tta_only']
steps = list(range(4))
# ------------------------------------------------------------- #
# Directory passed to eva_test_acc as the results location.
results_dir = './results_ablation'
# ------------------------------------------------------------- #

eva_test_acc(dataset_name, model_name, noise_type,
             methods=methods, steps=steps, results_dir=results_dir)