In [13]:
# embkit_run.py
# 依赖：pip install gensim scikit-learn datasets numpy scipy scikit-learn-intelex
import os, json, numpy as np
from typing import Dict, Tuple, List, Iterable
from gensim.models import KeyedVectors, Word2Vec
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from datasets import load_dataset
from scipy.stats import pearsonr, spearmanr

# =====================
# A. 加载三种词向量
# =====================
def load_glove(win: int, dim: int, root: str = "./models_glove") -> KeyedVectors:
    """Load GloVe vectors stored as a gensim ``.kv`` file.

    The file is expected at ``{root}/glove_win{win}_dim{dim}.kv`` (for
    example ``glove_win8_dim200.kv``, as saved from GloVe.ipynb).

    Args:
        win: Window size used in the file name.
        dim: Vector dimensionality used in the file name.
        root: Directory that contains the ``.kv`` files.

    Returns:
        The loaded ``KeyedVectors``, opened with ``mmap='r'`` and with
        ``fill_norms()`` already called.

    Raises:
        FileNotFoundError: If the expected ``.kv`` file does not exist.
    """
    file_name = f"glove_win{win}_dim{dim}.kv"
    path = os.path.join(root, file_name)
    if os.path.exists(path):
        vectors = KeyedVectors.load(path, mmap='r')
        vectors.fill_norms()
        return vectors
    raise FileNotFoundError(f"GloVe .kv not found: {path}")

def load_word2vec(win: int, dim: int, root: str = "./models_word2vec") -> KeyedVectors:
    """Load the word vectors of a saved gensim ``Word2Vec`` model (optional).

    The model is expected at ``{root}/word2vec_win{win}_dim{dim}.model``.

    Args:
        win: Window size used in the file name.
        dim: Vector dimensionality used in the file name.
        root: Directory that contains the ``.model`` files.

    Returns:
        The model's ``wv`` attribute, with ``fill_norms()`` already called.

    Raises:
        FileNotFoundError: If the expected ``.model`` file does not exist.
    """
    file_name = f"word2vec_win{win}_dim{dim}.model"
    path = os.path.join(root, file_name)
    if os.path.exists(path):
        # Only the keyed vectors are needed downstream, not the full model.
        word_vectors = Word2Vec.load(path).wv
        word_vectors.fill_norms()
        return word_vectors
    raise FileNotFoundError(f"Word2Vec .model not found: {path}")

def load_sppmi(win: int, dim: int, root: str = "./models_sppmi", id2word_path: str = "id2word.json",
               mean_center: bool = False, l2_normalize: bool = True) -> KeyedVectors:
    """Build ``KeyedVectors`` from an SPPMI embedding matrix and a vocabulary file.

    The matrix is read from ``{root}/emb_win{win}_dim{dim}.npy`` and the
    vocabulary from ``id2word_path`` (a JSON list, or a JSON dict indexed by
    consecutive ids starting at 0).

    Args:
        win: Window size used in the ``.npy`` file name.
        dim: Vector dimensionality used in the ``.npy`` file name.
        root: Directory that contains the ``.npy`` files.
        id2word_path: Path to the JSON vocabulary. Note that it is NOT joined
            with ``root``.
        mean_center: Subtract the column-wise mean from the matrix.
        l2_normalize: Scale every row to (approximately) unit L2 norm.

    Returns:
        A new ``KeyedVectors`` holding one vector per word, with
        ``fill_norms()`` already called.

    Raises:
        FileNotFoundError: If the ``.npy`` file or the vocabulary is missing.
        ValueError: If the vocabulary is neither a list nor a dict, or has
            fewer entries than the matrix has rows.
    """
    emb_path = os.path.join(root, f"emb_win{win}_dim{dim}.npy")
    # Check both inputs up front, before any potentially large load.
    if not os.path.exists(emb_path):
        raise FileNotFoundError(f"SPPMI .npy not found: {emb_path}")
    if not os.path.exists(id2word_path):
        raise FileNotFoundError(f"id2word not found: {id2word_path}")

    emb = np.load(emb_path).astype(np.float32)

    with open(id2word_path, "r", encoding="utf-8") as handle:
        id2word = json.load(handle)

    if isinstance(id2word, list):
        words = id2word
    elif isinstance(id2word, dict):
        ids = range(len(id2word))
        try:
            # JSON object keys are strings, so try "0", "1", ... first.
            words = [id2word[str(idx)] for idx in ids]
        except KeyError:
            words = [id2word[idx] for idx in ids]
    else:
        raise ValueError("id2word must be list or dict")

    n_rows = emb.shape[0]
    if len(words) > n_rows:
        # A longer vocabulary is silently cut down to the number of rows.
        words = words[:n_rows]
    elif len(words) < n_rows:
        raise ValueError(f"len(id2word) ({len(words)}) != emb rows ({emb.shape[0]})")

    matrix = emb.copy()
    if mean_center:
        np.subtract(matrix, matrix.mean(0, keepdims=True), out=matrix)
    if l2_normalize:
        # The small epsilon keeps all-zero rows from dividing by zero.
        row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        np.divide(matrix, row_norms + 1e-9, out=matrix)

    vectors = KeyedVectors(vector_size=matrix.shape[1])
    vectors.add_vectors(words, matrix)
    vectors.fill_norms()
    return vectors

# =====================
# B. 可选：ABTT 后处理（去掉前 r 个主成分）
# =====================
def abtt_inplace(kv: KeyedVectors, r: int = 2):
    """Apply All-but-the-Top post-processing to ``kv``.

    The vectors are mean-centered and their projection onto the first ``r``
    principal directions is removed (``r`` of 1-3 is typical). The result
    replaces ``kv.vectors`` and the cached norms are recomputed.

    Args:
        kv: Word vectors to modify.
        r: Number of leading principal components to remove. Values <= 0
            leave ``kv`` untouched.

    Returns:
        None. ``kv`` is modified.
    """
    if r <= 0:
        return

    centered = kv.vectors - kv.vectors.mean(0, keepdims=True)
    # Rows of Vt are the principal directions, ordered by singular value.
    _, _, Vt = np.linalg.svd(centered, full_matrices=False)
    top = Vt[:r]  # (r, d)
    cleaned = centered - centered @ top.T @ top

    if kv.vectors.flags.writeable:
        kv.vectors[:] = cleaned
    else:
        # Vectors loaded with mmap='r' are read-only, so writing into them
        # would raise; rebind to a fresh in-memory array instead.
        kv.vectors = np.array(cleaned, dtype=kv.vectors.dtype)

    # force=True is required: without it fill_norms() keeps norms that were
    # already computed, leaving them stale for the modified vectors.
    kv.fill_norms(force=True)

# =====================
# C. 文本表示（句子/文档平均）
# =====================
def tokenize(s: str) -> List[str]:
    """Lower-case ``s``, split on whitespace and keep purely alphabetic tokens.

    This is a deliberately simple tokenizer: any token containing a digit or
    punctuation (e.g. ``"hello,"``) is dropped entirely. Swap in a stronger
    tokenizer if that loss matters.
    """
    return list(filter(str.isalpha, s.lower().split()))

def sent_embedding_avg(kv: KeyedVectors, text: str, unk: str = None) -> np.ndarray:
    """Embed ``text`` as the mean of the vectors of its in-vocabulary tokens.

    Args:
        kv: Word vectors used for lookup.
        text: Raw text; it is split with ``tokenize``.
        unk: Optional fallback key used when no token of ``text`` is in ``kv``.

    Returns:
        The averaged vector. If no token is known, the vector of ``unk`` is
        used when ``unk`` is given and present in ``kv``; otherwise a float32
        zero vector of length ``kv.vector_size`` is returned.
    """
    known = [kv[token] for token in tokenize(text) if token in kv]
    if known:
        return np.mean(known, axis=0)
    if unk is not None and unk in kv:
        return np.mean([kv[unk]], axis=0)
    return np.zeros(kv.vector_size, dtype=np.float32)

def batch_embed(kv: KeyedVectors, texts: Iterable[str]) -> np.ndarray:
    """Embed every text with ``sent_embedding_avg`` and stack the results.

    Args:
        kv: Word vectors used for lookup.
        texts: Any iterable of raw strings.

    Returns:
        An array with one row per text. For an empty ``texts`` a float32
        array of shape ``(0, kv.vector_size)`` is returned.
    """
    rows = [sent_embedding_avg(kv, text) for text in texts]
    if not rows:
        # np.stack raises on an empty sequence; return an empty matrix with
        # the right width instead.
        return np.zeros((0, kv.vector_size), dtype=np.float32)
    return np.stack(rows, axis=0)

# =====================
# D1. 下游：AG News 文本分类（已修正切片）
# =====================
def evaluate_agnews(kv: KeyedVectors, train_size: int = 20000, test_size: int = 7600,
                    C: float = 2.0, max_iter: int = 1000, random_state: int = 42) -> Dict[str, float]:
    """Text classification on the ``ag_news`` dataset with averaged embeddings.

    The first ``train_size`` training rows and first ``test_size`` test rows
    (capped at the split sizes) are embedded with ``batch_embed`` and fed to
    an L2-regularised logistic regression.

    Args:
        kv: Word vectors under evaluation.
        train_size: Maximum number of training rows to use.
        test_size: Maximum number of test rows to use.
        C: Inverse regularisation strength passed to ``LogisticRegression``.
        max_iter: Iteration cap passed to ``LogisticRegression``.
        random_state: Seed passed to ``LogisticRegression``.

    Returns:
        Dict with ``acc`` (test accuracy), ``train_size`` and ``test_size``
        (number of rows actually used).
    """
    corpus = load_dataset("ag_news")

    # Pick rows with select() first and read columns afterwards; this avoids
    # the "string indices must be integers" error.
    n_train = min(train_size, len(corpus["train"]))
    n_test = min(test_size, len(corpus["test"]))
    train_part = corpus["train"].select(range(n_train))
    test_part = corpus["test"].select(range(n_test))

    train_labels = np.array(train_part["label"])
    test_labels = np.array(test_part["label"])
    train_feats = batch_embed(kv, train_part["text"])
    test_feats = batch_embed(kv, test_part["text"])

    # Logistic regression acts as the linear classifier on top of the features.
    classifier = LogisticRegression(
        solver="saga",
        penalty="l2",
        C=C,
        max_iter=max_iter,
        random_state=random_state,
        n_jobs=-1,
    )
    classifier.fit(train_feats, train_labels)
    accuracy = accuracy_score(test_labels, classifier.predict(test_feats))
    return {
        "acc": float(accuracy),
        "train_size": int(len(train_labels)),
        "test_size": int(len(test_labels)),
    }

# =====================
# D2. 下游：STS-B 语义相似度
# =====================
def evaluate_stsb(kv: KeyedVectors, split: str = "validation") -> Dict[str, float]:
    """Semantic similarity on GLUE STS-B with averaged embeddings.

    Gold labels lie in [0, 5]. The cosine similarity of the two sentence
    embeddings is correlated directly with the human score.

    Args:
        kv: Word vectors under evaluation.
        split: Name of the dataset split to evaluate on.

    Returns:
        Dict with ``pearson`` and ``spearman`` correlations and ``n``, the
        number of sentence pairs.
    """
    ds = load_dataset("glue", "stsb")[split]

    # Read each column once instead of materialising every row as a dict
    # three times over.
    E1 = batch_embed(kv, ds["sentence1"])
    E2 = batch_embed(kv, ds["sentence2"])
    y = np.array(ds["label"], dtype=np.float32)

    dots = np.sum(E1 * E2, axis=1)
    norm_products = np.linalg.norm(E1, axis=1) * np.linalg.norm(E2, axis=1)
    # The epsilon makes an all-zero embedding yield cosine 0 rather than a
    # division by zero.
    cos = dots / (norm_products + 1e-9)

    pear = pearsonr(cos, y)[0]
    spear = spearmanr(cos, y)[0]
    return {"pearson": float(pear), "spearman": float(spear), "n": int(len(y))}

def evaluate_yelp_polarity(kv, train_size=50000, test_size=5000, C=2.0, max_iter=1000, random_state=42):
    """Text classification on the ``yelp_polarity`` dataset with averaged embeddings.

    Uses the first ``train_size`` / ``test_size`` rows of the train / test
    splits (capped at the split sizes) and an L2-regularised logistic
    regression on top of ``batch_embed`` features.

    Args:
        kv: Word vectors under evaluation.
        train_size: Maximum number of training rows to use.
        test_size: Maximum number of test rows to use.
        C: Inverse regularisation strength passed to ``LogisticRegression``.
        max_iter: Iteration cap passed to ``LogisticRegression``.
        random_state: Seed passed to ``LogisticRegression``.

    Returns:
        Dict with ``acc``, ``train_size`` and ``test_size``.
    """
    corpus = load_dataset("yelp_polarity")
    n_train = min(train_size, len(corpus["train"]))
    n_test = min(test_size, len(corpus["test"]))
    train_part = corpus["train"].select(range(n_train))
    test_part = corpus["test"].select(range(n_test))

    train_feats = batch_embed(kv, train_part["text"])
    train_labels = np.array(train_part["label"])
    test_feats = batch_embed(kv, test_part["text"])
    test_labels = np.array(test_part["label"])

    classifier = LogisticRegression(
        solver="saga",
        penalty="l2",
        C=C,
        max_iter=max_iter,
        random_state=random_state,
        n_jobs=-1,
    )
    classifier.fit(train_feats, train_labels)
    predictions = classifier.predict(test_feats)
    accuracy = accuracy_score(test_labels, predictions)
    return {
        "acc": float(accuracy),
        "train_size": int(len(train_labels)),
        "test_size": int(len(test_labels)),
    }

def _pair_features(U, V):
    """Combine two embedding matrices into features for a sentence pair.

    For ``U`` and ``V`` of shape ``(n, d)`` the result stacks, side by side,
    ``U``, ``V``, ``|U - V|`` and ``U * V``, giving shape ``(n, 4 * d)``.
    """
    abs_diff = np.abs(U - V)
    product = U * V
    return np.hstack((U, V, abs_diff, product))

def evaluate_qqp(kv, train_size=50000, valid_size=10000, C=2.0, max_iter=1000, random_state=42):
    """Pair classification on GLUE QQP with averaged embeddings.

    Each question is embedded with ``batch_embed``; the two embeddings of a
    pair are combined with ``_pair_features`` and classified with an
    L2-regularised logistic regression. The validation split serves as the
    evaluation set.

    Args:
        kv: Word vectors under evaluation.
        train_size: Maximum number of training pairs to use.
        valid_size: Maximum number of validation pairs to use.
        C: Inverse regularisation strength passed to ``LogisticRegression``.
        max_iter: Iteration cap passed to ``LogisticRegression``.
        random_state: Seed passed to ``LogisticRegression``.

    Returns:
        Dict with ``acc``, ``train_size`` and ``valid_size`` (number of pairs
        actually used).
    """
    # Load the dataset once and reuse it for both splits.
    ds = load_dataset("glue", "qqp")
    # Cap the requested sizes at the split sizes, as the other evaluators do,
    # so an oversized request does not fail inside select().
    n_train = min(train_size, len(ds["train"]))
    n_valid = min(valid_size, len(ds["validation"]))
    ds_tr = ds["train"].select(range(n_train))
    ds_va = ds["validation"].select(range(n_valid))

    Utr = batch_embed(kv, ds_tr["question1"])
    Vtr = batch_embed(kv, ds_tr["question2"])
    Xtr = _pair_features(Utr, Vtr)
    ytr = np.array(ds_tr["label"])

    Uva = batch_embed(kv, ds_va["question1"])
    Vva = batch_embed(kv, ds_va["question2"])
    Xva = _pair_features(Uva, Vva)
    yva = np.array(ds_va["label"])

    clf = LogisticRegression(solver="saga", penalty="l2", C=C, max_iter=max_iter,
                             random_state=random_state, n_jobs=-1)
    clf.fit(Xtr, ytr)
    acc = accuracy_score(yva, clf.predict(Xva))
    return {"acc": float(acc), "train_size": int(len(ytr)), "valid_size": int(len(yva))}


def evaluate_trec6(kv, train_size=5452, test_size=500, C=2.0, max_iter=1000, random_state=42):
    """Question classification on the ``trec`` dataset using its coarse labels.

    Uses the first ``train_size`` / ``test_size`` rows of the train / test
    splits (capped at the split sizes) and an L2-regularised logistic
    regression on top of ``batch_embed`` features.

    Args:
        kv: Word vectors under evaluation.
        train_size: Maximum number of training rows to use.
        test_size: Maximum number of test rows to use.
        C: Inverse regularisation strength passed to ``LogisticRegression``.
        max_iter: Iteration cap passed to ``LogisticRegression``.
        random_state: Seed passed to ``LogisticRegression``.

    Returns:
        Dict with ``acc``, ``train_size`` and ``test_size``.
    """
    corpus = load_dataset("trec")
    n_train = min(train_size, len(corpus["train"]))
    n_test = min(test_size, len(corpus["test"]))
    train_part = corpus["train"].select(range(n_train))
    test_part = corpus["test"].select(range(n_test))

    # The coarse-label column is read as "coarse_label" when present and as
    # "label-coarse" otherwise; the check is made per split.
    if "coarse_label" in train_part.column_names:
        train_label_col = "coarse_label"
    else:
        train_label_col = "label-coarse"
    if "coarse_label" in test_part.column_names:
        test_label_col = "coarse_label"
    else:
        test_label_col = "label-coarse"

    train_feats = batch_embed(kv, train_part["text"])
    train_labels = np.array(train_part[train_label_col])
    test_feats = batch_embed(kv, test_part["text"])
    test_labels = np.array(test_part[test_label_col])

    classifier = LogisticRegression(
        solver="saga",
        penalty="l2",
        C=C,
        max_iter=max_iter,
        random_state=random_state,
        n_jobs=-1,
    )
    classifier.fit(train_feats, train_labels)
    predictions = classifier.predict(test_feats)
    accuracy = accuracy_score(test_labels, predictions)
    return {
        "acc": float(accuracy),
        "train_size": int(len(train_labels)),
        "test_size": int(len(test_labels)),
    }

# =====================
# F. 便捷统一入口（按需使用）
# =====================
# Maps the short model-kind name accepted by load_model() to its loader.
LOADERS = dict(
    glove=load_glove,
    w2v=load_word2vec,
    sppmi=load_sppmi,
)

def load_model(kind: str, win: int, dim: int, **kwargs) -> KeyedVectors:
    """Load word vectors of the given kind through the ``LOADERS`` registry.

    Args:
        kind: A key of ``LOADERS``.
        win: Window size, forwarded to the loader.
        dim: Vector dimensionality, forwarded to the loader.
        **kwargs: Extra keyword arguments forwarded to the loader unchanged.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If ``kind`` is not a key of ``LOADERS``.
    """
    loader = LOADERS.get(kind)
    if loader is None:
        raise ValueError(f"Unknown kind={kind}, choose from {list(LOADERS)}")
    return loader(win, dim, **kwargs)

# =====================
# G. 示例主程序
# =====================
if __name__ == "__main__":
    # Example driver: load the three embedding models, then run every
    # downstream evaluation on each of them in turn.
    WIN, DIM = 4, 200
    model_name = ["w2v", "sppmi", "glove"]
    # NOTE(review): GloVe is loaded with win=8 while the other two use WIN;
    # confirm this is intentional.
    models = [
        load_word2vec(WIN, DIM, root = "./word2vec/models_word2vec"),
        load_sppmi(WIN, DIM, root="./SPPMI/models_sppmi"),
        load_model("glove", win=8, dim=200, root="./GloVe/models_glove"),
    ]

    for name, kv in zip(model_name, models):
        print(name)

        # AG News text classification.
        agnews_result = evaluate_agnews(kv, train_size=20000, test_size=7600)
        print("[AG News]", agnews_result)

        # STS-B sentence similarity on the validation split.
        stsb_result = evaluate_stsb(kv, split="validation")
        print("[STS-B]", stsb_result)

        yelp_result = evaluate_yelp_polarity(kv, train_size=50000, test_size=5000)
        print("[Yelp Polarity]", yelp_result)

        qqp_result = evaluate_qqp(kv, train_size=50000, valid_size=10000)
        print("[QQP]", qqp_result)

        trec_result = evaluate_trec6(kv)
        print("[TREC-6]", trec_result)



w2v
[AG News] {'acc': 0.8503947368421053, 'train_size': 20000, 'test_size': 7600}
[STS-B] {'pearson': 0.5030552262224939, 'spearman': 0.5305749224499638, 'n': 1500}


Generating train split: 100%|██████████| 560000/560000 [00:00<00:00, 1564162.30 examples/s]
Generating test split: 100%|██████████| 38000/38000 [00:00<00:00, 1690713.40 examples/s]


[Yelp Polarity] {'acc': 0.8144, 'train_size': 50000, 'test_size': 5000}


Generating train split: 100%|██████████| 363846/363846 [00:00<00:00, 5870647.67 examples/s]
Generating validation split: 100%|██████████| 40430/40430 [00:00<00:00, 5076509.12 examples/s]
Generating test split: 100%|██████████| 390965/390965 [00:00<00:00, 6609643.30 examples/s]


[QQP] {'acc': 0.7125, 'train_size': 50000, 'valid_size': 10000}


Downloading data: 100%|██████████| 336k/336k [00:01<00:00, 230kB/s] 
Downloading data: 100%|██████████| 23.4k/23.4k [00:00<00:00, 26.1MB/s]
Generating train split: 100%|██████████| 5452/5452 [00:00<00:00, 72436.54 examples/s]
Generating test split: 100%|██████████| 500/500 [00:00<00:00, 69814.31 examples/s]


[TREC-6] {'acc': 0.85, 'train_size': 5452, 'test_size': 500}
sppmi
[AG News] {'acc': 0.8477631578947369, 'train_size': 20000, 'test_size': 7600}
[STS-B] {'pearson': 0.37298143895901104, 'spearman': 0.45016584520149194, 'n': 1500}
[Yelp Polarity] {'acc': 0.7788, 'train_size': 50000, 'test_size': 5000}
[QQP] {'acc': 0.6931, 'train_size': 50000, 'valid_size': 10000}
[TREC-6] {'acc': 0.706, 'train_size': 5452, 'test_size': 500}
glove
[AG News] {'acc': 0.8318421052631579, 'train_size': 20000, 'test_size': 7600}
[STS-B] {'pearson': 0.3354405743842207, 'spearman': 0.41965032614010145, 'n': 1500}
[Yelp Polarity] {'acc': 0.7964, 'train_size': 50000, 'test_size': 5000}
[QQP] {'acc': 0.7024, 'train_size': 50000, 'valid_size': 10000}
[TREC-6] {'acc': 0.762, 'train_size': 5452, 'test_size': 500}
