# Overall Performance

In [1]:
# 导入库
import sys
import os

# Make the project root importable so `core_model` and `configs` below resolve
# to the project's own packages. The root is taken to be the parent of the
# current working directory (normally this notebook's directory).
# Insert at the front so project modules take precedence over any installed
# package with the same name, and skip the insert when already present so
# re-running this cell does not keep growing sys.path.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt


from core_model.custom_model import ClassifierWrapper, load_custom_model
from core_model.dataset import get_dataset_loader

from core_model.train_test import model_test

from configs import settings


import pandas as pd



# Select the compute device: CUDA when torch reports it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")





## Cifar-10 分类

In [2]:
def get_suffix(method, step):
    """Return the checkpoint-name suffix to evaluate for ``method`` at ``step``.

    Different methods save checkpoints under different suffixes. For example,
    ``contra`` has both a restore and a tta checkpoint, and the overall
    comparison should only consider the tta one.

    - Step 0: every method uses ``"worker_restore"``.
    - Steps 1+: ``cotta``, ``plf`` and ``contra`` use ``"worker_tta"``;
      every other method uses ``"worker_restore"``.
    """
    has_tta_checkpoint = step != 0 and method in ('cotta', 'plf', 'contra')
    return "worker_tta" if has_tta_checkpoint else "worker_restore"
    

# --- Evaluation settings ---
dataset_name = 'cifar-10'
model_name = 'cifar-resnet18'
mean, std = None, None  # passed through unchanged to get_dataset_loader
case = settings.get_case(noise_ratio=0.2, noise_type='symmetric', balanced=True)
num_classes = settings.num_classes_dict[dataset_name]
print(f'目前测试的数据集：{dataset_name}, case模式：{case}')

# Rows (methods) and columns (steps) of the accuracy table.
methods = ['raw', 'coteaching', 'coteaching_plus', 'jocor', 'cotta', 'plf', 'contra']
steps = list(range(4))

# Load the test split once; every checkpoint is evaluated on the same loader.
test_data, test_labels, test_dataloader = get_dataset_loader(
    dataset_name, "test", case, None, mean, std, batch_size=128, shuffle=False
)


def _build_model():
    """Return a new, untrained `model_name` backbone wrapped in ClassifierWrapper."""
    backbone = load_custom_model(model_name, num_classes, load_pretrained=False)
    return ClassifierWrapper(backbone, num_classes)


def _load_checkpoint(path):
    """Build a fresh model and load the state dict saved at `path` into it.

    A new model is created for every checkpoint. With ``strict=False``,
    reusing one model would let any parameter absent from `path` silently
    keep its value from whichever checkpoint was loaded before. Missing and
    unexpected keys are reported instead of ignored, because a mismatch means
    the measured accuracy does not reflect the saved weights.
    """
    fresh_model = _build_model()
    # map_location: allow GPU-saved checkpoints to load on a CPU-only host.
    # weights_only: load tensors only (no arbitrary unpickling); also silences
    # the torch.load FutureWarning about the default changing.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    incompatible = fresh_model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        tqdm.write(
            f'WARNING: {path}: {len(incompatible.missing_keys)} missing and '
            f'{len(incompatible.unexpected_keys)} unexpected keys '
            f'(e.g. missing={incompatible.missing_keys[:3]}, '
            f'unexpected={incompatible.unexpected_keys[:3]})'
        )
    return fresh_model


# Instantiate the architecture once up front so a config error fails fast
# (and `model` is defined even if `methods` is empty).
model = _build_model()

df = pd.DataFrame(index=methods, columns=steps)

for method in tqdm(methods):
    for step in steps:
        # Which checkpoint variant (restore vs. tta) to evaluate for this cell.
        model_suffix = get_suffix(method, step)
        model_repair_save_path = settings.get_ckpt_path(
            dataset_name, case, model_name,
            model_suffix=model_suffix, step=step, unique_name=method,
        )
        # tqdm.write keeps the progress bar intact (plain print garbles it).
        tqdm.write(str(model_repair_save_path))
        model = _load_checkpoint(model_repair_save_path)
        test_acc = model_test(test_dataloader, model, device=device)
        df.loc[method, step] = test_acc['global']

# Persist the method x step accuracy table.
results_dir = './results_main'
os.makedirs(results_dir, exist_ok=True)
results_file_name = f'{dataset_name}.csv'
results_file_dir = os.path.join(results_dir, results_file_name)
df.to_csv(results_file_dir)

目前测试的数据集：cifar-10, case模式：nr_0.2_nt_symmetric_balanced


  0%|          | 0/7 [00:00<?, ?it/s]

/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/raw/cifar-resnet18_worker_restore.pth


  checkpoint = torch.load(model_repair_save_path)


test_acc: 10.00
label: 0, acc: 100.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/raw/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 100.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/raw/cifar-resnet18_worker_restore.pth
test_acc: 12.51
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 0.00
label: 4, acc: 73.40
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 51.70
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_3/raw/cifar-resnet18_worker_r

 14%|█▍        | 1/7 [00:02<00:14,  2.34s/it]

test_acc: 10.00
label: 0, acc: 100.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/coteaching/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 100.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/coteaching/cifar-resnet18_worker_restore.pth
test_acc: 27.33
label: 0, acc: 0.00
label: 1, acc: 40.90
label: 2, acc: 1.30
label: 3, acc: 0.20
label: 4, acc: 12.70
label: 5, acc: 8.60
label: 6, acc: 67.80
label: 7, acc: 39.40
label: 8, acc: 75.10
label: 9, acc: 27.30
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/coteachin

 29%|██▊       | 2/7 [00:03<00:09,  1.92s/it]

test_acc: 27.08
label: 0, acc: 0.00
label: 1, acc: 44.10
label: 2, acc: 0.70
label: 3, acc: 0.00
label: 4, acc: 10.10
label: 5, acc: 8.30
label: 6, acc: 69.20
label: 7, acc: 37.20
label: 8, acc: 70.40
label: 9, acc: 30.80
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/coteaching_plus/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 100.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/coteaching_plus/cifar-resnet18_worker_restore.pth
test_acc: 27.33
label: 0, acc: 0.00
label: 1, acc: 40.90
label: 2, acc: 1.30
label: 3, acc: 0.20
label: 4, acc: 12.70
label: 5, acc: 8.60
label: 6, acc: 67.80
label: 7, acc: 39.40
label: 8, acc: 75.10
label: 9, acc: 27.30
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/st

 43%|████▎     | 3/7 [00:05<00:07,  1.80s/it]

test_acc: 27.08
label: 0, acc: 0.00
label: 1, acc: 44.10
label: 2, acc: 0.70
label: 3, acc: 0.00
label: 4, acc: 10.10
label: 5, acc: 8.30
label: 6, acc: 69.20
label: 7, acc: 37.20
label: 8, acc: 70.40
label: 9, acc: 30.80
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/jocor/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 100.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/jocor/cifar-resnet18_worker_restore.pth
test_acc: 27.33
label: 0, acc: 0.00
label: 1, acc: 40.90
label: 2, acc: 1.30
label: 3, acc: 0.20
label: 4, acc: 12.70
label: 5, acc: 8.60
label: 6, acc: 67.80
label: 7, acc: 39.40
label: 8, acc: 75.10
label: 9, acc: 27.30
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/jocor/cifar-res

 57%|█████▋    | 4/7 [00:07<00:05,  1.74s/it]

test_acc: 27.08
label: 0, acc: 0.00
label: 1, acc: 44.10
label: 2, acc: 0.70
label: 3, acc: 0.00
label: 4, acc: 10.10
label: 5, acc: 8.30
label: 6, acc: 69.20
label: 7, acc: 37.20
label: 8, acc: 70.40
label: 9, acc: 30.80
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/cotta/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/cotta/cifar-resnet18_worker_tta.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/cotta/cifar-resnet18_wo

 71%|███████▏  | 5/7 [00:09<00:03,  1.74s/it]

test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/plf/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/plf/cifar-resnet18_worker_tta.pth
test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/plf/cifar-resnet18_worker_tta.p

 86%|████████▌ | 6/7 [00:10<00:01,  1.74s/it]

test_acc: 10.00
label: 0, acc: 0.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 100.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_0/contra/cifar-resnet18_worker_restore.pth
test_acc: 10.00
label: 0, acc: 100.00
label: 1, acc: 0.00
label: 2, acc: 0.00
label: 3, acc: 0.00
label: 4, acc: 0.00
label: 5, acc: 0.00
label: 6, acc: 0.00
label: 7, acc: 0.00
label: 8, acc: 0.00
label: 9, acc: 0.00
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_1/contra/cifar-resnet18_worker_tta.pth
test_acc: 57.31
label: 0, acc: 66.40
label: 1, acc: 83.20
label: 2, acc: 28.40
label: 3, acc: 30.80
label: 4, acc: 26.10
label: 5, acc: 58.10
label: 6, acc: 78.20
label: 7, acc: 66.80
label: 8, acc: 66.90
label: 9, acc: 68.20
/nvme/sunzekun/Projects/tta-mr/ckpt/cifar-10/nr_0.2_nt_symmetric_balanced/step_2/contra/cifar-resn

100%|██████████| 7/7 [00:12<00:00,  1.77s/it]

test_acc: 64.64
label: 0, acc: 74.90
label: 1, acc: 85.10
label: 2, acc: 39.70
label: 3, acc: 30.00
label: 4, acc: 40.20
label: 5, acc: 60.60
label: 6, acc: 85.80
label: 7, acc: 72.40
label: 8, acc: 83.50
label: 9, acc: 74.20



