In [3]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter


def read_fasta(filepath):
    """Read a FASTA file and return its sequences as a list of strings.

    Lines beginning with ``>`` are treated as record headers and discarded.
    Every other line is stripped of surrounding whitespace and appended to
    the record currently being built. A record is emitted when the next
    header (or the end of the file) is reached, and only if it is non-empty.
    """
    records = []
    chunks = []

    def flush():
        # Emit the record collected so far (if any) and start a new one.
        joined = ''.join(chunks)
        if joined:
            records.append(joined)
        chunks.clear()

    with open(filepath, 'r') as handle:
        for raw in handle:
            if raw.startswith('>'):
                flush()
            else:
                chunks.append(raw.strip())
    flush()
    return records


# Load the test-set sequences for the positive and negative classes.
pos_test_seqs = read_fasta('/exp_data/sjx/star/first_data/shisuandanbai/positive_test.fasta')
neg_test_seqs = read_fasta('/exp_data/sjx/star/first_data/shisuandanbai/negative_test.fasta')

n_pos_test, n_neg_test = len(pos_test_seqs), len(neg_test_seqs)
print(f"正样本测试集序列数: {n_pos_test}")
print(f"负样本测试集序列数: {n_neg_test}")

# Load the arrays dumped from the model run (attention, gating, routing, labels).
attn_weights = np.load("/exp_data/sjx/star/main_transformer_moe_weight/experiment_data/attn_weights.npy")
gate_scores = np.load("/exp_data/sjx/star/main_transformer_moe_weight/experiment_data/gate_scores.npy")
topk_idx = np.load("/exp_data/sjx/star/main_transformer_moe_weight/experiment_data/topk_idx.npy")
labels = np.load("/exp_data/sjx/star/main_transformer_moe_weight/experiment_data/labels.npy")

# Report the shape of every loaded array, one per line.
print("数据形状:")
for _array_name, _array in (
    ("attn_weights", attn_weights),
    ("gate_scores", gate_scores),
    ("topk_idx", topk_idx),
    ("labels", labels),
):
    print(f"  {_array_name}: {_array.shape}")

# Choose which positive test sequence and which layer to inspect.
# NOTE(review): the original comment called this the "first" positive sample,
# but index 1 is the second entry (0-based) — confirm which one is intended.
seq_idx = 1
layer_idx = 0  # first layer
target_seq = pos_test_seqs[seq_idx]
seq_len = len(target_seq)

_rule = "=" * 80
print(_rule)
print("序列详细信息")
print(_rule)
print(f"序列索引: {seq_idx}")
print(f"序列长度: {seq_len}")
# NOTE(review): labels is indexed with the positive-set index; this presumes the
# dumped arrays list the positive test sequences first, in file order — confirm.
print(f"序列标签: {labels[seq_idx]} (1=正样本)")
print(f"完整序列: {target_seq}")
print(f"序列前50个氨基酸: {target_seq[:50]}")
# Slicing already returns the whole string when it has 50 residues or fewer.
print(f"序列后50个氨基酸: {target_seq[-50:]}")

# Amino-acid composition of the selected sequence, listed alphabetically.
aa_counts = Counter(target_seq)
print("\n氨基酸组成:")
for aa, aa_count in sorted(aa_counts.items()):
    print(f"  {aa}: {aa_count} ({aa_count / seq_len * 100:.1f}%)")

# Take the attention entry for the chosen layer/sequence once and report its shape.
_raw_attn = attn_weights[layer_idx, seq_idx]
print("\n注意力权重形状检查:")
print(f"  attn_weights[layer_idx, seq_idx].shape: {_raw_attn.shape}")
print(f"  attn_weights[layer_idx, seq_idx].ndim: {_raw_attn.ndim}")

# Reduce the entry to a single 2-D matrix, depending on its rank.
if _raw_attn.ndim == 2:
    # Already a 2-D matrix: use it as is.
    seq_attn = _raw_attn
elif _raw_attn.ndim == 3:
    # 3-D entry: average over the first axis (presumably attention heads).
    seq_attn = _raw_attn.mean(axis=0)
else:
    print(f"意外的注意力权重形状: {_raw_attn.shape}")
    # Best-effort fallback: assumes a padded maximum length of 300, so this
    # raises unless the entry holds exactly 300 * 300 values.
    seq_attn = _raw_attn.reshape(300, 300)

# Gating scores and top-k expert assignment for the same layer/sequence.
# Presumably shaped (positions, num_experts) and (positions, topk) — TODO confirm.
seq_gate_scores = gate_scores[layer_idx, seq_idx]
seq_expert_assign = topk_idx[layer_idx, seq_idx]

print("\n处理后的数据形状:")
for _name, _arr in (
    ("seq_attn", seq_attn),
    ("seq_gate_scores", seq_gate_scores),
    ("seq_expert_assign", seq_expert_assign),
):
    print(f"  {_name}: {_arr.shape}")

# Keep only the first seq_len positions, i.e. drop the padding.
_keep = slice(None, seq_len)
seq_attn = seq_attn[_keep, _keep]
seq_gate_scores = seq_gate_scores[_keep]
seq_expert_assign = seq_expert_assign[_keep]

print("\n裁剪后的数据形状:")
for _name, _arr in (
    ("seq_attn", seq_attn),
    ("seq_gate_scores", seq_gate_scores),
    ("seq_expert_assign", seq_expert_assign),
):
    print(f"  {_name}: {_arr.shape}")

# Per-position incoming attention: column sums of the cropped attention matrix.
# NOTE(review): reading this as "attention received from other positions"
# assumes rows are queries and columns are keys — confirm against the model.
attention_in = seq_attn.sum(axis=0)

_section_rule = "=" * 80
print("\n" + _section_rule)
print("高注意力区域分析")
print(_section_rule)

# Flag positions strictly above the 80th percentile (roughly the top 20%).
attention_threshold = np.percentile(attention_in, 80)
high_attention_positions = np.where(attention_in > attention_threshold)[0]

print(f"注意力阈值 (前20%): {attention_threshold:.4f}")
print(f"高注意力位置数量: {len(high_attention_positions)}")
print(f"高注意力位置索引: {high_attention_positions.tolist()}")

# One line per flagged position; the printed position is 1-based.
print("\n高注意力位置详细信息:")
for pos in high_attention_positions:
    _residue, _score = target_seq[pos], attention_in[pos]
    print(f"  位置 {pos + 1}: {_residue} (注意力分数: {_score:.4f})")

# Concatenate the flagged residues in sequence order.
print("\n高注意力区域序列片段:")
high_attention_seq = ''.join(target_seq[p] for p in high_attention_positions)
print(f"  完整片段: {high_attention_seq}")

# For every expert, list the positions routed to it and compare them with the
# high-attention positions found above.
print("\n" + "=" * 80)
print("专家专注区域分析")
print("=" * 80)

num_experts = gate_scores.shape[-1]

# The high-attention set is the same for every expert, so build it once.
high_attention_set = {int(p) for p in high_attention_positions}

# Never index past the routing array: if the sequence is longer than the
# array (e.g. the model truncated it — TODO confirm), range(seq_len) would
# raise IndexError.
n_routed_positions = min(seq_len, len(seq_expert_assign))

for expert_id in range(num_experts):
    # Positions whose top-k assignment contains this expert.
    expert_positions = [
        pos for pos in range(n_routed_positions)
        if expert_id in seq_expert_assign[pos]
    ]
    if not expert_positions:
        continue

    expert_residues = [target_seq[i] for i in expert_positions]

    # Fraction of this expert's positions that are also high-attention positions.
    overlap = len(set(expert_positions) & high_attention_set)
    overlap_ratio = overlap / len(expert_positions)

    print(f"\nExpert {expert_id}:")
    print(f"  专注位置数量: {len(expert_positions)}")
    print(f"  专注位置索引: {expert_positions}")
    print(f"  专注位置氨基酸: {expert_residues}")
    print(f"  专注区域序列: {''.join(expert_residues)}")
    print(f"  与高注意力区域重合度: {overlap_ratio:.2f} ({overlap}/{len(expert_positions)})")

    # Residue types seen by this expert. A separate name is used so the
    # whole-sequence `aa_counts` computed earlier is not overwritten.
    expert_aa_counts = Counter(expert_residues)
    print(f"  氨基酸分布: {dict(expert_aa_counts)}")

    # Mean gate score of this expert over the positions routed to it.
    avg_gate_score = np.mean([seq_gate_scores[pos, expert_id] for pos in expert_positions])
    print(f"  平均门控分数: {avg_gate_score:.4f}")

# Save a short summary of the selected sequence to a text file.
analysis_info_path = '/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/sequence_analysis_info.txt'

# The summary contains non-ASCII text, so the encoding is set explicitly
# instead of relying on the platform/locale default.
with open(analysis_info_path, 'w', encoding='utf-8') as f:
    f.write("序列分析详细信息\n")
    f.write("=" * 50 + "\n")
    f.write(f"序列索引: {seq_idx}\n")
    f.write(f"序列长度: {seq_len}\n")
    f.write(f"完整序列: {target_seq}\n")
    f.write(f"高注意力位置: {high_attention_positions.tolist()}\n")
    f.write(f"高注意力区域序列: {high_attention_seq}\n")

# Report the same path that was written, so the two cannot drift apart.
print(f"\n序列信息已保存到: {analysis_info_path}")


In [None]:
##### UniProt ID: A0A3M8QPJ6
import numpy as np
import matplotlib.pyplot as plt

# 假设你已经有 attn_weights, pos_test_seqs, seq_idx, target_seq, seq_len
import numpy as np

# 假设你已经有 attn_weights, target_seq, seq_idx, seq_len
import numpy as np


# 调整高注意力token的输出，减少数量，提高阈值
def output_high_attention_tokens(attn_weights, target_seq, seq_idx, seq_len, output_txt, percentile=95,
                                 num_layers=4):
    """Write and print the high-attention tokens of one sequence, layer by layer.

    For each layer, ``attn_weights[layer, seq_idx]`` is reduced to a 2-D
    matrix (a 3-D entry is averaged over its first axis, presumably the
    attention heads), cropped to the first ``seq_len`` rows and columns, and
    summed down each column. Positions whose column sum is greater than or
    equal to the requested percentile of those sums are reported.

    Args:
        attn_weights: NumPy array indexed as ``[layer, sequence, ...]``.
        target_seq: Sequence string; reported residues are looked up in it.
        seq_idx: Index of the sequence along the second axis of
            ``attn_weights``.
        seq_len: Number of leading positions to keep (drops padding).
        output_txt: Path of the text report to write (overwritten).
        percentile: Percentile of the per-position sums used as threshold.
        num_layers: Maximum number of layers to report. It is capped at the
            number of layers present in ``attn_weights`` so that an array
            with fewer layers no longer raises ``IndexError`` after a
            partial report has been written.

    Returns:
        None. Results are written to ``output_txt`` and echoed to stdout.
    """
    layer_count = min(num_layers, attn_weights.shape[0])
    with open(output_txt, 'w', encoding='utf-8') as f:
        f.write(f"Sequence index: {seq_idx}\n")
        f.write(f"Sequence length: {seq_len}\n")
        f.write(f"Full sequence:\n{target_seq}\n\n")
        for layer_idx in range(layer_count):
            attn = attn_weights[layer_idx, seq_idx]
            if attn.ndim == 3:
                attn = attn.mean(axis=0)
            attn = attn[:seq_len, :seq_len]

            # Column sums of the cropped matrix, one value per position.
            attention_in = attn.sum(axis=0)
            attention_threshold = np.percentile(attention_in, percentile)
            # ">=" keeps positions that sit exactly on the threshold.
            high_idx = np.where(attention_in >= attention_threshold)[0]
            high_aas = ''.join(target_seq[i] for i in high_idx)

            f.write(
                f"Layer {layer_idx} high-attention tokens (threshold={attention_threshold:.4f}, top {len(high_idx)} tokens):\n")
            f.write(f"Indices: {high_idx.tolist()}\n")
            f.write(f"Amino acids: {high_aas}\n\n")
            print(f"Layer {layer_idx}: high-attention token indices: {high_idx.tolist()}")
            print(f"Layer {layer_idx}: high-attention amino acids: {high_aas}")
            print(f"Layer {layer_idx}: number of high-attention tokens: {len(high_idx)}")


# Usage example (a percentile of 90 changes how many tokens are reported):
# output_high_attention_tokens(attn_weights, target_seq, seq_idx, seq_len, '/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/high_attention_tokens_info.txt', percentile=90)
for layer_idx in range(4):  # 4 layers
    # 1. Attention matrix of this layer for the selected sequence; a 3-D
    #    entry is averaged over its first axis, then padding is cropped.
    _layer_attn = attn_weights[layer_idx, seq_idx]
    attn = _layer_attn.mean(axis=0) if _layer_attn.ndim == 3 else _layer_attn
    attn = attn[:seq_len, :seq_len]

    # 2. Heat map of the attention matrix.
    _heat_fig, _heat_ax = plt.subplots(figsize=(8, 6))
    _heat_img = _heat_ax.imshow(attn, cmap='Blues', aspect='auto')
    _heat_ax.set_title(f'Attention Weights (Layer {layer_idx})')
    _heat_ax.set_xlabel('Token Position')
    _heat_ax.set_ylabel('Token Position')
    _heat_fig.colorbar(_heat_img, ax=_heat_ax)
    _heat_fig.tight_layout()
    _heat_fig.savefig(f'/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/attn_heatmap_layer{layer_idx}.pdf')
    plt.show()
    plt.close(_heat_fig)

    # 3. Per-position column sums, with positions above the 80th percentile
    #    highlighted.
    attention_in = attn.sum(axis=0)
    attention_threshold = np.percentile(attention_in, 80)
    high_attention_positions = np.where(attention_in > attention_threshold)[0]

    _line_fig, _line_ax = plt.subplots(figsize=(10, 5))
    _line_ax.plot(range(seq_len), attention_in, 'b-', linewidth=2)
    _line_ax.axhline(y=attention_threshold, color='r', linestyle='--',
                     label=f'Threshold ({attention_threshold:.4f})')
    _line_ax.scatter(high_attention_positions, attention_in[high_attention_positions],
                     color='red', s=50, zorder=5, label='High Attention Positions')
    _line_ax.set_title(f'Attention Incoming Weights (Layer {layer_idx})')
    _line_ax.set_xlabel('Token Position')
    _line_ax.set_ylabel('Sum of Incoming Attention')
    _line_ax.legend()
    _line_fig.tight_layout()
    _line_fig.savefig(f'/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/attn_incoming_layer{layer_idx}.pdf')
    plt.show()
    plt.close(_line_fig)

    # 3.1 Report the five positions with the largest column sums, largest first.
    top_n = 5
    top_idx = np.argsort(attention_in)[::-1][:top_n]
    top_aas = [target_seq[i] for i in top_idx]
    top_values = attention_in[top_idx]
    print(f"Layer {layer_idx}: Top {top_n} high-attention token indices: {top_idx.tolist()}")
    print(f"Layer {layer_idx}: Top {top_n} high-attention amino acids: {''.join(top_aas)}")
    print(f"Layer {layer_idx}: Top {top_n} attention values: {top_values.tolist()}")

# Finally write the per-layer high-attention token report (default percentile).
output_high_attention_tokens(
    attn_weights,
    target_seq,
    seq_idx,
    seq_len,
    '/exp_data/sjx/star/main_transformer_moe_weight/moe_analysis/high_attention_tokens_info.txt',
)
