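# NeuroTrain Analyzers Usage Examples

This Notebook demonstrates the three analyzers in the `tools/analyzers` module:
- DatasetAnalyzer (dataset analysis)
- MetricsAnalyzer (metrics analysis)
- AttentionAnalyzer (attention analysis)

Prerequisites:
- The project dependencies are installed and `tools.analyzers` is importable
- A plotting backend is available for Matplotlib/Seaborn: use `%matplotlib inline` inside a Notebook, or switch to a suitable non-interactive backend (for example `Agg`) when running from a plain terminal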


In [6]:
%matplotlib inline
import sys, os
from pathlib import Path

# Add the project root to sys.path so the package can be imported from the Notebook
# (adjust this path to your own checkout)
PROJECT_ROOT = Path('/home/rczx/workspace/sxy/lab/NeuroTrain')
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Core dependencies and environment check
import json
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

print('Python:', sys.version)
print('PyTorch:', torch.__version__)
print('Matplotlib backend:', matplotlib.get_backend())
print('Project root exists:', PROJECT_ROOT.exists())

# Import the unified analyzers interface
from tools.analyzers import (
    DatasetAnalyzer, MetricsAnalyzer, AttentionAnalyzer,
    analyze_dataset, analyze_model_metrics, analyze_model_attention,
    UnifiedAnalyzer, create_unified_analyzer, run_comprehensive_analysis
)

print('Analyzers imported OK')


Python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
PyTorch: 2.7.1+cu126
Matplotlib backend: inline
Project root exists: True
Analyzers imported OK
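
The hard-coded `PROJECT_ROOT` in the cell above only resolves on the machine where this Notebook was recorded. A minimal sketch of one alternative, assuming the Notebook is started somewhere inside the repository and that the repository root is the directory containing `tools/analyzers`. The helper name `find_project_root` is hypothetical and not part of NeuroTrain:

```python
from pathlib import Path
from typing import Optional


def find_project_root(start: Path, marker: str = "tools/analyzers") -> Optional[Path]:
    """Walk upward from `start` until a directory containing `marker` is found."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    return None  # marker not found anywhere above `start`


root = find_project_root(Path.cwd())
print("Detected project root:", root)
```

If the function returns `None`, falling back to an explicit path (as the original cell does) is still the safest choice.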


## DatasetAnalyzer Example

Two approaches are shown:
1) Load one of the project's built-in datasets directly (if available)
2) Pass in a custom `TensorDataset` (always runnable)


In [7]:
# Option A: use a built-in dataset (e.g. CIFAR10), if the project's `src.dataset` is available
from tools.analyzers import analyze_dataset

try:
    dataset_config = {
        "dataset_name": "CIFAR10",
        "data_dir": str(PROJECT_ROOT / "data"),
        "batch_size": 32,
    }
    ds_results = analyze_dataset(
        dataset_name="CIFAR10", dataset_config=dataset_config, splits=["train", "test"]
    )
    print("Built-in dataset analysis complete. Output directory:", ds_results["output_directory"])
except Exception as e:
    print("Built-in dataset approach failed:", e)
    print("Continuing the demo with a custom TensorDataset.")

2025-10-12 16:55:52,860 - DatasetAnalyzer.CIFAR10 - INFO - Starting full dataset analysis...
INFO:DatasetAnalyzer.CIFAR10:Starting full dataset analysis...
2025-10-12 16:55:52,861 - DatasetAnalyzer.CIFAR10 - INFO - Loading datasets: CIFAR10
INFO:DatasetAnalyzer.CIFAR10:Loading datasets: CIFAR10
2025-10-12 16:55:52,863 - DatasetAnalyzer.CIFAR10 - INFO - Analyzing class distribution...
INFO:DatasetAnalyzer.CIFAR10:Analyzing class distribution...
2025-10-12 16:55:52,863 - DatasetAnalyzer.CIFAR10 - INFO - Analyzing data quality...
INFO:DatasetAnalyzer.CIFAR10:Analyzing data quality...
2025-10-12 16:55:52,864 - DatasetAnalyzer.CIFAR10 - INFO - Analyzing dataset balance...
INFO:DatasetAnalyzer.CIFAR10:Analyzing dataset balance...


2025-10-12 16:55:53,707 - DatasetAnalyzer.CIFAR10 - INFO - Report saved to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc/dataset_analysis_report.txt
INFO:DatasetAnalyzer.CIFAR10:Report saved to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc/dataset_analysis_report.txt
2025-10-12 16:55:53,708 - DatasetAnalyzer.CIFAR10 - INFO - Results exported to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc/analysis_results.json
INFO:DatasetAnalyzer.CIFAR10:Results exported to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc/analysis_results.json
2025-10-12 16:55:53,708 - DatasetAnalyzer.CIFAR10 - INFO - Analysis completed. Results saved to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc
INFO:DatasetAnalyzer.CIFAR10:Analysis completed. Results saved to: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc


Built-in dataset analysis complete. Output directory: output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc
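
The log above reports that results were exported to an `analysis_results.json` file inside the run's output directory. A minimal sketch for reading that file back, assuming only that it is valid JSON. Its internal structure is not documented in this Notebook, so the sketch just lists the top-level keys; the helper name `load_analysis_results` is hypothetical:

```python
import json
from pathlib import Path


def load_analysis_results(output_dir):
    """Load analysis_results.json from an analyzer output directory, or None if absent."""
    path = Path(output_dir) / "analysis_results.json"
    if not path.is_file():
        return None
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)


# The directory below is the one printed in the recorded run; substitute your own,
# e.g. the value of ds_results["output_directory"].
results = load_analysis_results(
    "output/analysis/CIFAR10/CIFAR10_20251012_165552_1dd293bc"
)
if results is None:
    print("No analysis_results.json found")
elif isinstance(results, dict):
    print("Top-level keys:", sorted(results))
else:
    print("Top-level JSON type:", type(results).__name__)
```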


In [8]:
# Option B: a custom TensorDataset (always runnable)
from torch.utils.data import TensorDataset

X = torch.randn(500, 3, 32, 32)
y = torch.randint(0, 10, (500,))
custom_dataset = TensorDataset(X, y)

from tools.analyzers import DatasetAnalyzer

custom_da = DatasetAnalyzer(
    dataset_name="CustomDataset",
    dataset_config=None,
    output_dir=str(PROJECT_ROOT / "output/analysis"),
)

# Compatibility path: register the dataset, then call run_full_analysis directly
custom_da.load_custom_dataset(custom_dataset, "custom")
custom_results = custom_da.run_full_analysis(
    splits=["custom"], save_plots=True, max_samples=200
)
print("Custom dataset analysis complete. Output directory:", custom_results["output_directory"])

2025-10-12 16:55:53,721 - DatasetAnalyzer.CustomDataset - INFO - Loaded custom dataset 'custom': 500 samples
INFO:DatasetAnalyzer.CustomDataset:Loaded custom dataset 'custom': 500 samples
2025-10-12 16:55:53,722 - DatasetAnalyzer.CustomDataset - INFO - Starting full dataset analysis...
INFO:DatasetAnalyzer.CustomDataset:Starting full dataset analysis...
2025-10-12 16:55:53,723 - DatasetAnalyzer.CustomDataset - INFO - Loading datasets: CustomDataset
INFO:DatasetAnalyzer.CustomDataset:Loading datasets: CustomDataset
2025-10-12 16:55:53,724 - DatasetAnalyzer.CustomDataset - INFO - Analyzing class distribution...
INFO:DatasetAnalyzer.CustomDataset:Analyzing class distribution...
2025-10-12 16:55:53,725 - DatasetAnalyzer.CustomDataset - INFO -   Analyzing custom dataset...
INFO:DatasetAnalyzer.CustomDataset:  Analyzing custom dataset...
2025-10-12 16:55:53,726 - DatasetAnalyzer.CustomDataset - INFO -     Found 10 unique classes
INFO:DatasetAnalyzer.CustomDataset:    Found 10 unique classes


Custom dataset analysis complete. Output directory: /home/rczx/workspace/sxy/lab/NeuroTrain/output/analysis/CustomDataset_20251012_165553_a656d26c


## MetricsAnalyzer Example

We build a dictionary of per-class metrics and use it to demonstrate statistics, visualization, and report export.


In [9]:
from tools.analyzers import analyze_model_metrics

# Build synthetic per-class metrics
rng = np.random.default_rng(42)
class_names = [f"class_{i}" for i in range(1, 11)]
metrics_per_class = {
    name: {
        "accuracy": float(rng.uniform(0.80, 0.99)),
        "precision": float(rng.uniform(0.75, 0.98)),
        "recall": float(rng.uniform(0.70, 0.98)),
        "f1": float(rng.uniform(0.72, 0.98)),
    }
    for name in class_names
}

metrics_results = analyze_model_metrics(metrics_per_class, task_type="classification")
print("Metrics analysis complete. Output directory:", metrics_results["output_directory"])
print("Number of plots generated:", len(metrics_results["generated_plots"]))

Metrics analysis complete. Output directory: output/analysis/metrics_analysis_20251012_165554
Number of plots generated: 5
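
To make the per-class input format concrete, the kind of summary one might compute from such a dictionary can be sketched in plain Python. This is an illustration only and is not how `MetricsAnalyzer` is implemented; the function name `summarize_metrics` and the small example dictionary are hypothetical:

```python
from statistics import mean


def summarize_metrics(metrics_per_class):
    """Return the macro-average and the weakest class for each metric."""
    metric_names = sorted({m for scores in metrics_per_class.values() for m in scores})
    summary = {}
    for metric in metric_names:
        # Collect this metric's value for every class that reports it
        values = {
            cls: scores[metric]
            for cls, scores in metrics_per_class.items()
            if metric in scores
        }
        worst_class = min(values, key=values.get)
        summary[metric] = {
            "macro_avg": mean(values.values()),
            "worst_class": worst_class,
            "worst_value": values[worst_class],
        }
    return summary


example = {
    "class_1": {"precision": 0.90, "recall": 0.80},
    "class_2": {"precision": 0.70, "recall": 0.95},
}
for metric, stats in summarize_metrics(example).items():
    print(metric, stats)
```

The same function accepts the `metrics_per_class` dictionary built in the cell above, since it has the same `{class: {metric: value}}` layout.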


## AttentionAnalyzer Example

A minimal TransformerEncoder is run on random input to demonstrate attention-weight extraction and visualization.


In [10]:
import torch
import torch.nn as nn
from tools.analyzers import AttentionAnalyzer

# Minimal TransformerEncoder
embed_dim = 32
num_heads = 4
num_layers = 2
seq_len = 12
batch_size = 1

encoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, batch_first=True
)
model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

# Build random token embeddings
x = torch.randn(batch_size, seq_len, embed_dim)

# Note: AttentionAnalyzer.extract_attention_weights looks for attention weights via forward hooks.
# The stock nn.TransformerEncoderLayer calls its self-attention with need_weights=False, so
# nn.MultiheadAttention returns None in place of the weights and a hook has nothing to collect.

analyzer = AttentionAnalyzer(output_dir=str(PROJECT_ROOT / "output/analysis"))

# Because the stock module does not expose an attention_weights attribute, run_full_analysis is
# wrapped in try/except; in the recorded run below it fails and the except branch is taken.
# If your model defines its own attention layers and exposes their weights, the analyzer can collect them.
try:
    report = analyzer.run_full_analysis(
        model=model,
        input_tensor=x,
        tokens=[f"T{i}" for i in range(seq_len)],
        model_name="MiniTransformer",
    )
    print("Attention analysis complete; report exported.")
except Exception as e:
    print("Attention analysis failed:", e)
    print(
        "You can pass weights from a custom attention module to the visualization functions in (num_heads, seq_len, seq_len) form."
    )

2025-10-12 16:55:57,174 - AttentionAnalyzer - INFO - Starting full attention analysis for model 'MiniTransformer'
INFO:AttentionAnalyzer:Starting full attention analysis for model 'MiniTransformer'
2025-10-12 16:55:57,174 - AttentionAnalyzer - INFO - Extracting attention weights...
INFO:AttentionAnalyzer:Extracting attention weights...
2025-10-12 16:55:57,176 - AttentionAnalyzer - ERROR - ❌ Error during attention analysis: 'NoneType' object has no attribute 'detach'
ERROR:AttentionAnalyzer:❌ Error during attention analysis: 'NoneType' object has no attribute 'detach'


Attention analysis failed: 'NoneType' object has no attribute 'detach'
You can pass weights from a custom attention module to the visualization functions in (num_heads, seq_len, seq_len) form.
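
The `'NoneType' object has no attribute 'detach'` error is consistent with how the stock PyTorch layer behaves: `nn.TransformerEncoderLayer` calls its `self_attn` module with `need_weights=False`, so no weights are returned. A minimal sketch of one workaround, which bypasses `AttentionAnalyzer` entirely and calls each layer's `self_attn` by hand with `need_weights=True`. It assumes the default post-norm layout (`norm_first=False`), where self-attention operates directly on the layer input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, num_layers, seq_len = 32, 4, 2, 12

encoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, batch_first=True
)
model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
model.eval()  # disable dropout so the extracted weights are deterministic

x = torch.randn(1, seq_len, embed_dim)

per_layer_weights = []
with torch.no_grad():
    hidden = x
    for layer in model.layers:
        # Ask MultiheadAttention for per-head weights instead of the head average.
        _, weights = layer.self_attn(
            hidden, hidden, hidden, need_weights=True, average_attn_weights=False
        )
        per_layer_weights.append(weights[0])  # (num_heads, seq_len, seq_len)
        hidden = layer(hidden)  # feed the layer output to the next layer

for i, w in enumerate(per_layer_weights):
    print(f"layer {i}: {tuple(w.shape)}")
```

Each entry has the `(num_heads, seq_len, seq_len)` layout mentioned in the message above. Which visualization function of the analyzer accepts such tensors is not shown in this Notebook, so that step is left out here.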


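## Fix Notes

The exported `analysis_results.json` turned out to contain many empty arrays `[]`:

**Cause:**
- Each label returned by a `TensorDataset` is a scalar tensor (shape `()`)
- When `label_shapes` is recorded, the `shape` of a scalar is the empty tuple `()`
- JSON serialization turns an empty tuple into an empty array `[]`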
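
The chain of causes listed above can be reproduced in a few lines, independently of `DatasetAnalyzer`. The small dataset below is made up purely for the demonstration:

```python
import json

import torch
from torch.utils.data import TensorDataset

# Four random "images" with integer class labels, as in the cells above
dataset = TensorDataset(torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)))
_, label = dataset[0]  # indexing a 1-D label tensor yields a 0-dimensional tensor

shape = tuple(label.shape)
print(shape)              # ()
print(json.dumps(shape))  # []
```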

**Fix:**
- The label-shape recording logic in `DatasetAnalyzer` was changed
- Scalar labels are now meant to be recorded as `(1,)` instead of `()`
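
The actual change inside `DatasetAnalyzer` is not shown in this Notebook. As an illustration of the idea only, a normalization step of this kind could look like the sketch below; the function name `normalize_label_shape` is hypothetical:

```python
def normalize_label_shape(shape):
    """Map the empty shape of a scalar label to (1,); leave other shapes unchanged."""
    shape = tuple(shape)
    return shape if shape else (1,)


print(normalize_label_shape(()))       # (1,)
print(normalize_label_shape((3, 32)))  # (3, 32)
```

Recording `(1,)` keeps every entry of `label_shapes` non-empty after JSON export, at the cost of no longer distinguishing true scalars from one-element vectors.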

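The quick test below re-runs the analysis to check whether the fix has taken effect. In the output recorded here the label shapes are still `()`, so this particular run did not pick up the change; restarting the kernel (or otherwise reloading the modified module) before re-running may be needed.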


In [11]:
# Verify the fix: re-run DatasetAnalyzer
from torch.utils.data import TensorDataset
from tools.analyzers import DatasetAnalyzer

# Create fresh test data
X_test = torch.randn(50, 3, 32, 32)
y_test = torch.randint(0, 5, (50,))
test_dataset = TensorDataset(X_test, y_test)

# Create a new analyzer instance (intended to use the fixed code)
test_analyzer = DatasetAnalyzer(
    dataset_name="FixedTestDataset", output_dir=str(PROJECT_ROOT / "output/analysis")
)

# Run the analysis
test_analyzer.load_custom_dataset(test_dataset, "test")
test_results = test_analyzer.run_full_analysis(
    splits=["test"], save_plots=False, max_samples=50
)

# Inspect the results
label_shapes = test_results["analysis_results"]["class_distribution"]["test"][
    "label_shapes"
]
print(f"Label shapes (first 10): {label_shapes[:10]}")
print(f"Any empty shapes left: {any(len(shape) == 0 for shape in label_shapes)}")
print(
    f"Most common shape: {test_results['analysis_results']['class_distribution']['test']['most_common_shape']}"
)

2025-10-12 16:55:57,187 - DatasetAnalyzer.FixedTestDataset - INFO - Loaded custom dataset 'test': 50 samples
INFO:DatasetAnalyzer.FixedTestDataset:Loaded custom dataset 'test': 50 samples
2025-10-12 16:55:57,188 - DatasetAnalyzer.FixedTestDataset - INFO - Starting full dataset analysis...
INFO:DatasetAnalyzer.FixedTestDataset:Starting full dataset analysis...
2025-10-12 16:55:57,189 - DatasetAnalyzer.FixedTestDataset - INFO - Loading datasets: FixedTestDataset
INFO:DatasetAnalyzer.FixedTestDataset:Loading datasets: FixedTestDataset
2025-10-12 16:55:57,191 - DatasetAnalyzer.FixedTestDataset - INFO - Analyzing class distribution...
INFO:DatasetAnalyzer.FixedTestDataset:Analyzing class distribution...
2025-10-12 16:55:57,192 - DatasetAnalyzer.FixedTestDataset - INFO -   Analyzing test dataset...
INFO:DatasetAnalyzer.FixedTestDataset:  Analyzing test dataset...
2025-10-12 16:55:57,193 - DatasetAnalyzer.FixedTestDataset - INFO -     Found 5 unique classes
INFO:DatasetAnalyzer.FixedTestDatas

Label shapes (first 10): [(), (), (), (), (), (), (), (), (), ()]
Any empty shapes left: True
Most common shape: ((), 50)
