Cell 1：环境配置与可视化工具（矩阵热力图 + 数值标注）

In [1]:
# ====== Cell 1: imports & visualization helpers ======
import os, math, random, html
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Fonts that can render Chinese labels, tried in order (Windows font names first).
# axes.unicode_minus=False draws minus signs as an ASCII hyphen instead of U+2212,
# which CJK fonts may not contain.
plt.rcParams.update({
    "font.family": ["Microsoft YaHei", "SimHei", "SimSun", "Arial Unicode MS"],
    "axes.unicode_minus": False,
})

# Use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE)

def seed_all(seed=42):
    """Seed every random number generator the notebook uses.

    Covers Python's `random`, NumPy's global RNG, the PyTorch CPU generator and
    all CUDA devices, in that order, all with the same `seed`.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

seed_all(42)

# ---------- 矩阵可视化：热力图+数值标注 ----------
def show_matrix(M, row_labels=None, col_labels=None, title="", figsize=(8,4), cmap=None, annotate=True, fmt="{:.2f}"):
    """Draw a 2-D matrix as a heatmap with a colorbar and optional per-cell value labels.

    Args:
        M: 2-D torch.Tensor, np.ndarray, or anything np.asarray accepts (e.g. nested lists).
        row_labels: optional y-axis tick labels, one per row.
        col_labels: optional x-axis tick labels, one per column (drawn rotated 45 degrees).
        title: figure title.
        figsize: figure size in inches.
        cmap: colormap forwarded to imshow; None keeps matplotlib's default colormap.
        annotate: write each value inside its cell. This is skipped for matrices
            with more than 30 rows or 40 columns (presumably to keep plots legible).
        fmt: str.format pattern used for the annotations.
    """
    if isinstance(M, torch.Tensor):
        M = M.detach().cpu().float().numpy()
    # Accept plain Python sequences as well as ndarrays.
    M = np.asarray(M)
    assert M.ndim == 2, f"show_matrix expects 2D, got {M.ndim}"

    fig, ax = plt.subplots(figsize=figsize)
    # Forward cmap so callers can actually choose a colormap (None = matplotlib default).
    im = ax.imshow(M, aspect="auto", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    if row_labels is not None:
        ax.set_yticks(range(len(row_labels)))
        ax.set_yticklabels(row_labels)
    if col_labels is not None:
        ax.set_xticks(range(len(col_labels)))
        ax.set_xticklabels(col_labels, rotation=45, ha="right")

    if annotate and M.shape[0] <= 30 and M.shape[1] <= 40:
        for i in range(M.shape[0]):
            for j in range(M.shape[1]):
                # imshow puts column j at x=j and row i at y=i.
                ax.text(j, i, fmt.format(M[i, j]), ha="center", va="center", fontsize=8)

    fig.tight_layout()
    plt.show()

def show_vector(v, labels=None, title="", figsize=(8,2)):
    """Plot a vector as a line chart with a grid.

    Args:
        v: torch.Tensor or np.ndarray; it is flattened with reshape(-1) before plotting.
        labels: optional x tick labels. They are used only when there is exactly
            one label per flattened element.
        title: figure title.
        figsize: figure size in inches.
    """
    values = v.detach().cpu().float().numpy() if isinstance(v, torch.Tensor) else v
    values = values.reshape(-1)

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(values)
    ax.set_title(title)
    if labels is not None and len(labels) == len(values):
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def show_attention(attn, q_tokens, k_tokens, title="", figsize=(8,4), annotate=True):
    """Heatmap of a (tq, tk) softmax attention-weight matrix.

    Rows are labelled with the query tokens (`q_tokens`) and columns with the key
    tokens (`k_tokens`). All drawing is delegated to show_matrix.
    """
    show_matrix(attn, q_tokens, k_tokens, title, figsize, annotate=annotate)


DEVICE: cpu


Cell 2：读取 WMT 中英平行语料（只取前 60000 个句对）

In [3]:
# ====== Cell 2: load WMT zh-en pairs (first 60000) ======
def load_pairs_from_csv_first_comma(path: str | Path, max_rows: int | None = None):
    path = Path(path)
    pairs = []

    with path.open("r", encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue

            # 跳过表头：0,1 或 "0","1"
            if line_idx == 0:
                head = line.replace('"', "").replace(" ", "")
                if head == "0,1":
                    continue

            # 找第一个逗号
            k = line.find(",")
            if k == -1:
                continue

            src = line[:k].strip()
            tgt = line[k+1:].strip()

            # 去两侧引号（若有）
            if len(src) >= 2 and src[0] == '"' and src[-1] == '"':
                src = src[1:-1].strip()
            if len(tgt) >= 2 and tgt[0] == '"' and tgt[-1] == '"':
                tgt = tgt[1:-1].strip()

            # 反转义：&apos; &quot; 等
            src = html.unescape(src)
            tgt = html.unescape(tgt)

            if src and tgt:
                pairs.append((src, tgt))

            if max_rows is not None and len(pairs) >= max_rows:
                break

    return pairs

# Load the first MAX_PAIRS sentence pairs and print a few for a visual sanity check.
MAX_PAIRS = 60000
N_EXAMPLES = 10
# Build the path from parts so it resolves on Windows and POSIX alike.
# A raw "wmt_data\..." string is a single, odd filename on Linux/macOS.
csv_path = Path("wmt_data") / "wmt_zh_en_training_corpus.csv"
pairs_all = load_pairs_from_csv_first_comma(csv_path, max_rows=MAX_PAIRS)
print("Loaded pairs:", len(pairs_all))
# Slice instead of indexing range(10), so a short file does not raise IndexError.
for pair in pairs_all[:N_EXAMPLES]:
    print("Example:", pair)


Loaded pairs: 60000
Example: ('表演 的 明星 是 X 女孩 团队 — — 由 一对 具有 天才 技艺 的 艳舞 女孩 们 组成 ， 其中 有些 人 受过 专业 的 训练 。', 'the show stars the X Girls - a troupe of talented topless dancers , some of whom are classically trained .')
Example: ('表演 的 压轴戏 是 闹剧 版 《 天鹅湖 》 ， 男女 小 人们 身着 粉红色 的 芭蕾舞 裙 扮演 小天鹅 。', 'the centerpiece of the show is a farcical rendition of Swan Lake in which male and female performers dance in pink tutus and imitate swans .')
Example: ('表演 和 后期制作 之间 的 屏障 被 清除 了 ， 这 对 演员 来说 一样 大有裨益 。', 'the removal of the barrier between performance and post @-@ production was just as helpful for the actors .')
Example: ('（ 表演 或 背诵 时 ） 通过 暗示 下面 忘记 或 记地 不准 的 东西 来 帮助 某人 。', 'assist ( somebody acting or reciting ) by suggesting the next words of something forgotten or imperfectly learned .')
Example: ('表演 基本上 很 精彩 - - 我 只 对 她 的 技巧 稍 有 意见 。', 'basically it was a fine performance I have only minor quibbles to make about her technique .')
Example: ('表演 结束 后 ， 我们 看到 一对对 车灯 沿主路 一路 排回 镇上 ， 然后 散开 来 各回 各家 。', "aft

Cell 3（替换版）：按空格切词 + 统计词表（适用于已预先分词的语料）

In [4]:
# ====== Cell 3 (REPLACE): whitespace tokenizer for pre-segmented corpus ======
from collections import Counter

# Special tokens; build_vocab_from_texts puts these at the front of every vocabulary.
SPECIALS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]

def tokenize_whitespace(s: str):
    """Split an already word-segmented sentence into its tokens.

    The corpus is pre-segmented with spaces, so a whitespace split is enough.
    str.split() with no separator ignores leading and trailing whitespace and
    treats runs of whitespace as one separator, so no empty tokens are produced.
    """
    return s.split()

def build_vocab_from_texts(texts, vocab_size=50000, min_freq=1):
    """Count whitespace tokens in `texts` and build a frequency-ranked vocabulary.

    Args:
        texts: iterable of pre-segmented sentences (tokens separated by whitespace).
        vocab_size: maximum vocabulary size, including the special tokens. The
            special tokens are always kept, even if vocab_size is smaller.
        min_freq: tokens seen fewer than this many times are dropped.

    Returns:
        (stoi, itos, counter):
        - stoi: token -> id dict.
        - itos: id -> token list. SPECIALS come first, at ids 0..len(SPECIALS)-1,
          followed by corpus tokens by descending count, ties in first-seen order.
        - counter: the raw Counter over all tokens, before filtering or truncation.
    """
    counter = Counter()
    for s in texts:
        counter.update(tokenize_whitespace(s))

    # A corpus token that equals a special token (e.g. a literal "[UNK]") must not
    # get a second slot. A duplicate in itos would overwrite the special's id in stoi.
    specials = set(SPECIALS)
    # most_common() sorts by descending count and keeps insertion order for ties.
    items = [(w, c) for w, c in counter.most_common() if c >= min_freq and w not in specials]

    # vocab_size counts the specials, so only the remaining slots go to corpus tokens.
    n_regular = max(0, vocab_size - len(SPECIALS))
    itos = list(SPECIALS) + [w for w, _ in items[:n_regular]]
    stoi = {w: i for i, w in enumerate(itos)}
    return stoi, itos, counter

# Split the already-loaded pairs_all into parallel Chinese (source) and English (target) corpora.
zh_texts = [zh for zh, _en in pairs_all]
en_texts = [en for _zh, en in pairs_all]

# Tunable vocabulary sizes: larger keeps more of the corpus vocabulary, smaller is faster.
ZH_VOCAB_SIZE, EN_VOCAB_SIZE = 50000, 40000
MIN_FREQ = 1

zh_stoi, zh_itos, zh_counter = build_vocab_from_texts(
    zh_texts, vocab_size=ZH_VOCAB_SIZE, min_freq=MIN_FREQ
)
en_stoi, en_itos, en_counter = build_vocab_from_texts(
    en_texts, vocab_size=EN_VOCAB_SIZE, min_freq=MIN_FREQ
)

def zh_vocab_id(tok):
    """Return the id of a Chinese token, or the "[UNK]" id if it is not in zh_stoi."""
    unk_id = zh_stoi["[UNK]"]
    return zh_stoi.get(tok, unk_id)

def en_vocab_id(tok):
    """Return the id of an English token, or the "[UNK]" id if it is not in en_stoi."""
    unk_id = en_stoi["[UNK]"]
    return en_stoi.get(tok, unk_id)

# Sanity check: vocabulary sizes (special tokens included) and the ten most frequent raw tokens.
print("ZH vocab size:", len(zh_stoi), "top10:", zh_counter.most_common(10))
print("EN vocab size:", len(en_stoi), "top10:", en_counter.most_common(10))


ZH vocab size: 50000 top10: [('的', 85405), ('，', 78216), ('。', 59252), ('是', 14200), ('在', 13879), ('了', 12592), ('和', 11764), ('、', 7595), ('从', 7549), ('他', 6788)]
EN vocab size: 40000 top10: [('the', 73883), (',', 64554), ('.', 63219), ('of', 39217), ('to', 33312), ('and', 31547), ('a', 30286), ('in', 20508), ('is', 15376), ('that', 10916)]
