In [1]:
import sys
import os

# A notebook kernel has no __file__, so anchor on the working directory.
# This is the same directory that dirname(abspath('__file__')) resolved to,
# without relying on a fake file name being joined to the cwd.
current_notebook_dir = os.path.abspath(os.getcwd())
project_root_dir = os.path.abspath(
    os.path.join(current_notebook_dir, os.pardir, os.pardir)
)

# Put the project root at the front of sys.path so project-local packages
# take precedence; skip if it is already present to keep re-runs idempotent.
if project_root_dir not in sys.path:
    sys.path.insert(0, project_root_dir)

print(sys.path)

['/home/hqdeng7/lijuyang/generalization/loss_distribution', '/home/hqdeng7/.conda/envs/ljy/lib/python311.zip', '/home/hqdeng7/.conda/envs/ljy/lib/python3.11', '/home/hqdeng7/.conda/envs/ljy/lib/python3.11/lib-dynload', '', '/home/hqdeng7/.conda/envs/ljy/lib/python3.11/site-packages']


In [2]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from pytorch_script.visual_utils import load_cifar10_data, \
	get_sorted_model_paths, evaluate_model_performance, load_model_state_dict


In [5]:
import seaborn as sns

# Evaluation settings; relative paths are resolved against the working directory.
config = {
    'model_dir': '../../model_training_results/cifar10_resnet20',  # directory holding the saved checkpoints
    'model_prefix': 'model_',                                      # file-name prefix of checkpoint files
    'model_extension': '.pth',                                     # file extension of checkpoint files
    'cifar10_data_path': '../../pytorch_script/data/cifar10',      # where the CIFAR-10 dataset is stored
    'batch_size': 64,                                              # DataLoader batch size
    'num_workers': 2                                               # DataLoader worker processes
}

# Fall back to the CPU when no GPU is visible instead of failing on first use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the data.
train_loader, test_loader = load_cifar10_data(
    config['cifar10_data_path'], config['batch_size'], config['num_workers']
)
if train_loader is None or test_loader is None:
    # The message announces termination, so actually stop here rather than
    # carrying None loaders into the evaluation loop below.
    raise RuntimeError("数据加载失败，程序终止。")

# Collect the checkpoint paths as (epoch, path) pairs.
sorted_model_paths = get_sorted_model_paths(
    config['model_dir'], config['model_prefix'], config['model_extension']
)

if not sorted_model_paths:
    print("没有找到符合条件或能成功加载的模型文件。请检查 MODEL_DIR, MODEL_PREFIX 和 MODEL_EXTENSION 设置。")

# Parallel series: index i of every list refers to the checkpoint epochs[i].
epochs = []
train_accuracies = []
test_accuracies = []
train_losses = []
test_losses = []

# NOTE(review): the recorded output of this cell shows test loss identical to
# training loss for every checkpoint — verify evaluate_model_performance.
print("\n开始逐个评估模型...")
for epoch, model_path in sorted_model_paths:
    print(f"\n正在评估模型：{model_path} (纪元: {epoch})")

    model = load_model_state_dict('cifar10', 'resnet20', 10, model_path, device)

    # Evaluate on the training set.
    train_acc, train_loss = evaluate_model_performance(model, train_loader, device)
    if train_acc is None:
        print("  训练集评估失败，跳过。")
        continue  # skip the whole checkpoint when the train evaluation fails
    print(f"  training acc = {train_acc:.2f}%")
    print(f"  training loss = {train_loss:.5f}")

    # Evaluate on the test set.
    test_acc, test_loss = evaluate_model_performance(model, test_loader, device)
    if test_acc is None:
        print("  测试集评估失败，跳过。")
        continue  # skip the whole checkpoint when the test evaluation fails
    print(f"  test acc = {test_acc:.2f}%")
    print(f"  test loss = {test_loss:.5f}")

    # Record only after both evaluations succeeded, so that the five lists
    # always have the same length and stay aligned with `epochs`.
    train_accuracies.append(train_acc)
    train_losses.append(train_loss)
    test_accuracies.append(test_acc)
    test_losses.append(test_loss)
    epochs.append(epoch)

正在加载 CIFAR-10 数据集到 ../../pytorch_script/data/cifar10...
CIFAR-10 数据集加载成功！

开始逐个评估模型...

正在评估模型：../../model_training_results/cifar10_resnet20/model_0.pth (纪元: 0)
  从字典中提取模型状态字典...
提取成功
  training acc = 48.33%
  training loss = 1.47464
  test acc = 47.84%
  test loss = 1.47464

正在评估模型：../../model_training_results/cifar10_resnet20/model_5.pth (纪元: 5)
  从字典中提取模型状态字典...
提取成功
  training acc = 73.19%
  training loss = 0.80459
  test acc = 72.31%
  test loss = 0.80459

正在评估模型：../../model_training_results/cifar10_resnet20/model_10.pth (纪元: 10)
  从字典中提取模型状态字典...
提取成功
  training acc = 58.85%
  training loss = 1.80034
  test acc = 57.67%
  test loss = 1.80034

正在评估模型：../../model_training_results/cifar10_resnet20/model_20.pth (纪元: 20)
  从字典中提取模型状态字典...
提取成功
  training acc = 72.31%
  training loss = 0.92275
  test acc = 70.61%
  test loss = 0.92275

正在评估模型：../../model_training_results/cifar10_resnet20/model_30.pth (纪元: 30)
  从字典中提取模型状态字典...
提取成功
  training acc = 70.76%
  training loss = 0.91712
  te

KeyboardInterrupt: 

In [None]:
# Persist each per-checkpoint metric series as its own .npy file in the
# current working directory.
np.save(file='train_accuracies.npy', arr=np.array(train_accuracies))
np.save(file='test_accuracies.npy', arr=np.array(test_accuracies))
np.save(file='train_losses.npy', arr=np.array(train_losses))
np.save(file='test_losses.npy', arr=np.array(test_losses))

In [22]:
def model_losses(model, data_loader, device=None):
    """Compute the per-sample cross-entropy loss of ``model`` over a loader.

    Args:
        model: Module called on a batch of images; its output is passed to
            ``torch.nn.functional.cross_entropy`` together with the labels.
        data_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
            When omitted, the device of the model's first parameter is used
            (CPU if the model has no parameters).

    Returns:
        A flat list of Python floats, one loss per sample, in the order the
        loader yields them.

    Side effects:
        Switches ``model`` to eval mode and leaves it in that mode.
    """
    if device is None:
        # Follow the model instead of depending on a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    losses = []
    model.eval()
    # The losses are only read, never back-propagated, so disable autograd
    # to avoid building a graph for every batch.
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = torch.nn.functional.cross_entropy(outputs, labels, reduction='none')
            losses.extend(loss.tolist())
    return losses

In [24]:
# Per-sample losses for every checkpoint, keyed by the checkpoint's epoch.
all_model_train_losses, all_model_test_losses = {}, {}

for epoch, model_path in sorted_model_paths:
    print(epoch)
    model = load_model_state_dict(
        'cifar10', 'resnet20', 10, model_path, device
    )
    model.to(device)
    # Run the same checkpoint over both splits before loading the next one.
    all_model_train_losses.update({epoch: model_losses(model, train_loader)})
    all_model_test_losses.update({epoch: model_losses(model, test_loader)})

0
  从字典中提取模型状态字典...
提取成功


5
  从字典中提取模型状态字典...
提取成功
10
  从字典中提取模型状态字典...
提取成功
20
  从字典中提取模型状态字典...
提取成功
30
  从字典中提取模型状态字典...
提取成功
40
  从字典中提取模型状态字典...
提取成功
50
  从字典中提取模型状态字典...
提取成功
60
  从字典中提取模型状态字典...
提取成功
70
  从字典中提取模型状态字典...
提取成功
80
  从字典中提取模型状态字典...
提取成功
90
  从字典中提取模型状态字典...
提取成功
100
  从字典中提取模型状态字典...
提取成功
110
  从字典中提取模型状态字典...
提取成功
120
  从字典中提取模型状态字典...
提取成功
130
  从字典中提取模型状态字典...
提取成功
140
  从字典中提取模型状态字典...
提取成功
150
  从字典中提取模型状态字典...
提取成功
160
  从字典中提取模型状态字典...
提取成功
170
  从字典中提取模型状态字典...
提取成功
180
  从字典中提取模型状态字典...
提取成功
190
  从字典中提取模型状态字典...
提取成功
200
  从字典中提取模型状态字典...
提取成功
210
  从字典中提取模型状态字典...
提取成功
220
  从字典中提取模型状态字典...
提取成功
230
  从字典中提取模型状态字典...
提取成功
240
  从字典中提取模型状态字典...
提取成功
250
  从字典中提取模型状态字典...
提取成功
260
  从字典中提取模型状态字典...
提取成功
270
  从字典中提取模型状态字典...
提取成功
280
  从字典中提取模型状态字典...
提取成功
290
  从字典中提取模型状态字典...
提取成功
300
  从字典中提取模型状态字典...
提取成功
310
  从字典中提取模型状态字典...
提取成功
320
  从字典中提取模型状态字典...
提取成功
330
  从字典中提取模型状态字典...
提取成功
340
  从字典中提取模型状态字典...
提取成功
350
  从字典中提取模型状态字典...
提取成功
360
  从字典中提取模型状态字典...
提取成功
370
  从字典中提取

In [25]:
import pickle

# Write each {epoch: per-sample losses} mapping to its own pickle file.
# 'wb' is required because pickle produces a binary stream.
for file_path_pickle, losses_by_epoch in (
    ("all_model_train_losses.pickle", all_model_train_losses),
    ("all_model_test_losses.pickle", all_model_test_losses),
):
    with open(file_path_pickle, 'wb') as f:
        pickle.dump(losses_by_epoch, f)

In [26]:
# Persist the list of evaluated epochs.
# NOTE(review): `epochs` is filled by the accuracy-evaluation loop; if that
# loop was interrupted, this list may not match the epochs stored in the
# pickled loss dictionaries — confirm before relying on epochs.npy.
np.save(file='epochs.npy', arr=epochs)