In [17]:
# 1. 导入必要的库

import pickle
import sys
from pathlib import Path
from typing import Dict, Any

import joblib
import numpy as np
import pandas as pd
import scipy.stats as stats
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

# Project-local modules
import MyNet
from MyNet import NetList, DNN

In [18]:
def _describe_value(value):
    """Print a short type summary for one entry of the loaded result dict."""
    if isinstance(value, NetList):
        print(f"  Type: NetList containing {len(value)} models")
        # Guard against an empty NetList instead of raising IndexError.
        if len(value) > 0:
            print(f"  First model type: {type(value.models[0])}")
    elif isinstance(value, list):
        print(f"  Type: list of length {len(value)}")
        if value:
            print(f"  First item type: {type(value[0])}")
    elif isinstance(value, np.ndarray):
        print(f"  Type: numpy array of shape {value.shape}")
    elif isinstance(value, (int, float)):
        print(f"  Type: {type(value)}, Value: {value}")
    else:
        print(f"  Type: {type(value)}")


def _save_loss_plot(train_losses, main_losses, save_path):
    """Plot per-step training loss plus evenly spaced validation loss and save it."""
    n_train = len(train_losses)
    n_main = len(main_losses)
    # Spread the validation points evenly over the training-step axis so the
    # x array always has exactly n_main entries. The previous
    # np.arange(0, n_train, n_train // n_main) produced a mismatched length
    # (e.g. 14 positions for 121 steps / 13 losses), which made plot() raise,
    # and divided by zero when main_losses was empty.
    x_main = np.linspace(0, max(n_train - 1, 0), num=n_main)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(train_losses, label='Training Loss')
        ax.plot(x_main, main_losses, label='Validation Loss')
        ax.set_xlabel('Training Steps')
        ax.set_ylabel('Loss')
        ax.set_title('Training and Validation Losses')
        ax.legend()
        ax.grid(True)
        fig.savefig(save_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def load_and_analyze_model(file_path):
    """Load a joblib-saved result dict, print its structure and plot its losses.

    Parameters
    ----------
    file_path : str or path-like
        File readable by ``joblib.load``; the loaded object must support
        ``.items()`` (a dict is expected).

    Returns
    -------
    dict or None
        The loaded data, or ``None`` if loading or the structure printout
        failed. A failure while plotting is reported but no longer discards
        the successfully loaded data.

    Side effects
    ------------
    If both ``'train_losses'`` and ``'main_losses'`` are present, writes
    ``loss_plot.png`` to the current working directory.
    """
    try:
        data = joblib.load(file_path)
        print("File loaded successfully")

        print("\nData Structure Analysis:")
        for key, value in data.items():
            print(f"\n{key}:")
            _describe_value(value)
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error during loading or analysis: {e}")
        return None

    if 'train_losses' in data and 'main_losses' in data:
        try:
            _save_loss_plot(data['train_losses'], data['main_losses'], 'loss_plot.png')
            print("\nLoss plot saved as 'loss_plot.png'")
        except Exception as e:
            # Plotting is optional; keep the loaded data available to the caller.
            print(f"Error while plotting losses: {e}")

    return data

def generate_report(res: Dict[str, Any], n_train: int = 200) -> str:
    """Build a plain-text summary report (in Chinese) of a saved SGD experiment.

    Parameters
    ----------
    res : dict
        Experiment result. Keys read: ``'info'`` (a sequence whose items map
        ``'idx'`` to a sequence; the batch size is taken as
        ``len(res['info'][0]['idx'])``), ``'models'``, ``'main_losses'``,
        ``'train_losses'``, ``'counterfactual'`` (a sequence of sequences)
        and ``'alpha'``.
    n_train : int, default 200
        Training-set size used to derive steps per epoch as
        ``n_train / batch_size``. Previously hard-coded to 200.

    Returns
    -------
    str
        The formatted report. Empty ``info``, loss records or
        ``counterfactual`` are rendered as ``nan`` / ``N/A`` / ``0``
        instead of raising ``IndexError`` / ``ZeroDivisionError``.
    """
    def _last_loss(losses) -> str:
        # Final recorded loss, tolerating an empty record.
        return f"{losses[-1]:.4f}" if len(losses) > 0 else "N/A"

    info = res['info']
    n_steps = len(info)
    batch_size = len(info[0]['idx']) if n_steps > 0 else 0
    if batch_size > 0:
        steps_per_epoch = n_train / batch_size
        epochs = n_steps / steps_per_epoch
    else:
        # No batch to derive an epoch length from; these format as "nan".
        steps_per_epoch = epochs = float('nan')

    models = res['models']
    main_losses = res['main_losses']
    train_losses = res['train_losses']
    counterfactual = res['counterfactual']
    n_states = len(counterfactual[0]) if len(counterfactual) > 0 else 0

    lines = [
        "实验数据分析报告",
        "================",
        "",
        # Section 1: training information
        "1. 训练信息",
        "-------------",
        f"总训练轮数 (epochs): {epochs:.1f}",
        f"每轮步数: {steps_per_epoch:.0f}",
        f"批次大小: {batch_size}",
        f"总训练步数: {n_steps}",
        "",
        # Section 2: model information
        "2. 模型信息",
        "-------------",
        f"模型类型: {type(models)}",
        f"保存的模型数量: {len(models)}",
        "",
        # Section 3: loss information
        "3. 损失信息",
        "-------------",
        "主模型损失:",
        f"  类型: {type(main_losses)}",
        f"  记录次数: {len(main_losses)}",
        f"  最终损失: {_last_loss(main_losses)}",
        "",
        "训练损失:",
        f"  类型: {type(train_losses)}",
        # np.shape works for ndarray/memmap and for plain lists alike.
        f"  形状: {np.shape(train_losses)}",
        f"  最终损失: {_last_loss(train_losses)}",
        "",
        # Section 4: counterfactual analysis
        "4. 反事实分析",
        "-------------",
        f"分析的样本数: {len(counterfactual)}",
        f"每个样本的模型状态数: {n_states}",
        "",
        # Section 5: other information
        "5. 其他信息",
        "-------------",
        f"Alpha 值: {res['alpha']}",
    ]
    return "\n".join(lines) + "\n"

def generate_combined_loss_plot(res: Dict[str, Any], save_path: str, n_train: int = 200):
    """Plot training loss and main-model loss on a shared epoch axis and save it.

    Parameters
    ----------
    res : dict
        Experiment result; reads ``'info'`` (the batch size is taken as
        ``len(res['info'][0]['idx'])``), ``'train_losses'``,
        ``'main_losses'`` and ``'alpha'``.
    save_path : str
        Destination image path (written at 300 dpi).
    n_train : int, default 200
        Training-set size used to derive steps per epoch as
        ``n_train / batch_size``. Previously hard-coded to 200.
    """
    # Derive run geometry before creating the figure so a malformed `res`
    # cannot leave an open figure behind.
    batch_size = len(res['info'][0]['idx'])
    steps_per_epoch = n_train / batch_size
    epochs = len(res['info']) / steps_per_epoch

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Training loss: one value per recorded step, rescaled to epochs.
        x_train = np.arange(len(res['train_losses'])) / steps_per_epoch
        ax.plot(x_train, res['train_losses'], label='Training Loss', alpha=0.7)

        # Main-model loss: plotted at integer positions 0, 1, 2, ...
        # NOTE(review): treats each entry as one epoch — confirm that the
        # training script records main_losses once per epoch.
        x_main = np.arange(len(res['main_losses']))
        ax.plot(x_main, res['main_losses'], label='Main Model Loss', alpha=0.7)

        ax.set_title('Training and Main Model Losses vs. Epochs', fontsize=16)
        ax.set_xlabel('Epochs', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend(fontsize=10)

        # Run summary box in the top-right corner of the axes.
        info_text = (
            f"Total Epochs: {epochs:.1f}\n"
            f"Steps per Epoch: {steps_per_epoch:.0f}\n"
            f"Batch Size: {batch_size}\n"
            f"Alpha: {res['alpha']}"
        )
        ax.text(0.95, 0.95, info_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        fig.tight_layout()
        fig.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Loss plot saved to: {save_path}")

In [19]:
# Load the saved SGD-influence experiment result, then summarise and plot it.
#
# The previous `sys.modules['MyNet'] = sys.modules[__name__]` alias was
# removed: MyNet is genuinely imported in the first cell, so unpickling can
# resolve `MyNet.NetList` / `MyNet.DNN` directly. The alias replaced the real
# module with __main__, hiding every MyNet name not re-imported here.

# Saved experiment file (absolute, machine-specific path — adjust as needed).
DATA_PATH = '/home/zihan/codes/sgd-influence/experiment/Sec71/mnist_dnn/sgd000.dat'
path = DATA_PATH  # kept for any later cell that still refers to `path`

# Try joblib first (memory-mapped, read-only), then fall back to plain pickle.
res = None
try:
    res = joblib.load(DATA_PATH, mmap_mode='r')
    print("数据加载成功")
    print("数据键:", res.keys())
except Exception as e:
    print(f"Joblib 加载失败: {e}")
    try:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(DATA_PATH, 'rb') as f:
            res = pickle.load(f)
        print("Pickle 加载成功")
        print("数据键:", res.keys())
    except Exception as e:
        print(f"Pickle 加载失败: {e}")
        res = None

if res is None:
    print("无法加载数据，请检查文件路径和格式")
else:
    # Data analysis: text report plus combined loss plot next to the data file.
    report = generate_report(res)
    print(report)

    # with_suffix swaps only the final extension (str.replace would also hit
    # any '.dat' occurring earlier in the path).
    generate_combined_loss_plot(res, str(Path(DATA_PATH).with_suffix('.png')))

    # Optional extra analysis: print the architecture of the first saved model.
    # This used to run even when loading failed, raising TypeError on None.
    if ('models' in res and isinstance(res['models'], NetList)
            and len(res['models']) > 0):
        print("\n模型结构:")
        print(res['models'].models[0])  # assumes the first model represents the overall architecture

    # Conclusions and next steps (only meaningful when data was analysed).
    print("\n分析结论:")
    print("1. 模型训练过程中的损失变化已可视化，可以观察训练和验证损失的趋势。")
    print("2. 生成的报告提供了关于模型训练的关键信息，包括训练轮数、批次大小等。")
    print("3. 反事实分析部分显示了每个样本的模型状态数，这可能对理解模型的鲁棒性很有帮助。")

    print("\n下一步计划:")
    print("1. 深入分析反事实模型，比较不同样本对模型训练的影响。")
    print("2. 考虑添加更多的可视化，例如参数分布图或学习率变化图。")
    print("3. 如果可能，进行交叉验证或在不同数据集上测试模型的泛化能力。")

数据加载成功
数据键: dict_keys(['models', 'info', 'counterfactual', 'alpha', 'main_losses', 'train_losses'])
实验数据分析报告

1. 训练信息
-------------
总训练轮数 (epochs): 12.0
每轮步数: 10
批次大小: 20
总训练步数: 120

2. 模型信息
-------------
模型类型: <class 'MyNet.NetList'>
保存的模型数量: 121

3. 损失信息
-------------
主模型损失:
  类型: <class 'list'>
  记录次数: 13
  最终损失: 0.1119

训练损失:
  类型: <class 'numpy.memmap'>
  形状: (121,)
  最终损失: 0.0090

4. 反事实分析
-------------
分析的样本数: 200
每个样本的模型状态数: 121

5. 其他信息
-------------
Alpha 值: 0.001

Loss plot saved to: /home/zihan/codes/sgd-influence/experiment/Sec71/mnist_dnn/sgd000.png

模型结构:
DNN(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=8, bias=True)
    (3): ReLU()
    (4): Linear(in_features=8, out_features=1, bias=True)
  )
)

分析结论:
1. 模型训练过程中的损失变化已可视化，可以观察训练和验证损失的趋势。
2. 生成的报告提供了关于模型训练的关键信息，包括训练轮数、批次大小等。
3. 反事实分析部分显示了每个样本的模型状态数，这可能对理解模型的鲁棒性很有帮助。

下一步计划:
1. 深入分析反事实模型，比较不同样本对模型训练的影响。
2. 考虑添加更多的可视化，例如参数分布图