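# Provably Robust MoE Watermarking (Google Colab Experiments)

This project provides a Colab environment for experimenting with a provably robust watermarking scheme built on the MoE gating mechanism. It follows *Signal-Attack Decoupling in MoE Watermarks* and its engineering plan (`align_to_proofs.md`).

**Run the cells below in order:**

1.  **Setup Environment**: install the required Python dependencies.
2.  **Write Python Files**: write all of the project's `.py` scripts to the Colab file system.
3.  **Run Experiments**: run the example commands for the three modes (`calibrate`, `embed`, `detect`).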

## 1. Setup Environment

Run this cell to install all dependencies. **After installation you may need to restart the runtime (Runtime -> Restart session).**

In [None]:
!pip install torch transformers numpy scipy scikit-learn tqdm datasets sentencepiece accelerate
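
After a runtime restart it can be useful to confirm that the packages are visible to the interpreter. The snippet below is a minimal sketch using only the standard library; the helper name `find_missing_modules` is illustrative and not part of the project.

```python
import importlib.util

# Import names for the packages installed above. Note that scikit-learn
# is imported as `sklearn`, not under its pip name.
REQUIRED_MODULES = [
    "torch", "transformers", "numpy", "scipy", "sklearn",
    "tqdm", "datasets", "sentencepiece", "accelerate",
]

def find_missing_modules(names):
    """Return the names that cannot be located on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All dependencies are importable.")
```

If anything is reported missing, rerunning the install cell and restarting the session is the usual remedy.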

## 2. Write Python Files

Run the following cells in order to write the project's Python source files to the Colab environment.

In [None]:
%%writefile moe_watermark.py
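import hashlib

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from typing import Callable, Optional, Tuple

class MoEWatermark:
    """
    Core logic of the MoE gating watermark.
    The methods of this class are injected (patched) into the MoE gating modules of a pretrained model.
    """
    
    def __init__(self, K_sec: str, epsilon: float, num_experts: int, k_top: int, device: torch.device):
        self.secret_key = K_sec
        self.epsilon = epsilon
        self.num_experts = num_experts
        self.k_top = k_top
        self.device = device
        
        # Theory: epsilon = Var[delta_l]
        # Implementation: apply a positive bias b to the "green" expert; a balanced variant
        # would also apply a negative bias -b' to k_top "red" experts to keep the mean near 0.
        # This is a simplification; a more exact implementation would match the variance.
        # Stored as a Python float so it can be passed as the scalar value of Tensor.scatter_.
        self.bias_strength = float(self.epsilon) ** 0.5 * 5.0  # heuristic scaling

    def get_context_hash(self, hidden_states: torch.Tensor) -> int:
        """
        Derive a PRNG seed from the context (hidden_states).
        This is a simplified implementation.
        """
        # [batch_size, seq_len, dim] -> [batch_size, seq_len]
        hashed = torch.sum(hidden_states, dim=-1).long() 
        # Use the hash of the last token
        # [batch_size]
        last_token_hash = hashed[:, -1] 
        # Combine the hashes across the batch
        seed = torch.sum(last_token_hash).item()
        # Python's built-in hash() is salted per process for strings, so the embed and
        # detect runs would derive different seeds; use a keyed SHA-256 digest instead.
        digest = hashlib.sha256(f"{self.secret_key}:{seed}".encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % (2**63)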

    def get_bias_vector(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Generate the bias vector delta_l.
        Follows Section 2 of the scheme.
        """
        batch_size, seq_len, _ = hidden_states.shape
        context_hash = self.get_context_hash(hidden_states)
        # Follow the input's device so the bias can be added to the router logits directly
        device = hidden_states.device
        
        # Use a deterministic seed
        generator = torch.Generator(device=device)
        generator.manual_seed(context_hash)
        
        delta_l = torch.zeros((batch_size, seq_len, self.num_experts), device=device)
        
        # 1. Pick one "green" expert per position
        green_expert = torch.randint(0, self.num_experts, (batch_size, seq_len), generator=generator, device=device)
        
        # 2. Apply the positive bias (scatter_ needs a Python scalar for a constant fill)
        delta_l.scatter_(-1, green_expert.unsqueeze(-1), float(self.bias_strength))
        
        # 3. (Optional) apply a negative bias to balance
        # ...
        
        return delta_l

    def watermarked_router_forward(
        self, 
        original_forward: Callable, 
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Patched forward pass of the router.
        """
        # 1. Original logits
        # [batch_size, seq_len, num_experts]
        l_0 = original_forward(hidden_states)
        
        # 2. Generate and inject the bias
        delta_l = self.get_bias_vector(hidden_states)
        l_1 = l_0 + delta_l
        
        # 3. Compute p0 (for the detector) and p1 (for routing)
        # Note: p0 and p1 are both full softmax distributions
        p_0_dist = torch.softmax(l_0, dim=-1)
        p_1_dist = torch.softmax(l_1, dim=-1)
        
        # 4. Gating (the routing actually performed)
        # Top-k selection uses the watermarked distribution p1 (same ranking as the logits l_1)
        # [batch_size, seq_len, k_top]
        top_k_scores, S_indices = torch.topk(p_1_dist, self.k_top, dim=-1)
        
        # Normalize the top-k scores
        top_k_scores = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-9)
        
        # 5. Return what the detector needs
        # p0, p1 (full distributions), S_indices (the k active expert indices)
        return top_k_scores, S_indices, p_0_dist, p_1_dist, S_indices

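def patch_moe_model_with_watermark(
    model: AutoModelForCausalLM, 
    K_sec: str, 
    epsilon: float
) -> AutoModelForCausalLM:
    """
    Patch a pretrained MoE model in place to inject the watermark logic.
    Note: this depends heavily on the model's implementation (Mixtral here).
    """
    
    # Assume the model is Mixtral. The config reports model_type in lowercase ("mixtral").
    if "mixtral" not in model.config.model_type.lower():
        raise NotImplementedError("This patch script currently supports only the Mixtral architecture.")

    num_experts = model.config.num_local_experts
    k_top = model.config.num_experts_per_tok
    device = model.device

    watermark_injector = MoEWatermark(K_sec, epsilon, num_experts, k_top, device)
    
    print(f"Patching {model.config.model_type} with MoE watermark...")
    
    def make_forward(router, original_forward):
        # Build the replacement in a factory so each layer's closure keeps its own
        # router; a function defined directly in the loop would only see the last one.
        def new_forward(hidden_states: torch.Tensor):
            # Assumption: the Mixtral MoE block may pass a flattened [tokens, dim]
            # tensor to the gate; treat it as a single batch in that case.
            if hidden_states.dim() == 2:
                hidden_states = hidden_states.unsqueeze(0)

            # Run the watermark logic
            top_k_scores, S_indices, p_0, p_1, S_obs = \
                watermark_injector.watermarked_router_forward(original_forward, hidden_states)
            
            # Attach the detector data to the router object for later access
            router._watermark_detection_data = (p_0, p_1, S_obs)
            
            # MoE routing expects sparse gate logits
            # [batch_size * seq_len, num_experts]
            # Note: if the surrounding block applies its own softmax and top-k to this
            # output, the same experts are selected but their weights are re-normalized.
            batch_size, seq_len, _ = hidden_states.shape
            router_logits = torch.zeros(
                (batch_size * seq_len, num_experts), 
                dtype=top_k_scores.dtype, 
                device=top_k_scores.device
            )
            router_logits.scatter_(-1, S_indices.view(-1, k_top), top_k_scores.view(-1, k_top))
            
            return router_logits

        return new_forward

    for layer in model.model.layers:
        # Locate the gating network (router) of the MoE layer
        router = layer.block_sparse_moe.gate
        
        # Apply the patch, handing the original bound forward method to the factory
        router.forward = make_forward(router, router.forward)
        
    print("Patching complete.")
    return model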

def get_watermark_data_from_model(model: AutoModelForCausalLM) -> list:
    """
    Collect the data the detector needs from the patched model.
    """
    data = []
    for layer in model.model.layers:
        if hasattr(layer.block_sparse_moe.gate, '_watermark_detection_data'):
            data.append(layer.block_sparse_moe.gate._watermark_detection_data)
            # (Optional) clear the data once it has been read
            del layer.block_sparse_moe.gate._watermark_detection_data
    return data
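
The patching code above replaces each router's `forward` inside a loop over layers. Python closures look up free variables when they are called, not when they are defined, so a function defined directly in such a loop sees the value left by the last iteration. The toy sketch below illustrates the pitfall and a factory-function workaround; the `Router` class is a hypothetical stand-in for the gate module, not part of the project.

```python
class Router:
    """Stand-in for a gating module; `forward` just reports the layer id."""
    def __init__(self, layer_id):
        self.layer_id = layer_id

    def forward(self, x):
        return (self.layer_id, x)

def patch_naive(routers):
    # Each closure reads `original` when it is called, i.e. after the loop
    # has finished, so every patched router ends up calling the last one.
    for router in routers:
        original = router.forward
        def new_forward(x):
            return original(x + 1)
        router.forward = new_forward

def patch_with_factory(routers):
    # The factory call creates a fresh scope per layer, freezing `original`.
    def make_forward(original):
        def new_forward(x):
            return original(x + 1)
        return new_forward
    for router in routers:
        router.forward = make_forward(router.forward)

naive = [Router(i) for i in range(3)]
patch_naive(naive)
print([r.forward(0) for r in naive])   # [(2, 1), (2, 1), (2, 1)]

fixed = [Router(i) for i in range(3)]
patch_with_factory(fixed)
print([r.forward(0) for r in fixed])   # [(0, 1), (1, 1), (2, 1)]
```

Binding through default arguments (`def new_forward(x, original=original)`) is a common alternative to the factory.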


In [None]:
%%writefile calibration.py
import torch
import numpy as np
from tqdm import tqdm
from scipy.optimize import minimize_scalar
from sklearn.linear_model import RANSACRegressor
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from typing import Tuple, Dict

from moe_watermark import patch_moe_model_with_watermark, get_watermark_data_from_model
from attacks import estimate_gamma_from_text, paraphrase_text_batch

# Helper: compute the Chernoff information
def compute_chernoff_information(p0: torch.Tensor, p1: torch.Tensor) -> float:
    """
    Compute D*(p0, p1):
    D* = -min_{lambda in [0,1]} log( sum( p0^(1-lambda) * p1^lambda ) )
    """
    # Cast to float32 first: bfloat16 tensors cannot be converted to NumPy directly
    p0 = p0.float().cpu().numpy()
    p1 = p1.float().cpu().numpy()
    
    def log_chernoff_coefficient(lambda_):
        # The small constant guards against log(0)
        return np.log(np.sum(np.power(p0, 1 - lambda_) * np.power(p1, lambda_)) + 1e-9)

    # Minimize the log-coefficient over [0, 1]; D* is the negated minimum
    result = minimize_scalar(log_chernoff_coefficient, bounds=(0.0, 1.0), method="bounded")
    if result.success:
        return max(0.0, -float(result.fun))
    else:
        # Fallback: evaluate at lambda = 0.5 (the Bhattacharyya distance, a lower bound on D*)
        return max(0.0, -float(log_chernoff_coefficient(0.5)))

def calibrate_Lg(model: AutoModelForCausalLM, dataloader: DataLoader, device: torch.device) -> float:
    """
    Calibrate the Lipschitz constant Lg (corresponds to Algorithm 1).
    """
    print("Starting Lg calibration (Algorithm 1)...")
    model.eval()
    ratios = []
    
    # The dataloader yields 'input_ids'; they are mapped to embeddings with
    # model.get_input_embeddings() so that noise can be added in embedding space.
    
    embedding_layer = model.get_input_embeddings()
    
    for batch in tqdm(dataloader, desc="Calibrating Lg"):
        input_ids = batch['input_ids'].to(device)
        inputs_embeds = embedding_layer(input_ids)
        
        # 1. Original logits l(e)
        with torch.no_grad():
            # Mixtral needs an attention_mask
            attention_mask = batch.get('attention_mask', torch.ones_like(input_ids)).to(device)
            outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, output_hidden_states=True)
            # Ideally this would use the hidden states that actually enter the MoE layer.
            # Simplification: apply the first layer's router to hidden_states[0] (the embedding output).
            # The Mixtral router's forward takes a single argument.
            hs_e = model.model.layers[0].block_sparse_moe.gate(outputs.hidden_states[0])
            
        # 2. Generate the perturbed input e'
        epsilon = 0.01
        noise = torch.randn_like(inputs_embeds) * epsilon
        e_prime = inputs_embeds + noise
        
        # 3. Perturbed logits l(e')
        with torch.no_grad():
            outputs_prime = model(inputs_embeds=e_prime, attention_mask=attention_mask, output_hidden_states=True)
            hs_e_prime = model.model.layers[0].block_sparse_moe.gate(outputs_prime.hidden_states[0])

        # 4. L2 norms of the logit change and of the input change
        delta_l = torch.norm(hs_e - hs_e_prime, p=2, dim=-1).view(-1)
        delta_x = torch.norm(noise, p=2, dim=-1).view(-1)
        
        # 5. Ratios (cast to float32: bfloat16 tensors cannot be converted to NumPy)
        valid_mask = delta_x > 1e-6
        r_i = delta_l[valid_mask] / delta_x[valid_mask]
        ratios.extend(r_i.float().cpu().numpy())

    if not ratios:
        print("Warning: Lg calibration failed to produce ratios.")
        return 2.0 # fall back to the default

    Lg_95 = np.percentile(ratios, 95)
    print(f"Lg (95th percentile) calibrated: {Lg_95:.4f}")
    return float(Lg_95)


def calibrate_C(model: AutoModelForCausalLM, dataloader: DataLoader, tokenizer, device: torch.device, Lg: float) -> Tuple[float, float, float]:
    """
    Calibrate C_prop, C_stability and C (corresponds to Algorithm 2).
    """
    print("Starting C calibration (Algorithm 2)...")
    model.eval()
    
    gammas = []
    deltas_tv = []
    
    # 1. Calibrate C_prop
    print("Calibrating C_prop...")
    for batch in tqdm(dataloader, desc="Calibrating C_prop"):
        inputs = batch['input_ids'].to(device)
        text_batch = tokenizer.batch_decode(inputs, skip_special_tokens=True)
        
        # 2. Generate the paraphrase attack x'
        text_prime_batch = paraphrase_text_batch(text_batch)
        
        for text, text_prime in zip(text_batch, text_prime_batch):
            # 3. Estimate gamma
            gamma_i = estimate_gamma_from_text(text, text_prime, tokenizer.vocab_size)
            if gamma_i < 1e-6:
                continue
            
            # 4. Compute the activation distributions p(e|x) and p(e|x')
            # For simplicity only the first MoE layer is used
            def get_activation_dist(txt: str) -> torch.Tensor:
                inputs = tokenizer(txt, return_tensors="pt").to(device)
                with torch.no_grad():
                    model(**inputs) # run forward to populate _watermark_detection_data
                data = get_watermark_data_from_model(model)
                if not data:
                    return torch.empty(0)
                # Each entry is (p_0, p_1, S_obs)
                # We need p_0 (the original distribution)
                # [batch, seq, K_experts]
                p_0_dist = data[0][0] 
                # -> [K_experts]
                return p_0_dist.mean(dim=[0, 1]) # average over batch and sequence

            p_dist = get_activation_dist(text)
            p_prime_dist = get_activation_dist(text_prime)

            if p_dist.numel() == 0 or p_prime_dist.numel() == 0:
                continue
                
            # 5. Total variation distance delta
            delta_i_tv = 0.5 * torch.sum(torch.abs(p_dist - p_prime_dist))
            
            gammas.append(gamma_i)
            deltas_tv.append(delta_i_tv.item())

    if not gammas or not deltas_tv:
        print("Warning: C_prop calibration failed (no data). Using defaults.")
        return 1.0, 1.0, 1.0 # fall back to defaults

    # 6. Robust regression: delta ≈ C_prop * sqrt(gamma)
    # Note: RANSACRegressor's default base estimator also fits an intercept.
    X = np.sqrt(np.array(gammas)).reshape(-1, 1)
    y = np.array(deltas_tv)
    
    try:
        ransac = RANSACRegressor().fit(X, y)
        C_prop = ransac.estimator_.coef_[0]
    except Exception as e:
        print(f"RANSAC fit failed: {e}. Using numpy.linalg.lstsq.")
        try:
            C_prop = np.linalg.lstsq(X, y, rcond=None)[0][0]
        except Exception:
            print("Fallback linear fit failed. Using default C_prop=1.0")
            C_prop = 1.0
            
    print(f"C_prop (Propagation Constant) calibrated: {C_prop:.4f}")
    
    # 7. Calibrate C_stability (simplified)
    # Doing this properly is involved, because it requires D*(p', q') and D*(p, q).
    # For now a heuristic based on Lg (with a floor of 1.0) is used.
    C_stability = max(1.0, Lg / 2.0) # heuristic
    print(f"C_stability (Stability Constant) estimated: {C_stability:.4f}")

    # 8. Overall constant C
    C = C_stability * C_prop
    print(f"Overall System Constant C calibrated: {C:.4f}")
    
    return float(C_prop), float(C_stability), float(C)

def calibrate_C_star(
    model: AutoModelForCausalLM, 
    dataloader: DataLoader, 
    tokenizer: AutoTokenizer,
    C: float, 
    gamma_design: float, 
    device: str,
    lambda_weight: float = 1.0,
    delta_error: float = 0.001
) -> float:
    """
    Calibrate the optimal security factor c* (corresponds to Algorithm 3).
    """
    print("Starting c* calibration (Algorithm 3)...")
    
    # 1. Calibrate the performance cost function ΔA(c)
    # Perplexity (PPL) serves as the performance metric
    c_scan = np.linspace(C + 0.1, C * 2.5, 10)
    delta_A_values = []
    
    print("Scanning c values for performance cost (PPL)...")
    
    # Measure the baseline PPL
    base_ppl = measure_ppl(model, dataloader, device)
    print(f"Base PPL (no watermark): {base_ppl:.4f}")
    
    # Save the current router forward methods so they can be restored
    original_routers = {}
    for i, layer in enumerate(model.model.layers):
        original_routers[i] = layer.block_sparse_moe.gate.forward
        
    for c_val in tqdm(c_scan, desc="Calibrating ΔA(c)"):
        epsilon = c_val**2 * gamma_design
        # K_sec is a throwaway key here, used only to measure PPL
        temp_model = patch_moe_model_with_watermark(model, "temp_calib_key", epsilon)
        
        ppl = measure_ppl(temp_model, dataloader, device)
        delta_A = ppl - base_ppl # performance *loss*: the higher the PPL, the larger ΔA
        delta_A_values.append(delta_A)
        
        # Remove the patch and restore the model
        for i, layer in enumerate(model.model.layers):
            layer.block_sparse_moe.gate.forward = original_routers[i]
        
    # Fit ΔA(c) = a * c^p
    # For simplicity a second-degree polynomial is used instead
    try:
        poly_coeffs = np.polyfit(c_scan, delta_A_values, 2)
        delta_A_func = np.poly1d(poly_coeffs)
        print(f"Performance cost function ΔA(c) fitted: {delta_A_func}")
    except np.linalg.LinAlgError:
        print("Warning: PPL fitting failed. Using linear approximation.")
        slope = (delta_A_values[-1] - delta_A_values[0]) / (c_scan[-1] - c_scan[0] + 1e-9)
        delta_A_func = lambda c: max(0, slope * (c - C))

    # 2. Grid search for c*
    # Objective: n*(c) + λ * ΔA(c)
    def objective_func(c):
        if c <= C:
            return np.inf
        # Sample complexity n*
        n_star = np.log(1.0 / delta_error) / (gamma_design * c * (c - C) + 1e-9)
        # Performance cost ΔA
        delta_A = delta_A_func(c)
        
        return n_star + lambda_weight * delta_A

    # 3. Solve for c*
    # Search for the minimum over the same range as c_scan
    best_c_star = c_scan[0]
    min_obj = np.inf
    
    for c_val in np.linspace(C + 0.01, C * 2.5, 50): # fine grid
        obj = objective_func(c_val)
        if obj < min_obj:
            min_obj = obj
            best_c_star = c_val
            
    print(f"Optimal Security Factor c* calibrated: {best_c_star:.4f}")
    return float(best_c_star)

def measure_ppl(model: AutoModelForCausalLM, dataloader: DataLoader, device: torch.device) -> float:
    """
    Helper: measure the model's perplexity (PPL).
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    for batch in tqdm(dataloader, desc="Measuring PPL", leave=False):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch.get('attention_mask', torch.ones_like(input_ids)).to(device)
        labels = input_ids.clone()
        # Exclude padding positions from the loss (-100 is the ignore index)
        labels[attention_mask == 0] = -100
        
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
        
        # Skip batches with no scorable tokens (e.g. empty lines), which give a NaN loss
        if not torch.isfinite(loss):
            continue
        total_loss += loss.item()
        num_batches += 1
        
    if num_batches == 0:
        return 0.0
        
    avg_loss = total_loss / num_batches
    ppl = np.exp(avg_loss)
    return float(ppl)
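
`compute_chernoff_information` earlier in this file takes its definition from the formula D* = -min over λ in [0, 1] of log Σ p0^(1-λ) · p1^λ. As a cross-check, the following dependency-free sketch evaluates the same definition with a plain grid search over λ. The function name and grid size are illustrative choices, and the distributions are toy values rather than router outputs.

```python
import math

def chernoff_information(p0, p1, grid=1001):
    """Grid-search estimate of D*(p0, p1) for two discrete distributions."""
    best = 0.0  # the log-sum equals 0 at lambda = 0 and lambda = 1
    for i in range(grid):
        lam = i / (grid - 1)
        # Terms where either probability is zero contribute nothing to the sum.
        s = sum((a ** (1 - lam)) * (b ** lam) for a, b in zip(p0, p1) if a > 0 and b > 0)
        best = min(best, math.log(s))
    return -best

same = chernoff_information([0.5, 0.5], [0.5, 0.5])
swapped = chernoff_information([0.9, 0.1], [0.1, 0.9])
print(f"identical distributions: {same:.6f}")
print(f"mirrored distributions:  {swapped:.6f}")
```

For the mirrored pair the minimum sits at λ = 0.5 by symmetry, so the result equals -log(0.6), which makes it a convenient sanity check for any optimizer-based version.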


In [None]:
%%writefile detector.py
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple

from moe_watermark import get_watermark_data_from_model

class LLRDetector:
    """
    LLR-based optimal detector (corresponds to Section 3 of the scheme).
    """
    def __init__(
        self,
        model: AutoModelForCausalLM, 
        tokenizer: AutoTokenizer, 
        tau_alpha: float = 20.0 # decision threshold; should be calibrated with an H0 experiment
    ):
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.tau_alpha = tau_alpha
        self.device = model.device

    def compute_llr_from_data(
        self, 
        watermark_data_list: list
    ) -> float:
        """
        Compute the LLR statistic from the data recorded during the forward pass.
        """
        total_llr = 0.0
        
        # Iterate over all MoE layers
        for layer_data in watermark_data_list:
            # (p_0, p_1, S_obs)
            p_0_dist_batch, p_1_dist_batch, S_indices_batch = layer_data
            
            # p_0, p_1: [batch, seq, K_experts]
            # S_obs:    [batch, seq, k_top]
            
            # Theoretical LLR: Λ = Σ_i log( p1(S_i) / p0(S_i) )
            # p(S_i) is the probability of a top-k activation pattern, which is costly to compute.
            
            # Simplified LLR (as in the engineering plan's pseudocode):
            # approximate it by Σ_i Σ_{e in S_i} log(p1(e)/p0(e)),
            # i.e. sum the per-expert LLR over the activated experts only.
            
            batch_size, seq_len, k_top = S_indices_batch.shape
            
            # Gather the p0 and p1 probabilities of the experts in S_i
            # [batch, seq, k_top]
            p0_S = torch.gather(p_0_dist_batch, -1, S_indices_batch)
            p1_S = torch.gather(p_1_dist_batch, -1, S_indices_batch)
            
            # Avoid log(0)
            p0_S = torch.clamp(p0_S, min=1e-9)
            p1_S = torch.clamp(p1_S, min=1e-9)
            
            # Per-expert LLR
            llr_per_expert = torch.log(p1_S) - torch.log(p0_S)
            
            # Sum over the k_top experts, then over batch and sequence
            total_llr += torch.sum(llr_per_expert).item()
            
        return total_llr

    def detect(self, text: str) -> Tuple[bool, float]:
        """
        Check whether the given text carries the watermark.
        Returns (detected, LLR score).
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        
        # Run the (patched) model
        with torch.no_grad():
            self.model(**inputs)
        
        # Extract the recorded data
        watermark_data = get_watermark_data_from_model(self.model)
        
        if not watermark_data:
            print("Error: no watermark data could be extracted from the model. Has it been patched correctly?")
            return False, 0.0
            
        # Compute the LLR
        llr_score = self.compute_llr_from_data(watermark_data)
        
        # Decision
        is_detected = llr_score > self.tau_alpha
        
        print(f"LLR Score: {llr_score:.4f} (Threshold: {self.tau_alpha})")
        
        return is_detected, llr_score


In [None]:
%%writefile attacks.py
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List

# Global cache for the paraphrase model
_paraphrase_model = None
_paraphrase_tokenizer = None
_device = "cuda" if torch.cuda.is_available() else "cpu"

def _load_paraphrase_model():
    """Helper: load the T5 paraphrase model."""
    global _paraphrase_model, _paraphrase_tokenizer
    if _paraphrase_model is None:
        model_name = "Vamsi/T5_Paraphrase" # a T5 paraphrase model
        print(f"Loading paraphrase model ({model_name})...")
        _paraphrase_tokenizer = AutoTokenizer.from_pretrained(model_name)
        _paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(_device)
        _paraphrase_model.eval()
        print("Paraphrase model loaded.")

def paraphrase_text_batch(text_list: List[str]) -> List[str]:
    """
    Apply the paraphrase attack to a batch of texts.
    """
    _load_paraphrase_model()
    
    # T5 expects a task prefix
    inputs = _paraphrase_tokenizer(
        [f"paraphrase: {text}" for text in text_list],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(_device)
    
    with torch.no_grad():
        outputs = _paraphrase_model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            num_return_sequences=1,
            early_stopping=True
        )
        
    paraphrased_texts = _paraphrase_tokenizer.batch_decode(
        outputs, 
        skip_special_tokens=True
    )
    
    # num_return_sequences=1, so there is exactly one output per input, in input order
    return paraphrased_texts

def estimate_gamma_from_text(
    text_original: str, 
    text_attacked: str, 
    vocab_size: int
) -> float:
    """
    Estimate the attack strength γ (corresponds to Section 7.4.1 of the manuscript).
    Uses an upper bound on the KL divergence: γ ≈ (L/N) * H(V)
    """
    # Rough proxy for an edit distance: whitespace-split words are compared
    # position by position, which is not a true alignment-based distance.
    tokens_orig = text_original.split()
    tokens_atk = text_attacked.split()
    
    L = abs(len(tokens_orig) - len(tokens_atk)) # insertions/deletions
    
    # Substitutions
    for t1, t2 in zip(tokens_orig, tokens_atk):
        if t1 != t2:
            L += 1
            
    N = max(len(tokens_orig), 1)
    
    # H(V) = log|V| (in nats)
    H_V = np.log(vocab_size)
    
    gamma = (L / N) * H_V
    
    return float(gamma)


In [None]:
%%writefile main.py
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
import json
import os

# The local modules written by the cells above
import moe_watermark
import calibration
import detector
import attacks

def load_model_and_tokenizer(model_name: str, device: str):
    print(f"Loading model and tokenizer: {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16, # bfloat16 to save GPU memory
        device_map="auto", # place the weights on the available devices automatically
    )
    print("Model and tokenizer loaded.")
    return model, tokenizer

def get_dataloader(dataset_name: str, split: str, tokenizer: AutoTokenizer, batch_size: int, num_samples: int):
    print(f"Loading dataset: {dataset_name} (split: {split})...")
    dataset = load_dataset(dataset_name, name="wikitext-103-v1", split=split) # example configuration
    
    # Take the subset first so that only the needed samples are tokenized
    dataset = dataset.select(range(min(num_samples, len(dataset))))
    
    # Tokenize
    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], truncation=True, max_length=512, padding="max_length")
        return tokenized

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
    
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)
    print(f"Dataset loaded with {len(tokenized_dataset)} samples.")
    return dataloader

def main():
    parser = argparse.ArgumentParser(description="MoE Provably Robust Watermark Project")
    parser.add_argument("--mode", type=str, required=True, choices=["calibrate", "embed", "detect"], help="Operating mode")
    parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-v0.1", help="MoE model to use")
    parser.add_argument("--dataset_name", type=str, default="wikitext", help="Dataset used for calibration")
    parser.add_argument("--dataset_split", type=str, default="train", help="Dataset split")
    parser.add_argument("--num_calib_samples", type=int, default=10, help="Number of calibration samples (small default for Colab)")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size during calibration (small default for Colab)")
    
    # Embed & Detect
    parser.add_argument("--prompt", type=str, default="Once upon a time", help="Prompt used for generation")
    parser.add_argument("--text_to_check", type=str, help="Text to run detection on")
    parser.add_argument("--secret_key", type=str, default="DEFAULT_SECRET_KEY", help="Watermark secret key")
    parser.add_argument("--attack", type=str, choices=["none", "paraphrase"], default="none", help="Attack applied before detection")
    
    # Watermark Params (can be overridden with the values from calibrate)
    parser.add_argument("--gamma_design", type=float, default=0.03, help="Design attack strength γ")
    parser.add_argument("--C_system", type=float, default=1.5, help="System constant C (from calibration)")
    parser.add_argument("--c_star", type=float, default=2.0, help="Security factor c* (from calibration)")
    parser.add_argument("--tau_alpha", type=float, default=20.0, help="LLR detection threshold τ")
    
    # Arguments come from sys.argv. run_main_with_args() sets sys.argv before
    # calling main(), so this works from the command line and from a notebook cell.
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # ----- Colab Overrides -----
    # To run smoothly on Colab, prefer a small model and few samples.
    # **Warning**: Mixtral (~47B parameters) does not fit on the free Colab tier (OOM).
    # A smaller MoE checkpoint would be needed on a T4 GPU, but the patch code in
    # moe_watermark.py currently supports only the Mixtral architecture.
    # **For demonstration we keep Mixtral, which needs far more GPU memory than a T4 offers.**
    # args.model_name = "google/switch-base-8" # example of a smaller MoE (not supported by the current patch)
    # ---------------------------
    
    
    if args.mode == "calibrate":
        # --- Calibration mode ---
        print("--- Mode: Calibrate ---")
        # Warning: calibration is very compute-intensive
        print("Warning: calibration is very compute-intensive and may fail on Colab (non-Pro).")
        print(f"Calibrating with {args.num_calib_samples} samples.")
        
        model, tokenizer = load_model_and_tokenizer(args.model_name, device)
        dataloader = get_dataloader(
            args.dataset_name, 
            args.dataset_split, 
            tokenizer, 
            args.batch_size, 
            args.num_calib_samples
        )
        
        # Run the calibration steps, falling back to defaults on failure
        try:
            Lg = calibration.calibrate_Lg(model, dataloader, device)
        except Exception as e:
            print(f"Lg calibration failed: {e}. Using default 2.0")
            Lg = 2.0

        # Patch with a temporary key for the C calibration.
        # Note: patching happens in place, so `model` itself stays patched afterwards.
        patched_model_for_C = moe_watermark.patch_moe_model_with_watermark(model, "calib_C_key", 0.01)
        try:
            C_prop, C_stability, C = calibration.calibrate_C(patched_model_for_C, dataloader, tokenizer, device, Lg)
        except Exception as e:
            print(f"C calibration failed: {e}. Using defaults 1.0, 1.5, 1.5")
            C_prop, C_stability, C = 1.0, 1.5, 1.5

        try:
            c_star = calibration.calibrate_C_star(model, dataloader, tokenizer, C, args.gamma_design, device)
        except Exception as e:
            print(f"c* calibration failed: {e}. Using default 2.0")
            c_star = 2.0
        
        print("\n--- Calibration Results --- (defaults may have been used)")
        print(f"Lg (95th percentile): {Lg:.4f}")
        print(f"System Constant C:    {C:.4f}")
        print(f"Optimal Factor c*:    {c_star:.4f}")
        print("--------------------------------------")
        print("Use these values for the embed and detect modes")

    elif args.mode == "embed":
        # --- Embedding mode ---
        print("--- Mode: Embed ---")
        model, tokenizer = load_model_and_tokenizer(args.model_name, device)
        
        # Compute the watermark strength ε
        epsilon = args.c_star**2 * args.gamma_design
        print(f"Using c*={args.c_star}, γ={args.gamma_design} -> ε={epsilon:.4f}")
        
        # Patch the model
        patched_model = moe_watermark.patch_moe_model_with_watermark(model, args.secret_key, epsilon)
        
        print(f"\nGenerating watermarked text from prompt: '{args.prompt}'...")
        inputs = tokenizer(args.prompt, return_tensors="pt").to(device)
        
        with torch.no_grad():
            outputs = patched_model.generate(
                **inputs, 
                max_new_tokens=100, 
                do_sample=True, # enable sampling
                top_k=50
            )
            
        watermarked_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("\n--- Watermarked Output --- (keep this text for detection)")
        print(watermarked_text)
        print("--------------------------")
        # Write the generated text to a file so that later cells can use it
        with open("generated_text.txt", "w", encoding="utf-8") as f:
            f.write(watermarked_text)

    elif args.mode == "detect":
        # --- Detection mode ---
        print("--- Mode: Detect ---")
        if not args.text_to_check:
            # Try the file written by embed mode
            if os.path.exists("generated_text.txt"):
                print("--text_to_check not provided; reading from generated_text.txt...")
                with open("generated_text.txt", "r", encoding="utf-8") as f:
                    args.text_to_check = f.read()
            else:
                raise ValueError("--text_to_check is required for detect mode (or run embed mode first)")
            
        model, tokenizer = load_model_and_tokenizer(args.model_name, device)
        
        # Compute the watermark strength ε
        epsilon = args.c_star**2 * args.gamma_design
        print(f"Loading detector with c*={args.c_star}, γ={args.gamma_design} -> ε={epsilon:.4f}")

        # Patch the model so that the LLR detector can access p0 and p1
        patched_model = moe_watermark.patch_moe_model_with_watermark(model, args.secret_key, epsilon)
        
        # Use a distinct name: assigning to `detector` here would shadow the imported module
        llr_detector = detector.LLRDetector(patched_model, tokenizer, tau_alpha=args.tau_alpha)
        
        text_to_check = args.text_to_check
        
        # Apply the attack
        if args.attack == "paraphrase":
            print("Applying paraphrase attack before detection...")
            original_text = text_to_check
            text_to_check = attacks.paraphrase_text_batch([original_text])[0]
            
            gamma_est = attacks.estimate_gamma_from_text(original_text, text_to_check, tokenizer.vocab_size)
            print(f"Paraphrased text: '{text_to_check}'")
            print(f"Estimated attack strength γ: {gamma_est:.4f}")

        print(f"\nDetecting watermark in text (length {len(text_to_check)})...")
        is_detected, llr_score = llr_detector.detect(text_to_check)
        
        print("\n--- Detection Result ---")
        if is_detected:
            print(f"Result: Watermark DETECTED (Score: {llr_score:.2f}) / Threshold: {args.tau_alpha}")
        else:
            print(f"Result: Watermark NOT DETECTED (Score: {llr_score:.2f}) / Threshold: {args.tau_alpha}")
        print("------------------------")

def run_main_with_args(args_list):
    """Run main() from a notebook cell with simulated command-line arguments."""
    import sys
    sys.argv = ['main.py'] + args_list
    main()


if __name__ == "__main__":
    main()


## 3. Run Experiments

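You can now run the experiments.

**Warning:** Mixtral-8x7B (about 47B parameters) is a very large model. It **cannot run** on the free Colab T4 GPU (16 GB VRAM) and fails with an out-of-memory (OOM) error. In bfloat16 the weights alone take roughly 94 GB, so a single A100 (40 GB) or V100 (16 GB) cannot hold the full model either. Expect to need multiple GPUs, quantization, or CPU offloading (which `device_map="auto"` can fall back to, at a large cost in speed).

For demonstration, the commands below are commented out. If you have enough GPU resources, uncomment and run them.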
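
To make the memory requirement concrete, the sketch below estimates the space taken by the weights alone at several precisions. It uses the approximate 47B parameter count quoted above and ignores activations, the KV cache, and framework overhead, so real usage is higher. The helper name is illustrative.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory needed for the weights alone, in GB (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9

NUM_PARAMS = 47e9  # approximate parameter count quoted above

for label, nbytes in [("float32", 4), ("bfloat16", 2), ("int8", 1), ("4-bit", 0.5)]:
    print(f"{label:>8}: {weight_memory_gb(NUM_PARAMS, nbytes):6.1f} GB")
```

Under these assumptions, bfloat16 weights come to about 94 GB, well beyond a 16 GB T4.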

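### 3.1 Mode 1: `calibrate` (parameter calibration)

(**Warning:** this mode is extremely compute-intensive. On Colab it may take hours or fail with a timeout. `main.py` already defaults to a very small workload, `num_calib_samples=10` and `batch_size=1`.)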

In [None]:
# !python main.py --mode calibrate --num_calib_samples 50
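
The last calibration step chooses c* by trading off sample complexity against performance cost: it minimizes n*(c) + λ·ΔA(c) over c > C, with n*(c) = log(1/δ) / (γ·c·(c − C)). The sketch below reproduces that grid search in plain Python. The quadratic `toy_cost` is a made-up stand-in for the fitted perplexity curve, and the function names are illustrative.

```python
import math

def sample_complexity(c, C, gamma, delta_error=0.001):
    """n*(c) = log(1/delta) / (gamma * c * (c - C)), defined for c > C."""
    if c <= C:
        return math.inf
    return math.log(1.0 / delta_error) / (gamma * c * (c - C))

def find_c_star(C, gamma, cost_fn, lambda_weight=1.0, num_points=50):
    """Grid-search c in (C, 2.5*C] for the minimum of n*(c) + lambda * cost(c)."""
    best_c, best_obj = None, math.inf
    lo, hi = C + 0.01, 2.5 * C
    for i in range(num_points):
        c = lo + (hi - lo) * i / (num_points - 1)
        obj = sample_complexity(c, C, gamma) + lambda_weight * cost_fn(c)
        if obj < best_obj:
            best_c, best_obj = c, obj
    return best_c, best_obj

# Hypothetical quadratic cost, for illustration only (not a measured PPL curve)
def toy_cost(c):
    return 5.0 * c ** 2

c_star, obj = find_c_star(C=1.5, gamma=0.03, cost_fn=toy_cost)
print(f"c* = {c_star:.3f}, objective = {obj:.2f}")
```

With a zero cost function the search simply returns the upper end of the grid, since n*(c) decreases monotonically in c. The cost term is what pulls c* back toward C.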

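### 3.2 Mode 2: `embed` (watermark embedding)

This mode loads the model (requires Colab Pro), applies the watermark patch, and generates a piece of watermarked text. The generated text is saved to `generated_text.txt` for the detection step that follows.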

In [None]:
# !python main.py --mode embed --secret_key "colab_test_key_42" --prompt "The theoretical foundation of this watermark is signal-attack decoupling"
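
In embed mode the watermark strength is derived as ε = c*² · γ, and `MoEWatermark` turns it into a logit bias of 5 · √ε through its heuristic scaling. The worked example below plugs in the `main.py` defaults (c* = 2.0, γ = 0.03). The function names are illustrative.

```python
import math

def watermark_strength(c_star: float, gamma_design: float) -> float:
    """epsilon = c*^2 * gamma, as computed in main.py."""
    return c_star ** 2 * gamma_design

def bias_strength(epsilon: float, scale: float = 5.0) -> float:
    """Logit bias added to the green expert: scale * sqrt(epsilon)."""
    return scale * math.sqrt(epsilon)

eps = watermark_strength(c_star=2.0, gamma_design=0.03)  # main.py defaults
print(f"epsilon = {eps:.4f}")                  # epsilon = 0.1200
print(f"bias    = {bias_strength(eps):.4f}")   # bias    = 1.7321
```

A larger calibrated c* therefore raises the bias on the green expert with the square root of ε, that is, linearly in c*.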

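### 3.3 Mode 3: `detect` (watermark detection)

This mode loads the model and checks whether the text in `generated_text.txt` (produced in the previous step) carries the watermark.

**(a) Detection without attack**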

In [None]:
# !python main.py --mode detect --secret_key "colab_test_key_42" --attack "none"
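
The detector scores a text with a simplified log-likelihood ratio: for every routing decision it sums log(p1(e)/p0(e)) over the activated experts. The pure-Python sketch below applies that approximation to a single toy router decision. The logits, the bias value, and the helper names are illustrative assumptions, not values taken from a real model.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_indices(values, k):
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]

def simplified_llr(l0, bias, k_top):
    """Sum of log(p1(e)/p0(e)) over the k_top experts selected under the biased logits."""
    l1 = [a + b for a, b in zip(l0, bias)]
    p0, p1 = softmax(l0), softmax(l1)
    selected = top_k_indices(p1, k_top)
    return sum(math.log(p1[e]) - math.log(p0[e]) for e in selected)

# Toy router with 4 experts; the bias on expert 2 plays the role of the watermark.
l0 = [1.0, 0.5, 0.2, -0.3]
llr_plain = simplified_llr(l0, [0.0, 0.0, 0.0, 0.0], k_top=2)
llr_marked = simplified_llr(l0, [0.0, 0.0, 1.7, 0.0], k_top=2)
print(f"no bias:   {llr_plain:.4f}")
print(f"with bias: {llr_marked:.4f}")
```

Without a bias the two distributions coincide and the score is zero. With the bias, the boosted expert enters the top-k set and contributes a positive term. Summed over many tokens and layers, such terms are what the threshold τ is compared against.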

**(b) Detection under a paraphrase attack**

(**Note:** this additionally loads the T5 paraphrase model, which further increases GPU memory usage.)

In [None]:
# !python main.py --mode detect --secret_key "colab_test_key_42" --attack "paraphrase"
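
When the paraphrase attack is enabled, `main.py` also prints an estimated attack strength γ ≈ (L/N) · log|V|, where L counts word positions that differ and N is the original word count. The standalone sketch below mirrors that estimate without NumPy. The example sentences and the vocabulary size of 32000 are illustrative assumptions.

```python
import math

def estimate_gamma(original: str, attacked: str, vocab_size: int) -> float:
    """gamma ~ (L / N) * log|V|, with L a positional word-mismatch count."""
    a, b = original.split(), attacked.split()
    # Length difference approximates insertions/deletions; mismatches at the
    # same position approximate substitutions.
    L = abs(len(a) - len(b)) + sum(1 for x, y in zip(a, b) if x != y)
    N = max(len(a), 1)
    return (L / N) * math.log(vocab_size)

g = estimate_gamma("the cat sat on the mat", "the cat lay on the rug", vocab_size=32000)
print(f"gamma = {g:.4f}")   # gamma = 3.4578
```

Here two of six words change, so γ is one third of log(32000). Because the comparison is positional, a single inserted word shifts every later position and inflates the estimate, which is worth keeping in mind when reading the printed value.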