In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


# 配置函数
def load_config():
	"""
	加载并返回脚本的配置参数。
	"""
	config = {
		'model_dir': '../../model_training_results/cifar10_resnet18',  # 模型存储的文件夹路径
		'model_prefix': 'model_',        # 模型文件名的前缀
		'model_extension': '.pth',                         # 保存模型的文件扩展名
		'cifar10_data_path': '../../pytorch_script/data/cifar10',           # CIFAR-10 数据集存储路径
		'batch_size': 64,                                 # DataLoader 的批次大小
		'num_workers': 2                                  # DataLoader 的工作进程数
	}
	return config

# 数据加载函数
def load_cifar10_data(data_path, batch_size, num_workers):
	"""
	加载 CIFAR-10 数据集并返回训练和测试 DataLoader。
	"""
	print(f"正在加载 CIFAR-10 数据集到 {data_path}...")

	# 数据预处理
	transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10 的均值和标准差
	])

	try:
		train_dataset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
		test_dataset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)

		train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
		test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

		print("CIFAR-10 数据集加载成功！")
		return train_loader, test_loader
	except Exception as e:
		print(f"加载 CIFAR-10 数据集时发生错误：{e}")
		print("请检查 CIFAR10_DATA_PATH 是否正确，并确保网络连接正常以便下载数据集。")
		return None, None

# 模型定义函数
def get_resnet18_cifar10():
	"""
	返回一个为 CIFAR-10 任务定制的 ResNet-18 模型实例。
	"""
	model = torchvision.models.resnet18(weights=None) # 不加载预训练权重
	# CIFAR-10 有 10 个类别，因此修改最后一层全连接层
	num_ftrs = model.fc.in_features
	model.fc = nn.Linear(num_ftrs, 10) # 10 是 CIFAR-10 的类别数
	return model

# 模型加载函数
def load_model_state_dict(model_path, device):
	# 实例化模型并加载状态字典
	model = get_resnet18_cifar10().to(device)
	
	checkpoint = torch.load(model_path, map_location=device, weights_only=False)

	# 核心修改：检查加载的内容，并提取 model 的 state_dict
	if isinstance(checkpoint, dict) and 'model' in checkpoint:
		# 如果保存的是一个包含 'model' 键的字典
		print(f"  从字典中提取模型状态字典...")
		model_state_dict = checkpoint['model']
	elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
		# 有些框架会把模型的state_dict放在'state_dict'键下
		print(f"  从字典中提取 'state_dict' 键...")
		model_state_dict = checkpoint['state_dict']
	else:
		# 如果直接保存的就是 state_dict
		print(f"  直接加载模型状态字典...")
		model_state_dict = checkpoint
	
	model.load_state_dict(model_state_dict)
	return model



# 模型评估函数
def evaluate_model_performance(model_path, data_loader, device):
	try:
		model = load_model_state_dict(model_path, device)
		model.eval() # 设置模型为评估模式

		correct = 0
		total_loss = 0
		total = 0
		with torch.no_grad(): # 在评估时不计算梯度
			for images, labels in data_loader:
				images, labels = images.to(device), labels.to(device)
				outputs = model(images)

				loss = criterion(outputs, labels)
				total_loss += loss.item() * images.size(0)

				_, predicted = torch.max(outputs.data, 1)
				correct += (predicted == labels).sum().item()

				total += labels.size(0)
		
		accuracy = 100 * correct / total
		return accuracy, total_loss
	except Exception as e:
		print(f"评估模型 {model_path} 时发生错误：{e}")
		return None

# 模型文件管理函数
def get_sorted_model_paths(model_dir, model_prefix, model_extension):
	"""
	扫描模型目录，获取所有符合命名约定的模型文件路径，并按纪元排序。
	返回一个 (epoch, model_path) 元组的列表。
	"""
	model_files = []
	if not os.path.exists(model_dir):
		print(f"模型目录 {model_dir} 不存在。请确保模型已保存到此目录。")
		os.makedirs(model_dir) # 尝试创建目录，避免后续错误
		return []

	for f_name in os.listdir(model_dir):
		if f_name.startswith(model_prefix) and f_name.endswith(model_extension):
			try:
				# 从文件名中提取 epoch 号码 (例如 'resnet18_cifar10_epoch_X.pt' -> X)
				epoch_str = f_name.replace(model_prefix, '').replace(model_extension, '')
				epoch = int(epoch_str)
				model_files.append((epoch, os.path.join(model_dir, f_name)))
			except ValueError:
				print(f"跳过无法解析的MODEL_PREFIX或MODEL_EXTENSION文件名：{f_name}")
				continue
	
	model_files.sort(key=lambda x: x[0]) # 按 epoch 排序
	return model_files

In [None]:
import seaborn as sns

config = load_config()
device = torch.device("cuda")

# 加载数据
train_loader, test_loader = load_cifar10_data(
	config['cifar10_data_path'], config['batch_size'], config['num_workers']
)
if train_loader is None or test_loader is None:
	print("数据加载失败，程序终止。")

# 获取排序后的模型文件路径
sorted_model_paths = get_sorted_model_paths(
	config['model_dir'], config['model_prefix'], config['model_extension']
)

if not sorted_model_paths:
	print("没有找到符合条件或能成功加载的模型文件。请检查 MODEL_DIR, MODEL_PREFIX 和 MODEL_EXTENSION 设置。")

epochs = []
train_accuracies = []
test_accuracies = []
train_losses = []
test_losses = []

print("\n开始逐个评估模型...")
for epoch, model_path in sorted_model_paths:
	print(f"\n正在评估模型：{model_path} (纪元: {epoch})")
	
	# 评估训练集性能
	train_acc, train_loss = evaluate_model_performance(model_path, train_loader, device)
	if train_acc is not None:
		train_accuracies.append(train_acc)
		train_losses.append(train_loss)
		print(f"  训练集准确率 = {train_acc:.2f}%")
		print(f"  训练集loss = {train_loss:.2f}%")
	else:
		print("  训练集评估失败，跳过。")
		continue # 如果训练集评估失败，则整个模型跳过

	# 评估测试集性能
	test_acc = evaluate_model_performance(model_path, test_loader, device)
	if test_acc is not None:
		test_accuracies.append(test_acc)
		print(f"  测试集准确率 = {test_acc:.2f}%")
	else:
		print("  测试集评估失败，跳过。")
		continue # 如果测试集评估失败，则整个模型跳过

	epochs.append(epoch)

正在加载 CIFAR-10 数据集到 ../pytorch_script/data/cifar10...
CIFAR-10 数据集加载成功！

开始逐个评估模型...

正在评估模型：../model_training_results/cifar10_resnet18/model_0.pth (纪元: 0)
  从字典中提取模型状态字典...
  训练集准确率 = 40.80%
  从字典中提取模型状态字典...
  测试集准确率 = 40.58%

正在评估模型：../model_training_results/cifar10_resnet18/model_1.pth (纪元: 1)
  从字典中提取模型状态字典...
  训练集准确率 = 45.50%
  从字典中提取模型状态字典...
  测试集准确率 = 45.38%

正在评估模型：../model_training_results/cifar10_resnet18/model_2.pth (纪元: 2)
  从字典中提取模型状态字典...
  训练集准确率 = 55.61%
  从字典中提取模型状态字典...
  测试集准确率 = 55.33%

正在评估模型：../model_training_results/cifar10_resnet18/model_3.pth (纪元: 3)
  从字典中提取模型状态字典...
  训练集准确率 = 62.35%
  从字典中提取模型状态字典...
  测试集准确率 = 62.08%

正在评估模型：../model_training_results/cifar10_resnet18/model_4.pth (纪元: 4)
  从字典中提取模型状态字典...
  训练集准确率 = 64.59%
  从字典中提取模型状态字典...
  测试集准确率 = 63.84%

正在评估模型：../model_training_results/cifar10_resnet18/model_5.pth (纪元: 5)
  从字典中提取模型状态字典...
  训练集准确率 = 63.31%
  从字典中提取模型状态字典...
  测试集准确率 = 62.61%

正在评估模型：../model_training_results/cifar10_resnet18/model_6

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 12,
 14,
 16,
 18,
 20,
 22,
 24,
 26,
 28,
 30,
 35,
 40,
 45,
 50,
 55,
 60,
 65,
 70,
 75,
 80,
 85,
 90,
 95,
 100,
 110,
 120,
 130,
 140,
 150,
 160,
 170,
 180,
 190,
 200,
 210,
 220,
 230,
 240,
 250,
 260,
 270,
 280,
 290,
 300]

In [13]:
np.save('train_accuracies.npy', np.array(train_accuracies))
np.save('test_accuracies.npy', np.array(test_accuracies))

In [18]:
def model_losses(model):
  losses = []
  model.eval()
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = torch.nn.functional.cross_entropy(outputs, labels)
    losses.append(loss.item())
  return losses

In [19]:
all_model_losses = {}

for epoch, model_path in sorted_model_paths:
	print(epoch)
	model = load_model_state_dict(model_path, device)
	all_model_losses[epoch] = model_losses(model)

all_model_losses

0
  从字典中提取模型状态字典...
1
  从字典中提取模型状态字典...
2
  从字典中提取模型状态字典...
3
  从字典中提取模型状态字典...
4
  从字典中提取模型状态字典...
5
  从字典中提取模型状态字典...
6
  从字典中提取模型状态字典...
7
  从字典中提取模型状态字典...
8
  从字典中提取模型状态字典...
9
  从字典中提取模型状态字典...
10
  从字典中提取模型状态字典...
12
  从字典中提取模型状态字典...
14
  从字典中提取模型状态字典...
16
  从字典中提取模型状态字典...
18
  从字典中提取模型状态字典...
20
  从字典中提取模型状态字典...
22
  从字典中提取模型状态字典...
24
  从字典中提取模型状态字典...
26
  从字典中提取模型状态字典...
28
  从字典中提取模型状态字典...
30
  从字典中提取模型状态字典...
35
  从字典中提取模型状态字典...
40
  从字典中提取模型状态字典...
45
  从字典中提取模型状态字典...
50
  从字典中提取模型状态字典...
55
  从字典中提取模型状态字典...
60
  从字典中提取模型状态字典...
65
  从字典中提取模型状态字典...
70
  从字典中提取模型状态字典...
75
  从字典中提取模型状态字典...
80
  从字典中提取模型状态字典...
85
  从字典中提取模型状态字典...
90
  从字典中提取模型状态字典...
95
  从字典中提取模型状态字典...
100
  从字典中提取模型状态字典...
110
  从字典中提取模型状态字典...
120
  从字典中提取模型状态字典...
130
  从字典中提取模型状态字典...
140
  从字典中提取模型状态字典...
150
  从字典中提取模型状态字典...
160
  从字典中提取模型状态字典...
170
  从字典中提取模型状态字典...
180
  从字典中提取模型状态字典...
190
  从字典中提取模型状态字典...
200
  从字典中提取模型状态字典...
210
  从字典中提取模型状态字典...
220
  从字典中提取模型状态字典...
230
  从字典中

{0: [1.6776453256607056,
  1.481772541999817,
  1.5385147333145142,
  1.5767208337783813,
  1.5361666679382324,
  1.502784013748169,
  1.5533438920974731,
  1.677118182182312,
  1.642412781715393,
  1.7337770462036133,
  1.7029331922531128,
  1.5303983688354492,
  1.668331265449524,
  1.7342358827590942,
  1.54615318775177,
  1.5239574909210205,
  1.8629697561264038,
  1.3631869554519653,
  1.7457176446914673,
  1.4739675521850586,
  1.6581141948699951,
  1.4355956315994263,
  1.7806957960128784,
  1.534877896308899,
  1.5871309041976929,
  1.5405389070510864,
  1.706526279449463,
  1.513169288635254,
  1.5025135278701782,
  1.608475923538208,
  1.7326946258544922,
  1.5627689361572266,
  1.6026928424835205,
  1.454591989517212,
  1.6110373735427856,
  1.5992308855056763,
  1.449425220489502,
  1.6776336431503296,
  1.5569002628326416,
  1.627196192741394,
  1.526212215423584,
  1.4922181367874146,
  1.6241145133972168,
  1.683565378189087,
  1.520033359527588,
  1.5614761114120483,
  

In [21]:
import pickle
file_path_pickle = "all_model_losses.pickle"
with open(file_path_pickle, 'wb') as f: # 注意 'wb' 表示写入二进制
    pickle.dump(all_model_losses, f)

In [None]:
np.save('epochs.npy', epochs)