In [None]:
# # 1. 安装所有必需的库
# !pip install torch transformers scikit-learn umap-learn matplotlib pandas seaborn

In [1]:
# 2. Mount Google Drive so the notebook can read the dataset and checkpoints stored there
from google.colab import drive
drive.mount('/content/drive')

# 3. Change Colab's current working directory (CWD)
# !!! Make sure this path matches the location of the project folder in your Google Drive !!!
# All relative paths used later (./checkpoints/..., the CSV file name) are resolved against it.
%cd /content/drive/MyDrive/NLP/

Mounted at /content/drive
/content/drive/MyDrive/NLP


In [2]:
#!/usr/bin/env python3
"""
analyze_open_set.py

Complete open-set analysis script (implemented for team member D):
- load checkpoints/authorship_model.pt
- extract logits & embeddings
- compute three open-set scores (max-prob, energy, prototype distance)
- plot ROC / PR curves and search for the best threshold tau (when a validation set is available)
- visualize author embeddings with UMAP / t-SNE
- export a CSV of misclassified samples for error analysis

Dependencies:
pip install torch transformers scikit-learn umap-learn matplotlib pandas seaborn
"""
import os
import argparse
import json
from typing import Tuple, Dict, List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import roc_curve, precision_recall_curve, auc, f1_score, classification_report
from sklearn.manifold import TSNE
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# --------------------------
# Config / Defaults
# --------------------------
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
# DEVICE is read by main() when building the model and extracting embeddings.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"设备检测: {DEVICE}")

设备检测: cuda


In [3]:
# --------------------------
# 1. 最小数据集和模型定义 (与 train.py 对齐)
# --------------------------

class AuthorDataset(Dataset):
    """
    Dataset of author-attributed texts.

    Each item is the tokenized text plus a ``labels`` tensor holding the
    author's class index, and the raw ``text`` string (kept for error
    analysis). Rows whose author is missing or not in ``author2id`` receive
    the open-set label ``-1``.
    """

    def __init__(self, df: pd.DataFrame, tokenizer, author2id: Dict[str,int], max_len: int = 512):
        # Re-index so positional access in __getitem__ is contiguous.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.author2id = author2id  # e.g. "user_a" -> 0
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        raw_text = str(record["text"])

        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt"
        )
        # The tokenizer returns tensors with a leading batch axis of size 1; drop it.
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze(0)

        # --- Open-set handling ---
        # Known authors map to their class index; everything else becomes -1.
        author = record.get("author", None)
        is_known = (not pd.isna(author)) and author in self.author2id
        label_id = self.author2id[author] if is_known else -1
        sample["labels"] = torch.tensor(label_id, dtype=torch.long)

        sample["text"] = raw_text  # keep the raw text for error analysis
        return sample

class AuthorModel(nn.Module):
    """
    Authorship-attribution model (mirrors the definition used in train.py).

    Pipeline: pretrained encoder -> pooling -> projection head (embedding z)
    -> linear classifier (logits).
    """

    def __init__(self, model_name: str, num_classes: int, emb_dim: int = 256, use_mean_pool: bool = True):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.use_mean_pool = use_mean_pool
        # Projection head: maps the pooled encoder output to a fixed-size embedding z.
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, emb_dim)
        )
        # Classifier head: predicts the author from the embedding z.
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, **batch):
        """
        Run the model on a batch given as keyword arguments.

        ``labels`` (if present) is removed from the batch and returned
        untouched. Only ``input_ids``, ``attention_mask`` and
        ``token_type_ids`` are forwarded to the encoder.

        Returns:
            (logits, z, labels)
        """
        labels = batch.pop("labels", None)

        # 1. Encoder pass, restricted to the keys the encoder understands.
        encoder_keys = ("input_ids", "attention_mask", "token_type_ids")
        encoder_inputs = {}
        for key, value in batch.items():
            if key in encoder_keys:
                encoder_inputs[key] = value
        out = self.encoder(**encoder_inputs)

        # 2. Pooling: [CLS] token, or attention-mask-weighted mean over tokens.
        if not self.use_mean_pool:
            pooled = out.last_hidden_state[:, 0]
        else:
            token_states = out.last_hidden_state
            mask = batch["attention_mask"].unsqueeze(-1)
            masked_sum = (token_states * mask).sum(1)
            # clamp guards against division by zero for an all-padding row
            token_count = mask.sum(1).clamp_min(1e-6)
            pooled = masked_sum / token_count

        # 3. Embedding and logits.
        z = self.proj(pooled)
        logits = self.classifier(z)

        return logits, z, labels

# --------------------------
# 2. 辅助函数：加载和提取
# --------------------------

def load_checkpoint(path: str, map_location="cpu"):
    """
    Load a checkpoint file with ``torch.load`` and return its content.

    Args:
        path: location of the ``.pt`` file.
        map_location: forwarded to ``torch.load`` (defaults to CPU).

    NOTE(review): torch.load deserializes with pickle; only open checkpoints
    from a trusted source.
    """
    return torch.load(path, map_location=map_location)

def build_model_from_ck(ckpt: dict, device: str):
    """
    Rebuild the model, tokenizer and author list from a loaded checkpoint dict.

    Reads ``authors`` and ``state_dict`` (required) plus the optional keys
    ``model_name`` / ``model`` (encoder name, default "bert-base-uncased")
    and ``use_mean_pool`` (default True). The model is moved to ``device``
    and switched to eval mode.

    Returns:
        (model, tokenizer, authors)
    """
    authors = ckpt["authors"]

    # Encoder name: prefer "model_name", then "model", then the BERT default.
    if "model_name" in ckpt:
        backbone = ckpt["model_name"]
    elif "model" in ckpt:
        backbone = ckpt["model"]
    else:
        backbone = "bert-base-uncased"
    mean_pool = ckpt.get("use_mean_pool", True)

    # emb_dim is not read from the checkpoint; AuthorModel's default is used.
    model = AuthorModel(backbone, num_classes=len(authors), use_mean_pool=mean_pool)
    model.load_state_dict(ckpt["state_dict"])
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=False)
    return model, tokenizer, authors

@torch.no_grad()
def extract_logits_and_embeddings(model: nn.Module, dataloader: DataLoader, device: str):
    """
    Run the model over every batch of ``dataloader`` and collect its outputs.

    Each batch must carry a ``text`` entry (list of raw strings) and a
    ``labels`` tensor; all remaining tensors are moved to ``device`` and
    passed to the model as keyword arguments.

    Returns:
        (logits, embeddings, labels, texts) where the first two are the
        per-batch model outputs concatenated along axis 0 as numpy arrays,
        ``labels`` is a numpy array and ``texts`` a Python list.
    """
    logit_chunks, emb_chunks = [], []
    label_list, text_list = [], []

    for batch in tqdm(dataloader, desc="Extracting embeddings"):
        # Raw text is not a tensor, so take it out before moving the batch.
        text_list.extend(batch.pop("text"))
        label_list.extend(batch["labels"].cpu().numpy().tolist())

        inputs = {name: tensor.to(device) for name, tensor in batch.items()}
        logits, z, _ = model(**inputs)

        logit_chunks.append(logits.detach().cpu().numpy())
        emb_chunks.append(z.detach().cpu().numpy())

    stacked_logits = np.concatenate(logit_chunks, axis=0)
    stacked_embs = np.concatenate(emb_chunks, axis=0)
    return stacked_logits, stacked_embs, np.array(label_list), text_list

In [4]:
# -------------------------------------------------------------
# 任务 1: 设计开集判别策略 (最大概率阈值、能量分数、原型距离)
# -------------------------------------------------------------
# 这些函数接收 Logits 或 Embeddings，并为每个样本计算一个“已知度”分数。
# “已知度”分数越高 = 越像已知作者 (In-Distribution)
# “已知度”分数越低 = 越像未知作者 (Out-of-Distribution)
# -------------------------------------------------------------

def softmax(x):
    """Row-wise softmax of a 2-D array; each row's max is subtracted first for numerical stability."""
    row_peak = np.max(x, axis=1, keepdims=True)
    exps = np.exp(x - row_peak)
    return exps / np.sum(exps, axis=1, keepdims=True)

def max_probability_score(logits: np.ndarray):
    """
    Strategy 1: maximum softmax probability (MSP).

    Intuition: for a sample from a known author the model should be confident
    about one class, so the largest softmax probability is high.

    Returns:
        (scores, probs): the per-sample maximum probability, shape (N,), and
        the full softmax matrix, shape (N, C).
    """
    class_probs = softmax(logits)
    top_prob = np.max(class_probs, axis=1)
    return top_prob, class_probs

def energy_score(logits: np.ndarray, T: float = 1.0):
    """
    Strategy 2: energy score.

    The energy of a sample is ``-T * logsumexp(logits / T)``. This function
    returns the *negative* energy, i.e. ``T * logsumexp(logits / T)``, so
    that a higher score means "more like a known author".

    Args:
        logits: array of shape (N, C).
        T: temperature; must be non-zero.
    """
    from scipy.special import logsumexp

    negative_energy = T * logsumexp(logits / T, axis=1)
    return negative_energy

def prototype_centroids(embeddings: np.ndarray, labels: np.ndarray, authors: List[str]):
    """
    Compute one prototype (mean embedding) per known author.

    Args:
        embeddings: array of shape (N, D).
        labels: array of shape (N,) with class indices; rows whose label does
            not match any author index are ignored.
        authors: list of author names; position = class index.

    Returns:
        dict mapping class index -> centroid of shape (D,). An author with no
        samples gets an all-zero vector.
    """
    emb_dim = embeddings.shape[1]

    def _centroid_for(class_id):
        members = embeddings[labels == class_id]
        if len(members) == 0:
            # No sample of this author: placeholder prototype at the origin.
            return np.zeros(emb_dim, dtype=float)
        return members.mean(axis=0)

    return {class_id: _centroid_for(class_id) for class_id, _ in enumerate(authors)}

def prototype_distance_score(embeddings: np.ndarray, centroids: Dict[int, np.ndarray], metric: str = "euclidean"):
    """
    Strategy 3: prototype distance.

    Intuition: the embedding of a sample from a known author should lie close
    to that author's prototype (centroid). The score is the negated distance
    to the *nearest* prototype, so closer -> higher score -> "more known".

    Args:
        embeddings: array of shape (N, D).
        centroids: dict mapping class index -> centroid of shape (D,).
        metric: "euclidean" or "cosine".

    Returns:
        array of shape (N,) with ``-min_c distance(embedding, centroid_c)``.

    Raises:
        ValueError: if ``metric`` is not one of the supported names, or if
            ``centroids`` is empty.
    """
    # Validate up front: previously any unrecognised metric name (e.g. a typo)
    # silently fell through to cosine distance.
    if metric not in ("euclidean", "cosine"):
        raise ValueError(f"Unsupported metric {metric!r}; expected 'euclidean' or 'cosine'.")
    if len(centroids) == 0:
        raise ValueError("centroids is empty; at least one prototype is required.")

    from sklearn.metrics.pairwise import euclidean_distances, cosine_distances

    # Rows are ordered by class index so the matrix layout is deterministic.
    centroid_matrix = np.stack([centroids[i] for i in sorted(centroids.keys())], axis=0)  # C x D
    if metric == "euclidean":
        d = euclidean_distances(embeddings, centroid_matrix)  # N x C
    else:
        d = cosine_distances(embeddings, centroid_matrix)  # N x C
    min_d = d.min(axis=1)  # distance to the nearest prototype
    return -min_d  # higher = more "known"

In [5]:
# ----------------------------------------------------
# 任务 2: 画 ROC / Precision-Recall 曲线, 调 τ (tau)
# ----------------------------------------------------

def plot_roc_pr(y_true_binary: np.ndarray, score: np.ndarray, out_prefix: str, name: str):
    """
    Plot and save the ROC and precision-recall curves of one open-set score.

    Args:
        y_true_binary: ground truth, 1 = known (in the closed set), 0 = unknown.
        score: "known-ness" score from one of the strategies (higher = more known).
        out_prefix: path prefix of the output images.
        name: strategy name, used in titles and file names.

    Writes ``{out_prefix}_roc_{name}.png`` and ``{out_prefix}_pr_{name}.png``.

    Returns:
        (roc_auc, pr_auc)
    """
    def _save_curve(xs, ys, legend_text, x_label, y_label, title, path):
        # One figure per curve; closed after saving so figures do not pile up.
        plt.figure()
        plt.plot(xs, ys, label=legend_text)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.savefig(path, bbox_inches="tight")
        plt.close()

    fpr, tpr, _ = roc_curve(y_true_binary, score)
    precision, recall, _ = precision_recall_curve(y_true_binary, score)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(recall, precision)

    _save_curve(fpr, tpr, f"ROC AUC={roc_auc:.4f}", "FPR", "TPR",
                f"ROC - {name}", f"{out_prefix}_roc_{name}.png")
    _save_curve(recall, precision, f"PR AUC={pr_auc:.4f}", "Recall", "Precision",
                f"Precision-Recall - {name}", f"{out_prefix}_pr_{name}.png")
    return roc_auc, pr_auc

def tune_threshold_by_f1(y_true_binary: np.ndarray, score: np.ndarray, num_steps: int = 1000):
    """
    Search the threshold tau that maximizes the known-vs-unknown F1 score.

    ``num_steps`` evenly spaced candidates between ``score.min()`` and
    ``score.max()`` are tried; a sample is predicted "known" (1) when its
    score is >= the candidate. On ties the smallest candidate wins.

    Returns:
        (best_tau, best_f1); ``(None, -1.0)`` if no candidate was evaluated.
    """
    best = (None, -1.0)
    for tau in np.linspace(score.min(), score.max(), num_steps):
        predicted_known = (score >= tau).astype(int)
        current_f1 = f1_score(y_true_binary, predicted_known)
        # Strict '>' keeps the first (lowest) threshold among equal F1 values.
        if current_f1 > best[1]:
            best = (tau, current_f1)
    return best

In [6]:
# ----------------------------------------------------
# 任务 3: 做 UMAP/T-SNE 作者嵌入聚类可视化
# ----------------------------------------------------
# 我们使用从模型中提取的嵌入 (z) 来进行降维和可视化，
# 观察“已知”作者是否聚类，“未知”作者是否远离它们。
# ----------------------------------------------------

def plot_umap(embeddings: np.ndarray, labels: np.ndarray, authors: List[str], out_path: str, n_neighbors: int = 15, min_dist: float = 0.1):
    """
    Project the embeddings to 2-D with UMAP and save a scatter plot.

    Args:
        embeddings: array of shape (N, D).
        labels: array of shape (N,) with class indices; negative values are
            drawn as "unknown".
        authors: list of author names; position = class index.
        out_path: where the figure is written.
        n_neighbors, min_dist: forwarded to ``umap.UMAP``.

    Returns:
        the 2-D projection produced by UMAP.
    """
    reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=42)
    emb2 = reducer.fit_transform(embeddings)
    plt.figure(figsize=(8,6))
    unique_labels = np.unique(labels)
    # One distinct hue per label. The palette used to be computed but never
    # passed to scatter, so colours came from matplotlib's default cycle,
    # which repeats once there are more labels than cycle entries.
    palette = sns.color_palette("hls", len(unique_labels))
    for color, lab in zip(palette, unique_labels):
        mask = labels == lab
        legend_name = authors[lab] if lab >= 0 else "unknown"
        plt.scatter(emb2[mask,0], emb2[mask,1], s=10, color=color, label=legend_name, alpha=0.7)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05,1), loc='upper left')
    plt.title("UMAP of embeddings")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    return emb2

def plot_tsne(embeddings: np.ndarray, labels: np.ndarray, authors: List[str], out_path: str, perplexity: int = 30):
    """
    Project the embeddings to 2-D with t-SNE and save a scatter plot.

    Args:
        embeddings: array of shape (N, D).
        labels: array of shape (N,) with class indices; negative values are
            drawn as "unknown".
        authors: list of author names; position = class index.
        out_path: where the figure is written.
        perplexity: requested t-SNE perplexity. It is lowered automatically
            when the data set is too small for it.

    Returns:
        the 2-D projection produced by t-SNE.
    """
    # scikit-learn requires perplexity < n_samples; clamp so a small split
    # does not abort the whole analysis.
    n_samples = len(embeddings)
    if n_samples > 1 and perplexity >= n_samples:
        perplexity = n_samples - 1
    ts = TSNE(n_components=2, perplexity=perplexity, random_state=42, init='pca')
    emb2 = ts.fit_transform(embeddings)
    plt.figure(figsize=(8,6))
    for lab in np.unique(labels):
        mask = labels == lab
        legend_name = authors[lab] if lab >= 0 else "unknown"
        plt.scatter(emb2[mask,0], emb2[mask,1], s=10, label=legend_name, alpha=0.7)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05,1), loc='upper left')
    plt.title("t-SNE of embeddings")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    return emb2

In [7]:
# ----------------------------------------------------
# 任务 4: 提取错误案例分析
# ----------------------------------------------------
# 将模型预测错误 (真值为 A 却预测为 B，或 真值为 A 却预测为 Unknown)
# 的样本保存到 CSV 中，以便人工分析。
# ----------------------------------------------------

def extract_error_cases(texts: List[str], y_true: np.ndarray, y_pred: np.ndarray, probs: np.ndarray,
                        emb: np.ndarray, centroid_scores: np.ndarray, out_csv: str, authors: List[str]):
    """
    Collect every sample whose predicted label differs from its true label
    and save them to ``out_csv`` for manual inspection.

    Args:
        texts: raw text of each sample.
        y_true: true class indices (-1 = unknown).
        y_pred: predicted class indices (-1 = rejected as unknown).
        probs: softmax probabilities, shape (N, C).
        emb: embeddings; only used to bound the iteration length, not exported.
        centroid_scores: prototype-distance score of each sample.
        out_csv: destination CSV path.
        authors: list of author names; position = class index.

    Returns:
        DataFrame with columns idx, text, true_label, pred_label, max_prob,
        pred_prob, centroid_score. The columns (and the CSV header) are
        present even when there is no error, so the file can always be read
        back.
    """
    columns = ["idx", "text", "true_label", "pred_label", "max_prob", "pred_prob", "centroid_score"]
    n_classes = len(authors)

    def _label_name(index):
        # Any index outside [0, n_classes) is reported as "unknown".
        index = int(index)
        return authors[index] if 0 <= index < n_classes else "unknown"

    rows = []
    for i, (txt, t, p, pr, _e, cscore) in enumerate(zip(texts, y_true, y_pred, probs, emb, centroid_scores)):
        if t == p:
            continue  # correct prediction, nothing to report
        p_idx = int(p)
        has_class_prob = 0 <= p_idx < n_classes and p_idx < pr.shape[0]
        rows.append({
            "idx": i,
            "text": txt,
            "true_label": _label_name(t),
            "pred_label": _label_name(p),
            "max_prob": float(pr.max()),
            # Probability of the predicted class; None when rejected as unknown.
            "pred_prob": float(pr[p_idx]) if has_class_prob else None,
            "centroid_score": float(cscore),
        })

    # Passing the columns explicitly keeps the schema stable when rows is empty.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(out_csv, index=False)
    return df

In [8]:
# ----------------------------------------------------
# 主函数 (Main Pipeline)
# ----------------------------------------------------
# (已修复 NameError: 'emb' -> 'embs')
# (已修复 ValueError: 在 classification_report 中添加 'labels' 参数)
# ----------------------------------------------------

def main(args):
    """
    End-to-end open-set analysis pipeline.

    Steps: load the checkpoint and data, extract logits/embeddings, compute
    the three open-set scores, plot ROC/PR curves, pick thresholds, build
    open-set predictions, write the classification report, error-case CSVs,
    UMAP/t-SNE figures, a per-sample scores CSV and ``summary.json``.

    Args:
        args: namespace with the attributes checkpoint, data, train_data,
            split, out_dir, prefix, batch_size, max_len, tau, tune_threshold,
            energy_T and proto_metric.

    Returns:
        None. When an input file is missing or a required split is empty, a
        message is printed and the function returns early.
    """
    os.makedirs(args.out_dir, exist_ok=True)
    print("[Info] Loading checkpoint:", args.checkpoint)

    # --- Fail early, with a readable message, when an input file is missing ---
    if not os.path.exists(args.checkpoint):
        print(f"错误：找不到 Checkpoint 文件: {args.checkpoint}")
        print("请上传 'authorship_model.pt' 文件，或修改 '--checkpoint' 路径。")
        return
    if not os.path.exists(args.data):
        print(f"错误：找不到 Data 文件: {args.data}")
        print("请上传数据文件 (例如 'author_style_dataset_OPENSET.csv')，或修改 '--data' 路径。")
        return
    if args.train_data and not os.path.exists(args.train_data):
        print(f"错误：找不到 Train Data 文件: {args.train_data}")
        return

    ck = load_checkpoint(args.checkpoint, map_location="cpu")
    model, tokenizer, authors = build_model_from_ck(ck, device=DEVICE)
    print(f"[Info] Loaded model for {len(authors)} authors.")

    # Load the evaluation data; a file without a "split" column is treated as all-test.
    df = pd.read_csv(args.data)
    if "split" not in df.columns:
        df["split"] = "test"
    test_df = df[df["split"]==args.split].copy()
    print(f"[Info] Using split {args.split} with {len(test_df)} samples")
    if test_df.empty:
        # Without this guard the run dies later inside np.concatenate with an unrelated-looking error.
        print(f"[Error] Split '{args.split}' contains no samples in {args.data}; nothing to analyze.")
        return

    # Build the dataloader (no shuffling: outputs must stay aligned with the rows).
    author2id = {a:i for i,a in enumerate(authors)}
    dataset = AuthorDataset(test_df, tokenizer, author2id, max_len=args.max_len)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # 1. Extract logits and embeddings
    logits, embs, labels, texts = extract_logits_and_embeddings(model, loader, DEVICE)

    # 2. [Task 1] Scores of all open-set strategies
    scores_maxprob, full_probs = max_probability_score(logits)
    scores_energy = energy_score(logits, T=args.energy_T)

    # Prototypes (centroids). Truthiness is used here to match the existence
    # check above, so an empty string behaves like "not provided".
    if args.train_data:
        train_df = pd.read_csv(args.train_data)
        # A training CSV without a "split" column is used in full.
        if "split" in train_df.columns:
            train_df = train_df[train_df["split"]=="train"]
        train_df = train_df.copy()
        if train_df.empty:
            print(f"[Error] No 'train' samples found in {args.train_data}; cannot compute prototypes.")
            return
        train_dataset = AuthorDataset(train_df, tokenizer, author2id, max_len=args.max_len)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
        print("[Info] Computing centroids from provided train_data...")
        _, emb_train, labels_train, _ = extract_logits_and_embeddings(model, train_loader, DEVICE)
        centroids = prototype_centroids(emb_train, labels_train, authors)
    else:
        print("[Warning] No --train_data provided. Computing centroids from *this* split's known labels. (Not recommended for final evaluation)")
        centroids = prototype_centroids(embs, labels, authors)

    scores_proto = prototype_distance_score(embs, centroids, metric=args.proto_metric)

    # Ground truth for open-set detection (1 = known, 0 = unknown)
    y_true_binary = (labels != -1).astype(int)

    # Warn when there is no out-of-distribution sample at all.
    if np.sum(y_true_binary == 0) == 0:
        print("\n[Warning] 您的数据集中没有“未知”(unknown)样本 (label = -1)。")
        print("           ROC/PR 曲线对于评估“开集”检测将没有意义。\n")

    # 3. [Task 2] ROC / PR curves
    out_prefix = os.path.join(args.out_dir, args.prefix)
    roc_auc_mp, pr_auc_mp = plot_roc_pr(y_true_binary, scores_maxprob, out_prefix, "maxprob")
    roc_auc_en, pr_auc_en = plot_roc_pr(y_true_binary, scores_energy, out_prefix, "energy")
    roc_auc_pr, pr_auc_pr = plot_roc_pr(y_true_binary, scores_proto, out_prefix, "prototype")

    print("[Info] AUCs:")
    print(f"MaxProb ROC={roc_auc_mp:.4f} PR={pr_auc_mp:.4f}")
    print(f"Energy  ROC={roc_auc_en:.4f} PR={pr_auc_en:.4f}")
    print(f"Proto   ROC={roc_auc_pr:.4f} PR={pr_auc_pr:.4f}")

    # Threshold (tau) selection
    if args.tune_threshold:
        tau_mp, f1_mp = tune_threshold_by_f1(y_true_binary, scores_maxprob)
        tau_en, f1_en = tune_threshold_by_f1(y_true_binary, scores_energy)
        tau_pr, f1_pr = tune_threshold_by_f1(y_true_binary, scores_proto)
        print("[Tuning] best tau (maxprob)=", tau_mp, "F1=", f1_mp)
        print("[Tuning] best tau (energy)=", tau_en, "F1=", f1_en)
        print("[Tuning] best tau (proto)=", tau_pr, "F1=", f1_pr)
    else:
        # NOTE: the same fixed tau is applied to all three scores although they
        # live on different scales (max-prob is a probability, the prototype
        # score is a negated distance and therefore never positive).
        print(f"[Info] Using fixed threshold tau = {args.tau}")
        tau_mp = args.tau
        tau_en = args.tau
        tau_pr = args.tau

    # 4. Final open-set predictions (-1 means "unknown")
    pred_labels_closed = np.argmax(logits, axis=1)  # closed-set prediction
    pred_labels_mp = np.where(scores_maxprob >= tau_mp, pred_labels_closed, -1)
    pred_labels_en = np.where(scores_energy >= tau_en, pred_labels_closed, -1)
    pred_labels_pr = np.where(scores_proto >= tau_pr, pred_labels_closed, -1)

    # 5. [Task 4] Classification report, computed on known authors only
    #    (MaxProb predictions are used as the representative strategy).
    mask_known = labels != -1
    if mask_known.sum() > 0:
        report_mp = classification_report(
            labels[mask_known],
            pred_labels_mp[mask_known],
            labels=list(range(len(authors))),  # report the author classes only, not the -1 rejections
            target_names=authors,
            zero_division=0
        )
        print("\n--- Classification Report (MaxProb OpenSet) on Known Authors ---")
        print(report_mp)
        # Explicit encoding: author names may contain non-ASCII characters.
        with open(os.path.join(args.out_dir, "class_report_maxprob.txt"), "w", encoding="utf-8") as f:
            f.write(report_mp)

    # Error cases for all three strategies
    print("[Info] Saving error cases for all 3 strategies...")
    strategy_predictions = [
        ("MaxProb", "maxprob", pred_labels_mp),
        ("Energy", "energy", pred_labels_en),
        ("Prototype", "prototype", pred_labels_pr),
    ]
    for display_name, file_key, strategy_pred in strategy_predictions:
        out_csv = os.path.join(args.out_dir, f"error_cases_{file_key}.csv")
        df_errors = extract_error_cases(texts, labels, strategy_pred, full_probs, embs, scores_proto, out_csv, authors)
        print(f"[Info] Saved {len(df_errors)} error cases ({display_name}) to {out_csv}")

    # 6. [Task 3] Visualizations
    umap_path = os.path.join(args.out_dir, "umap_embeddings.png")
    tsne_path = os.path.join(args.out_dir, "tsne_embeddings.png")
    print("[Info] Generating visualizations...")
    plot_umap(embs, labels, authors, umap_path)
    plot_tsne(embs, labels, authors, tsne_path)
    print(f"[Info] Saved UMAP -> {umap_path} and t-SNE -> {tsne_path}")

    # 7. Per-sample scores. The energy/prototype predictions are appended after
    #    the original columns so existing readers keep working.
    out_df = pd.DataFrame({
        "text": texts,
        "true_label_idx": labels,
        "pred_closed_idx": pred_labels_closed,
        "pred_maxprob_idx": pred_labels_mp,
        "score_maxprob": scores_maxprob,
        "score_energy": scores_energy,
        "score_proto": scores_proto,
        "pred_energy_idx": pred_labels_en,
        "pred_proto_idx": pred_labels_pr
    })
    out_df.to_csv(os.path.join(args.out_dir, args.prefix + "_scores.csv"), index=False)
    print(f"[Info] Saved scores CSV to {args.out_dir}/{args.prefix + '_scores.csv'}")

    # 8. Summary
    summary = {
        "aucs": {
            "maxprob": {"roc": float(roc_auc_mp), "pr": float(pr_auc_mp)},
            "energy": {"roc": float(roc_auc_en), "pr": float(pr_auc_en)},
            "prototype": {"roc": float(roc_auc_pr), "pr": float(pr_auc_pr)}
        },
        "best_taus": {
            "maxprob": float(tau_mp),
            "energy": float(tau_en),
            "prototype": float(tau_pr)
        } if args.tune_threshold or args.tau is not None else {}
    }
    with open(os.path.join(args.out_dir, "summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print("[Done] Analysis complete. Results saved to", args.out_dir)

In [None]:
# # ----------------------------------------------------
# # (新) 任务 5: 微调模型头部 (Head Fine-tuning)
# # ----------------------------------------------------
# #
# # 目标：冻结 BERT encoder，只使用“分类损失 + 原型损失”
# #      来微调 proj 和 classifier 头部。
# #
# # ----------------------------------------------------

# # [修改后代码] (第 10 行)
# from torch.optim import AdamW
# import torch.nn.functional as F

# # 1. 定义原型/中心损失函数
# def compute_prototype_loss(z, labels, device="cuda"):
#     """
#     计算简化的原型/中心损失 (只在批次内计算)
#     z: 嵌入 (B, D)
#     labels: 标签 (B,)
#     """
#     # 初始化一个标量张量来累积损失
#     loss_proto = torch.tensor(0.0, device=device)
#     num_classes_in_batch = 0

#     # 遍历这个批次中出现的所有唯一类别
#     for c in torch.unique(labels):
#         # 只为“已知”类别 (c >= 0) 计算簇内距离
#         if c.item() == -1:
#             continue

#         # 找出所有属于类别 'c' 的嵌入
#         mask = (labels == c)
#         z_class = z[mask]

#         # 至少要有2个样本才能计算中心
#         if z_class.shape[0] > 1:
#             num_classes_in_batch += 1
#             # (a) 计算该类的“原型/中心”
#             center = z_class.mean(dim=0)

#             # (b) 计算该类所有样本到其中心的 L2 距离的平方
#             # (z_class - center)^2 -> sum over dim 1 -> mean over batch
#             distances_sq = torch.sum((z_class - center.unsqueeze(0))**2, dim=1)

#             # 累加该类的平均距离损失
#             loss_proto += torch.mean(distances_sq)

#     # 返回所有类别的平均损失
#     if num_classes_in_batch > 0:
#         return loss_proto / num_classes_in_batch
#     else:
#         return loss_proto # 返回 0.0

# # 2. 定义训练一个 Epoch 的函数
# def train_epoch(model, dataloader, optimizer, lambda_proto, device):
#     model.train() # 切换到训练模式
#     total_loss_epoch = 0
#     total_loss_ce_epoch = 0
#     total_loss_proto_epoch = 0

#     # 注意：我们在这里使用 tqdm 来显示进度条
#     for batch in tqdm(dataloader, desc="Finetune Epoch"):
#         batch.pop("text", None) # 训练时不需要 text
#         batch = {k: v.to(device) for k, v in batch.items()}
#         labels = batch["labels"]

#         # --- 关键：只在“已知”样本上训练 ---
#         # 过滤掉标签为 -1 的 "unknown" 样本
#         known_mask = (labels != -1)
#         if known_mask.sum() == 0:
#             continue # 如果这个批次全是 unknown，跳过

#         labels = labels[known_mask]
#         batch = {k: v[known_mask] for k, v in batch.items()}
#         # -----------------------------------

#         # 1. 前向传播
#         logits, z, _ = model(**batch)

#         # 2. 计算 Loss 1: 分类损失 (Cross-Entropy)
#         loss_ce = nn.CrossEntropyLoss()(logits, labels)

#         # 3. 计算 Loss 2: 原型损失
#         loss_proto = compute_prototype_loss(z, labels, device)

#         # 4. 合并损失 (按 lambda_proto 比例)
#         total_loss = loss_ce + (lambda_proto * loss_proto)

#         # 5. 反向传播 (只会更新 proj 和 classifier 的权重)
#         optimizer.zero_grad()
#         total_loss.backward()
#         optimizer.step()

#         total_loss_epoch += total_loss.item()
#         total_loss_ce_epoch += loss_ce.item()
#         total_loss_proto_epoch += loss_proto.item()

#     N = len(dataloader)
#     return total_loss_epoch / N, total_loss_ce_epoch / N, total_loss_proto_epoch / N

# # 3. 定义主微调函数
# def run_finetune(
#     original_ckpt_path="./checkpoints/authorship_model.pt",
#     finetuned_ckpt_path="./checkpoints/authorship_model_FINETUNED.pt",
#     train_data_path="author_style_dataset_OPENSET.csv",
#     num_epochs=5,       # 只需要几个 Epoch
#     batch_size=32,
#     max_len=512,
#     lambda_proto=0.1,   # 超参数：原型损失的权重，您可以调这个
#     learning_rate=1e-4  # 头部微调，学习率可以高一点
# ):
#     print("[Info] Starting fine-tuning...")

#     # 1. 加载您原始的、训练好的模型
#     ck = load_checkpoint(original_ckpt_path, map_location="cpu")
#     model, tokenizer, authors = build_model_from_ck(ck, device=DEVICE)
#     print(f"[Info] Loaded original model from {original_ckpt_path}")

#     # 2. 冻结 ENCODER (这是关键)
#     print("[Info] Freezing encoder parameters...")
#     for param in model.encoder.parameters():
#         param.requires_grad = False

#     # 3. 确保头部 (proj, classifier) 是可训练的
#     for param in model.proj.parameters():
#         param.requires_grad = True
#     for param in model.classifier.parameters():
#         param.requires_grad = True

#     num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     num_total = sum(p.numel() for p in model.parameters())
#     print(f"[Info] Trainable params: {num_trainable} / {num_total} ({num_trainable/num_total*100:.2f}%)")

#     # 4. 准备训练数据 (只使用 'train' split)
#     df_train = pd.read_csv(train_data_path)
#     df_train = df_train[df_train["split"] == "train"].copy()

#     author2id = {a:i for i,a in enumerate(authors)}
#     train_dataset = AuthorDataset(df_train, tokenizer, author2id, max_len=max_len)
#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
#     print(f"[Info] Loaded {len(df_train)} training samples.")

#     # 5. 设置优化器 (!!! 只优化可训练的参数 !!!)
#     optimizer = AdamW(
#         filter(lambda p: p.requires_grad, model.parameters()),
#         lr=learning_rate
#     )

#     # 6. 训练循环
#     print("[Info] Starting training loop...")
#     for epoch in range(num_epochs):
#         avg_loss, avg_ce, avg_proto = train_epoch(model, train_loader, optimizer, lambda_proto, DEVICE)
#         print(f"Epoch {epoch+1}/{num_epochs} -> "
#               f"Total Loss: {avg_loss:.4f} (CE: {avg_ce:.4f} + Proto: {avg_proto:.4f})")

#     # 7. 保存微调后的模型
#     # 我们更新原始 checkpoint 中的 state_dict
#     ck["state_dict"] = model.state_dict()
#     ck["finetuned_from"] = original_ckpt_path
#     ck["finetune_loss"] = f"CE + {lambda_proto}*Proto"

#     os.makedirs(os.path.dirname(finetuned_ckpt_path), exist_ok=True)
#     torch.save(ck, finetuned_ckpt_path)
#     print(f"[Done] Fine-tuning complete. Model saved to {finetuned_ckpt_path}")

#     # 返回新模型的路径，以便下一步使用
#     return finetuned_ckpt_path

# # --- 立即执行微调 ---
# # (这里我们使用 5 个 Epoch 和 0.1 的 lambda，您可以根据需要调整)
# NEW_MODEL_PATH = run_finetune(
#     num_epochs=5,
#     lambda_proto=0.1,
#     learning_rate=1e-4
# )

In [9]:
# ----------------------------------------------------
# Run the analysis
# ----------------------------------------------------

# 1. Define the parameters (argparse is used only to hold defaults inside the notebook)
p = argparse.ArgumentParser()
p.add_argument("--checkpoint", type=str, default="./checkpoints/authorship_model.pt", help="path to checkpoint")
# Default points at the uploaded CSV file.
p.add_argument("--data", type=str, default="author_style_dataset_OPENSET.csv", help="csv dataset with columns [text,author,split]")
# Same file as --data: main() keeps only its 'train' split to compute the prototypes.
p.add_argument("--train_data", type=str, default="author_style_dataset_OPENSET.csv", help="optionally provide training csv to compute prototypes from train set")
p.add_argument("--split", type=str, default="test", help="which split in csv to analyze")
# p.add_argument("--out_dir", type=str, default="./open_set_analysis", help="output directory")
p.add_argument("--out_dir", type=str, default="./open_set_finetune", help="output directory")
p.add_argument("--prefix", type=str, default="run1", help="prefix for output files")

# --batch_size is declared exactly once (a duplicate declaration was removed earlier).
p.add_argument("--batch_size", type=int, default=32)

p.add_argument("--max_len", type=int, default=512)
p.add_argument("--tau", type=float, default=0.5, help="fallback threshold if not tuning")
p.add_argument("--tune_threshold", action="store_true", help="search tau by maximizing F1 on known/unknown (requires ground-truth unknown labels)")
p.add_argument("--energy_T", type=float, default=1.0)
p.add_argument("--proto_metric", type=str, choices=["euclidean","cosine"], default="euclidean")

# 2. Parse an empty argument list so the notebook kernel's own command-line flags
#    (e.g. -f) are not handed to argparse.
args = p.parse_args(args=[])

# 3. --- Manual overrides of the defaults go here ---
# For example, to tune the threshold on the 'val' split:
# args.split = "val"
# args.tune_threshold = True

# Point --checkpoint at the fine-tuned model.
# ----------------------------------------------------
# args.checkpoint = NEW_MODEL_PATH # <--- only valid when the fine-tuning cell has been run in this session
# The path is hard-coded instead, so this cell does not depend on the (commented-out) fine-tuning cell.
args.checkpoint = "./checkpoints/authorship_model_FINETUNED.pt"

# Analyze the 'test' split and tune tau on it.
args.split = "test"
args.tune_threshold = True # NOTE: tuning needs samples labelled as unknown in the chosen split; here tau is tuned on the very split that is evaluated
# args.tau = 0.5 # use a fixed threshold instead

# IMPORTANT: the 'test' split of 'author_style_dataset_OPENSET.csv' must contain
# "unknown" authors for tune_threshold=True to be meaningful.
# If the 'test' split only holds known authors, set tune_threshold=False.
# ------------------------------------

# 4. Run
try:
    main(args)
except FileNotFoundError as e:
    print(f"\n[Execution Stopped] 关键文件未找到: {e}")
    print("请确保在 Colab 中上传了所需的数据集和模型文件，并正确设置了 --checkpoint 和 --data 参数。")
except Exception as e:
    print(f"\n[Execution Stopped] 发生意外错误: {e}")
    # Print the full traceback so the failure can be diagnosed from the cell output
    import traceback
    traceback.print_exc()

[Info] Loading checkpoint: ./checkpoints/authorship_model_FINETUNED.pt


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/874M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

[Info] Loaded model for 9 authors.
[Info] Using split test with 2781 samples


Extracting embeddings: 100%|██████████| 87/87 [06:59<00:00,  4.82s/it]


[Info] Computing centroids from provided train_data...


Extracting embeddings: 100%|██████████| 175/175 [14:12<00:00,  4.87s/it]


[Info] AUCs:
MaxProb ROC=0.6299 PR=0.6715
Energy  ROC=0.6253 PR=0.6675
Proto   ROC=0.7322 PR=0.7867
[Tuning] best tau (maxprob)= 0.4466705 F1= 0.688034188034188
[Tuning] best tau (energy)= 3.5247774 F1= 0.6892411143131604
[Tuning] best tau (proto)= -1.5387969 F1= 0.697560975609756

--- Classification Report (MaxProb OpenSet) on Known Authors ---
                  precision    recall  f1-score   support

ArthurConanDoyle       0.54      0.75      0.63       111
  CharlesDickens       0.71      0.94      0.81        80
      Chesterton       0.99      0.41      0.58       192
   EdgarAllanPoe       0.65      0.87      0.74       172
  HermanMelville       0.98      0.62      0.76       301
      JaneAusten       1.00      0.99      0.99       168
       MarkTwain       0.62      1.00      0.77        38
      OscarWilde       0.78      1.00      0.87       101
   VirginiaWoolf       0.86      0.98      0.92       293

       micro avg       0.80      0.80      0.80      1456
       macro

  warn(


[Info] Saved UMAP -> ./open_set_finetune/umap_embeddings.png and t-SNE -> ./open_set_finetune/tsne_embeddings.png
[Info] Saved scores CSV to ./open_set_finetune/run1_scores.csv
[Done] Analysis complete. Results saved to ./open_set_finetune
