# Multi-seed Grokking Results Visualization

This notebook visualizes results from multi-seed grokking experiments.
Each plot shows:
- Train Accuracy (left y-axis)
- Test Accuracy (left y-axis)
- A specific metric (right y-axis): Complexity (LLC×L2), LLC, L2 Norm, Spectral Entropy, Attention Entropy, or Embedding Entropy

All plots show mean ± std over multiple seeds.


In [4]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path


In [6]:
def plot_with_metric(task_name, wd, results_base="/root/autodl-tmp/test/results/data", save_dir=None):
    """Plot train/test accuracy against each tracked metric, aggregated over seeds.

    Every ``seed*.csv`` found in ``<results_base>/<task_safe>/wd_<wd>`` is read as
    one run. One twin-axis figure is drawn per metric: mean ± std of train/test
    accuracy on the left axis and mean ± std of the metric on the right axis,
    over a log-scaled step axis.

    Args:
        task_name: Task label, e.g. ``'x+y'``. ``/``, ``*`` and ``+`` are replaced
            to build directory and file names.
        wd: Weight-decay value; used verbatim in the ``wd_<wd>`` directory name.
        results_base: Root directory that holds the per-task CSV folders.
        save_dir: Directory for the PNG files (created if needed). When falsy the
            figures are built and closed without being written to disk.

    Returns:
        None. A status line is printed; when no usable CSV exists a warning is
        printed and nothing is drawn.

    Raises:
        KeyError: If a CSV lacks one of the expected columns.
        ValueError: If a CSV cell cannot be parsed as a float.
    """
    task_safe = task_name.replace("/", "_div_").replace("*", "_mul_").replace("+", "_plus_")
    data_dir = Path(results_base) / task_safe / f"wd_{wd}"
    seed_files = sorted(data_dir.glob("seed*.csv"))

    if not seed_files:
        print(f"⚠ No data: {task_name} WD={wd}")
        return

    # Load each seed's CSV into a {column name: list of floats} mapping.
    all_data = []
    for seed_file in seed_files:
        columns = {}
        with open(seed_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                for key, value in row.items():
                    columns.setdefault(key, []).append(float(value))
        # A header-only CSV contributes no rows; skip it instead of failing later.
        if columns.get('steps'):
            all_data.append(columns)

    if not all_data:
        print(f"⚠ No data: {task_name} WD={wd}")
        return

    # Runs may have logged a different number of rows (e.g. an interrupted seed).
    # Truncate to the shortest run so the per-seed arrays stack into a rectangle.
    n_common = min(len(run['steps']) for run in all_data)
    steps = np.array(all_data[0]['steps'][:n_common])

    def per_seed(key):
        """Return a (num_seeds, n_common) array for one CSV column."""
        return np.array([run[key][:n_common] for run in all_data])

    def stats(values):
        """Return (mean, std) across the seed axis."""
        return np.mean(values, axis=0), np.std(values, axis=0)

    train_acc_m, train_acc_s = stats(per_seed('train_acc'))
    test_acc_m, test_acc_s = stats(per_seed('test_acc'))
    llc_runs = per_seed('llc')
    l2_runs = per_seed('l2_norm')
    llc_m, llc_s = stats(llc_runs)
    l2_m, l2_s = stats(l2_runs)
    se_m, se_s = stats(per_seed('spectral_entropy'))
    attn_se_m, attn_se_s = stats(per_seed('attention_spectral_entropy'))
    emb_se_m, emb_se_s = stats(per_seed('embedding_spectral_entropy'))
    # Complexity is LLC × L2 of the same run, so form the product per seed and
    # aggregate afterwards. Multiplying the two means and propagating the error
    # would assume LLC and L2 are independent across seeds.
    complexity_m, complexity_s = stats(llc_runs * l2_runs)

    # (display name, mean, std, colour) for every right-axis metric.
    metrics = [
        ('Complexity (LLC×L2)', complexity_m, complexity_s, 'green'),
        ('LLC', llc_m, llc_s, 'purple'),
        ('L2 Norm', l2_m, l2_s, 'orange'),
        ('Spectral Entropy', se_m, se_s, 'brown'),
        ('Attention Entropy', attn_se_m, attn_se_s, 'cyan'),
        ('Embedding Entropy', emb_se_m, emb_se_s, 'magenta'),
    ]

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # One figure per metric.
    for metric_name, metric_m, metric_s, color in metrics:
        fig, ax1 = plt.subplots(figsize=(12, 7))
        ax2 = ax1.twinx()

        # Left axis: train and test accuracy.
        line_train = ax1.plot(steps, train_acc_m, 'b-', label='Train Acc', linewidth=2, alpha=0.8)
        ax1.fill_between(steps, train_acc_m - train_acc_s, train_acc_m + train_acc_s,
                         color='blue', alpha=0.15)
        line_test = ax1.plot(steps, test_acc_m, 'r-', label='Test Acc', linewidth=2.5, alpha=0.9)
        ax1.fill_between(steps, test_acc_m - test_acc_s, test_acc_m + test_acc_s,
                         color='red', alpha=0.2)

        ax1.set_xlabel('Training Steps', fontsize=13, fontweight='bold')
        ax1.set_ylabel('Accuracy', fontsize=13, fontweight='bold', color='black')
        ax1.tick_params(axis='y', labelcolor='black')
        ax1.set_xscale('log')
        ax1.set_ylim(-0.05, 1.05)
        ax1.grid(True, alpha=0.3, linestyle='--')

        # Right axis: the metric for this figure.
        line_metric = ax2.plot(steps, metric_m, color=color, linestyle='--',
                               label=metric_name, linewidth=2.5, alpha=0.9)
        ax2.fill_between(steps, metric_m - metric_s, metric_m + metric_s,
                         color=color, alpha=0.2)

        ax2.set_ylabel(metric_name, fontsize=13, fontweight='bold', color=color)
        ax2.tick_params(axis='y', labelcolor=color)

        # Single legend covering the curves from both axes.
        handles = [*line_train, *line_test, *line_metric]
        labels = [handle.get_label() for handle in handles]
        ax1.legend(handles, labels, loc='best', fontsize=11, framealpha=0.95)

        ax2.set_title(f'Task: {task_name} | WD={wd} | {metric_name}\n(Mean ± Std, {len(all_data)} seeds)',
                      fontsize=14, fontweight='bold', pad=15)

        fig.tight_layout()

        if save_dir:
            metric_safe = metric_name.replace(' ', '_').replace('(', '').replace(')', '').replace('×', 'x')
            save_path = os.path.join(save_dir, f"{task_safe}_wd{wd}_{metric_safe}.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight')

        # Close this specific figure so a batch run does not accumulate open figures.
        plt.close(fig)

    print(f"✓ {task_name:15s} WD={wd} → {len(metrics)} plots generated")

# 批量绘制所有任务（9个任务 × 2个权重衰减 × 6个指标 = 108张图）
tasks = ['x+y', 'x-y', 'x*y', 'x_div_y', 'x2+y2', 'x2+xy+y2', 'x2+xy+y2+x', 'x3+xy', 'x3+xy2+y']
weight_decays = [0.0, 1.0]
save_base = '/root/autodl-tmp/test/results/plots'

# Every (task, weight decay) combination, in task-major order.
jobs = [(task, wd) for task in tasks for wd in weight_decays]

# Announce the size of the batch before rendering anything.
print(f"{'='*70}")
print(f"Generating plots: {len(tasks)} tasks × {len(weight_decays)} WDs × 6 metrics")
print(f"Total: {len(jobs) * 6} figures")
print(f"{'='*70}\n")

# A failure in one combination is reported and does not stop the batch.
for task, wd in jobs:
    try:
        plot_with_metric(task, wd, save_dir=save_base)
    except Exception as err:
        print(f"✗ Error in {task} WD={wd}: {err}")

print(f"\n{'='*70}")
print(f"All plots saved to: {save_base}")
print(f"{'='*70}")


Generating plots: 9 tasks × 2 WDs × 6 metrics
Total: 108 figures

✓ x+y             WD=0.0 → 6 plots generated
✓ x+y             WD=1.0 → 6 plots generated
✓ x-y             WD=0.0 → 6 plots generated
✓ x-y             WD=1.0 → 6 plots generated
✓ x*y             WD=0.0 → 6 plots generated
✓ x*y             WD=1.0 → 6 plots generated
✓ x_div_y         WD=0.0 → 6 plots generated
✓ x_div_y         WD=1.0 → 6 plots generated
✓ x2+y2           WD=0.0 → 6 plots generated
✓ x2+y2           WD=1.0 → 6 plots generated
✓ x2+xy+y2        WD=0.0 → 6 plots generated
✓ x2+xy+y2        WD=1.0 → 6 plots generated
✓ x2+xy+y2+x      WD=0.0 → 6 plots generated
✓ x2+xy+y2+x      WD=1.0 → 6 plots generated
✓ x3+xy           WD=0.0 → 6 plots generated
✓ x3+xy           WD=1.0 → 6 plots generated
✓ x3+xy2+y        WD=0.0 → 6 plots generated
✓ x3+xy2+y        WD=1.0 → 6 plots generated

All plots saved to: /root/autodl-tmp/test/results/plots


In [9]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import torch

tasks = ['x+y', 'x-y', 'x*y', 'x_div_y', 'x2+y2', 'x2+xy+y2', 'x2+xy+y2+x', 'x3+xy', 'x3+xy2+y']
weight_decays = [0.0, 1.0]
seeds = [42, 101, 2025]
steps = [100, 1000, 10000, 100000]
seed_colors = {42: 'red', 101: 'blue', 2025: 'green'}
seed_markers = {42: 'o', 101: 's', 2025: '^'}
checkpoint_base = Path("/root/autodl-tmp/test/results/checkpoints")
save_base = Path("/root/autodl-tmp/test/results/embedding_viz")
save_base.mkdir(parents=True, exist_ok=True)


def _embed_2d(method, matrix):
    """Project the rows of ``matrix`` to 2-D with ``'PCA'`` or (otherwise) t-SNE."""
    if method == 'PCA':
        return PCA(n_components=2).fit_transform(matrix)
    # Cap perplexity below the sample count, as the original code did.
    perplexity = min(30, matrix.shape[0] - 1)
    return TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(matrix)


def _draw_seed_panel(ax, coords, seed_labels, n_annotate, title, axis_names):
    """Scatter ``coords`` grouped by seed and number the first ``n_annotate`` rows.

    ``seed_labels`` is an array with one seed id per row of ``coords``. The
    numbered rows are the leading rows of the stacked matrix, i.e. the first
    tokens of the first seed that was loaded.
    """
    for seed in seeds:
        mask = seed_labels == seed
        if not mask.any():
            # No checkpoint for this seed at this step: keep it out of the legend.
            continue
        ax.scatter(coords[mask, 0], coords[mask, 1],
                   c=seed_colors[seed], marker=seed_markers[seed],
                   s=60, alpha=0.6, label=f'Seed {seed}')
    for i in range(n_annotate):
        ax.annotate(str(i), (coords[i, 0], coords[i, 1]),
                    fontsize=7, alpha=0.7, fontweight='bold')
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlabel(axis_names[0], fontsize=9)
    ax.set_ylabel(axis_names[1], fontsize=9)
    ax.legend(fontsize=8, loc='upper right')
    ax.grid(True, alpha=0.3)


print(f"{'='*70}")
print(f"Processing: {len(tasks)} tasks × {len(weight_decays)} WDs ({len(seeds)} seeds combined)")
print(f"{'='*70}\n")

for task in tasks:
    task_safe = task.replace("/", "_div_").replace("*", "_mul_").replace("+", "_plus_")
    for wd in weight_decays:
        checkpoint_dir = checkpoint_base / task_safe / f"wd_{wd}"
        if not checkpoint_dir.exists():
            continue

        # One row per training step; columns are input/output × PCA/t-SNE.
        # squeeze=False keeps ``axes`` 2-D even when ``steps`` has a single entry.
        fig, axes = plt.subplots(len(steps), 4, figsize=(24, 6 * len(steps)), squeeze=False)
        found_any = False

        for step_idx, step in enumerate(steps):
            # Gather the embeddings of every seed that has a checkpoint at this step.
            all_input_embs, all_output_embs, row_seeds = [], [], []
            for seed in seeds:
                checkpoint_file = checkpoint_dir / f"seed{seed}_step{step}.pt"
                if not checkpoint_file.exists():
                    continue

                # NOTE: weights_only=False unpickles arbitrary objects; only load
                # checkpoints produced by your own training runs.
                checkpoint = torch.load(checkpoint_file, map_location='cpu', weights_only=False)
                state_dict = checkpoint['model_state_dict']
                input_emb = state_dict['token_embeddings.weight'].numpy()
                output_emb = state_dict['head.weight'].numpy()

                all_input_embs.append(input_emb)
                all_output_embs.append(output_emb)
                row_seeds.extend([seed] * input_emb.shape[0])

            if not all_input_embs:
                # Nothing to draw for this step: blank the row instead of leaving
                # empty default axes in the saved figure.
                for ax in axes[step_idx]:
                    ax.axis('off')
                continue

            found_any = True
            seed_labels = np.array(row_seeds)
            n_annotate = min(all_input_embs[0].shape[0], 15)
            combined = {
                'Input': np.vstack(all_input_embs),
                'Output': np.vstack(all_output_embs),
            }

            panels = [
                ('Input', 'PCA', ('PC1', 'PC2')),
                ('Input', 't-SNE', ('t-SNE 1', 't-SNE 2')),
                ('Output', 'PCA', ('PC1', 'PC2')),
                ('Output', 't-SNE', ('t-SNE 1', 't-SNE 2')),
            ]
            for col, (which, method, axis_names) in enumerate(panels):
                coords = _embed_2d(method, combined[which])
                _draw_seed_panel(axes[step_idx, col], coords, seed_labels, n_annotate,
                                 f'{which} {method} (Step {step})', axis_names)

        if found_any:
            seed_list = ', '.join(str(s) for s in seeds)
            fig.suptitle(f'Task: {task} | WD={wd} | Multi-seed Embedding (Seeds: {seed_list})',
                         fontsize=16, fontweight='bold', y=0.995)
            fig.tight_layout()
            save_path = save_base / f"{task_safe}_wd{wd}_multiseed_embeddings.png"
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"✓ {task:15s} WD={wd}")
        # Always release the figure, including when no checkpoint was found.
        plt.close(fig)

print(f"\n{'='*70}")
print(f"Done! Saved to: {save_base}")
print(f"{'='*70}")

Processing: 9 tasks × 2 WDs (3 seeds combined)

✓ x+y             WD=0.0
✓ x+y             WD=1.0
✓ x-y             WD=0.0
✓ x-y             WD=1.0
✓ x*y             WD=0.0
✓ x*y             WD=1.0
✓ x_div_y         WD=0.0
✓ x_div_y         WD=1.0
✓ x2+y2           WD=0.0
✓ x2+y2           WD=1.0
✓ x2+xy+y2        WD=0.0
✓ x2+xy+y2        WD=1.0
✓ x2+xy+y2+x      WD=0.0
✓ x2+xy+y2+x      WD=1.0
✓ x3+xy           WD=0.0
✓ x3+xy           WD=1.0
✓ x3+xy2+y        WD=0.0
✓ x3+xy2+y        WD=1.0

Done! Saved to: /root/autodl-tmp/test/results/embedding_viz


In [12]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import torch
import matplotlib.cm as cm

tasks = ['x+y', 'x-y', 'x*y', 'x_div_y', 'x2+y2', 'x2+xy+y2', 'x2+xy+y2+x', 'x3+xy', 'x3+xy2+y']
weight_decays = [0.0, 1.0]
seed = 101
steps = [100, 1000, 10000, 100000]
checkpoint_base = Path("/root/autodl-tmp/test/results/checkpoints")
save_base = Path("/root/autodl-tmp/test/results/embedding_viz_single")
save_base.mkdir(parents=True, exist_ok=True)


def _draw_token_panel(fig, ax, matrix, method, title, axis_names):
    """Project ``matrix`` rows to 2-D and scatter them coloured by row index.

    ``method`` is ``'PCA'`` or anything else for t-SNE. The row index is used as
    the token id for the colour scale, and a colorbar is attached to ``ax``.
    """
    n_tokens = matrix.shape[0]
    if method == 'PCA':
        coords = PCA(n_components=2).fit_transform(matrix)
    else:
        # Cap perplexity below the sample count, as the original code did.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n_tokens - 1))
        coords = tsne.fit_transform(matrix)
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=np.arange(n_tokens),
                         cmap='viridis', s=100, alpha=0.8, edgecolors='white', linewidths=1)
    cbar = fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Token ID', fontsize=9)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel(axis_names[0], fontsize=10)
    ax.set_ylabel(axis_names[1], fontsize=10)
    ax.grid(True, alpha=0.3)


print(f"{'='*70}")
print(f"Processing: {len(tasks)} tasks × {len(weight_decays)} WDs (Seed {seed} only)")
print(f"{'='*70}\n")

for task in tasks:
    task_safe = task.replace("/", "_div_").replace("*", "_mul_").replace("+", "_plus_")
    for wd in weight_decays:
        checkpoint_dir = checkpoint_base / task_safe / f"wd_{wd}"
        if not checkpoint_dir.exists():
            continue

        # One row per training step; columns are input/output × PCA/t-SNE.
        # squeeze=False keeps ``axes`` 2-D even when ``steps`` has a single entry.
        fig, axes = plt.subplots(len(steps), 4, figsize=(22, 5.5 * len(steps)), squeeze=False)
        found_any = False

        for step_idx, step in enumerate(steps):
            checkpoint_file = checkpoint_dir / f"seed{seed}_step{step}.pt"
            if not checkpoint_file.exists():
                # Blank the row rather than leaving empty default axes in the figure.
                for ax in axes[step_idx]:
                    ax.axis('off')
                continue

            found_any = True
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # checkpoints produced by your own training runs.
            checkpoint = torch.load(checkpoint_file, map_location='cpu', weights_only=False)
            state_dict = checkpoint['model_state_dict']

            embeddings = {
                'Input': state_dict['token_embeddings.weight'].numpy(),
                'Output': state_dict['head.weight'].numpy(),
            }

            panels = [
                ('Input', 'PCA', ('PC1', 'PC2')),
                ('Input', 't-SNE', ('t-SNE 1', 't-SNE 2')),
                ('Output', 'PCA', ('PC1', 'PC2')),
                ('Output', 't-SNE', ('t-SNE 1', 't-SNE 2')),
            ]
            for col, (which, method, axis_names) in enumerate(panels):
                _draw_token_panel(fig, axes[step_idx, col], embeddings[which], method,
                                  f'{which} Emb {method} (Step {step})', axis_names)

        if found_any:
            fig.suptitle(f'Task: {task} | WD={wd} | Seed={seed}',
                         fontsize=18, fontweight='bold', y=0.995)
            fig.tight_layout()
            save_path = save_base / f"{task_safe}_wd{wd}_seed{seed}_embeddings.png"
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"✓ {task:15s} WD={wd}")
        # Always release the figure, including when no checkpoint was found.
        plt.close(fig)

print(f"\n{'='*70}")
print(f"Done! Saved to: {save_base}")
print(f"{'='*70}")


Processing: 9 tasks × 2 WDs (Seed 101 only)

✓ x+y             WD=0.0
✓ x+y             WD=1.0
✓ x-y             WD=0.0
✓ x-y             WD=1.0
✓ x*y             WD=0.0
✓ x*y             WD=1.0
✓ x_div_y         WD=0.0
✓ x_div_y         WD=1.0
✓ x2+y2           WD=0.0
✓ x2+y2           WD=1.0
✓ x2+xy+y2        WD=0.0
✓ x2+xy+y2        WD=1.0
✓ x2+xy+y2+x      WD=0.0
✓ x2+xy+y2+x      WD=1.0
✓ x3+xy           WD=0.0
✓ x3+xy           WD=1.0
✓ x3+xy2+y        WD=0.0
✓ x3+xy2+y        WD=1.0

Done! Saved to: /root/autodl-tmp/test/results/embedding_viz_single


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# 定义能够输出注意力权重的模型
class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then a 4x-wide GELU MLP.

    Both sub-layers are wrapped in residual connections. The block can optionally
    return the per-head attention weights alongside its output.
    """

    def __init__(self, dim, num_heads, dropout=0.0):
        super().__init__()
        hidden = dim * 4
        self.ln_1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_mask=None, return_attn=False):
        """Apply the block to ``x``.

        Args:
            x: Input passed to ``nn.MultiheadAttention`` with ``batch_first=True``.
            attn_mask: Optional attention mask forwarded to the attention layer.
            return_attn: When truthy, also return the un-averaged (per-head)
                attention weights.

        Returns:
            The block output, or ``(output, attn_weights)`` if ``return_attn``.
        """
        normed = self.ln_1(x)
        # Weights are only requested when the caller wants them; with
        # need_weights=False the layer returns None in the second slot.
        attn_out, attn_weights = self.attn(
            normed, normed, normed,
            attn_mask=attn_mask,
            need_weights=bool(return_attn),
            average_attn_weights=False,
        )
        residual = x + attn_out
        out = residual + self.mlp(self.ln_2(residual))
        if return_attn:
            return out, attn_weights
        return out

class Decoder(nn.Module):
    """Small causal transformer with learned token and position embeddings.

    The model stacks ``num_layers`` ``Block`` modules, applies a final LayerNorm
    and projects to ``num_tokens`` logits with a bias-free linear head.
    """

    def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5, dropout=0.0):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(seq_len, dim)
        blocks = [Block(dim, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_tokens, bias=False)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal = torch.full((seq_len, seq_len), float('-inf')).triu(1)
        self.register_buffer("causal_mask", causal)
        self.register_buffer("pos_ids", torch.arange(seq_len))

    def forward(self, x, return_attn=False):
        """Compute logits for token ids ``x``.

        Args:
            x: Tensor of token ids; its second dimension is used as the
                sequence length.
            return_attn: When truthy, also collect each layer's attention weights.

        Returns:
            The logits, or ``(logits, attn_weights_list)`` if ``return_attn``,
            with one entry per layer in ``attn_weights_list``.
        """
        n = x.size(1)
        h = self.token_embeddings(x) + self.position_embeddings(self.pos_ids[:n])
        mask = self.causal_mask[:n, :n]

        attn_weights_list = []
        for layer in self.layers:
            if return_attn:
                h, layer_weights = layer(h, attn_mask=mask, return_attn=True)
                attn_weights_list.append(layer_weights)
            else:
                h = layer(h, attn_mask=mask)

        logits = self.head(self.ln_f(h))
        if return_attn:
            return logits, attn_weights_list
        return logits

# 可视化注意力热力图
task = 'x-y'
wd = 1.0
seed = 101
steps = [100, 1000, 10000, 100000]
num_layers = 2  # must match the Decoder built below; one heatmap column per layer
task_safe = task.replace("/", "_div_").replace("*", "_mul_").replace("+", "_plus_")
checkpoint_base = Path("/root/autodl-tmp/test/results/checkpoints")
checkpoint_dir = checkpoint_base / task_safe / f"wd_{wd}"
save_base = Path("/root/autodl-tmp/test/results/attention_viz")
save_base.mkdir(parents=True, exist_ok=True)

# Probe input: 50 - 30 = 20 (mod 97), laid out as [x, op, y, eq].
p = 97
x, y = 50, 30
# The model below is built with num_tokens = p + 2, i.e. two special tokens beyond
# the residues 0..p-1. Previously both op and eq were set to p, which fed the
# same id for two different symbols and left id p + 1 unused.
# NOTE(review): confirm which of {p, p + 1} the training data used for op vs. eq.
op_token = p      # operator token
eq_token = p + 1  # equals-sign token
result = (x - y) % p
test_input = torch.tensor([[x, op_token, y, eq_token]]).long()
token_labels = [f'{x}', 'op', f'{y}', 'eq']
# Buffers the Decoder creates itself; their absence from a checkpoint is harmless.
derived_buffers = {'causal_mask', 'pos_ids'}

print(f"Test input: {x} - {y} = {result} (mod {p})")
print(f"Input tokens: {test_input.tolist()}")
print(f"\n{'='*70}")
print(f"Visualizing attention for task: {task} | WD={wd} | Seed={seed}")
print(f"{'='*70}\n")

# One row per training step, one column per layer. squeeze=False keeps ``axes``
# 2-D even if ``steps`` or ``num_layers`` is reduced to a single entry.
fig, axes = plt.subplots(len(steps), num_layers, figsize=(16, 4 * len(steps)), squeeze=False)

for step_idx, step in enumerate(steps):
    checkpoint_file = checkpoint_dir / f"seed{seed}_step{step}.pt"
    if not checkpoint_file.exists():
        print(f"⚠ Missing: {checkpoint_file}")
        # Blank the row rather than leaving empty default axes in the figure.
        for ax in axes[step_idx]:
            ax.axis('off')
        continue

    # Load the model.
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints
    # produced by your own training runs.
    checkpoint = torch.load(checkpoint_file, map_location='cpu', weights_only=False)
    model = Decoder(dim=128, num_layers=num_layers, num_heads=4, num_tokens=p+2, seq_len=5)
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # strict=False silently skips mismatched keys; surface them so a model that is
    # partly randomly initialised is not mistaken for the trained one.
    missing = [k for k in load_result.missing_keys if k not in derived_buffers]
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(f"⚠ Step {step}: state_dict mismatch | missing={missing} | unexpected={unexpected}")
    model.eval()

    # Forward pass, collecting the attention weights of every layer.
    with torch.no_grad():
        logits, attn_weights_list = model(test_input, return_attn=True)
        final_probs = torch.softmax(logits[0, -1], dim=0)
        prediction = logits[0, -1].argmax().item()
        confidence = final_probs.max().item()

    is_correct = prediction == result

    # One heatmap per layer.
    for layer_idx, attn_weights in enumerate(attn_weights_list):
        # Batch item 0, averaged over the head dimension -> [seq_len, seq_len].
        avg_attn = attn_weights[0].mean(dim=0).numpy()

        ax = axes[step_idx, layer_idx]
        sns.heatmap(avg_attn, annot=True, fmt='.3f', cmap='viridis',
                    cbar_kws={'label': 'Attention Weight'},
                    xticklabels=token_labels,
                    yticklabels=token_labels,
                    vmin=0, vmax=1, ax=ax, square=True)

        # The title reports the prediction made at this step.
        acc_status = "✓" if is_correct else "✗"
        ax.set_title(f'Step {step} | Layer {layer_idx+1} | Pred: {prediction} {acc_status} (Conf: {confidence:.3f})',
                     fontsize=12, fontweight='bold')
        ax.set_xlabel('Key Position', fontsize=10)
        ax.set_ylabel('Query Position', fontsize=10)

    print(f"✓ Step {step:6d}: Pred={prediction:2d} (True={result:2d}) | Confidence={confidence:.3f} | {'Correct' if is_correct else 'Wrong'}")

fig.suptitle(f'Attention Patterns | Task: {task} | WD={wd} | Seed={seed}\nInput: {x} - {y} = {result}',
             fontsize=16, fontweight='bold', y=0.995)
fig.tight_layout()

save_path = save_base / f"{task_safe}_wd{wd}_seed{seed}_attention.png"
fig.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"\n{'='*70}")
print(f"Saved to: {save_path}")
print(f"{'='*70}")
plt.show()
