# OKR (Opportunistic Keyed Routing) 水印实验

本notebook实现了OKR水印算法，可以在Google Colab上运行实验。

## 核心思想

OKR是一种"机会主义"水印方法：
- 只在"安全区域"（indifference zone）内修改路由决策
- 安全区域定义：`max_logit - current_logit <= epsilon`（与代码中的 `raw_logits >= max_logits - epsilon` 一致，边界值也算在安全区内）
- 使用LSH投影计算水印偏好
- 如果不在安全区，保持原始路由（fail-open）

## 实验流程

1. 安装依赖
2. 加载模型和分词器
3. 注入OKR水印
4. 生成带水印的文本
5. 检测水印


## 1. 安装依赖


In [None]:
# Install the notebook's runtime dependencies (IPython/Colab line magic, not plain Python).
%pip install -q torch transformers accelerate sentencepiece tqdm numpy matplotlib


## 2. 导入库和核心代码


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Union, Optional, Tuple
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm import tqdm
import json
from pathlib import Path
import hashlib


### 2.1 OKR核心路由逻辑 (okr_kernel.py)


In [None]:
class OKRRouter(nn.Module):
    """
    Opportunistic Keyed Routing (OKR) router.

    The watermark only steers routing inside the "indifference zone": the
    experts whose gate logit lies within ``epsilon`` of the best logit. Inside
    that zone experts are ranked by a keyed projection of the hidden state;
    experts outside it keep their ordinary gate ranking and are only used when
    the zone holds fewer than ``top_k`` experts (fail-open).

    Args:
        input_dim: Size of the last dimension of the hidden states.
        num_experts: Number of experts scored by the gate.
        top_k: Number of experts selected per token.
        epsilon: Largest logit gap to the best expert for an expert to still
            count as safe.
    """

    def __init__(self,
                 input_dim: int,
                 num_experts: int,
                 top_k: int = 2,
                 epsilon: float = 1.5):
        super().__init__()
        self.top_k = top_k
        self.epsilon = epsilon

        # Standard gating layer.
        self.gate_network = nn.Linear(input_dim, num_experts, bias=False)

        # Secret projection matrix acting as the watermark key. It is a buffer,
        # not a parameter, so it follows the module across devices untrained.
        self.register_buffer(
            "secret_projection",
            torch.randn(input_dim, num_experts)
        )

    def forward(self, hidden_states: torch.Tensor):
        """
        Select experts for every token.

        Args:
            hidden_states: [batch_size, seq_len, input_dim]

        Returns:
            routing_weights: softmax over the raw gate logits of the selected
                experts, shape [batch_size, seq_len, top_k].
            selected_experts: indices of the selected experts, same shape.
        """
        # 1. Original routing scores.
        raw_logits = self.gate_network(hidden_states)

        # 2. Watermark preference (keyed projection of the hidden state).
        if self.secret_projection.device != hidden_states.device:
            self.secret_projection = self.secret_projection.to(hidden_states.device)
        watermark_bias = torch.matmul(hidden_states, self.secret_projection)
        if watermark_bias.device != raw_logits.device:
            watermark_bias = watermark_bias.to(raw_logits.device)

        # 3. Safe mask (indifference zone).
        max_logits, _ = raw_logits.max(dim=-1, keepdim=True)
        safe_mask = raw_logits >= (max_logits - self.epsilon)

        # 4. Opportunistic injection. Unsafe experts are placed strictly below
        # every watermark score but keep their raw-logit order, so that when
        # fewer than top_k experts are safe the remaining slots are filled by
        # the gate's own ranking instead of an arbitrary tie-break between
        # identical sentinel values. Scores are built in float32 so the small
        # logit gaps survive for half-precision inputs.
        min_bias, _ = watermark_bias.min(dim=-1, keepdim=True)
        fallback_scores = (min_bias.float() - 1.0) + (raw_logits - max_logits).float()
        modified_scores = torch.where(safe_mask, watermark_bias.float(), fallback_scores)

        # 5. Top-K selection.
        _, selected_experts = torch.topk(modified_scores, self.top_k, dim=-1)

        # 6. Weights come from the original logits, not the watermark scores.
        router_logits = torch.gather(raw_logits, -1, selected_experts)
        routing_weights = F.softmax(router_logits, dim=-1)

        return routing_weights, selected_experts


### 2.2 OKR注入代码 (okr_patch.py)


In [None]:
def _okr_forward_core(router: "OKRRouter", hidden_states: torch.Tensor):
    """
    Core OKR forward pass, usable outside ``OKRRouter.forward``.

    Args:
        router: Object exposing ``gate_network``, ``secret_projection``,
            ``epsilon`` and ``top_k`` (an OKRRouter).
        hidden_states: Tensor whose last dimension matches the gate input
            size; it is cast to the gate weight dtype before use.

    Returns:
        routing_weights: softmax over the raw gate logits of the selected
            experts (last dimension ``top_k``).
        selected_experts: indices of the selected experts, same shape.
    """
    gate_dtype = router.gate_network.weight.dtype
    if hidden_states.dtype != gate_dtype:
        hidden_states = hidden_states.to(dtype=gate_dtype)

    raw_logits = router.gate_network(hidden_states)

    if router.secret_projection.device != hidden_states.device:
        router.secret_projection = router.secret_projection.to(hidden_states.device)

    watermark_bias = torch.matmul(hidden_states, router.secret_projection)
    if watermark_bias.device != raw_logits.device:
        watermark_bias = watermark_bias.to(raw_logits.device)

    # Experts within epsilon of the best logit form the indifference zone.
    max_logits, _ = raw_logits.max(dim=-1, keepdim=True)
    safe_mask = raw_logits >= (max_logits - router.epsilon)

    # Unsafe experts are pushed strictly below every watermark score while
    # keeping their raw-logit order. If fewer than top_k experts are safe, the
    # remaining slots then follow the gate's own ranking (fail-open) rather
    # than an arbitrary tie-break between identical sentinel values. float32
    # keeps the small logit gaps distinguishable for half-precision inputs.
    min_bias, _ = watermark_bias.min(dim=-1, keepdim=True)
    fallback_scores = (min_bias.float() - 1.0) + (raw_logits - max_logits).float()
    modified_scores = torch.where(safe_mask, watermark_bias.float(), fallback_scores)

    _, selected_experts = torch.topk(modified_scores, router.top_k, dim=-1)

    # Mixing weights are computed from the original logits.
    router_logits = torch.gather(raw_logits, -1, selected_experts)
    routing_weights = F.softmax(router_logits, dim=-1)

    return routing_weights, selected_experts


def _make_okr_forward(router, okr_router, record_history: bool = False):
    """
    Build the replacement ``forward`` for one patched router.

    Args:
        router: The original router module being patched; routing results are
            stored on it as attributes.
        okr_router: The OKRRouter holding the copied gate weights and the
            secret projection for this router.
        record_history: When True, every call also appends its per-token
            expert choices to ``router._okr_all_selected_experts``.

    Returns:
        A function ``okr_forward(hidden_states)`` returning the tuple
        ``(router_mask, router_probs, router_logits)``.
    """

    def okr_forward(hidden_states: torch.Tensor):
        routing_weights, selected_experts = _okr_forward_core(okr_router, hidden_states)

        if record_history:
            if not hasattr(router, '_okr_all_selected_experts'):
                router._okr_all_selected_experts = []
            # One [batch, 1, top_k] entry per token position, so the history
            # can be concatenated along dim=1 regardless of call granularity.
            router._okr_all_selected_experts.extend(selected_experts.split(1, dim=1))

        router._selected_experts = selected_experts

        # Assemble the three-tuple returned to the caller.
        batch_size, seq_len, _ = hidden_states.shape
        num_experts = okr_router.gate_network.out_features
        out_dtype = hidden_states.dtype

        # _okr_forward_core works in the gate dtype; scatter_ needs the source
        # to match the destination dtype, so cast explicitly.
        log_weights = torch.log(routing_weights + 1e-9).to(dtype=out_dtype)
        routing_weights = routing_weights.to(dtype=out_dtype)

        router_mask = torch.zeros(
            (batch_size, seq_len, num_experts),
            device=hidden_states.device,
            dtype=out_dtype
        )
        router_mask.scatter_(dim=-1, index=selected_experts, src=routing_weights)

        router_probs = routing_weights.sum(dim=-1, keepdim=True)

        # Experts that were not selected keep a logit of -inf.
        router_logits = torch.full(
            (batch_size, seq_len, num_experts),
            float('-inf'),
            device=hidden_states.device,
            dtype=out_dtype
        )
        router_logits.scatter_(dim=-1, index=selected_experts, src=log_weights)

        return (router_mask, router_probs, router_logits)

    return okr_forward


def inject_okr(model: AutoModelForSeq2SeqLM,
               epsilon: float = 1.5,
               secret_key: Optional[str] = None) -> AutoModelForSeq2SeqLM:
    """
    Inject OKR into the decoder routers of an MoE model (in place).

    Every decoder block whose FFN sub-layer (index 2) exposes ``mlp.router``
    with a Linear ``classifier`` or ``gate`` gets its ``forward`` replaced by
    an OKR forward. The OKRRouter is stored on the router as ``_okr_router``.
    Only the first patched router records its selection history in
    ``_okr_all_selected_experts``. A ``clear_okr_stats()`` method is attached
    to the model to reset the recorded data.

    Args:
        model: Pretrained MoE model.
        epsilon: Quality tolerance threshold (size of the safe zone).
        secret_key: Optional key; when given, the secret projection is derived
            deterministically from it, otherwise it stays randomly initialised.

    Returns:
        The same model object, now patched.

    Raises:
        ValueError: If the expert count, the decoder, the decoder blocks, or
            any patchable router cannot be found.
    """
    count = 0

    # Read the model configuration.
    if hasattr(model.config, 'num_local_experts'):
        num_experts = model.config.num_local_experts
    elif hasattr(model.config, 'num_experts'):
        num_experts = model.config.num_experts
    else:
        raise ValueError("无法检测模型配置：找不到num_experts或num_local_experts")
    top_k = getattr(model.config, 'num_experts_per_tok', 1)

    print(f"DEBUG: Initial model config: num_experts={num_experts}, top_k={top_k}")

    # Locate the decoder.
    decoder = None
    if hasattr(model, 'decoder'):
        decoder = model.decoder
        print(f"DEBUG: Found decoder directly in model: {type(decoder)}")
    elif hasattr(model, 'model') and hasattr(model.model, 'decoder'):
        decoder = model.model.decoder
        print(f"DEBUG: Found decoder in model.model: {type(decoder)}")

    if decoder is None:
        print("DEBUG: Could not find decoder.")
        raise ValueError("无法找到decoder")

    # Locate the decoder blocks.
    decoder_blocks = None
    if hasattr(decoder, 'block'):
        decoder_blocks = decoder.block
        print(f"DEBUG: Found decoder blocks as 'block': {type(decoder_blocks)}")
    elif hasattr(decoder, 'layers'):
        decoder_blocks = decoder.layers
        print(f"DEBUG: Found decoder blocks as 'layers': {type(decoder_blocks)}")

    if decoder_blocks is None:
        print("DEBUG: Could not find decoder blocks.")
        raise ValueError("无法找到decoder blocks")

    # The seed depends only on the key, so derive it once.
    seed = None
    if secret_key is not None:
        seed = int(hashlib.sha256(secret_key.encode()).hexdigest()[:16], 16)

    # Walk all layers looking for MoE routers.
    for layer_idx, layer in enumerate(decoder_blocks):
        router = None

        print(f"DEBUG: Processing decoder layer {layer_idx}")
        print(f"DEBUG: Layer type: {type(layer)}")

        # Switch Transformer layout: a T5-style decoder block is usually
        # self-attn (0), cross-attn (1), FFN (2).
        if hasattr(layer, 'layer'):
            layer_list = layer.layer
            print(f"DEBUG: Layer.layer type: {type(layer_list)}, length: {len(layer_list)}")
            if len(layer_list) > 2:  # index 2 must exist
                ffn_layer = layer_list[2]
                print(f"DEBUG: FFN Layer type (index 2): {type(ffn_layer)}")
                # The router lives at ffn_layer.mlp.router.
                if hasattr(ffn_layer, 'mlp') and hasattr(ffn_layer.mlp, 'router'):
                    router = ffn_layer.mlp.router
                    print(f"DEBUG: Found router in layer {layer_idx} at ffn_layer.mlp.router, type: {type(router)}")
                else:
                    print(f"DEBUG: FFN layer (index 2) in layer {layer_idx} does NOT have 'router' attribute directly or in 'mlp'.")
            else:
                print(f"DEBUG: layer.layer in layer {layer_idx} has {len(layer_list)} items, which is not enough for FFN at index 2.")
        else:
            print(f"DEBUG: Layer {layer_idx} does NOT have 'layer' attribute.")

        if router is None:
            print(f"DEBUG: No router found for layer {layer_idx}. Continuing...")
            continue

        # Find the gate layer; its shape defines the OKR router dimensions.
        gate_layer = None
        for attr_name in ('classifier', 'gate'):
            candidate = getattr(router, attr_name, None)
            if isinstance(candidate, torch.nn.Linear):
                gate_layer = candidate
                break

        if gate_layer is None:
            # Without the original gate weights the OKR router would route
            # with random logits, so leave this router untouched.
            print(f"DEBUG: Router in layer {layer_idx} has no Linear 'classifier' or 'gate'. Skipping...")
            continue

        input_dim = gate_layer.in_features
        num_experts_found = gate_layer.out_features
        router_dtype = gate_layer.weight.dtype
        device = gate_layer.weight.device

        # Build the OKR router on the same device / dtype as the gate.
        okr_router = OKRRouter(input_dim, num_experts_found, top_k=top_k, epsilon=epsilon)
        okr_router = okr_router.to(device=device, dtype=router_dtype)

        with torch.no_grad():
            # Shapes match by construction: both are [num_experts, input_dim].
            okr_router.gate_network.weight.copy_(gate_layer.weight)

            if seed is not None:
                # Always draw on CPU in float32: CPU and CUDA generators yield
                # different streams for the same seed, which would make the
                # key-derived projection depend on where the model runs.
                generator = torch.Generator(device="cpu")
                generator.manual_seed(seed)
                projection = torch.randn(
                    input_dim, num_experts_found,
                    generator=generator,
                    dtype=torch.float32
                )
                okr_router.secret_projection.copy_(projection)

        # Keep the OKR router reachable from the patched module.
        router._okr_router = okr_router

        # Bind this iteration's router explicitly. A closure over the loop
        # variable would resolve to whatever the loop left behind.
        router.forward = _make_okr_forward(router, okr_router, record_history=(count == 0))
        count += 1

    if count == 0:
        raise ValueError("未找到任何router层")

    # Attach clear_okr_stats to the model.
    def clear_okr_stats(mdl):
        if hasattr(mdl, '_okr_routing_data'):
            mdl._okr_routing_data = {}
        for _, module in mdl.named_modules():
            if hasattr(module, '_okr_all_selected_experts'):
                module._okr_all_selected_experts = []
            if hasattr(module, '_selected_experts'):
                module._selected_experts = None

    import types
    model.clear_okr_stats = types.MethodType(clear_okr_stats, model)

    print(f"✓ 已注入OKR到 {count} 个decoder层")
    return model

In [None]:
class OKRDetector:
    """
    OKR watermark detector.

    Verification steps:
    1. Recompute the original gate logits (ground truth).
    2. Recompute the watermark signal (expected signal).
    3. Recompute the opportunity windows (tokens with >= 2 safe experts).
    4. Check whether the recorded expert choices hit the expected expert.
    """

    def __init__(self, model, epsilon: float = 1.5):
        self.model = model
        self.epsilon = epsilon

    def detect(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
               decoder_input_ids: Optional[torch.Tensor] = None):
        """
        Run detection for one sequence using the routing history on the model.

        The model is re-run on the given inputs and the tensor actually fed to
        the router is captured with a forward pre-hook, so the gate logits are
        recomputed from the same input the router saw. If nothing is captured,
        the last decoder hidden state is used instead.

        Args:
            input_ids: Encoder input token IDs [batch, seq_len].
            attention_mask: Attention mask [batch, seq_len].
            decoder_input_ids: Decoder input token IDs [batch, seq_len];
                defaults to ``input_ids`` when omitted.

        Returns:
            score: Watermark hit rate (0-1).
            verdict: "Watermarked", "Clean", or a message explaining why no
                score could be computed.
        """
        actual_selected_experts = self._extract_selected_experts()
        if actual_selected_experts is None:
            return 0.0, "No routing data available"

        if decoder_input_ids is None:
            decoder_input_ids = input_ids

        # Align the decoder inputs with the recorded routing length.
        routing_seq_len = actual_selected_experts.shape[1]
        decoder_len = decoder_input_ids.shape[1]
        if decoder_len < routing_seq_len:
            actual_selected_experts = actual_selected_experts[:, -decoder_len:, :]
        elif decoder_len > routing_seq_len:
            if decoder_len - routing_seq_len == 1:
                decoder_input_ids = decoder_input_ids[:, 1:]
            else:
                decoder_input_ids = decoder_input_ids[:, :routing_seq_len]

        # The re-run below triggers the patched routers again; snapshot every
        # recorded history so the caller's data survives it.
        holders = [
            module for _, module in self.model.named_modules()
            if hasattr(module, '_okr_all_selected_experts')
        ]
        saved_histories = [(module, list(module._okr_all_selected_experts or [])) for module in holders]
        for module in holders:
            module._okr_all_selected_experts = []

        captured_inputs = []

        def _capture_router_input(_module, args):
            # Only 3-D tensors can be stitched together along dim=1.
            if args and torch.is_tensor(args[0]) and args[0].dim() == 3:
                captured_inputs.append(args[0].detach())

        router = self._get_router()
        hook_handle = None
        if router is not None:
            hook_handle = router.register_forward_pre_hook(_capture_router_input)

        try:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    output_hidden_states=True
                )
        finally:
            # Restore state even if the forward pass raised.
            if hook_handle is not None:
                hook_handle.remove()
            for module, history in saved_histories:
                module._okr_all_selected_experts = history

        if captured_inputs:
            hidden_states = torch.cat(captured_inputs, dim=1)
        elif hasattr(outputs, 'decoder_hidden_states') and outputs.decoder_hidden_states:
            hidden_states = outputs.decoder_hidden_states[-1]
        elif hasattr(outputs, 'hidden_states') and outputs.hidden_states:
            hidden_states = outputs.hidden_states[-1]
        else:
            return 0.0, "No hidden states available"

        # Final length alignment.
        min_seq_len = min(hidden_states.shape[1], actual_selected_experts.shape[1])
        hidden_states = hidden_states[:, :min_seq_len, :]
        actual_selected_experts = actual_selected_experts[:, :min_seq_len, :]

        return self.verify_batch(hidden_states, actual_selected_experts)

    def verify_batch(self, hidden_states: torch.Tensor, actual_selected_experts: torch.Tensor) -> Tuple[float, str]:
        """
        Core verification logic.

        Args:
            hidden_states: [batch, seq, dim]
            actual_selected_experts: [batch, seq, top_k]

        Returns:
            score: Fraction of opportunity tokens whose recorded experts
                contain the expected expert (0-1).
            verdict: "Watermarked" or "Clean", or a message when no router or
                no opportunity exists.
        """
        router = self._get_router()
        if router is None:
            return 0.0, "No router found"

        okr_router = None
        if hasattr(router, '_okr_router'):
            okr_router = router._okr_router
        elif hasattr(router, 'gate_network') and hasattr(router, 'secret_projection'):
            okr_router = router

        if okr_router is None:
            return 0.0, "No OKR router found"

        with torch.no_grad():
            gate_weight = okr_router.gate_network.weight
            hidden_states = hidden_states.to(device=gate_weight.device, dtype=gate_weight.dtype)

            # 1. Recompute the original logits.
            raw_logits = okr_router.gate_network(hidden_states)
            max_logits, _ = raw_logits.max(dim=-1, keepdim=True)

            # 2. Recompute the watermark signal.
            secret_proj = okr_router.secret_projection.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            watermark_bias = torch.matmul(hidden_states, secret_proj)

            # 3. Recompute the opportunity windows: the watermark can only
            # express a preference where at least two experts are safe.
            safe_mask = raw_logits >= (max_logits - self.epsilon)
            valid_opportunity_mask = safe_mask.sum(dim=-1) >= 2

            if not valid_opportunity_mask.any():
                return 0.0, "No Opportunities"

            # 4. Expected expert = best watermark score among safe experts.
            masked_watermark_scores = watermark_bias.masked_fill(~safe_mask, float('-inf'))
            target_expert = masked_watermark_scores.argmax(dim=-1, keepdim=True)

            actual = actual_selected_experts.to(target_expert.device)
            hits = (actual == target_expert).any(dim=-1)

            # 5. Score over opportunity tokens only.
            score = hits[valid_opportunity_mask].float().mean().item()

            # 6. Threshold at twice the uniform-over-experts chance level.
            # NOTE(review): the target is drawn from the safe subset only, so
            # the hit rate of unwatermarked routing may exceed 1/num_experts —
            # confirm this threshold against a clean baseline.
            num_experts = raw_logits.shape[-1]

        random_baseline = 1.0 / num_experts
        threshold = random_baseline * 2.0

        verdict = "Watermarked" if score > threshold else "Clean"
        return score, verdict

    def _get_router(self):
        """Return the first module that is, or carries, an OKR router."""
        for _, module in self.model.named_modules():
            if hasattr(module, '_okr_router'):
                return module
            if hasattr(module, 'secret_projection') and hasattr(module, 'gate_network'):
                return module
        return None

    def _extract_selected_experts(self) -> Optional[torch.Tensor]:
        """
        Return the recorded expert choices concatenated along the sequence
        dimension, or None when no module holds a non-empty history.

        The router returned by ``_get_router`` is checked first, then every
        module of the model in ``named_modules`` order.
        """
        candidates = []
        preferred = self._get_router()
        if preferred is not None:
            candidates.append(preferred)
        candidates.extend(module for _, module in self.model.named_modules())

        for module in candidates:
            history = getattr(module, '_okr_all_selected_experts', None)
            if history:
                return torch.cat(list(history), dim=1)

        return None


In [None]:
# Experiment configuration shared by the cells below.
CONFIG = dict(
    model_name="google/switch-base-8",      # Switch Transformers checkpoint
    epsilon=1.5,                            # quality tolerance threshold
    secret_key="OKR_COLAB_EXPERIMENT_KEY",  # watermark key
    num_samples=10,                         # number of prompts (keep small on Colab)
    max_length=128,                         # maximum generation length
    device="cpu" if not torch.cuda.is_available() else "cuda",
)

# Echo the settings that matter most for the run.
print("\n".join(
    f"{label}: {CONFIG[key]}"
    for label, key in (
        ("设备", "device"),
        ("模型", "model_name"),
        ("Epsilon", "epsilon"),
        ("样本数", "num_samples"),
    )
))


## 4. 加载模型和分词器


In [None]:
# Load the tokenizer; reuse EOS as the padding token when none is defined.
print("加载分词器...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model in float32. On CUDA the placement is delegated to
# device_map="auto"; otherwise no device map is passed.
print("加载模型...")
load_kwargs = {
    "torch_dtype": torch.float32,
    "device_map": None if CONFIG["device"] != "cuda" else "auto",
    "low_cpu_mem_usage": True,
}
model = AutoModelForSeq2SeqLM.from_pretrained(CONFIG["model_name"], **load_kwargs)
model.eval()

# Explicit placement is only done for the CPU case.
if CONFIG["device"] == "cpu":
    model = model.to(CONFIG["device"])

print("✓ 模型加载完成")


## 5. 注入OKR水印


In [None]:
# Patch the loaded model with OKR using the configured threshold and key.
print("注入OKR水印...")
okr_settings = {"epsilon": CONFIG["epsilon"], "secret_key": CONFIG["secret_key"]}
watermarked_model = inject_okr(model, **okr_settings)
print("✓ 水印注入完成")

## 6. 准备测试数据


In [None]:
# 测试提示词
test_prompts = [
    "The quick brown fox jumps over the lazy dog.",
    "In a world where technology advances rapidly, artificial intelligence plays a crucial role.",
    "Climate change is one of the most pressing issues facing humanity today.",
    "The study of mathematics has been fundamental to scientific progress throughout history.",
    "Machine learning algorithms can identify patterns in large datasets.",
    "The human brain is one of the most complex structures in the known universe.",
    "Renewable energy sources are becoming increasingly important for sustainable development.",
    "The internet has revolutionized the way we communicate and access information.",
    "Quantum computing represents a paradigm shift in computational capabilities.",
    "Biodiversity conservation is essential for maintaining ecosystem stability."
]

# 只使用前num_samples个
test_prompts = test_prompts[:CONFIG["num_samples"]]
print(f"测试提示词数量: {len(test_prompts)}")

## 7. 生成带水印的文本


In [None]:
watermarked_texts = []
watermarked_token_ids = []
sample_routing_data = []

# Token ids must be compared against None: a valid id of 0 is falsy, so
# `pad_token_id or eos_token_id` would silently pick EOS instead.
fallback_pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)

# Prefer the start token declared by the model config; fall back to the
# tokenizer's pad/EOS id when the config does not define one.
decoder_start_token_id = getattr(watermarked_model.config, "decoder_start_token_id", None)
if decoder_start_token_id is None:
    decoder_start_token_id = fallback_pad_token_id

for i, text in enumerate(tqdm(test_prompts, desc="生成文本")):
    # Reset the routing history so it only covers this sample.
    if hasattr(watermarked_model, 'clear_okr_stats'):
        watermarked_model.clear_okr_stats()

    # Encode the prompt.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True,
                       max_length=CONFIG["max_length"]).to(CONFIG["device"])

    # Greedy decoding (single beam, no sampling).
    with torch.no_grad():
        outputs = watermarked_model.generate(
            **inputs,
            max_length=CONFIG["max_length"],
            num_beams=1,
            do_sample=False,
            decoder_start_token_id=decoder_start_token_id,
            pad_token_id=fallback_pad_token_id
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    watermarked_texts.append(generated_text)
    watermarked_token_ids.append(outputs[0])

    # Take the history from the first module that recorded any, concatenated
    # along the sequence dimension; None when nothing was recorded.
    recorded = next(
        (
            module._okr_all_selected_experts
            for _, module in watermarked_model.named_modules()
            if getattr(module, '_okr_all_selected_experts', None)
        ),
        None,
    )
    sample_routing_data.append(torch.cat(recorded, dim=1) if recorded else None)

    print(f"样本 {i+1}: 原始='{text[:50]}...', 生成='{generated_text[:50]}...'")

## 8. 检测水印


In [None]:
detector = OKRDetector(watermarked_model, epsilon=CONFIG["epsilon"])

detection_results = []

for i, (original_text, watermarked_text, generated_token_ids, sample_routing) in enumerate(
    zip(test_prompts, watermarked_texts, watermarked_token_ids, sample_routing_data)
):
    # Encoder input: the original prompt.
    encoder_inputs = tokenizer(original_text, return_tensors="pt", padding=True, truncation=True,
                               max_length=CONFIG["max_length"]).to(CONFIG["device"])

    # Decoder input: the tokens the decoder consumed while generating.
    generated_seq = generated_token_ids if generated_token_ids.dim() == 1 else generated_token_ids[0]
    start_token = torch.tensor([decoder_start_token_id], device=generated_seq.device, dtype=torch.long)

    if generated_seq.numel() > 0 and generated_seq[0].item() == decoder_start_token_id:
        # The sequence already begins with the start token, so dropping the
        # last token gives the decoder inputs; prepending another start token
        # would make them one step too long.
        shifted = generated_seq[:-1]
    else:
        shifted = torch.cat([start_token, generated_seq[:-1].to(torch.long)])
    if shifted.numel() == 0:
        shifted = start_token
    decoder_input_ids = shifted.unsqueeze(0).to(CONFIG["device"])

    # Replace whatever history is on the model with this sample's routing.
    if hasattr(watermarked_model, 'clear_okr_stats'):
        watermarked_model.clear_okr_stats()

    if sample_routing is not None:
        router_for_detection = next(
            (
                module for _, module in watermarked_model.named_modules()
                if hasattr(module, '_okr_all_selected_experts')
            ),
            None,
        )
        if router_for_detection is not None:
            # Stored as one [batch, 1, top_k] entry per token position.
            router_for_detection._okr_all_selected_experts = list(sample_routing.split(1, dim=1))

    # Run detection.
    score, verdict = detector.detect(
        input_ids=encoder_inputs["input_ids"],
        attention_mask=encoder_inputs.get("attention_mask"),
        decoder_input_ids=decoder_input_ids
    )

    detection_results.append({
        "sample_id": i,
        "original_text": original_text[:100],
        "watermarked_text": watermarked_text[:100],
        "hit_rate": float(score),
        "verdict": verdict
    })

    print(f"样本 {i+1}: 命中率={score:.4f}, 判定={verdict}")

In [None]:
# Aggregate the per-sample detection results.
hit_rates = [r["hit_rate"] for r in detection_results]
avg_hit_rate = np.mean(hit_rates) if hit_rates else 0.0
# min()/max() raise ValueError on an empty list; use the same 0.0 fallback
# as the mean so an empty run still produces a summary.
min_hit_rate = min(hit_rates) if hit_rates else 0.0
max_hit_rate = max(hit_rates) if hit_rates else 0.0
watermarked_count = sum(1 for r in detection_results if r["verdict"] == "Watermarked")

print("=" * 60)
print("实验结果汇总")
print("=" * 60)
print(f"总样本数: {len(test_prompts)}")
print(f"检测为水印: {watermarked_count}")
print(f"平均命中率: {avg_hit_rate:.4f}")
print(f"命中率范围: [{min_hit_rate:.4f}, {max_hit_rate:.4f}]")
print("=" * 60)

# Persist the results.
# NOTE(review): CONFIG is written verbatim, including "secret_key" — confirm
# that storing the watermark key in the results file is intended.
results = {
    "config": CONFIG,
    "summary": {
        "total_samples": len(test_prompts),
        "watermarked_samples": watermarked_count,
        "average_hit_rate": float(avg_hit_rate),
        "hit_rate_range": [float(min_hit_rate), float(max_hit_rate)]
    },
    "detailed_results": detection_results
}

with open("okr_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print("\n✓ 实验结果已保存到 okr_results.json")

## 10. 可视化结果（可选）


In [None]:
import matplotlib.pyplot as plt

# Histogram of the per-sample hit rates, with the mean marked by a dashed line.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(hit_rates, bins=20, edgecolor='black', alpha=0.7)
ax.axvline(avg_hit_rate, color='r', linestyle='--', label=f'Average Hit Rate: {avg_hit_rate:.4f}')
ax.set_xlabel('Hit Rate')
ax.set_ylabel('Number of Samples')
ax.set_title('Distribution of OKR Watermark Detection Hit Rates')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

# Per-sample verdicts.
print("\nDetailed Results:")
for entry in detection_results:
    print(f"Sample {entry['sample_id']+1}: {entry['verdict']} (Hit Rate: {entry['hit_rate']:.4f})")