In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from generalization.loss_distribution.pytorch_script.visual_utils import load_config, load_cifar10_data, \
	get_sorted_model_paths, evaluate_model_performance, load_model_state_dict


In [3]:
import seaborn as sns

config = load_config()
device = torch.device("cuda")

# 加载数据
train_loader, test_loader = load_cifar10_data(
	config['cifar10_data_path'], config['batch_size'], config['num_workers']
)
if train_loader is None or test_loader is None:
	print("数据加载失败，程序终止。")

# 获取排序后的模型文件路径
sorted_model_paths = get_sorted_model_paths(
	config['model_dir'], config['model_prefix'], config['model_extension']
)

if not sorted_model_paths:
	print("没有找到符合条件或能成功加载的模型文件。请检查 MODEL_DIR, MODEL_PREFIX 和 MODEL_EXTENSION 设置。")

epochs = []
train_accuracies = []
test_accuracies = []
train_losses = []
test_losses = []

print("\n开始逐个评估模型...")
for epoch, model_path in sorted_model_paths:
	print(f"\n正在评估模型：{model_path} (纪元: {epoch})")
	
	# 评估训练集性能
	train_acc, train_loss = evaluate_model_performance(model_path, train_loader, device)
	if train_acc is not None:
		train_accuracies.append(train_acc)
		train_losses.append(train_loss)
		print(f"  training acc = {train_acc:.2f}%")
		print(f"  training loss = {train_loss:.2f}")
	else:
		print("  训练集评估失败，跳过。")
		continue # 如果训练集评估失败，则整个模型跳过

	# 评估测试集性能
	test_acc, test_loss = evaluate_model_performance(model_path, test_loader, device)
	if test_acc is not None:
		test_accuracies.append(test_acc)
		test_losses.append(test_loss)
		print(f"  test acc = {test_acc:.2f}%")
		print(f"  test loss = {train_loss:.2f}")
	else:
		print("  测试集评估失败，跳过。")
		continue # 如果测试集评估失败，则整个模型跳过

	epochs.append(epoch)

正在加载 CIFAR-10 数据集到 ../../pytorch_script/data/cifar10...
CIFAR-10 数据集加载成功！

开始逐个评估模型...

正在评估模型：../../model_training_results/cifar10_resnet18/model_0.pth (纪元: 0)
  从字典中提取模型状态字典...
  training acc = 40.80%
  training loss = 1.61
  从字典中提取模型状态字典...
  test acc = 40.58%
  test loss = 1.61

正在评估模型：../../model_training_results/cifar10_resnet18/model_1.pth (纪元: 1)
  从字典中提取模型状态字典...
  training acc = 45.50%
  training loss = 1.51
  从字典中提取模型状态字典...
  test acc = 45.38%
  test loss = 1.51

正在评估模型：../../model_training_results/cifar10_resnet18/model_2.pth (纪元: 2)
  从字典中提取模型状态字典...
  training acc = 55.61%
  training loss = 1.24
  从字典中提取模型状态字典...
  test acc = 55.33%
  test loss = 1.24

正在评估模型：../../model_training_results/cifar10_resnet18/model_3.pth (纪元: 3)
  从字典中提取模型状态字典...
  training acc = 62.35%
  training loss = 1.07
  从字典中提取模型状态字典...
  test acc = 62.08%
  test loss = 1.07

正在评估模型：../../model_training_results/cifar10_resnet18/model_4.pth (纪元: 4)
  从字典中提取模型状态字典...
  training acc = 64.59%
  training lo

In [4]:
np.save('train_accuracies.npy', np.array(train_accuracies))
np.save('test_accuracies.npy', np.array(test_accuracies))
np.save('train_losses.npy', np.array(train_losses))
np.save('test_losses.npy', np.array(test_losses))

In [5]:
def model_losses(model, data_loader):
  losses = []
  model.eval()
  for images, labels in data_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = torch.nn.functional.cross_entropy(outputs, labels)
    losses.append(loss.item())
  return losses

In [6]:
all_model_train_losses = {}
all_model_test_losses = {}

for epoch, model_path in sorted_model_paths:
	print(epoch)
	model = load_model_state_dict(model_path, device)
	all_model_train_losses[epoch] = model_losses(model, train_loader)
	all_model_test_losses[epoch] = model_losses(model, test_loader)

0
  从字典中提取模型状态字典...
1
  从字典中提取模型状态字典...
2
  从字典中提取模型状态字典...
3
  从字典中提取模型状态字典...
4
  从字典中提取模型状态字典...
5
  从字典中提取模型状态字典...
6
  从字典中提取模型状态字典...
7
  从字典中提取模型状态字典...
8
  从字典中提取模型状态字典...
9
  从字典中提取模型状态字典...
10
  从字典中提取模型状态字典...
12
  从字典中提取模型状态字典...
14
  从字典中提取模型状态字典...
16
  从字典中提取模型状态字典...
18
  从字典中提取模型状态字典...
20
  从字典中提取模型状态字典...
22
  从字典中提取模型状态字典...
24
  从字典中提取模型状态字典...
26
  从字典中提取模型状态字典...
28
  从字典中提取模型状态字典...
30
  从字典中提取模型状态字典...
35
  从字典中提取模型状态字典...
40
  从字典中提取模型状态字典...
45
  从字典中提取模型状态字典...
50
  从字典中提取模型状态字典...
55
  从字典中提取模型状态字典...
60
  从字典中提取模型状态字典...
65
  从字典中提取模型状态字典...
70
  从字典中提取模型状态字典...
75
  从字典中提取模型状态字典...
80
  从字典中提取模型状态字典...
85
  从字典中提取模型状态字典...
90
  从字典中提取模型状态字典...
95
  从字典中提取模型状态字典...
100
  从字典中提取模型状态字典...
110
  从字典中提取模型状态字典...
120
  从字典中提取模型状态字典...
130
  从字典中提取模型状态字典...
140
  从字典中提取模型状态字典...
150
  从字典中提取模型状态字典...
160
  从字典中提取模型状态字典...
170
  从字典中提取模型状态字典...
180
  从字典中提取模型状态字典...
190
  从字典中提取模型状态字典...
200
  从字典中提取模型状态字典...
210
  从字典中提取模型状态字典...
220
  从字典中提取模型状态字典...
230
  从字典中

In [None]:
import pickle
file_path_pickle = "all_model_train_losses.pickle"
with open(file_path_pickle, 'wb') as f: # 注意 'wb' 表示写入二进制
    pickle.dump(all_model_train_losses, f)

file_path_pickle = "all_model_test_losses.pickle"
with open(file_path_pickle, 'wb') as f: # 注意 'wb' 表示写入二进制
    pickle.dump(all_model_test_losses, f)

In [8]:
np.save('epochs.npy', epochs)