In [11]:
"""
Enhanced AI Text Detection Pipeline
-----------------------------------
核心改进:
1. 更好的LLM选择（T5/GPT-2用于重写）
2. 多维度特征提取
3. 更智能的扰动策略
4. 集成学习方法


"""


# Imports. nlpaug is auto-installed below if it is missing.
# `from __future__` imports must be the first statement after the module
# docstring; placed after other imports (as before) they raise a SyntaxError
# that aborts the entire cell.
from __future__ import annotations

import os
import json
import hashlib
import numpy as np
import pandas as pd
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from tqdm import tqdm
import random
import pickle

# OpenAI library imports
import openai  # exposes error types such as openai.RateLimitError
from openai import OpenAI  # API client class used by EnhancedLLMReviser
import tiktoken

# Core libraries
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import torch


import os, time, hashlib, json, random
from typing import List, Tuple

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report
# Transformers
from transformers import (
        T5ForConditionalGeneration, T5Tokenizer,
        GPT2LMHeadModel, GPT2Tokenizer,
        AutoTokenizer, AutoModel
    )
from sentence_transformers import SentenceTransformer
# Text processing
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords

# Text augmentation library (installed on the fly if missing)
try:
    import nlpaug.augmenter.word as naw
    import nlpaug.augmenter.sentence as nas
except ImportError:
    print("nlpaug not found, installing...")
    os.system("pip install nlpaug")
    import nlpaug.augmenter.word as naw
    import nlpaug.augmenter.sentence as nas

# Transformers and sentence-transformers

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, GPT2LMHeadModel, GPT2Tokenizer
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass, field
from typing import List, Optional
# Download the NLTK resources used for tokenization, POS tagging, WordNet
# synonyms and stopwords. Each resource is fetched independently, so one failing
# download no longer skips the rest (previously the first exception aborted the
# whole sequence).
_NLTK_RESOURCES = (
    'punkt',
    # Newer NLTK releases load 'punkt_tab' (not 'punkt') inside
    # sent_tokenize/word_tokenize; without it every tokenization call raises.
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'wordnet',
    'omw-1.4',
    'stopwords',
)

print("Downloading required NLTK data...")
_nltk_failed = []
for _nltk_resource in _NLTK_RESOURCES:
    try:
        # nltk.download reports failure through a falsy return value in quiet mode.
        if not nltk.download(_nltk_resource, quiet=True):
            _nltk_failed.append(_nltk_resource)
    except Exception as e:
        print(f"NLTK download warning ({_nltk_resource}): {e}")
        _nltk_failed.append(_nltk_resource)

if _nltk_failed:
    print(f"NLTK download warning: could not fetch {_nltk_failed}")
    print("Some augmentation features may not work properly.")
else:
    print("NLTK data download completed.")
@dataclass
class DetectionConfig:
    """Settings shared by the perturber, reviser, feature extractor and detector."""

    # Default perturbation strategies. This is a plain (unannotated) class
    # attribute, so the dataclass machinery does not turn it into a field.
    _DEFAULT_PERTURBATION_METHODS = ("synonym", "contextual")

    # --- Model selection ---
    # "t5*" / "gpt2" load a local model, "gpt-*" uses the OpenAI API, and any
    # other value falls back to rule-based revision (see EnhancedLLMReviser).
    revision_model: str = "t5-small"
    # SentenceTransformer model used for the semantic-similarity feature.
    embedding_model: str = "all-MiniLM-L6-v2"

    # --- OpenAI API (only used for "gpt-*" revision models) ---
    # When None, the reviser falls back to the OPENAI_API_KEY environment variable.
    api_key: Optional[str] = None
    # Optional custom API endpoint.
    base_url: Optional[str] = None

    # --- Perturbation ---
    # Passed as aug_p to every nlpaug augmenter and used by the fallback perturber.
    perturbation_rate: float = 0.15
    # default_factory gives each instance its own list (no shared mutable default).
    perturbation_methods: List[str] = field(
        default_factory=lambda: list(DetectionConfig._DEFAULT_PERTURBATION_METHODS)
    )

    # --- Detection ---
    # In threshold mode, a similarity above this value is predicted as label 1.
    # NOTE: the original author suggested 0.85 as a more robust default.
    similarity_threshold: float = 0.95
    # Train / use a RandomForest on all extracted features when labels are given.
    use_ml_classifier: bool = True

    # --- Batching ---
    batch_size: int = 16
    max_length: int = 512

    def __post_init__(self):
        # An explicit None (or an empty list, which would later make random method
        # selection fail) means "use the defaults". The previous None-fallback
        # disagreed with default_factory and added "backtranslation", whose
        # implementation is a no-op stub that returns the text unchanged.
        if not self.perturbation_methods:
            self.perturbation_methods = list(self._DEFAULT_PERTURBATION_METHODS)

class EnhancedTextPerturber:
    """Randomly perturbs text (synonyms, contextual substitution, word swaps, typos).

    The strategy for each call is drawn from ``config.perturbation_methods``. If
    the chosen augmenter fails, a cheap character-level perturbation is used
    instead, so a single bad input never aborts a batch.
    """

    def __init__(self, config: "DetectionConfig"):
        self.config = config
        self._init_augmenters()

    def _init_augmenters(self):
        """Build the nlpaug augmenters, keyed by method name.

        Every augmenter uses ``config.perturbation_rate`` as its ``aug_p``. The
        "contextual" augmenter loads the ``distilbert-base-uncased`` model.
        """
        rate = self.config.perturbation_rate
        self.augmenters = {
            "synonym": naw.SynonymAug(aug_src='wordnet', aug_p=rate),
            "contextual": naw.ContextualWordEmbsAug(
                model_path='distilbert-base-uncased',
                action="substitute",
                aug_p=rate
            ),
            "random_swap": naw.RandomWordAug(action="swap", aug_p=rate),
            "spelling": naw.SpellingAug(aug_p=rate),
        }

    @staticmethod
    def _first_result(augmented):
        """Normalize an augmenter result (a list of strings or a single string) to one string."""
        return augmented[0] if isinstance(augmented, list) else augmented

    def perturb(self, text: str, method: Optional[str] = None) -> str:
        """Perturb ``text`` with a single strategy.

        Args:
            text: Original text.
            method: Strategy name. ``None`` picks one at random from
                ``config.perturbation_methods``. ``"backtranslation"`` is currently a
                no-op stub. Names that are not keys of ``self.augmenters`` use the
                mixed per-sentence strategy.
        Returns:
            The perturbed text; if the strategy raises, the output of
            ``_simple_perturb`` instead.
        """
        if method is None:
            method = np.random.choice(self.config.perturbation_methods)

        try:
            if method == "backtranslation":
                return self._backtranslate(text)
            if method in self.augmenters:
                return self._first_result(self.augmenters[method].augment(text))
            return self._mixed_perturbation(text)
        except Exception as e:
            print(f"Perturbation failed: {e}")
            return self._simple_perturb(text)

    def _mixed_perturbation(self, text: str) -> str:
        """Perturb each sentence with an independently, randomly chosen augmenter.

        A sentence whose augmenter raises is kept unchanged. Text that yields no
        sentences is returned as-is.
        """
        sentences = sent_tokenize(text)
        if not sentences:
            return text

        method_names = list(self.augmenters.keys())
        perturbed_sentences = []
        for sent in sentences:
            method = np.random.choice(method_names)
            try:
                perturbed_sentences.append(
                    self._first_result(self.augmenters[method].augment(sent))
                )
            except Exception:
                # Best effort: keep the sentence. Unlike the former bare `except:`,
                # this lets KeyboardInterrupt / SystemExit propagate.
                perturbed_sentences.append(sent)

        return ' '.join(perturbed_sentences)

    def _simple_perturb(self, text: str) -> str:
        """Fallback perturbation: replace the last character of some words with a random vowel.

        Roughly ``config.perturbation_rate`` of the words (at least one) are
        selected, and only words longer than 3 characters are changed. Whitespace
        is normalized to single spaces. Text without words is returned unchanged.
        """
        words = text.split()
        if not words:
            return text

        num_changes = max(1, int(len(words) * self.config.perturbation_rate))
        indices = np.random.choice(range(len(words)), size=min(num_changes, len(words)), replace=False)

        for idx in indices:
            word = words[idx]
            if len(word) > 3:
                words[idx] = word[:-1] + np.random.choice(list('aeiou'))

        return ' '.join(words)

    def _backtranslate(self, text: str) -> str:
        """Back-translation placeholder: returns ``text`` unchanged.

        TODO: wire up a translation API; until then this strategy does not perturb.
        """
        return text

class EnhancedLLMReviser:
    """Rewrites text with the backend selected by ``config.revision_model``.

    * ``"t5*"``   -> local T5 model (``_revise_with_t5``)
    * ``"gpt2"``  -> local GPT-2 model (``_revise_with_gpt2``)
    * ``"gpt-*"`` -> OpenAI chat-completions API (``_revise_with_api``)
    * otherwise   -> rule-based rewriting (``_rule_based_revision``)
    """

    # Whole-word, case-insensitive replacements used by the rule-based fallback.
    _RULE_SYNONYMS = {
        "important": "significant",
        "however": "nevertheless",
        "therefore": "thus",
        "because": "since",
        "many": "numerous",
        "show": "demonstrate",
    }

    def __init__(self, config: "DetectionConfig"):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.client = None
        self.model = None
        self.tokenizer = None
        # Retry settings for the API backend; defined for every backend so the
        # attributes always exist.
        self.max_retries = 3
        self.retry_delay = 1.0
        self._init_reviser()

    def _init_model(self):
        """Deprecated alias of ``_init_reviser``, kept for backward compatibility.

        The former standalone copy called ``self.model.to(...)`` unconditionally
        and crashed for the API and rule-based backends, where ``self.model`` is None.
        """
        self._init_reviser()

    def _init_reviser(self):
        """Load the local model/tokenizer, or create the OpenAI client, for the configured backend.

        Raises:
            ValueError: for a ``"gpt-*"`` model when neither ``config.api_key`` nor
                ``OPENAI_API_KEY`` is set.
        """
        model_name = self.config.revision_model

        if model_name.startswith("t5"):
            self.tokenizer = T5Tokenizer.from_pretrained(model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(model_name)
            self.model.to(self.device).eval()
            print(f"Local T5 model loaded: {model_name}")

        elif model_name == "gpt2":
            self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
            self.model = GPT2LMHeadModel.from_pretrained(model_name)
            # GPT-2 has no pad token; reuse EOS so padding/generation works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.to(self.device).eval()
            print(f"Local GPT-2 model loaded: {model_name}")

        elif model_name.startswith("gpt-"):
            api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("请在DetectionConfig中提供api_key或设置OPENAI_API_KEY环境变量")

            self.client = OpenAI(api_key=api_key, base_url=self.config.base_url)
            self.max_retries = 3
            self.retry_delay = 1.0
            print(f"OpenAI client initialized for model: {model_name}")
            if self.config.base_url:
                print(f"Using custom base URL: {self.config.base_url}")
        else:
            print(f"Warning: Unknown revision model '{model_name}'. Will use rule-based fallback.")

    def revise(self, text: str, cache: Dict[str, str]) -> str:
        """Rewrite ``text`` with the configured backend, memoized in ``cache``.

        Args:
            text: Text to rewrite.
            cache: Mutable dict mapping an MD5 of (text, model name) to the revision.
        Returns:
            The revised text. If the backend raises, the rule-based revision is
            returned but NOT cached, so a transient failure is retried next time.
        """
        cache_key = hashlib.md5(f"{text}_{self.config.revision_model}".encode()).hexdigest()
        if cache_key in cache:
            return cache[cache_key]

        model_name = self.config.revision_model
        try:
            if model_name.startswith("t5"):
                revised = self._revise_with_t5(text)
            elif model_name == "gpt2":
                revised = self._revise_with_gpt2(text)
            elif model_name.startswith("gpt-"):
                revised = self._revise_with_api(text)
            else:
                revised = self._rule_based_revision(text)
        except Exception as e:
            print(f"Revision failed with model {model_name}: {e}")
            return self._rule_based_revision(text)

        cache[cache_key] = revised
        return revised

    def _revise_with_t5(self, text: str) -> str:
        """Paraphrase ``text`` with T5 using beam search (no sampling).

        The input is truncated to ``config.max_length`` tokens (512 if unset).
        NOTE(review): the "paraphrase: " prefix presumably only helps checkpoints
        fine-tuned on that task; confirm it is meaningful for the chosen model.
        """
        prompt = f"paraphrase: {text}"
        input_limit = getattr(self.config, "max_length", 512)

        inputs = self.tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=input_limit,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                num_beams=4,             # beam width
                early_stopping=True,     # stop once every beam has finished
                max_length=256,          # upper bound on output tokens
                # NOTE: also applies to short inputs, where it may pad the
                # revision with extra content.
                min_length=40,
                repetition_penalty=1.2   # mildly discourage repeated tokens
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _revise_with_gpt2(self, text: str) -> str:
        """Paraphrase ``text`` by prompting GPT-2 and keeping the first generated sentence.

        Returns ``text`` unchanged if generation fails or yields nothing usable.
        """
        import re

        try:
            prompt_templates = [
                f"Paraphrase the following text while keeping the same meaning: {text}\n\nParaphrased version:",
                f"Rewrite this sentence in a different way: {text}\n\nRewritten:",
                f"Express this differently: {text}\n\nAlternative expression:"
            ]
            prompt = np.random.choice(prompt_templates)
            inputs = self.tokenizer.encode(
                prompt, return_tensors="pt", max_length=512, truncation=True
            ).to(self.device)

            with torch.no_grad():
                # early_stopping only affects beam search, so it is not passed here.
                outputs = self.model.generate(
                    inputs,
                    max_length=min(len(inputs[0]) + 100, 1024),
                    num_return_sequences=1,
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=3
                )

            full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Take the text after the answer marker, or after the prompt as a fallback.
            rewritten = ""
            for marker in ("Paraphrased version:", "Rewritten:", "Alternative expression:"):
                if marker in full_text:
                    rewritten = full_text.split(marker)[-1].strip()
                    break
            if not rewritten:
                rewritten = full_text[len(prompt):].strip()

            # Keep the first line, then the first sentence. A period only ends a
            # sentence when followed by whitespace or end of text, so decimals like
            # "3.5%" are kept intact (previously "rose 3.5%" became "rose 3.").
            rewritten = rewritten.split('\n')[0]
            first_sentence = re.match(r"(.*?\.)(?=\s|$)", rewritten)
            if first_sentence:
                rewritten = first_sentence.group(1)

            return rewritten if rewritten else text

        except Exception as e:
            print(f"  修复版重写失败: {e}")
            return text

    def _revise_with_api(self, text: str, temperature: float = 0.7, revision_style: str = "standard") -> str:
        """Rewrite ``text`` through the OpenAI chat-completions API with retries.

        Args:
            text: Text to rewrite.
            temperature: Sampling temperature forwarded to the API.
            revision_style: Currently unused; kept for interface compatibility.
        Returns:
            The stripped completion, or ``text`` itself when the API returns no
            content, all retries fail, or an unexpected error occurs.
        Raises:
            ValueError: if the OpenAI client was never initialized.
        """
        if not self.client:
            raise ValueError("OpenAI client not initialized.")

        system_prompt = "Please rewrite the following text while maintaining its meaning:"
        # Exponential backoff is tracked locally. The old code doubled
        # self.retry_delay in place, so every later call started from an
        # ever-growing delay.
        delay = self.retry_delay

        for attempt in range(self.max_retries):
            has_retry_left = attempt < self.max_retries - 1
            try:
                response = self.client.chat.completions.create(
                    model=self.config.revision_model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": text}
                    ],
                    temperature=temperature,
                    top_p=0.9
                )
                content = response.choices[0].message.content
                if content is None:
                    return text
                return content.strip()

            except openai.RateLimitError:
                if has_retry_left:
                    print(f"Rate limit exceeded, waiting {delay}s...")
                    time.sleep(delay)
                    delay *= 2
                else:
                    print("Rate limit exceeded on final attempt; returning original text.")
            except openai.APIError as e:
                print(f"API error (attempt {attempt + 1}): {e}")
                if has_retry_left:
                    time.sleep(delay)
                    delay *= 2
                else:
                    return text
            except Exception as e:
                print(f"Unexpected API error: {e}")
                return text
        return text

    def _rule_based_revision(self, text: str) -> str:
        """Cheap rewrite: shuffle the middle sentences and swap a few words for synonyms.

        Replacements keep the capitalization style of the matched word
        (``However`` -> ``Nevertheless``, ``MANY`` -> ``NUMEROUS``).
        """
        import re

        sentences = sent_tokenize(text)
        if len(sentences) > 2:
            sentences = self._reorder_sentences(sentences)

        revised_sentences = []
        for sent in sentences:
            for word, synonym in self._RULE_SYNONYMS.items():
                pattern = r'\b' + re.escape(word) + r'\b'
                sent = re.sub(
                    pattern,
                    lambda m, syn=synonym: self._match_case(m.group(0), syn),
                    sent,
                    flags=re.IGNORECASE,
                )
            revised_sentences.append(sent)

        return ' '.join(revised_sentences)

    @staticmethod
    def _match_case(source: str, replacement: str) -> str:
        """Return ``replacement`` with the capitalization pattern of ``source``."""
        if len(source) > 1 and source.isupper():
            return replacement.upper()
        if source[:1].isupper():
            return replacement[:1].upper() + replacement[1:]
        return replacement

    def _reorder_sentences(self, sentences: List[str]) -> List[str]:
        """Shuffle the middle sentences while keeping the first and last in place."""
        if len(sentences) <= 2:
            return sentences

        first = sentences[0]
        last = sentences[-1]
        middle = sentences[1:-1]
        np.random.shuffle(middle)

        return [first] + middle + [last]

class MultiFeatureExtractor:
    """Computes difference/similarity features between an original text and its revision.

    Feature groups: semantic (embedding cosine), lexical, syntactic (sentence
    counts and POS mix), stylistic, and edit-distance based.
    """

    # POS tag prefixes tracked by the syntactic features: noun, verb, adjective, adverb.
    _POS_PREFIXES = ('NN', 'VB', 'JJ', 'RB')

    def __init__(self, config: "DetectionConfig"):
        self.config = config
        self.sentence_model = SentenceTransformer(config.embedding_model)
        self._init_linguistic_features()

    def _init_linguistic_features(self):
        """Ensure NLTK resources are present and load the English stopword set.

        Falls back to an empty stopword set (the stopword feature is then always 0)
        if the NLTK data cannot be downloaded or loaded.
        """
        try:
            nltk.download('punkt', quiet=True)
            nltk.download('averaged_perceptron_tagger', quiet=True)
            nltk.download('stopwords', quiet=True)
            self.stop_words = set(stopwords.words('english'))
        except Exception as e:
            print(f"Stopword loading failed, continuing without stopwords: {e}")
            self.stop_words = set()

    def extract_features(self, original: str, revised: str) -> Dict[str, float]:
        """Extract all feature groups for one (original, revised) pair.

        Returns:
            Flat dict mapping feature name to value.
        """
        features = {}
        features['semantic_similarity'] = self._compute_semantic_similarity(original, revised)
        features.update(self._extract_lexical_features(original, revised))
        features.update(self._extract_syntactic_features(original, revised))
        features.update(self._extract_stylistic_features(original, revised))
        features.update(self._extract_edit_features(original, revised))
        return features

    def _compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Cosine similarity of the two sentence embeddings (0.0 if either embedding is all zeros).

        Both texts are encoded in a single batch call.
        """
        embeddings = np.asarray(self.sentence_model.encode([text1, text2]), dtype=float)
        first, second = embeddings[0], embeddings[1]
        norm_product = np.linalg.norm(first) * np.linalg.norm(second)
        if norm_product == 0:
            return 0.0
        return float(np.dot(first, second) / norm_product)

    def _extract_lexical_features(self, original: str, revised: str) -> Dict[str, float]:
        """Vocabulary overlap/change ratios over lower-cased token sets."""
        orig_words = set(word_tokenize(original.lower()))
        rev_words = set(word_tokenize(revised.lower()))

        overlap = len(orig_words & rev_words)
        total = len(orig_words | rev_words)

        features = {
            'word_overlap_ratio': overlap / total if total > 0 else 0,
            'vocabulary_change_ratio': len(orig_words ^ rev_words) / total if total > 0 else 0,
            'unique_words_ratio': len(rev_words - orig_words) / len(rev_words) if rev_words else 0
        }

        # Absolute change in the fraction of distinct tokens that are stopwords.
        orig_stop = len([w for w in orig_words if w in self.stop_words])
        rev_stop = len([w for w in rev_words if w in self.stop_words])

        features['stopword_ratio_change'] = abs(
            (orig_stop / len(orig_words) if orig_words else 0) -
            (rev_stop / len(rev_words) if rev_words else 0)
        )

        return features

    def _extract_syntactic_features(self, original: str, revised: str) -> Dict[str, float]:
        """Sentence-count/length changes and per-POS-class ratio changes.

        The POS features are always present: if tagging fails (e.g. missing tagger
        data), they are reported as 0.0 rather than omitted, so every row exposes
        the same feature columns.
        """
        orig_sents = sent_tokenize(original)
        rev_sents = sent_tokenize(revised)

        features = {
            'sentence_count_ratio': len(rev_sents) / len(orig_sents) if orig_sents else 1,
            'avg_sentence_length_change': abs(
                np.mean([len(word_tokenize(s)) for s in orig_sents]) -
                np.mean([len(word_tokenize(s)) for s in rev_sents])
            ) if orig_sents and rev_sents else 0
        }

        try:
            orig_pos = nltk.pos_tag(word_tokenize(original))
            rev_pos = nltk.pos_tag(word_tokenize(revised))
        except Exception:
            # Best effort: tagging is optional; fall through to zero-valued features.
            orig_pos = rev_pos = None

        for pos_type in self._POS_PREFIXES:
            key = f'pos_{pos_type}_ratio_change'
            if orig_pos is None:
                features[key] = 0.0
                continue
            orig_count = sum(1 for _, pos in orig_pos if pos.startswith(pos_type))
            rev_count = sum(1 for _, pos in rev_pos if pos.startswith(pos_type))
            features[key] = abs(
                (orig_count / len(orig_pos) if orig_pos else 0) -
                (rev_count / len(rev_pos) if rev_pos else 0)
            )

        return features

    def _extract_stylistic_features(self, original: str, revised: str) -> Dict[str, float]:
        """Changes in punctuation density, capitalization density and mean word length."""
        features = {}

        orig_punct = sum(1 for c in original if c in '.,!?;:')
        rev_punct = sum(1 for c in revised if c in '.,!?;:')
        features['punctuation_ratio_change'] = abs(
            (orig_punct / len(original) if original else 0) -
            (rev_punct / len(revised) if revised else 0)
        )

        orig_caps = sum(1 for c in original if c.isupper())
        rev_caps = sum(1 for c in revised if c.isupper())
        features['capitalization_ratio_change'] = abs(
            (orig_caps / len(original) if original else 0) -
            (rev_caps / len(revised) if revised else 0)
        )

        orig_words = word_tokenize(original)
        rev_words = word_tokenize(revised)
        features['avg_word_length_change'] = abs(
            np.mean([len(w) for w in orig_words]) -
            np.mean([len(w) for w in rev_words])
        ) if orig_words and rev_words else 0

        return features

    def _extract_edit_features(self, original: str, revised: str) -> Dict[str, float]:
        """difflib similarity at character and token level, plus the length ratio."""
        from difflib import SequenceMatcher

        char_similarity = SequenceMatcher(None, original, revised).ratio()

        orig_words = word_tokenize(original)
        rev_words = word_tokenize(revised)
        word_similarity = SequenceMatcher(None, orig_words, rev_words).ratio()

        return {
            'char_level_similarity': char_similarity,
            'word_level_similarity': word_similarity,
            'length_ratio': len(revised) / len(original) if original else 1
        }

class AITextDetector:
    """End-to-end detector: perturb -> LLM revise -> extract features -> classify.

    Predictions use either a RandomForest trained on all extracted features or,
    when no classifier is available, a fixed threshold on semantic similarity
    (similarity above ``config.similarity_threshold`` -> label 1).
    """

    # Feature keys emitted (all set to 0.5) for a text whose processing fails.
    _FALLBACK_FEATURE_KEYS = (
        'semantic_similarity', 'word_overlap_ratio',
        'vocabulary_change_ratio', 'unique_words_ratio',
        'stopword_ratio_change', 'sentence_count_ratio',
        'avg_sentence_length_change', 'punctuation_ratio_change',
        'capitalization_ratio_change', 'avg_word_length_change',
        'char_level_similarity', 'word_level_similarity',
        'length_ratio',
    )

    def __init__(self, config: Optional["DetectionConfig"] = None):
        self.config = config or DetectionConfig()
        self.perturber = EnhancedTextPerturber(self.config)
        self.reviser = EnhancedLLMReviser(self.config)
        self.feature_extractor = MultiFeatureExtractor(self.config)
        self.classifier = None
        self.scaler = StandardScaler()
        self.cache = {}
        # Column order the classifier was trained on; used to align later batches.
        self.feature_names = None

    def _featurize_text(self, text: str) -> Dict[str, float]:
        """Perturb, revise and featurize a single text (may raise)."""
        perturbed = self.perturber.perturb(text)
        revised = self.reviser.revise(perturbed, self.cache)
        return self.feature_extractor.extract_features(text, revised)

    def detect_batch(self, texts: List[str],
                     labels: Optional[List[int]] = None) -> Dict[str, np.ndarray]:
        """Score a batch of texts, training the classifier on first use when labels are given.

        Args:
            texts: Texts to score.
            labels: Optional ground-truth labels; when provided and no classifier
                exists yet (and ``config.use_ml_classifier``), a classifier is trained.
        Returns:
            Dict with ``predictions``, ``probabilities``, ``features`` (one row
            per text, columns aligned to the training order once trained) and
            ``similarity_scores``.
        """
        all_features = []
        all_scores = []

        print("Processing texts...")
        for i, text in enumerate(tqdm(texts)):
            try:
                features = self._featurize_text(text)
                score = features['semantic_similarity']
            except Exception as e:
                print(f"Error processing text {i}: {e}")
                features = {key: 0.5 for key in self._FALLBACK_FEATURE_KEYS}
                score = 0.5
            # Appending only after both succeed keeps rows and scores aligned.
            all_features.append(features)
            all_scores.append(score)

        feature_frame = pd.DataFrame(all_features)
        trained_names = getattr(self, 'feature_names', None)
        if trained_names is not None:
            # DataFrame column order follows dict insertion order of the rows, so
            # it can differ between batches; align to the training order
            # (missing columns become 0 below).
            feature_frame = feature_frame.reindex(columns=trained_names)
        feature_frame = feature_frame.fillna(0)
        feature_matrix = feature_frame.values
        scores = np.array(all_scores)

        use_ml = self.config.use_ml_classifier
        if use_ml and labels is not None and self.classifier is None:
            self._train_classifier(feature_matrix, labels,
                                   feature_names=list(feature_frame.columns))

        if use_ml and self.classifier is not None and len(feature_matrix) > 0:
            feature_matrix_scaled = self.scaler.transform(feature_matrix)
            predictions = self.classifier.predict(feature_matrix_scaled)
            probabilities = self.classifier.predict_proba(feature_matrix_scaled)[:, 1]
        else:
            # Threshold mode (also used for an empty batch).
            predictions = (scores > self.config.similarity_threshold).astype(int)
            probabilities = scores

        return {
            'predictions': predictions,
            'probabilities': probabilities,
            'features': feature_matrix,
            'similarity_scores': scores
        }

    def _train_classifier(self, features: np.ndarray, labels: List[int],
                          feature_names: Optional[List[str]] = None):
        """Fit the scaler and a RandomForest, then print the top-10 feature importances.

        Args:
            features: Feature matrix, one row per sample.
            labels: Target labels.
            feature_names: Column names of ``features``. When given they are stored
                in ``self.feature_names`` and used to align later batches.
        """
        print("Training classifier...")

        features_scaled = self.scaler.fit_transform(features)

        self.classifier = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        )
        self.classifier.fit(features_scaled, labels)

        if feature_names is not None:
            self.feature_names = list(feature_names)
            display_names = self.feature_names
        else:
            # A plain ndarray carries no names (the old code printed 0, 1, 2, ...).
            display_names = [f"feature_{i}" for i in range(np.asarray(features).shape[1])]

        importances = self.classifier.feature_importances_
        print("\nTop 10 Most Important Features:")
        for feat, imp in sorted(zip(display_names, importances),
                                key=lambda x: x[1], reverse=True)[:10]:
            print(f"  {feat}: {imp:.4f}")

    def save_model(self, path: str):
        """Pickle the classifier, scaler, config, feature order and revision cache to ``path``.

        The API key is removed from the saved copy of the config so credentials are
        never written to disk; ``self.config`` itself is left untouched.
        """
        import copy
        import pickle

        config_to_save = copy.copy(self.config)
        if getattr(config_to_save, 'api_key', None) is not None:
            config_to_save.api_key = None

        model_data = {
            'classifier': self.classifier,
            'scaler': self.scaler,
            'config': config_to_save,
            'cache': self.cache,
            'feature_names': getattr(self, 'feature_names', None),
        }

        with open(path, 'wb') as f:
            pickle.dump(model_data, f)

        print(f"Model saved to {path}")

    def load_model(self, path: str):
        """Restore state written by :meth:`save_model`.

        SECURITY: unpickling executes arbitrary code; only load trusted files.
        The current API key is carried over when the saved config has none.
        NOTE: the perturber, reviser and feature extractor are not rebuilt, so
        they keep using the models from the config this detector was created with.
        """
        import pickle

        with open(path, 'rb') as f:
            model_data = pickle.load(f)

        loaded_config = model_data['config']
        current_key = getattr(self.config, 'api_key', None)
        if current_key is not None and getattr(loaded_config, 'api_key', None) is None:
            loaded_config.api_key = current_key

        self.classifier = model_data['classifier']
        self.scaler = model_data['scaler']
        self.config = loaded_config
        self.cache = model_data['cache']
        # Older pickles predate feature_names.
        self.feature_names = model_data.get('feature_names')

        print(f"Model loaded from {path}")
        
        
# ------------------------------------------------------------
# 数据集加载函数（从骨架代码整合）
# ------------------------------------------------------------
def load_data(path: str) -> pd.DataFrame:
    """Load a dataset that has ``text`` and ``label`` columns.

    Args:
        path: A ``.csv`` / ``.jsonl`` / ``.json`` file (extension matched
            case-insensitively), ``'finance'`` for the finance dataset, or
            ``'sample'`` for a generated sample dataset.
    Returns:
        DataFrame with columns ``['id', 'text', 'label']`` and a fresh 0..n-1
        index. Rows with missing text/label, or text of 10 characters or fewer,
        are dropped. When the file has no ``id`` column, ids are the original
        row positions.
    Raises:
        ValueError: for an unsupported extension or a missing required column.
    Note:
        A missing file falls back to ``create_sample_dataset()`` (with a message).
    """
    if path == 'finance':
        return load_finance_dataset()
    if path == 'sample':
        return create_sample_dataset()

    readers = {
        '.csv': pd.read_csv,
        '.jsonl': lambda p: pd.read_json(p, lines=True),
        '.json': pd.read_json,
    }
    extension = os.path.splitext(path)[1].lower()
    if extension not in readers:
        raise ValueError(f"Unsupported file format: {path}")

    try:
        df = readers[extension](path)
    except FileNotFoundError:
        print(f"File {path} not found. Creating sample dataset...")
        return create_sample_dataset()

    for column in ('text', 'label'):
        if column not in df.columns:
            raise ValueError(f"Missing '{column}' column in dataset")

    # Ids are assigned before filtering so they refer to rows of the source file.
    if 'id' not in df.columns:
        df['id'] = range(len(df))

    df = df.dropna(subset=['text', 'label'])
    # Cast to str so the length filter (and downstream tokenizers) never see numbers.
    df = df.assign(text=df['text'].astype(str))
    df = df[df['text'].str.len() > 10].reset_index(drop=True)

    print(f"Loaded {len(df)} samples from {path}")
    print(f"Label distribution: {df['label'].value_counts().to_dict()}")

    return df[['id', 'text', 'label']]


def find_finance_data_files(search_paths: Optional[List[str]] = None) -> Dict[str, str]:
    """Locate the finance dataset files under a list of search roots.

    Args:
        search_paths: Directories to search, in priority order. Defaults to the
            known Kaggle input folders, then ``./data``, then the working directory.
    Returns:
        Dict with any of the keys ``'revised_human'``, ``'revised_chatgpt'`` and
        ``'original'``, each mapping to a file path. The first match wins
        (earlier search roots first, then directories and files in sorted
        order), so the result is deterministic. Previously the last match
        silently overrode earlier ones, so the broad "." root beat the specific
        Kaggle dataset folders.
    """
    print("Searching for finance data files...")

    if search_paths is None:
        search_paths = [
            "/kaggle/input/finance-dataset",
            "/kaggle/input/finance-data",
            "/kaggle/input/human-chatgpt-finance",
            "/kaggle/input",
            "./data",  # local path
            "."
        ]

    files_found = {}
    visited_dirs = set()

    for base in search_paths:
        if not os.path.exists(base):
            continue
        print(f"Checking path: {base}")
        for root, dirs, files in os.walk(base):
            real_root = os.path.realpath(root)
            if real_root in visited_dirs:
                # Already scanned through an earlier (nested/overlapping) root.
                dirs[:] = []
                continue
            visited_dirs.add(real_root)
            dirs.sort()  # deterministic traversal order

            for name in sorted(files):
                if 'revised_human' in name:
                    key = 'revised_human'
                elif 'revised_chatgpt' in name:
                    key = 'revised_chatgpt'
                elif 'finance.jsonl' in name:
                    key = 'original'
                else:
                    continue
                files_found.setdefault(key, os.path.join(root, name))

    return files_found
def _clean_revised_text(text: str) -> str:
    """Flatten line breaks to spaces; handles literal "\\n" escapes as well as real newlines."""
    for token in ('\\n\\n', '\\n', '\r\n', '\n'):
        text = text.replace(token, ' ')
    return text.strip()


def _load_revised_jsonl(path: str, label: int,
                        all_texts: List[str], all_labels: List[int]) -> int:
    """Append cleaned texts from a JSON-lines file of ``{key: text}`` objects.

    Malformed lines, non-object lines and non-string values are reported or skipped
    individually instead of aborting the whole file. Only texts longer than 20
    characters after cleaning are kept.

    Returns:
        Number of texts appended to ``all_texts`` / ``all_labels``.
    """
    added = 0
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)  # each line is an independent JSON object
            except json.JSONDecodeError as e:
                print(f"  Line {line_no} parse error: {e}")
                continue
            if not isinstance(item, dict):
                print(f"  Line {line_no}: expected a JSON object, got {type(item).__name__}; skipped")
                continue
            for text in item.values():
                if not isinstance(text, str):
                    continue
                cleaned_text = _clean_revised_text(text)
                if len(cleaned_text) > 20:
                    all_texts.append(cleaned_text)
                    all_labels.append(label)
                    added += 1
    return added


def load_finance_dataset() -> pd.DataFrame:
    """
    Load the finance dataset, using the revised texts as the samples.

    Sources, in order:
    1. the revised human file (label 0) and the revised ChatGPT file (label 1), both
       JSON-lines files of ``{key: text}`` objects;
    2. if neither yields any text, the first human / ChatGPT answer of each row of
       the original ``finance.jsonl``;
    3. if that also yields nothing, ``create_sample_dataset()``.

    Returns:
        DataFrame with columns ['id', 'text', 'label'] (0 = human, 1 = ChatGPT),
        shuffled with random_state=42 and re-numbered 0..n-1.
    """
    files = find_finance_data_files()

    print("Loading finance datasets...")
    print(f"Found files: {list(files.keys())}")

    all_texts: List[str] = []
    all_labels: List[int] = []

    # 1-2. Revised texts (one JSON object per line).
    for file_key, label, name in (('revised_human', 0, 'human'),
                                  ('revised_chatgpt', 1, 'ChatGPT')):
        if file_key not in files:
            continue
        try:
            print(f"Loading revised {name} texts...")
            added = _load_revised_jsonl(files[file_key], label, all_texts, all_labels)
            print(f"  Loaded {added} {name} texts")
        except Exception as e:
            print(f"Error loading revised {name} texts: {e}")

    # 3. No revised data: fall back to the original file, then to a toy dataset.
    if len(all_texts) == 0:
        if 'original' in files:
            print("\nNo revised texts found. Loading from original finance.jsonl...")
            try:
                with open(files['original'], encoding="utf-8") as f:
                    for line_no, row in enumerate(f, 1):
                        if not row.strip():
                            continue
                        try:
                            data = json.loads(row)
                        except json.JSONDecodeError as e:
                            print(f"  Line {line_no} parse error: {e}")
                            continue
                        if not isinstance(data, dict):
                            continue
                        # First human answer -> 0, first ChatGPT answer -> 1.
                        for answers_key, label in (("human_answers", 0), ("chatgpt_answers", 1)):
                            answers = data.get(answers_key) or []
                            if answers and isinstance(answers[0], str) and answers[0]:
                                all_texts.append(answers[0])
                                all_labels.append(label)
                print(f"  Loaded {len(all_texts)} texts from original file")
            except Exception as e:
                print(f"Error loading original file: {e}")

        if len(all_texts) == 0:
            print("Creating sample dataset...")
            return create_sample_dataset()

    if len(set(all_labels)) < 2:
        print("Warning: only one class was loaded; AUC and per-class metrics will be undefined.")

    # 4. Build the DataFrame, shuffle reproducibly and re-number ids.
    df = pd.DataFrame({
        'id': range(len(all_texts)),
        'text': all_texts,
        'label': all_labels
    })
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    df['id'] = range(len(df))

    print(f"\n=== Dataset Summary ===")
    print(f"Total samples: {len(df)}")
    print(f"Human texts: {len(df[df['label'] == 0])}")
    print(f"ChatGPT texts: {len(df[df['label'] == 1])}")

    # Show one example per class.
    if len(df) > 0:
        print("\nSample texts:")
        for label in [0, 1]:
            label_name = "Human" if label == 0 else "ChatGPT"
            samples = df[df['label'] == label]
            if len(samples) > 0:
                sample_text = samples.iloc[0]['text']
                print(f"\n[{label_name}]: {sample_text[:100]}...")

    return df[['id', 'text', 'label']]

def create_sample_dataset(n_samples: int = 100, seed: Optional[int] = None) -> pd.DataFrame:
    """Build a small, balanced toy dataset for smoke-testing the pipeline.

    Args:
        n_samples: total number of rows requested. Each class gets ``n_samples // 2``
            rows, so an odd value yields ``n_samples - 1`` rows. The ten base sentences
            per class are cycled as often as needed, so any size works.
        seed: optional seed for the shuffle. ``None`` uses the global ``random`` state
            (the previous behaviour).

    Returns:
        DataFrame with columns ['id', 'text', 'label'] (0 = human, 1 = AI).
    """
    import random
    from itertools import cycle, islice

    human_base = [
        "I think artificial intelligence will revolutionize many industries in the coming decades.",
        "The weather today is quite pleasant, perfect for a walk in the park.",
        "Machine learning algorithms have shown remarkable progress in recent years.",
        "I love reading books, especially mystery novels and science fiction.",
        "Climate change is one of the most pressing issues facing humanity today.",
        "Just finished my morning coffee, feeling ready to tackle the day!",
        "My cat knocked over my plant again. Why do they always do that?",
        "Been trying to learn guitar but my fingers hurt so much. Is this normal?",
        "The new restaurant downtown is pretty good but a bit pricey for what you get.",
        "I can't believe how fast this year has gone by. Feels like January was yesterday."
    ]

    ai_base = [
        "Artificial intelligence represents a paradigm shift in technological advancement, offering unprecedented opportunities for innovation across multiple sectors.",
        "The meteorological conditions today present optimal parameters for outdoor recreational activities.",
        "Contemporary machine learning methodologies demonstrate significant improvements in performance metrics.",
        "Literary consumption, particularly within the mystery and science fiction genres, provides considerable intellectual stimulation.",
        "Climate change constitutes a critical global challenge requiring immediate and comprehensive action.",
        "The implementation of blockchain technology has demonstrated substantial potential for revolutionizing financial systems.",
        "Quantum computing represents a fundamental breakthrough in computational capabilities with far-reaching implications.",
        "The integration of renewable energy sources is essential for achieving sustainable development goals.",
        "Advancements in biotechnology continue to push the boundaries of medical treatment possibilities.",
        "The digital transformation of industries necessitates comprehensive adaptation strategies for organizational success."
    ]

    per_class = max(n_samples, 0) // 2
    # Pair each text with its label up front. The old `base * (n // 20)` produced too
    # few texts whenever n was not a multiple of 20, so zip() silently mislabelled rows.
    rows = ([(text, 0) for text in islice(cycle(human_base), per_class)]
            + [(text, 1) for text in islice(cycle(ai_base), per_class)])

    shuffler = random.Random(seed) if seed is not None else random
    shuffler.shuffle(rows)

    return pd.DataFrame({
        'id': range(len(rows)),
        'text': [text for text, _ in rows],
        'label': [label for _, label in rows]
    })





def evaluate_detector(detector: AITextDetector,
                     texts: List[str],
                     labels: List[int],
                     fit: bool = False) -> Dict[str, float]:
    """
    Evaluate a detector on a labelled set and print metrics.

    Args:
        detector: detector whose ``detect_batch`` returns a dict with 'predictions',
            'probabilities' and 'similarity_scores'.
        texts: texts to classify.
        labels: ground-truth labels (0 = human, 1 = AI), aligned with ``texts``.
        fit: if True, pass ``labels`` to ``detect_batch`` (the labelled call used as the
            training phase in ``main``). This scores the detector on data it was just
            fitted on; keep False for an honest held-out evaluation.

    Returns:
        {'accuracy': float, 'auc': float}. 'auc' is NaN when ``labels`` contain a
        single class.
    """
    # In ``main`` the labelled detect_batch call is the training pass and the
    # unlabelled one is inference. Passing labels here would presumably refit on the
    # evaluation set and inflate the metrics.
    results = detector.detect_batch(texts, labels) if fit else detector.detect_batch(texts)

    y_true = np.asarray(labels)
    predictions = np.asarray(results['predictions'])
    probabilities = np.asarray(results['probabilities'])
    similarity_scores = np.asarray(results['similarity_scores'])

    # roc_auc_score raises on a single-class input; report NaN instead.
    single_class = np.unique(y_true).size < 2
    metrics = {
        'accuracy': accuracy_score(y_true, predictions),
        'auc': float('nan') if single_class else roc_auc_score(y_true, probabilities),
    }

    print("\n=== Evaluation Results ===")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"AUC: {metrics['auc']:.4f}")

    print("\nClassification Report:")
    # labels=[0, 1] keeps target_names valid even if one class is absent.
    print(classification_report(y_true, predictions, labels=[0, 1],
                                target_names=['Human', 'AI']))

    # Similarity score distribution per true class.
    print(f"\nSimilarity Score Distribution:")
    for group_name, group_label in (("Human texts", 0), ("AI texts", 1)):
        group_scores = similarity_scores[y_true == group_label]
        if group_scores.size == 0:
            print(f"{group_name} - no samples")
            continue
        print(f"{group_name} - Mean: {np.mean(group_scores):.3f}, Std: {np.std(group_scores):.3f}")

    return metrics

# 主执行函数
def main(data_path: str = "finance",
         max_samples: Optional[int] = None,
         use_cache: bool = True,
         config: Optional[DetectionConfig] = None,
         train_ratio: float = 0.4,
         n_misclassified_examples: int = 50):
    """
    Run the enhanced detector end to end: load, split, train, test, report, persist.

    Args:
        data_path: data file path, or 'finance' / 'sample' (forwarded to ``load_data``).
        max_samples: if set, keep only this many samples, taken after shuffling, for
            quick tests.
        use_cache: load ``detector.cache`` from, and save it to,
            ``enhanced_cache_<model>.json``.
        config: detector configuration. A T5-small default is built when None.
        train_ratio: fraction of the shuffled samples used for training; the rest is
            the test set.
        n_misclassified_examples: how many misclassified test examples to print.

    Returns:
        dict with keys 'detector', 'test_results', 'accuracy', 'auc'. 'auc' is NaN
        when the test split contains a single class.
    """
    from sklearn.metrics import confusion_matrix

    # Resolve the config first so the banner names the model actually used.
    if config is None:
        config = DetectionConfig(
            revision_model="t5-small",  # or "gpt2"
            embedding_model="all-MiniLM-L6-v2",
            perturbation_rate=0.15,
            use_ml_classifier=True
        )

    print("=== Enhanced AI Text Detection ===")
    print(f"Method: Multi-feature extraction with {config.revision_model}")

    # 1. Load data. Shuffle BEFORE truncating: head() of a label-sorted file could
    #    otherwise keep only one class and break AUC.
    df = load_data(data_path)
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    if max_samples:
        df = df.head(max_samples)
        print(f"Using first {max_samples} samples for testing")

    texts = df['text'].tolist()
    labels = df['label'].tolist()

    # 2. Train/test split: first `train_ratio` of the shuffled data trains (default 40/60).
    split_idx = int(len(texts) * train_ratio)
    train_texts, test_texts = texts[:split_idx], texts[split_idx:]
    train_labels, test_labels = labels[:split_idx], labels[split_idx:]

    print(f"\nDataset split: {len(train_texts)} training, {len(test_texts)} testing")

    # 3. Build the detector.
    detector = AITextDetector(config)

    # 4. Load the cache. Replace '/' in hub-style model ids ("org/model") so the
    #    name is a valid file name.
    model_tag = str(config.revision_model).replace('/', '_')
    cache_path = f"enhanced_cache_{model_tag}.json"
    if use_cache and os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as f:
            detector.cache = json.load(f)
        print(f"Loaded cache with {len(detector.cache)} entries")

    # 5. Training phase: the labelled detect_batch call.
    print("\n=== Training Phase ===")
    print("Train labels distribution:",
          pd.Series(train_labels).value_counts().to_dict())
    train_results = detector.detect_batch(train_texts, train_labels)

    # 6. Testing phase: unlabelled inference.
    print("\n=== Testing Phase ===")
    test_results = detector.detect_batch(test_texts)

    # 7. Evaluation. Use arrays so elementwise comparison and masking work even if
    #    detect_batch returns plain lists.
    y_test = np.asarray(test_labels)
    test_predictions = np.asarray(test_results['predictions'])
    test_probabilities = np.asarray(test_results['probabilities'])
    similarity_scores = np.asarray(test_results['similarity_scores'])

    accuracy = accuracy_score(y_test, test_predictions)
    if np.unique(y_test).size < 2:
        print("Warning: the test split contains a single class; AUC is undefined.")
        auc = float('nan')
    else:
        auc = roc_auc_score(y_test, test_probabilities)

    print(f"\n=== Test Results ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"AUC: {auc:.4f}")

    print("\nClassification Report:")
    print(classification_report(y_test, test_predictions, labels=[0, 1],
                                target_names=['Human', 'AI']))

    # labels=[0, 1] guarantees a 2x2 matrix even if a class is missing.
    cm = confusion_matrix(y_test, test_predictions, labels=[0, 1])
    print("\nConfusion Matrix:")
    print(f"              Predicted")
    print(f"            Human    AI")
    print(f"Actual Human  {cm[0,0]:4d}  {cm[0,1]:4d}")
    print(f"       AI     {cm[1,0]:4d}  {cm[1,1]:4d}")

    # Similarity score statistics per true class.
    print(f"\n=== Similarity Score Analysis ===")
    for group_name, group_label in (("Human texts", 0), ("AI texts", 1)):
        group_scores = similarity_scores[y_test == group_label]
        print(f"{group_name}:")
        if group_scores.size == 0:
            print("  (no samples)")
            continue
        print(f"  Mean: {np.mean(group_scores):.3f}, Std: {np.std(group_scores):.3f}")
        print(f"  Min: {np.min(group_scores):.3f}, Max: {np.max(group_scores):.3f}")

    # Show up to n_misclassified_examples misclassified test texts.
    print("\n=== Misclassified Examples ===")
    misclassified_idx = np.flatnonzero(test_predictions != y_test)

    for i, idx in enumerate(misclassified_idx[:n_misclassified_examples]):
        true_label = "AI" if test_labels[idx] == 1 else "Human"
        pred_label = "AI" if test_predictions[idx] == 1 else "Human"
        text = test_texts[idx][:100] + "..." if len(test_texts[idx]) > 100 else test_texts[idx]

        print(f"\nExample {i+1}:")
        print(f"Text: {text}")
        print(f"True: {true_label}, Predicted: {pred_label}")
        print(f"Similarity: {similarity_scores[idx]:.3f}, Probability: {test_probabilities[idx]:.3f}")

    # 8. Persist the cache and the model.
    if use_cache:
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(detector.cache, f)
        print(f"\nCache saved to {cache_path}")

    model_path = f"enhanced_detector_{model_tag}.pkl"
    detector.save_model(model_path)

    return {
        'detector': detector,
        'test_results': test_results,
        'accuracy': accuracy,
        'auc': auc
    }

# 实验对比函数
def run_comparison_experiment(data_path: str = "finance", max_samples: Optional[int] = None):
    """
    Run ``main`` once per detector configuration and print an accuracy/AUC table.

    Configurations run:
    - T5-small;
    - GPT-2;
    - GPT-3.5-Turbo through an OpenAI-compatible API, only when the
      ``OPENAI_API_KEY`` environment variable is set (``OPENAI_BASE_URL`` optionally
      overrides the endpoint);
    - T5-small with a higher perturbation rate.

    Returns:
        dict mapping 't5' / 'gpt2' / 'gpt3.5' / 'high_perturb' to the dict returned by
        ``main``. 'gpt3.5' is absent when no API key is configured.

    NOTE(review): ``main`` is looked up when this runs. A later cell in this notebook
    redefines ``main`` as a zero-argument preprocessing CLI; after that cell runs,
    this call reaches the wrong function. Consider renaming one of them.
    """
    print("=== Comparison Experiment ===")

    # Credentials come from the environment. The key previously hard-coded here was
    # exposed in the notebook and should be revoked.
    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL", "https://api.bltcy.ai/v1")

    # (result key, table label, progress description, config or None to skip)
    experiments = [
        ('t5', 'T5-small', "T5-small", DetectionConfig(
            revision_model="t5-small",
            embedding_model="all-MiniLM-L6-v2",
            perturbation_rate=0.15,
            use_ml_classifier=True
        )),
        ('gpt2', 'GPT-2', "GPT-2", DetectionConfig(
            revision_model="gpt2",
            embedding_model="all-MiniLM-L6-v2",
            perturbation_rate=0.15,
            use_ml_classifier=True
        )),
        ('gpt3.5', 'GPT-3.5-Turbo', "GPT-3.5-Turbo API", DetectionConfig(
            revision_model="gpt-3.5-turbo",
            api_key=api_key,
            base_url=base_url
        ) if api_key else None),
        ('high_perturb', 'High Perturbation', "higher perturbation rate", DetectionConfig(
            revision_model="t5-small",
            embedding_model="all-MiniLM-L6-v2",
            perturbation_rate=0.25,
            perturbation_methods=["synonym", "contextual", "random_swap"],
            use_ml_classifier=True
        )),
    ]

    results = {}
    for step, (key, _, description, config) in enumerate(experiments, 1):
        print(f"\n{step}. Testing with {description}...")
        if config is None:
            print("   Skipped: set OPENAI_API_KEY (and optionally OPENAI_BASE_URL) to enable.")
            continue
        results[key] = main(data_path, max_samples, use_cache=True, config=config)

    # Comparison table (every configuration that actually ran).
    print("\n=== Comparison Results ===")
    print(f"{'Configuration':<20} {'Accuracy':<10} {'AUC':<10}")
    print("-" * 40)
    for key, label, _, _ in experiments:
        if key in results:
            run = results[key]
            print(f"{label:<20} {run['accuracy']:<10.4f} {run['auc']:<10.4f}")

    return results


Downloading required NLTK data...
NLTK data download completed.


In [12]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import json
import os
import time
import hashlib
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI

@dataclass
class RevisedTextEntry:
    """Data structure for storing one original/revised text pair.

    NOTE(review): not referenced elsewhere in this chunk; TextRevisionPreprocessor
    stores parallel lists of ids/texts instead of these records.
    """
    text_id: int  # id of the text; presumably the DataFrame 'id' column — confirm
    original_text: str  # text before revision
    revised_text: str  # text returned by the revision model
    text_type: str  # 'human' or 'ai'
    revision_model: str  # name of the model that produced revised_text
    timestamp: str  # when the revision was made; format is not fixed here
    
class TextRevisionPreprocessor:
    """Revise texts once with an OpenAI-compatible chat model and cache the results.

    Each group ('human' / 'ai') is stored as a pair of JSON files: original and
    revised texts, aligned by position and sharing one 'ids' list.
    ``processing_progress.json`` records which ids were revised, so an interrupted
    run can resume without repeating API calls. ``_save_all_data`` also writes a
    pickle holding both groups.
    """

    def __init__(self, config, output_dir: str = "./preprocessed_data"):
        """
        Args:
            config: object exposing ``api_key``, ``base_url`` and ``revision_model``
                (``DetectionConfig`` in this notebook).
            output_dir: directory for every cached file; created if missing.
        """
        self.config = config
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # The client is built eagerly, even for callers that only read cached files.
        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.base_url
        )

        # Output file locations.
        self.human_original_path = os.path.join(output_dir, "human_original_texts.json")
        self.human_revised_path = os.path.join(output_dir, "human_revised_texts.json")
        self.ai_original_path = os.path.join(output_dir, "ai_original_texts.json")
        self.ai_revised_path = os.path.join(output_dir, "ai_revised_texts.json")
        self.progress_path = os.path.join(output_dir, "processing_progress.json")
        self.combined_path = os.path.join(output_dir, "combined_dataset.pkl")

    @staticmethod
    def _write_json_atomic(obj, path: str, **dump_kwargs):
        """Dump ``obj`` as JSON to ``path`` via a temp file and ``os.replace``.

        An interrupted write leaves the previous file intact instead of a truncated
        one that would fail to parse on the next resume.
        """
        tmp_path = path + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, **dump_kwargs)
        os.replace(tmp_path, path)

    def _group_paths(self, text_type: str):
        """Return ``(original_path, revised_path)`` for 'human'; any other type maps to the AI files."""
        if text_type == 'human':
            return self.human_original_path, self.human_revised_path
        return self.ai_original_path, self.ai_revised_path

    def load_progress(self) -> Dict:
        """Return the saved progress dict, or a fresh one if none exists.

        Keys: 'human_processed' and 'ai_processed' (lists of revised ids) and
        'last_update'. Missing keys from older progress files are filled in.
        """
        progress = {}
        if os.path.exists(self.progress_path):
            with open(self.progress_path, 'r') as f:
                progress = json.load(f)
        progress.setdefault('human_processed', [])
        progress.setdefault('ai_processed', [])
        progress.setdefault('last_update', None)
        return progress

    def save_progress(self, progress: Dict):
        """Stamp ``progress`` with the current time and write it to disk atomically."""
        progress['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
        self._write_json_atomic(progress, self.progress_path, indent=2)

    def revise_text_with_retry(self, text: str, max_retries: int = 3) -> str:
        """Ask the revision model to rewrite ``text`` in a formal, academic style.

        Retries up to ``max_retries`` times, sleeping 2**attempt seconds between
        attempts. An empty completion counts as a failure and is retried.

        Returns:
            The stripped revision. If every attempt failed, or ``max_retries`` <= 0,
            ``text`` is returned unchanged. Such fallbacks cannot be told apart from
            real revisions downstream and would look like an unchanged pair.
        """
        system_prompt = """You are a professional editor. Your task is to rewrite the given text to make it more formal, standardized, and academic in style. 

Key guidelines:
- Replace casual expressions with formal language
- Use precise technical terms where appropriate
- Maintain professional tone throughout
- Standardize sentence structure
- Remove colloquialisms and personal expressions
- Keep the core meaning but express it in a more refined way

The more casual or informal the original text, the more changes you should make."""

        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.config.revision_model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": f"Please rewrite this text in a formal, academic style:\n\n{text}"}
                    ],
                    temperature=0.3,
                    top_p=0.9
                )
                content = response.choices[0].message.content
                if not content or not content.strip():
                    raise ValueError("empty completion returned by the revision model")
                return content.strip()

            except Exception as e:
                print(f"API error (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff

        print(f"Failed to revise text after {max_retries} attempts")
        return text  # Fall back to the original text

    def process_dataset(self, df: pd.DataFrame, batch_size: int = 10, use_multithread: bool = False):
        """Revise both groups of ``df`` and write the combined outputs.

        Args:
            df: DataFrame with 'id', 'text' and 'label' columns (0 = human, 1 = AI).
            batch_size: texts per batch; results are persisted after every batch.
            use_multithread: revise each batch with a thread pool (needs batch_size > 1).

        Returns:
            ``(human_data, ai_data)``, each a dict with 'ids', 'original_texts' and
            'revised_texts'.
        """
        print(f"=== Processing Dataset ===")
        print(f"Total samples: {len(df)}")

        progress = self.load_progress()

        human_df = df[df['label'] == 0].copy()
        ai_df = df[df['label'] == 1].copy()

        print(f"Human texts: {len(human_df)}")
        print(f"AI texts: {len(ai_df)}")

        print("\n--- Processing Human Texts ---")
        human_data = self._process_text_group(
            human_df,
            'human',
            progress['human_processed'],
            batch_size,
            use_multithread
        )

        print("\n--- Processing AI Texts ---")
        ai_data = self._process_text_group(
            ai_df,
            'ai',
            progress['ai_processed'],
            batch_size,
            use_multithread
        )

        self._save_all_data(human_data, ai_data)

        print("\n=== Processing Complete ===")
        print(f"Data saved to: {self.output_dir}")

        return human_data, ai_data

    def _process_text_group(self, df: pd.DataFrame, text_type: str,
                           processed_ids: List[int], batch_size: int,
                           use_multithread: bool) -> Dict:
        """Revise every text of one group that is not yet in the saved files.

        Saved results are loaded first and new results are appended to them, so the
        returned dict and the files on disk always hold the complete group. (The
        previous version started from an empty dict and overwrote the files with
        only the current session's texts.)

        Args:
            df: the group's rows ('id' and 'text' columns).
            text_type: 'human' or 'ai'; selects the output files and progress key.
            processed_ids: ids listed in the progress file. The saved files are
                treated as the source of truth: listed ids missing from the files are
                re-processed.
            batch_size: texts per batch; data and then progress are saved after each.
            use_multithread: use ``_process_batch_parallel`` when batch_size > 1.

        Returns:
            dict with 'ids', 'original_texts' and 'revised_texts' (aligned lists).
        """
        original_path, revised_path = self._group_paths(text_type)
        data = self._load_existing_data(original_path, revised_path)

        done_ids = set(data['ids'])
        lost_ids = set(processed_ids) - done_ids
        if lost_ids:
            print(f"Warning: {len(lost_ids)} ids are marked processed but missing from saved data; re-processing them.")

        remaining_df = df[~df['id'].isin(done_ids)]
        print(f"Already processed: {len(done_ids)}")
        print(f"Remaining to process: {len(remaining_df)}")

        if len(remaining_df) == 0:
            return data

        for start in tqdm(range(0, len(remaining_df), batch_size), desc=f"Processing {text_type} texts"):
            batch_df = remaining_df.iloc[start:start + batch_size]
            batch_texts = batch_df['text'].tolist()

            if use_multithread and batch_size > 1:
                revised_texts = self._process_batch_parallel(batch_texts)
            else:
                revised_texts = []
                for text in batch_texts:
                    revised_texts.append(self.revise_text_with_retry(text))
                    time.sleep(0.5)  # Rate limiting

            data['ids'].extend(batch_df['id'].tolist())
            data['original_texts'].extend(batch_texts)
            data['revised_texts'].extend(revised_texts)

            # Persist the data BEFORE recording progress, so the progress file never
            # claims ids whose revisions were not saved.
            self._save_intermediate_data(data, text_type)
            progress = self.load_progress()
            progress[f'{text_type}_processed'] = list(data['ids'])
            self.save_progress(progress)

        return data

    def _process_batch_parallel(self, texts: List[str], max_workers: int = 5) -> List[str]:
        """Revise ``texts`` concurrently; output order matches input order.

        A text whose future raises keeps its original content as a fallback.
        """
        revised_texts = [None] * len(texts)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self.revise_text_with_retry, text): idx
                for idx, text in enumerate(texts)
            }

            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    revised_texts[idx] = future.result()
                except Exception as e:
                    print(f"Error processing text {idx}: {e}")
                    revised_texts[idx] = texts[idx]  # Use original as fallback

        return revised_texts

    def _save_intermediate_data(self, data: Dict, text_type: str):
        """Write the group's original and revised texts (with ids) as readable JSON."""
        original_path, revised_path = self._group_paths(text_type)

        self._write_json_atomic(
            {'ids': data['ids'], 'texts': data['original_texts']},
            original_path, ensure_ascii=False, indent=2
        )
        self._write_json_atomic(
            {'ids': data['ids'], 'texts': data['revised_texts']},
            revised_path, ensure_ascii=False, indent=2
        )

    def _load_existing_data(self, original_path: str, revised_path: str) -> Dict:
        """Load a group's saved data; returns empty lists unless both files exist."""
        data = {'ids': [], 'original_texts': [], 'revised_texts': []}

        if os.path.exists(original_path) and os.path.exists(revised_path):
            with open(original_path, 'r', encoding='utf-8') as f:
                original_data = json.load(f)
            with open(revised_path, 'r', encoding='utf-8') as f:
                revised_data = json.load(f)

            # Ids are taken from the original-text file; both files are written together.
            data['ids'] = original_data['ids']
            data['original_texts'] = original_data['texts']
            data['revised_texts'] = revised_data['texts']

        return data

    def _save_all_data(self, human_data: Dict, ai_data: Dict):
        """Write the combined pickle of both groups plus a plain-text summary."""
        combined_data = {
            'human': {
                'ids': human_data['ids'],
                'original': human_data['original_texts'],
                'revised': human_data['revised_texts']
            },
            'ai': {
                'ids': ai_data['ids'],
                'original': ai_data['original_texts'],
                'revised': ai_data['revised_texts']
            },
            'metadata': {
                'revision_model': self.config.revision_model,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'total_human': len(human_data['ids']),
                'total_ai': len(ai_data['ids'])
            }
        }

        # Pickle for fast loading.
        with open(self.combined_path, 'wb') as f:
            pickle.dump(combined_data, f)

        summary_path = os.path.join(self.output_dir, "preprocessing_summary.txt")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write("=== Preprocessing Summary ===\n\n")
            f.write(f"Revision Model: {self.config.revision_model}\n")
            f.write(f"Timestamp: {combined_data['metadata']['timestamp']}\n")
            f.write(f"Total Human Texts: {combined_data['metadata']['total_human']}\n")
            f.write(f"Total AI Texts: {combined_data['metadata']['total_ai']}\n")
            f.write(f"\nFiles created:\n")
            f.write(f"- {self.human_original_path}\n")
            f.write(f"- {self.human_revised_path}\n")
            f.write(f"- {self.ai_original_path}\n")
            f.write(f"- {self.ai_revised_path}\n")
            f.write(f"- {self.combined_path}\n")

    def load_preprocessed_data(self) -> Dict:
        """Load the combined pickle written by ``_save_all_data``.

        Raises:
            FileNotFoundError: if the combined file does not exist.

        NOTE: ``pickle.load`` can execute arbitrary code; only load files this
        pipeline produced.
        """
        if not os.path.exists(self.combined_path):
            raise FileNotFoundError(f"Preprocessed data not found at {self.combined_path}")

        with open(self.combined_path, 'rb') as f:
            data = pickle.load(f)

        print(f"Loaded preprocessed data:")
        print(f"- Human texts: {len(data['human']['ids'])}")
        print(f"- AI texts: {len(data['ai']['ids'])}")
        print(f"- Revision model: {data['metadata']['revision_model']}")

        return data

def create_training_data_from_preprocessed(preprocessed_data: Dict,
                                           num_samples: Optional[int] = None,
                                           human_label: float = 0.2,
                                           ai_label: float = 0.85) -> List:
    """Turn cached (original, revised) pairs into sentence-transformers ``InputExample``s.

    Each example's label is the target similarity for its pair.

    Args:
        preprocessed_data: dict in the format returned by
            ``TextRevisionPreprocessor.load_preprocessed_data`` (keys 'human' and 'ai',
            each with 'ids', 'original' and 'revised' lists).
        num_samples: cap on the total; each class contributes at most
            ``num_samples // 2`` pairs. None uses every pair.
        human_label: target similarity for human pairs (low: revision is expected to
            change human text a lot).
        ai_label: target similarity for AI pairs (high).

    Returns:
        list of ``InputExample``; human pairs first, then AI pairs (not shuffled).

    Raises:
        ValueError: if a group has fewer original/revised texts than requested.
    """
    from sentence_transformers import InputExample

    samples = []
    for group, target in (('human', human_label), ('ai', ai_label)):
        group_data = preprocessed_data[group]
        n_available = len(group_data['ids'])
        n_take = n_available if num_samples is None else min(num_samples // 2, n_available)

        originals = group_data['original'][:n_take]
        revisions = group_data['revised'][:n_take]
        if len(originals) < n_take or len(revisions) < n_take:
            raise ValueError(f"'{group}' data has fewer texts than ids")

        for original, revised in zip(originals, revisions):
            samples.append(InputExample(texts=[original, revised], label=target))

    print(f"Created {len(samples)} training samples")
    return samples



In [13]:
import sys

def main(argv: Optional[List[str]] = None):
    """
    Preprocessing entry point, run in one of three modes (argv[1]):
    - 'human'  : revise only human-written texts (label 0).
    - 'ai'     : revise only AI-generated texts (label 1).
    - 'combine': merge the saved 'human' and 'ai' outputs into the combined pickle.

    Args:
        argv: argument vector; defaults to ``sys.argv``.

    Credentials are read from the environment: ``OPENAI_API_KEY`` (required for the
    'human' and 'ai' modes) and ``OPENAI_BASE_URL`` (optional; defaults to the proxy
    previously used here). Exits with status 1 on bad usage or missing inputs.

    NOTE(review): this redefines the detection ``main`` from an earlier cell. After
    this cell runs, ``run_comparison_experiment`` (which calls ``main(data_path, ...)``)
    reaches this function instead. Consider renaming one of them.
    """
    if argv is None:
        argv = sys.argv

    # --- Mode selection (validated before any client is built) ---
    if len(argv) < 2 or argv[1] not in ['human', 'ai', 'combine']:
        print("Usage: python <script_name>.py [human|ai|combine]")
        print("\n  'human'  : Preprocess only human texts.")
        print("  'ai'     : Preprocess only AI texts.")
        print("  'combine': Combine the outputs after both 'human' and 'ai' runs are finished.")
        sys.exit(1)

    mode = argv[1]

    # --- Configuration ---
    # Secrets come from the environment. The key that used to be hard-coded in this
    # notebook was exposed and should be revoked.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        if mode != 'combine':
            print("Error: set the OPENAI_API_KEY environment variable before running 'human' or 'ai' mode.")
            sys.exit(1)
        # 'combine' makes no API calls, but TextRevisionPreprocessor builds its client eagerly.
        api_key = "unused-in-combine-mode"

    config = DetectionConfig(
        revision_model="gpt-3.5-turbo",
        api_key=api_key,
        base_url=os.environ.get("OPENAI_BASE_URL", "https://api.chatanywhere.com.cn/v1")
    )
    preprocessor = TextRevisionPreprocessor(config)

    # --- 'human' / 'ai': revise one group, resuming from saved progress ---
    if mode in ['human', 'ai']:
        print(f"--- Running in '{mode}' mode ---")

        print("Loading full dataset...")
        df = load_data("finance")

        process_df = df[df['label'] == (0 if mode == 'human' else 1)].copy()
        text_type = mode

        progress = preprocessor.load_progress()
        processed_ids = progress.get(f'{text_type}_processed', [])

        preprocessor._process_text_group(
            df=process_df,
            text_type=text_type,
            processed_ids=processed_ids,
            batch_size=10,
            use_multithread=True
        )
        print(f"\n--- '{mode}' mode finished successfully. ---")
        print(f"Output files saved to directory: '{preprocessor.output_dir}'")

    # --- 'combine': merge both saved groups into the final dataset ---
    elif mode == 'combine':
        print("--- Running in 'combine' mode ---")

        try:
            print("Loading preprocessed human data...")
            human_data = preprocessor._load_existing_data(
                preprocessor.human_original_path,
                preprocessor.human_revised_path
            )

            print("Loading preprocessed AI data...")
            ai_data = preprocessor._load_existing_data(
                preprocessor.ai_original_path,
                preprocessor.ai_revised_path
            )

            # _load_existing_data returns empty lists when files are missing.
            if not human_data['ids'] or not ai_data['ids']:
                raise FileNotFoundError("One or both preprocessed data files are missing or empty.")

            print("\nCombining and saving final dataset...")
            preprocessor._save_all_data(human_data, ai_data)

            print("\n--- 'combine' mode finished successfully. ---")
            print(f"Final combined dataset saved to: {preprocessor.combined_path}")

        except FileNotFoundError as e:
            print(f"\nError: {e}")
            print("Please ensure you have run both 'human' and 'ai' modes successfully before running 'combine'.")
            sys.exit(1)



In [16]:
import sys

# main() reads the preprocessing mode from sys.argv[1]:
#   'human'   -> revise human-written texts
#   'ai'      -> revise AI-generated texts
#   'combine' -> merge both outputs (only after 'human' and 'ai' have finished)
sys.argv = ["ai_generated.ipynb", "combine"]

In [17]:
# In a notebook kernel __name__ is "__main__", so this cell runs main() with the
# mode chosen via sys.argv in the previous cell.
if __name__ == "__main__":
    main()

--- Running in 'combine' mode ---
Loading preprocessed human data...
Loading preprocessed AI data...

Combining and saving final dataset...

--- 'combine' mode finished successfully. ---
Final combined dataset saved to: ./preprocessed_data\combined_dataset.pkl
