In [None]:
# %% [markdown]
# # AI Detection Research: Stress Testing Suite
# 
# Este notebook implementa el roadmap de validación científica para elevar el nivel del detector de "Kaggle solution" a "Research Paper".
# 
# **Hardware:** 4x NVIDIA RTX 3090  
# **Objetivos:**
# 1. **Robustez Adversarial:** ¿Aguanta el modelo ataques de paráfrasis?
# 2. **Generalización (OOD):** ¿Detecta texto de modelos nuevos (Phi-3/Llama-3) no vistos durante el entrenamiento?
# 3. **Explicabilidad (XAI):** Análisis SHAP para entender qué features lingüísticas pesan más.

# %% [code]
# 1. Imports y Configuración
import os
import re
import torch
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path
import joblib
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, DataCollatorWithPadding, Trainer
from transformers.cache_utils import DynamicCache

# Compat shims: some transformers releases do not expose `seen_tokens` / `get_max_length`
# on DynamicCache, which Phi-3's modeling code (loaded later with trust_remote_code=True)
# requires. Only attributes that are missing get patched; existing ones are left alone.
# NOTE(review): if the installed version assigns `self.seen_tokens = ...` in __init__
# without defining it on the class, this read-only property would break that
# assignment — confirm against the installed transformers version.
_DYNAMIC_CACHE_SHIMS = {
    "seen_tokens": property(lambda self: self.get_seq_length()),
    "get_max_length": lambda self: getattr(self, "max_cache_len", None),
}
for _shim_name, _shim in _DYNAMIC_CACHE_SHIMS.items():
    if not hasattr(DynamicCache, _shim_name):
        setattr(DynamicCache, _shim_name, _shim)

# Hardware configuration: prefer CUDA whenever it is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {DEVICE}")
_n_gpus = torch.cuda.device_count()
if _n_gpus > 1:
    print(f"GPUs available: {_n_gpus} (Will use DataParallel or device_map for heavy models)")

# %% [markdown]
# ## 0. Carga de tu Modelo Entrenado
# **IMPORTANTE:** Aquí se instancia el pipeline de predicción tal como se definió en el notebook de entrenamiento.
# El resto del notebook usa `my_model.predict_proba(texts: list) -> np.ndarray`, que devuelve, para cada texto, la probabilidad de que sea IA.

# %% [code]
# --- INTEGRATION ZONE ---
# Load the real trained ensemble (DeBERTa + LightGBM + stacking meta-learner).
print("✅ Cargando ensemble DeBERTa + LightGBM + meta-learner (stack)")

import lightgbm as lgb
import textstat
from nltk.sentiment import SentimentIntensityAnalyzer

# Paths and constants (must match the training notebook).
WORK_DIR = Path("src/ai_vs_human")
DEBERTA_DIR = WORK_DIR.joinpath("models", "deberta_v3_base")
LGB_DIR = WORK_DIR.joinpath("models", "lightgbm_numeric")
META_PATH = WORK_DIR.joinpath("models", "stack_meta", "meta_learner.joblib")
MAX_LENGTH = 256  # token truncation length for the DeBERTa tokenizer
N_SPLITS = 5      # number of CV folds whose checkpoints are looked up

# Exact column order LightGBM was trained on (order matters for Booster.predict).
NUM_COLS = [
    # Hand-crafted text statistics
    'word_count', 'character_count', 'sentence_count', 'lexical_diversity',
    'avg_sentence_length', 'avg_word_length', 'punctuation_ratio',
    'flesch_reading_ease', 'gunning_fog_index', 'grammar_errors',
    'passive_voice_ratio', 'predictability_score', 'burstiness',
    'sentiment_score', 'length', 'potential_noise',
    # Language one-hot
    'lang_ar', 'lang_ca', 'lang_de', 'lang_en', 'lang_es', 'lang_hr', 'lang_nl',
    'lang_other', 'lang_pl', 'lang_pt', 'lang_ro', 'lang_ru', 'lang_unknown',
    # Source-model one-hot
    'model_Mistral-7B-Instruct-v0.2', 'model_aya-101', 'model_gemini',
    'model_gpt-3.5-turbo-0125', 'model_human', 'model_opt-iml-max-30b',
    'model_unknown_model', 'model_v5-Eagle-7B-HF', 'model_vicuna-13b',
    # Source-dataset one-hot
    'ds_ai_human', 'ds_multisocial',
]

# One-hot groups derived from NUM_COLS by prefix, preserving NUM_COLS order.
LANG_COLS, MODEL_COLS, DS_COLS = (
    [col for col in NUM_COLS if col.startswith(prefix)]
    for prefix in ("lang_", "model_", "ds_")
)

# Feature utilities (must stay identical to the dataset builder used for training).
# "Punctuation" = any character that is neither a word character nor whitespace.
punct_re = re.compile(r'[^\w\s]')
# Crude passive-voice heuristic: a form of "to be" followed by a word ending in "ed".
passive_re = re.compile(r'\b(?:is|are|was|were|be|been|being)\s+\w+ed\b', re.IGNORECASE)

# VADER needs the `vader_lexicon` resource; if construction fails, download it and retry.
try:
    sia = SentimentIntensityAnalyzer()
except Exception:
    import nltk
    nltk.download("vader_lexicon")
    sia = SentimentIntensityAnalyzer()

def safe_readability(fn, text: str):
    """Apply a readability scorer to ``text`` and coerce the result to ``float``.

    Any exception raised by ``fn`` (or by the float conversion) is reported as
    NaN, so one malformed text cannot abort feature extraction for a batch.
    """
    try:
        score = fn(text)
        return float(score)
    except Exception:
        return np.nan

def extract_features(text: str) -> pd.Series:
    """Compute the hand-crafted text-statistic features for a single text.

    Meant to stay identical to the dataset builder used at training time: any
    change to a formula here shifts the LightGBM input distribution.

    Args:
        text: raw text; any non-string value is treated as the empty string.

    Returns:
        pd.Series with 14 entries: word/character/sentence counts, lexical
        diversity, average sentence/word length, punctuation ratio, Flesch
        reading ease, Gunning fog index, grammar-error count, passive-voice
        ratio, predictability score, burstiness and sentiment score.
    """
    if not isinstance(text, str):
        text = ''
    # Flatten line breaks so the sentence/word regexes see one continuous string.
    clean = text.replace('\n', ' ').strip()

    # Sentences = non-empty chunks between runs of '.', '!' or '?'.
    sentences = [s.strip() for s in re.split(r'[.!?]+', clean) if s.strip()]
    # Floored at 1, so the divisions by sentence_count below never divide by zero.
    sentence_count = max(len(sentences), 1)

    words = re.findall(r'\b\w+\b', clean.lower())
    word_count = len(words)
    unique_words = len(set(words))
    lexical_diversity = unique_words / word_count if word_count else 0.0
    # The `if sentence_count` guard is always true (floored at 1 above).
    avg_sentence_length = word_count / sentence_count if sentence_count else 0.0
    avg_word_length = np.mean([len(w) for w in words]) if words else 0.0
    character_count = len(clean)
    # Share of characters matched by punct_re (neither word characters nor whitespace).
    punctuation_ratio = len(punct_re.findall(clean)) / character_count if character_count else 0.0

    # Readability scores: NaN if textstat raises (see safe_readability), 0.0 for texts without words.
    fre = safe_readability(textstat.flesch_reading_ease, clean) if word_count else 0.0
    fog = safe_readability(textstat.gunning_fog, clean) if word_count else 0.0

    # Burstiness ~ coefficient of variation (std / mean) of sentence lengths in
    # whitespace tokens; the epsilon keeps the empty-text case ([0]) finite.
    sentence_lengths = [len(s.split()) for s in sentences] or [0]
    burstiness = np.std(sentence_lengths) / (np.mean(sentence_lengths) + 1e-6)

    # VADER compound polarity score (0.0 for texts without words).
    sentiment_score = sia.polarity_scores(clean)['compound'] if word_count else 0.0
    # Heuristic passive-voice matches per sentence (can exceed 1).
    passive_voice_ratio = len(passive_re.findall(clean)) / sentence_count

    # Cheap proxy for predictability: 1 - lexical diversity.
    predictability_score = 1 - lexical_diversity if word_count else 0.0
    # Counts only a handful of fixed colloquial/ungrammatical forms.
    grammar_errors = len(re.findall(r"\b(?:ain't|could of|should of|would of)\b", clean.lower()))

    return pd.Series({
        'word_count': word_count,
        'character_count': character_count,
        'sentence_count': sentence_count,
        'lexical_diversity': lexical_diversity,
        'avg_sentence_length': avg_sentence_length,
        'avg_word_length': avg_word_length,
        'punctuation_ratio': punctuation_ratio,
        'flesch_reading_ease': fre,
        'gunning_fog_index': fog,
        'grammar_errors': grammar_errors,
        'passive_voice_ratio': passive_voice_ratio,
        'predictability_score': predictability_score,
        'burstiness': burstiness,
        'sentiment_score': sentiment_score,
    })

def build_feature_frame(texts):
    """Build the LightGBM input matrix (columns exactly ``NUM_COLS``) for raw texts.

    Metadata that only existed at training time (language, source model, source
    dataset) is unknown at inference, so every row gets the same one-hot pattern:
    ``lang_unknown=1``, ``model_unknown_model=1`` and ``ds_ai_human=1`` (the unknown
    dataset is treated as closest to ai_human); all other one-hots are 0.
    ``potential_noise`` is fixed at 0 and ``length`` is the raw (uncleaned)
    character length (0 for non-strings).

    NOTE(review): the ``model_*`` one-hots encode the generator identity (including
    ``model_human``). If they were populated from the labels at training time they
    leak the target, and a constant ``model_unknown_model`` at inference is out of
    the training distribution — worth confirming before publishing results.
    """
    # Constant one-hot block shared by every row (insertion order: lang, model, ds).
    metadata = {col: float(col == 'lang_unknown') for col in LANG_COLS}
    metadata.update({col: float(col == 'model_unknown_model') for col in MODEL_COLS})
    metadata.update({col: 0.0 for col in DS_COLS})
    metadata['ds_ai_human'] = 1.0

    rows = []
    for text in texts:
        row = extract_features(text)
        row['length'] = len(text) if isinstance(text, str) else 0
        row['potential_noise'] = 0
        for col, value in metadata.items():
            row[col] = value
        rows.append(row)

    frame = pd.DataFrame(rows)
    # Guarantee every training column exists (zero-filled) and return them in order.
    missing = [col for col in NUM_COLS if col not in frame]
    for col in missing:
        frame[col] = 0.0
    return frame[NUM_COLS]

# Dataset helper para DeBERTa
class HFTextDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes raw texts lazily, one item at a time.

    Padding is deliberately disabled here; batching/padding is left to a collator
    (this notebook uses ``DataCollatorWithPadding``) so each batch is padded only
    to its own longest sequence.
    """

    def __init__(self, texts, tokenizer, max_length=256):
        # Materialize so len()/indexing work for any iterable input.
        self.texts = list(texts)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encode_kwargs = dict(truncation=True, max_length=self.max_length, padding=False)
        # str() guards against non-string entries in the input sequence.
        return self.tokenizer(str(self.texts[idx]), **encode_kwargs)

def load_deberta_paths():
    """Return the existing ``fold_<k>/best`` DeBERTa checkpoint dirs, in fold order.

    Raises:
        RuntimeError: if no fold checkpoint exists under ``DEBERTA_DIR``.
    """
    candidates = (DEBERTA_DIR / f"fold_{fold}" / "best" for fold in range(N_SPLITS))
    found = [path for path in candidates if path.exists()]
    if found:
        return found
    raise RuntimeError("No se encontraron checkpoints de DeBERTa en models/deberta_v3_base")

def predict_deberta(texts):
    """Average P(AI) over every available DeBERTa fold checkpoint.

    Each fold model is loaded, run through ``Trainer.predict`` and fully released
    before the next one is loaded, so only one fold occupies GPU memory at a time.

    Args:
        texts: sequence of raw texts (non-strings are str()-converted by the dataset).

    Returns:
        np.ndarray of shape (len(texts),) with the fold-mean softmax probability
        of class 1 (AI). Empty input returns an empty array without loading models.

    Raises:
        RuntimeError: propagated from ``load_deberta_paths`` if no checkpoint exists.
    """
    from transformers import TrainingArguments

    texts = list(texts)
    if not texts:
        # Nothing to score: avoid loading every fold checkpoint for an empty batch.
        return np.zeros(0, dtype=float)

    fold_paths = load_deberta_paths()
    tokenizer = AutoTokenizer.from_pretrained(fold_paths[0], use_fast=False)
    collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
    dataset = HFTextDataset(texts, tokenizer, max_length=MAX_LENGTH)
    # Inference-only arguments (same output_dir as Trainer's default) with
    # experiment-tracker integrations disabled, so predict() never logs or prompts.
    eval_args = TrainingArguments(output_dir="tmp_trainer", report_to="none")

    fold_preds = []
    for path in fold_paths:
        model = AutoModelForSequenceClassification.from_pretrained(path).to(DEVICE)
        # The collator already pads batches, so the tokenizer need not be handed to
        # Trainer (its `tokenizer=` argument is deprecated in recent transformers).
        trainer = Trainer(model=model, args=eval_args, data_collator=collator)
        logits = trainer.predict(dataset).predictions
        probs = torch.softmax(torch.as_tensor(logits), dim=1)[:, 1].cpu().numpy()
        fold_preds.append(probs)
        # Drop every reference to the fold model BEFORE emptying the cache;
        # otherwise the allocator cannot release its memory.
        del trainer, model, logits
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return np.mean(fold_preds, axis=0)

def load_lgb_models():
    """Load every existing ``fold_<k>/best.txt`` LightGBM booster, in fold order.

    Raises:
        RuntimeError: if no fold model exists under ``LGB_DIR``.
    """
    model_files = [LGB_DIR / f"fold_{fold}" / "best.txt" for fold in range(N_SPLITS)]
    boosters = [lgb.Booster(model_file=str(path)) for path in model_files if path.exists()]
    if not boosters:
        raise RuntimeError("No se encontraron modelos LightGBM en models/lightgbm_numeric")
    return boosters

# Load once for the whole notebook: the per-fold LightGBM boosters and the
# stacking meta-learner (a joblib file).
# NOTE(review): joblib.load unpickles and can execute arbitrary code — only point
# META_PATH at artifacts you produced yourself.
LGB_MODELS = load_lgb_models()
META_MODEL = joblib.load(META_PATH)

class MyEnsembleModel:
    """Stacked AI-text detector: DeBERTa folds + LightGBM folds -> meta-learner.

    ``predict_proba`` returns, per text, the meta-learner's probability of the
    positive class (AI).

    NOTE(review): the SGD base model is replaced by a constant 0.5 column because
    the saved SGD models do not include their vectorizer. The meta-learner was
    presumably fitted on real SGD outputs, so its calibration at inference may
    differ from training — worth stating in the paper.
    """

    def __init__(self, lgb_models=None, meta=None):
        """Wire up the base and meta models.

        Args:
            lgb_models: list of LightGBM boosters; defaults to the module-level ``LGB_MODELS``.
            meta: fitted meta-learner exposing ``predict_proba``; defaults to ``META_MODEL``.
        """
        self.lgb_models = LGB_MODELS if lgb_models is None else lgb_models
        self.meta = META_MODEL if meta is None else meta
        # Column order the meta-learner expects; falls back to the training default.
        self.base_order = getattr(self.meta, "base_model_order", ["deberta", "lgb", "sgd"])

    def predict_proba(self, texts):
        """Return P(AI) for each text as a 1-D array of length ``len(texts)``."""
        texts = list(texts)
        if not texts:
            # Avoid loading fold checkpoints and calling the meta-learner on an empty batch.
            return np.zeros(0, dtype=float)
        base_preds = {
            "deberta": predict_deberta(texts),
            "lgb": self.predict_lgb(texts),
            # The saved SGD models have no vectorizer -> neutral placeholder.
            "sgd": np.full(len(texts), 0.5),
        }
        stack = np.column_stack([base_preds[name] for name in self.base_order])
        return self.meta.predict_proba(stack)[:, 1]

    def predict_lgb(self, texts):
        """Average the LightGBM fold predictions on the hand-crafted features."""
        feats = build_feature_frame(texts)
        fold_preds = [m.predict(feats) for m in self.lgb_models]
        return np.mean(fold_preds, axis=0)

    def get_numeric_features(self, texts):
        """Return the LightGBM feature frame (columns ``NUM_COLS``) for ``texts``."""
        return build_feature_frame(texts)

# Instantiate the real ensemble used by every experiment below.
my_model = MyEnsembleModel()
# ---------------------------

# %% [markdown]
# ## 1. Robustez Adversarial (Paraphrasing Attack)
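# Generamos versiones parafraseadas de textos que sabemos que son IA para ver si logran evadir la detección. (Por ahora se usan textos de ejemplo; sustitúyelos por un subset real de textos etiquetados como IA.)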

# %% [code]
print("--- Iniciando Ataque Adversarial ---")

# 1. Paraphrase model (T5): lightweight and effective at changing sentence
#    structure while keeping the meaning.
para_model_name = "Vamsi/T5_Paraphrase_Paws"
para_tokenizer = AutoTokenizer.from_pretrained(para_model_name)
para_model = AutoModelForSeq2SeqLM.from_pretrained(para_model_name)
para_model = para_model.to(DEVICE)

def paraphrase_text(text, beams=5, grams=1, max_length=256):
    """Generate one sampled paraphrase of ``text`` with the T5 paraphraser.

    Decoding is top-k/nucleus sampling (``top_k=120``, ``top_p=0.95``), so the
    output varies between calls unless the torch RNG is seeded beforehand.

    Args:
        text: source text to paraphrase.
        beams, grams: accepted only for backward compatibility; they never
            influenced generation (decoding is pure sampling, not beam search).
        max_length: token budget applied both to the encoded prompt (truncation,
            so very long inputs cannot blow up memory) and to the generated output.

    Returns:
        The decoded paraphrase with special tokens removed.
    """
    prompt = "paraphrase: " + text + " </s>"
    encoding = para_tokenizer(
        prompt,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(DEVICE)
    # No `early_stopping`: it only affects beam search and is ignored (with a
    # warning) when sampling with a single beam.
    outputs = para_model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=max_length,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        num_return_sequences=1,
    )
    return para_tokenizer.decode(outputs[0], skip_special_tokens=True)

# 2. Test dataset. Should be a subset of your data labelled as AI;
#    for now it is a simulation built from dummy texts.
ai_texts_original = [
    "Artificial Intelligence has revolutionized the way we process data.",
    "Deep learning models are becoming increasingly complex and accurate.",
    "The rapid expansion of neural networks poses ethical questions."
] * 10  # 30 examples (only 3 unique sentences)

# 3. Generate attacks. Paraphrasing is sampling-based, so seed the torch RNG to
#    make the attacked set (and every downstream metric) reproducible.
ATTACK_SEED = 42
torch.manual_seed(ATTACK_SEED)

print(f"Generando paráfrasis para {len(ai_texts_original)} textos...")
ai_texts_attacked = []
attack_failures = []  # (original_text, error_repr) for every failed paraphrase
for t in tqdm(ai_texts_original):
    try:
        ai_texts_attacked.append(paraphrase_text(t))
    except Exception as exc:
        # Fall back to the unattacked text, but record it: a fallback is not a real
        # attack and would otherwise silently inflate the robustness numbers.
        attack_failures.append((t, repr(exc)))
        ai_texts_attacked.append(t)

if attack_failures:
    print(
        f"⚠️ {len(attack_failures)} paráfrasis fallaron; se usó el texto original. "
        f"Primer error: {attack_failures[0][1]}"
    )

# 4. Evaluate robustness
DETECTION_THRESHOLD = 0.5  # P(AI) at or above which a text counts as detected as AI
preds_original = np.asarray(my_model.predict_proba(ai_texts_original))
preds_attacked = np.asarray(my_model.predict_proba(ai_texts_attacked))

# Metrics
drop_in_confidence = np.mean(preds_original) - np.mean(preds_attacked)
# Share of attacked texts scored as human (< threshold), regardless of how they were
# scored before the attack; 100 minus this is the "Adversarial Acc" of the LaTeX table.
success_attack_rate = np.mean(preds_attacked < DETECTION_THRESHOLD) * 100
# Conditional flip rate: among texts detected as AI *before* the attack, the share the
# paraphrase pushes below the threshold. Unlike the rate above, texts the detector
# already missed are not credited to the attack.
detected_before = preds_original >= DETECTION_THRESHOLD
n_detected_before = int(detected_before.sum())
flip_rate = (
    np.mean(preds_attacked[detected_before] < DETECTION_THRESHOLD) * 100
    if n_detected_before
    else float("nan")  # undefined: nothing was detected before the attack
)

print(f"\nResultados Adversariales:")
print(f"Confianza Media Original: {np.mean(preds_original):.4f}")
print(f"Confianza Media Atacada:  {np.mean(preds_attacked):.4f}")
print(f"Caída de Confianza:       {drop_in_confidence:.4f}")
print(f"Attack Success Rate (Flip to Human): {success_attack_rate:.2f}%")
print(f"Flip Rate condicionado (detectados antes del ataque, n={n_detected_before}): {flip_rate:.2f}%")

# %% [markdown]
# ## 2. Generalización Out-of-Distribution (New GenAI Models)
# Generamos texto con un modelo moderno (Phi-3 o Llama-3) que no estaba en el training set (que incluía, entre otros, GPT-3.5, Gemini, Mistral-7B, Vicuna-13B, Aya-101, OPT-IML y Eagle-7B).
# Aprovechamos tus 4x 3090 para esto.

# %% [code]
print("\n--- Iniciando Test OOD (Phi-3/Llama-3) ---")

# Phi-3 Mini is used for efficiency; swap in the Llama-3 id below for a larger generator.
# model_id_gen = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id_gen = "microsoft/Phi-3-mini-4k-instruct"

generator = pipeline(
    task="text-generation",
    model=model_id_gen,
    device_map="auto",  # place/shard the model across all visible GPUs
    trust_remote_code=True,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)
# Disable the generation KV cache (model config and generation config) to avoid
# incompatibilities between Phi-3's remote code and older transformers versions.
for _gen_cfg in (generator.model.config, generator.model.generation_config):
    _gen_cfg.use_cache = False

prompts = [
    "Explain the theory of relativity to a 5 year old.",
    "Write a polite email declining a job offer.",
    "Describe a sunset on Mars.",
    "Write a Python function to reverse a linked list.",
    "Argue for and against universal basic income."
] * 5  # 25 generations from 5 unique prompts

# Each prompt is repeated 5x, so decoding must be stochastic: with greedy decoding
# (transformers' default unless the model's generation config enables sampling)
# the repeats would be identical texts and artificially inflate the sample size.
# Seed the RNG so the sampled set is reproducible.
OOD_SEED = 42
OOD_GEN_KWARGS = dict(do_sample=True, temperature=0.7, top_p=0.9)
torch.manual_seed(OOD_SEED)

print(f"Generando textos con {model_id_gen}...")
ood_texts = []
for p in tqdm(prompts):
    messages = [{"role": "user", "content": p}]
    output = generator(
        messages,
        max_new_tokens=200,
        return_full_text=False,
        use_cache=False,
        **OOD_GEN_KWARGS,
    )
    ood_texts.append(output[0]['generated_text'])
print(f"Textos únicos generados: {len(set(ood_texts))}/{len(ood_texts)}")

# Evaluate detection: every OOD text is machine-generated, so the share scored
# above 0.5 is the detector's recall on this unseen generator.
ood_preds = my_model.predict_proba(ood_texts)
detected_mask = np.array(ood_preds) > 0.5
detection_rate_ood = detected_mask.mean() * 100

print(f"\nResultados Generalización ({model_id_gen}):")
print(f"Detection Rate (Recall): {detection_rate_ood:.2f}%")
generalizes_well = detection_rate_ood > 80
print(
    "✅ CONCLUSION: El modelo generaliza bien a arquitecturas nuevas."
    if generalizes_well
    else "⚠️ CONCLUSION: El modelo sufre con arquitecturas nuevas (Research Opportunity)."
)

# %% [markdown]
# ## 3. Explicabilidad (SHAP Analysis)
# Usamos SHAP sobre la parte tabular (LightGBM) de tu ensemble para entender qué características lingüísticas delatan a la IA.

# %% [code]
print("\n--- Iniciando Análisis de Explicabilidad (SHAP) ---")

# 1. Data to explain: the numeric features of the OOD texts generated above.
X_explain = my_model.get_numeric_features(ood_texts)


def _positive_class_shap(raw_values):
    """Return the SHAP values of the positive (AI) class as one 2-D array.

    Depending on the shap version, binary tree models yield either a per-class
    list (index 1 = AI) or a single array; both are normalised here.
    """
    return raw_values[1] if isinstance(raw_values, list) else raw_values


# 2. Explain every LightGBM fold (not only the first) and average the attributions,
#    so the figure reflects the whole tabular component rather than one fold.
#    SHAP is linear, so this equals the SHAP values of the fold-averaged raw score
#    (the ensemble averages probabilities, so this is a close approximation).
fold_explainers = [shap.TreeExplainer(booster) for booster in my_model.lgb_models]
explainer = fold_explainers[0]  # first-fold explainer, kept for later cells
fold_shap_values = [_positive_class_shap(e.shap_values(X_explain)) for e in fold_explainers]
shap_values = np.mean(fold_shap_values, axis=0)
vals = shap_values

# 3. Visualization — this figure is meant for the paper.
plt.figure(figsize=(10, 6))
# summary_plot resizes the current figure unless plot_size is given; keep 10x6.
shap.summary_plot(vals, X_explain, show=False, plot_size=(10, 6))
plt.title("Feature Importance (SHAP): ¿Qué delata a la IA?")
plt.tight_layout()
plt.savefig("shap_summary.png", dpi=300)
print("Gráfica guardada como 'shap_summary.png'")
plt.show()

# %% [markdown]
# ## 4. Informe Final para el Paper
# Generamos las filas LaTeX para tu tabla de resultados.

# %% [code]
# Rows for the results table: adversarial accuracy = 100 - attack success rate.
latex_rows = [
    "Method & Adversarial Acc & OOD Detection (Phi-3) \\\\",
    "\\hline",
    f"Ours (Ensemble) & {100 - success_attack_rate:.2f}\\% & {detection_rate_ood:.2f}\\% \\\\",
]
print("\n=== LaTeX Table Snippet ===")
print("\n".join(latex_rows))
