In [None]:
# Cell 1: 导入依赖
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
# Cell 4: correlation analysis across the experts, based on gate scores.
# The outputs of every expert are not available here, so the gate scores are
# used instead to gauge how similar the experts are to one another.

gate_scores = np.load(
    "/exp_data/sjx/star/main_transformer_moe_weight/experiment_data/gate_scores.npy"
)
# The recorded run printed (4, 1149, 300, 30).
# NOTE(review): the original note describes the leading axes as
# samples x layers x tokens; their exact order is not visible here -- confirm.
print("gate_scores形状:", gate_scores.shape)

# Collapse every leading axis so that each row is one observation and the last
# axis (one column per expert) is preserved.
gate_scores_reshaped = np.reshape(gate_scores, (-1, gate_scores.shape[-1]))

# np.corrcoef treats rows as variables, hence the transpose: the result is an
# expert-by-expert correlation matrix.
expert_gate_corr = np.corrcoef(np.transpose(gate_scores_reshaped))

print("专家门控分数相关性矩阵形状:", expert_gate_corr.shape)

gate_scores形状: (4, 1149, 300, 30)


In [None]:
# Cell 5: heatmap of the pairwise gate-score correlations between experts.
heatmap_path = '/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/expert_gate_correlation_30experts.svg'

# Derive the expert count from the matrix itself so the tick labels (and the
# title) always match its size instead of relying on a hard-coded 30.
n_experts = expert_gate_corr.shape[0]
expert_labels = [f'Expert_{i}' for i in range(n_experts)]

# Explicit figure/axes objects keep this cell independent of whatever figure
# pyplot currently considers "active".
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(expert_gate_corr, cmap='coolwarm', center=0, annot=False, square=True,
            xticklabels=expert_labels, yticklabels=expert_labels, ax=ax)
ax.set_title(f'Expert Gate Score Correlation (All {n_experts} Experts)', fontsize=14, fontweight='bold')
ax.set_xlabel('Expert ID', fontsize=12)
ax.set_ylabel('Expert ID', fontsize=12)
fig.tight_layout()

# savefig does not create missing directories; make sure the target exists so
# a fresh checkout does not fail after the figure has already been rendered.
os.makedirs(os.path.dirname(heatmap_path), exist_ok=True)
fig.savefig(heatmap_path, dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Cell 6: summary statistics of the expert gate-score correlations.
n_experts = expert_gate_corr.shape[0]

# Select every entry except the diagonal. The diagonal of a correlation matrix
# is each expert's correlation with itself, so including it would pin the
# maximum at 1.0 and bias the mean and standard deviation.
mask = ~np.eye(n_experts, dtype=bool)
off_diagonal_corr = expert_gate_corr[mask]  # 1-D, length n_experts * (n_experts - 1)

print("专家门控分数相关性统计:")
print(f"最大相关性: {off_diagonal_corr.max():.4f}")
print(f"最小相关性: {off_diagonal_corr.min():.4f}")
print(f"平均相关性: {off_diagonal_corr.mean():.4f}")
print(f"相关性标准差: {off_diagonal_corr.std():.4f}")

# Find the most and least correlated expert pairs.
# off_diagonal_corr is a flattened selection, so a position in it is NOT a flat
# index into the full square matrix and cannot be passed to np.unravel_index
# with the matrix shape. Boolean indexing and np.nonzero both traverse the
# mask in row-major order, so position k of off_diagonal_corr corresponds to
# the matrix entry (off_diag_rows[k], off_diag_cols[k]).
off_diag_rows, off_diag_cols = np.nonzero(mask)
max_pos = np.argmax(off_diagonal_corr)
min_pos = np.argmin(off_diagonal_corr)
max_corr_idx = (int(off_diag_rows[max_pos]), int(off_diag_cols[max_pos]))
min_corr_idx = (int(off_diag_rows[min_pos]), int(off_diag_cols[min_pos]))

# NOTE(review): the second line reports the pair with the minimum (most
# negative) correlation, not the pair whose correlation is closest to zero --
# confirm this is the intended meaning of "least correlated".
print(f"最相关的专家对: Expert_{max_corr_idx[0]} vs Expert_{max_corr_idx[1]} (相关性: {expert_gate_corr[max_corr_idx]:.4f})")
print(f"最不相关的专家对: Expert_{min_corr_idx[0]} vs Expert_{min_corr_idx[1]} (相关性: {expert_gate_corr[min_corr_idx]:.4f})")