
# WhaRAGBot — RAG local con tus chats de WhatsApp

Cuaderno paso a paso para: (1) leer los ZIP exportados de WhatsApp, (2) generar pares contexto→respuesta con tus mensajes, (3) indexar con embeddings locales y FAISS, y (4) chatear usando OpenAI API con contexto recuperado por RAG.

> ⚠️ Este repo está pensado para ser público. **No subas tus chats ni datos procesados**: quedan ignorados en `.gitignore`.



## 0. Entorno (solo kernel)
Crea o reutiliza el entorno virtual y registra el kernel. Las dependencias del proyecto se instalan en el paso 1.


In [None]:

# Run once: creates .venv (Python 3.12 by default) and registers the Jupyter kernel.
# Project dependencies are installed later, in step 1 (Dependencias).
import os, subprocess, pathlib, sys
from subprocess import CalledProcessError
# The virtual environment lives in the parent of the current working directory.
ENV_DIR = pathlib.Path("..").resolve() / ".venv"
# Interpreter used to build the environment; override with the PYTHON_BIN env var.
PY_BIN = os.getenv("PYTHON_BIN", "python3.12")
# Kernel name and display label passed to `ipykernel install` below.
KERNEL_NAME = "wha-ragbot"
DISPLAY_NAME = "wha-ragbot (py312)"


def create_with_venv():
    """Create the virtual environment at ENV_DIR using the ``venv`` module of PY_BIN.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    command = [PY_BIN, "-m", "venv", str(ENV_DIR)]
    subprocess.check_call(command)


def create_with_virtualenv():
    """Install ``virtualenv`` into the running interpreter and use it to build ENV_DIR for PY_BIN.

    Raises:
        subprocess.CalledProcessError: if either command exits with a non-zero status.
    """
    run_module = [sys.executable, "-m"]
    subprocess.check_call([*run_module, "pip", "install", "-q", "virtualenv"])
    subprocess.check_call([*run_module, "virtualenv", "-p", PY_BIN, str(ENV_DIR)])

# Create the environment only when the folder is missing; an existing one is reused untouched.
if ENV_DIR.exists():
    print(f"Ya existe {ENV_DIR}")
else:
    try:
        create_with_venv()
        print(f"Creado entorno con {PY_BIN} en {ENV_DIR} (venv)")
    except CalledProcessError as e:
        # The venv command exited non-zero: fall back to the virtualenv package.
        print(f"No se pudo crear con venv ({e}); probando con virtualenv...")
        create_with_virtualenv()
        print(f"Creado entorno con {PY_BIN} en {ENV_DIR} (virtualenv)")

# Check which Python version the environment runs (best effort: any failure is only reported).
# NOTE(review): the interpreter is looked up at bin/python, i.e. a POSIX venv layout — confirm if Windows matters.
try:
    out = subprocess.check_output([str(ENV_DIR/"bin/python"), "-c", "import sys;print(sys.version.split()[0])"], text=True).strip()
    if not out.startswith("3.12"):
        print(f"⚠️ El venv usa Python {out}; borra .venv y reejecuta esta celda para crearlo con 3.12")
    else:
        print(f"Usando venv con Python {out}")
except Exception as e:
    print(f"⚠️ No se pudo comprobar la versión del venv: {e}")

# Install only ipykernel inside the environment, then register it as a per-user Jupyter kernel.
subprocess.check_call([str(ENV_DIR/"bin/python"), "-m", "pip", "install", "-q", "ipykernel"])
subprocess.check_call([str(ENV_DIR/"bin/python"), "-m", "ipykernel", "install", "--user", "--name", KERNEL_NAME, "--display-name", DISPLAY_NAME])
print(f"Kernel registrado como '{DISPLAY_NAME}'. Ahora selecciónalo en VS Code/Jupyter y reinicia el kernel.")
print("Luego pasa al paso 1 para instalar requirements.")



## 1. Dependencias (ejecuta una vez)
Si ya instalaste los requisitos, puedes saltar esta celda.


In [None]:
%pip install -q -r ../requirements.txt


In [None]:

import os
# Set before the ML imports in the next cell.
# NOTE(review): presumably these tell `transformers` to skip its TensorFlow and Flax backends — confirm.
os.environ['TRANSFORMERS_NO_TF'] = '1'
os.environ['TRANSFORMERS_NO_FLAX'] = '1'



## 2. Configuración básica
Ajusta rutas y parámetros principales. `CHATS_ZIP_DIR` debe apuntar a la carpeta que contiene los `.zip` exportados por WhatsApp (opción "sin archivos multimedia").


In [None]:
from pathlib import Path
import os, zipfile, re, json
import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from dotenv import load_dotenv

# The project root is the parent of the working directory; its .env file is loaded
# before the settings below are read from the environment.
PROJECT_ROOT = Path("..").resolve()
load_dotenv(PROJECT_ROOT / ".env")

# Folder holding the exported chats (override with the CHATS_ZIP_DIR env var).
CHATS_ZIP_DIR = Path(os.getenv("CHATS_ZIP_DIR", str(PROJECT_ROOT / "Chats en .zip")))
# Output folders for processed data and for the index.
DATA_DIR = PROJECT_ROOT / "data"
INDEX_DIR = PROJECT_ROOT / "index"
# Sender name that identifies the user's own messages; "Tu Nombre" is the unset placeholder.
MY_NAME = os.getenv("MY_NAME", "Tu Nombre")
# NOTE(review): presumably the number of previous messages used as context; it is part of the step-5 build key.
CTX_WINDOW = 4
# Model identifiers and API key, all overridable through environment variables.
EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-small")
GEN_MODEL = os.getenv("GEN_MODEL", "gpt-4.1-mini")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# Make sure both output folders exist.
DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_DIR.mkdir(parents=True, exist_ok=True)



## 3. Funciones de parseo
Admite las exportaciones típicas de Android/iPhone en español. Filtra los mensajes de sistema y une en un solo mensaje los que ocupan varias líneas.


In [None]:

# Regexes that recognise the first line of a message in an exported chat.
# Each one exposes the named groups: date, time, sender and text.
START_PATTERNS = [
    # Dash-separated header, e.g. "12/3/21, 10:15 - Ana: hola".
    re.compile(r"^(?P<date>\d{1,2}/\d{1,2}/\d{2,4}), (?P<time>\d{1,2}:\d{2}) - (?P<sender>[^:]+): (?P<text>.*)$"),
    # Bracketed header, e.g. "[12/3/21, 10:15:32] Ana: hola" or "[12/3/21, 10:15 p.m.] Ana: hola".
    # Seconds and the a.m./p.m. marker are both optional: the previous pattern demanded the
    # marker and rejected seconds, so 24-hour headers with seconds were never recognised.
    re.compile(r"^\[(?P<date>\d{1,2}/\d{1,2}/\d{2,4}), (?P<time>\d{1,2}:\d{2}(?::\d{2})?(?:\s?[ap]\.?m\.?)?)\] (?P<sender>[^:]+): (?P<text>.*)$"),
]

# Lower-case fragments that identify WhatsApp system notices; parse_chat_text drops
# every message whose lower-cased text contains any of them.
SYSTEM_SNIPPETS = [
    "cifrado de extremo a extremo",
    "cambió el asunto",
    "cambió la foto",
    "cambió la descripción",
    "creó este grupo",
    "te añadieron",
    "mensaje eliminado",
    "multimedia omitido",
]


def match_start(line: str):
    """Return the match of the first START_PATTERNS regex that accepts ``line``, or None.

    A non-None result means ``line`` is the first line of a new message.
    """
    attempts = (pattern.match(line) for pattern in START_PATTERNS)
    return next((hit for hit in attempts if hit), None)


def parse_chat_text(text: str, chat_name: str) -> pd.DataFrame:
    """Parse the text of an exported chat into one row per message.

    Lines recognised by ``match_start`` open a new message; any other line is
    appended (space-separated) to the message currently being built. Lines
    that appear before the first recognised header are ignored.

    Args:
        text: Full contents of the exported chat file.
        chat_name: Label stored in the ``chat_name`` column of every row.

    Returns:
        DataFrame with columns ``timestamp``, ``sender``, ``text`` and
        ``chat_name``. Rows whose date/time cannot be parsed and rows that
        contain a SYSTEM_SNIPPETS fragment are dropped. The same four columns
        are present even when no message is found, so callers can always
        rely on the schema.
    """
    columns = ["timestamp", "sender", "text", "chat_name"]
    rows = []
    current = None
    for raw_line in text.splitlines():
        line = raw_line.strip("\ufeff")  # drop a byte-order mark if present
        m = match_start(line)
        if m:
            if current:
                rows.append(current)
            current = {
                "date": m.group("date"),
                "time": m.group("time"),
                "sender": m.group("sender").strip(),
                "text": m.group("text").strip(),
                "chat_name": chat_name,
            }
        elif current:
            # Continuation of a multi-line message.
            current["text"] += " " + line.strip()
    if current:
        rows.append(current)
    if not rows:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(rows)
    df["timestamp"] = pd.to_datetime(df["date"] + " " + df["time"], dayfirst=True, errors="coerce")
    df = df.dropna(subset=["timestamp"])

    # Filter out system messages. The explicit bool cast keeps the mask usable as a
    # boolean indexer even when no rows are left at this point.
    lowered = df["text"].str.lower()
    is_system = lowered.apply(lambda t: any(snippet in t for snippet in SYSTEM_SNIPPETS)).astype(bool)
    # .copy() so the assignment below writes to an independent frame, not a slice.
    df = df[~is_system].copy()

    # Turn non-breaking spaces (U+00A0, U+202F) into plain spaces; written as escapes
    # so the characters are visible in the source.
    df["text"] = df["text"].str.replace(r"[\u00a0\u202f]", " ", regex=True)
    return df[columns]


def parse_zip_chat(zip_path: Path) -> pd.DataFrame:
    """Read the first ``.txt`` member of an exported chat ZIP and parse it.

    Args:
        zip_path: Path to the ZIP archive; its stem becomes the chat name.

    Returns:
        The DataFrame produced by ``parse_chat_text``. When the archive has no
        ``.txt`` member a warning is printed and an empty frame with the
        columns ``timestamp``, ``sender``, ``text``, ``chat_name`` is returned.

    Raises:
        zipfile.BadZipFile, OSError: if the archive cannot be opened or read.
    """
    with zipfile.ZipFile(zip_path) as zf:
        txt_files = [n for n in zf.namelist() if n.lower().endswith(".txt")]
        if not txt_files:
            print(f"⚠️ No se encontró .txt en {zip_path}")
            return pd.DataFrame(columns=["timestamp", "sender", "text", "chat_name"])
        raw = zf.read(txt_files[0])

    # Prefer UTF-8; latin-1 maps every possible byte, so the fallback cannot fail.
    # (The former "could not decode" branch was unreachable and built
    # UnicodeDecodeError with a single argument, which is itself a TypeError.)
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError:
        text = raw.decode("latin-1")
    return parse_chat_text(text, chat_name=zip_path.stem)


def parse_csv_chat(csv_path: Path, my_name: str) -> pd.DataFrame:
    """Turn a question/answer CSV into chat-message rows.

    The CSV must contain the columns ``question`` and ``wenceslao_answer``;
    ``contact``, ``question_dt`` and ``answer_dt`` are optional. Each CSV row
    yields up to two messages: the question (sent by the contact) and the
    answer (sent by ``my_name``).

    Args:
        csv_path: Path of the CSV file.
        my_name: Sender name assigned to the answers.

    Returns:
        DataFrame with columns ``timestamp``, ``sender``, ``text`` and
        ``chat_name``, sorted by chat and timestamp. Messages without a
        parseable timestamp are dropped. An empty frame with the same columns
        is returned when the file cannot be read (a warning is printed), lacks
        the required columns, or produces no messages.
    """
    columns = ["timestamp", "sender", "text", "chat_name"]

    def _cell(value) -> str:
        # Missing cells arrive as NaN/None; str() would turn them into the text "nan".
        if value is None or (not isinstance(value, str) and pd.isna(value)):
            return ""
        return str(value).strip()

    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"⚠️ No se pudo leer CSV {csv_path.name}: {e}")
        return pd.DataFrame(columns=columns)

    required = {"question", "wenceslao_answer"}
    if not required.issubset(set(df.columns)):
        print(f"⚠️ CSV {csv_path.name} no tiene columnas requeridas {required}")
        return pd.DataFrame(columns=columns)

    has_contact = "contact" in df.columns
    qdt_col = "question_dt" if "question_dt" in df.columns else None
    adt_col = "answer_dt" if "answer_dt" in df.columns else None

    rows = []
    for _, r in df.iterrows():
        contact = _cell(r.get("contact")) if has_contact else "Contacto"
        if contact:
            chat_name = f"Chat CSV con {contact}"
        else:
            # No contact on this row: name the chat after the file and use a generic sender.
            chat_name = f"Chat CSV {csv_path.stem}"
            contact = "Contacto"

        qtxt = _cell(r.get("question"))
        atxt = _cell(r.get("wenceslao_answer"))
        qts = pd.to_datetime(r.get(qdt_col), dayfirst=True, errors="coerce") if qdt_col else pd.NaT
        ats = pd.to_datetime(r.get(adt_col), dayfirst=True, errors="coerce") if adt_col else pd.NaT

        if qtxt:
            rows.append({
                "timestamp": qts,
                "sender": contact,
                "text": qtxt,
                "chat_name": chat_name,
            })
        if atxt:
            rows.append({
                "timestamp": ats,
                "sender": my_name,
                "text": atxt,
                "chat_name": chat_name,
            })

    if not rows:
        return pd.DataFrame(columns=columns)

    out = pd.DataFrame(rows, columns=columns)
    out = out.dropna(subset=["timestamp"])
    # Turn non-breaking spaces (U+00A0, U+202F) into plain spaces; written as escapes
    # so the characters are visible in the source.
    out["text"] = out["text"].astype(str).str.replace(r"[\u00a0\u202f]", " ", regex=True)
    out = out.sort_values(["chat_name", "timestamp"]).reset_index(drop=True)
    return out[columns]


def parse_input_file(path: Path, my_name: str):
    """Parse ``path`` with the parser that matches its extension (case-insensitive).

    ``.zip`` files go to ``parse_zip_chat`` and ``.csv`` files to
    ``parse_csv_chat``; any other extension yields an empty DataFrame with the
    columns ``timestamp``, ``sender``, ``text`` and ``chat_name``.
    """
    parsers = {
        ".zip": lambda: parse_zip_chat(path),
        ".csv": lambda: parse_csv_chat(path, my_name=my_name),
    }
    parser = parsers.get(path.suffix.lower())
    if parser is None:
        return pd.DataFrame(columns=["timestamp", "sender", "text", "chat_name"])
    return parser()




## 4. Cargar todos los ZIP y guardar mensajes procesados


In [None]:
import hashlib
import time

# Per-file parse results are cached as parquet files in this folder; manifest.json
# records, for each source file, which cache file belongs to it and how to validate it.
INGEST_CACHE_DIR = DATA_DIR / "ingest_cache"
INGEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
INGEST_MANIFEST_PATH = INGEST_CACHE_DIR / "manifest.json"


def _load_json(path: Path, default):
    if not path.exists():
        return default
    try:
        return json.loads(path.read_text())
    except Exception:
        return default


def _save_json(path: Path, payload):
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))


def _file_sig(path: Path):
    st = path.stat()
    return {"size": int(st.st_size), "mtime_ns": int(st.st_mtime_ns)}


def _sha256_file(path: Path, chunk_size: int = 1024 * 1024):
    h = hashlib.sha256()
    with path.open("rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


# Collect every .zip and .csv directly inside the chats folder (ZIPs first, each group sorted).
zip_files = sorted(CHATS_ZIP_DIR.glob("*.zip"))
csv_files = sorted(CHATS_ZIP_DIR.glob("*.csv"))
all_files = zip_files + csv_files
print(f"Encontrados {len(zip_files)} ZIPs y {len(csv_files)} CSVs en {CHATS_ZIP_DIR}")
print(f"[Paso 4] Inicio: {pd.Timestamp.now().isoformat()} | archivos a revisar: {len(all_files)}")

ingest_t0 = time.perf_counter()
manifest = _load_json(INGEST_MANIFEST_PATH, {"files": {}, "version": 1})
manifest_files = manifest.get("files", {})
# The parse tag combines the parser version with MY_NAME; a cache entry with a different
# tag is never reused, so changing either one forces a re-parse.
PARSER_CACHE_VERSION = "v2_my_name_tag"
parse_tag = f"{PARSER_CACHE_VERSION}|my_name={str(MY_NAME).strip()}"
frames = []

# Counters reported in the progress and summary lines.
cache_fast_hits = 0
cache_hash_hits = 0
reparsed = 0
failed = 0
rows_total = 0

for idx, fp in enumerate(tqdm(all_files, desc="Paso 4 - ingesta", unit="archivo"), start=1):
    key = str(fp.resolve())
    sig = _file_sig(fp)
    meta = manifest_files.get(key, {})

    cache_file = meta.get("cache_file", "")
    cache_path = INGEST_CACHE_DIR / cache_file if cache_file else None

    # Fast path: same size, mtime and parse tag as recorded -> reuse the cache without hashing.
    fast_hit = (
        cache_path is not None
        and cache_path.exists()
        and meta.get("size") == sig["size"]
        and meta.get("mtime_ns") == sig["mtime_ns"]
        and meta.get("parse_tag") == parse_tag
    )

    df = None
    file_sha = None

    if fast_hit:
        try:
            df = pd.read_parquet(cache_path)
            cache_fast_hits += 1
        except Exception:
            # Unreadable cache file: fall through to the slower paths below.
            df = None

    if df is None:
        # Slow path: hash the content so a file that was only touched (different
        # size/mtime metadata, same bytes) can still reuse its cache.
        file_sha = _sha256_file(fp)
        hash_hit = (
            cache_path is not None
            and cache_path.exists()
            and meta.get("sha256") == file_sha
            and meta.get("parse_tag") == parse_tag
        )

        if hash_hit:
            try:
                df = pd.read_parquet(cache_path)
                cache_hash_hits += 1
            except Exception:
                df = None

        if df is None:
            try:
                df = parse_input_file(fp, my_name=MY_NAME)
                reparsed += 1
            except Exception as e:
                # A file that fails to parse is counted and skipped; its manifest entry is left as it was.
                failed += 1
                print(f"[Paso 4][WARN] fallo parseando {fp.name}: {e}")
                continue

            # Cache files are named after the content hash of the source file.
            cache_file = f"{file_sha}.parquet"
            cache_path = INGEST_CACHE_DIR / cache_file
            df.to_parquet(cache_path, index=False)

        # Record (or refresh) the entry whenever the fast path was not taken.
        manifest_files[key] = {
            "sha256": file_sha,
            "cache_file": cache_file,
            "rows": int(len(df)),
            "size": sig["size"],
            "mtime_ns": sig["mtime_ns"],
            "parse_tag": parse_tag,
        }

    rows_total += int(len(df))
    frames.append(df)

    # Progress line every 10 files and after the last one.
    if idx % 10 == 0 or idx == len(all_files):
        print(
            f"[Paso 4] progreso {idx}/{len(all_files)} | "
            f"fast={cache_fast_hits} hash={cache_hash_hits} reparsed={reparsed} failed={failed} "
            f"rows_acum={rows_total}"
        )

# Remove manifest entries for files that no longer exist in the chats folder.
valid_keys = {str(p.resolve()) for p in all_files}
for old_key in list(manifest_files.keys()):
    if old_key not in valid_keys:
        manifest_files.pop(old_key, None)

manifest["files"] = manifest_files
manifest["updated_at"] = pd.Timestamp.now("UTC").isoformat()
_save_json(INGEST_MANIFEST_PATH, manifest)

# Merge all per-file frames, drop rows without a timestamp and order by chat and time.
messages = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=["timestamp", "sender", "text", "chat_name"])
messages = messages.dropna(subset=["timestamp"]).sort_values(["chat_name", "timestamp"]).reset_index(drop=True)

elapsed = time.perf_counter() - ingest_t0
print(
    f"[Paso 4] Fin en {elapsed:.1f}s | total_filas={len(messages)} | "
    f"fast={cache_fast_hits} hash={cache_hash_hits} reparsed={reparsed} failed={failed}"
)
print(messages.head())

# Persist the merged messages as parquet, falling back to CSV if the parquet write fails.
try:
    messages.to_parquet(DATA_DIR / "messages.parquet", index=False)
    print(f"Guardado messages.parquet con {len(messages)} filas en {DATA_DIR}")
except Exception as e:
    print(f"⚠️ No se pudo escribir parquet ({e}) -> guardando CSV")
    messages.to_csv(DATA_DIR / "messages.csv", index=False)
    print(f"Guardado messages.csv con {len(messages)} filas en {DATA_DIR}")




## 4.1 Limpieza y etiquetas (recomendado)
Normaliza el texto y marca las preguntas y los mensajes de bajo contenido. Esto mejora la recuperación y reduce las respuestas raras.


In [None]:
import re
import unicodedata

# Short acknowledgements and laughs treated as low-signal messages. _is_low_signal
# lower-cases the text before the lookup, so the mixed-case "xD" entry never matches
# there; it is kept so the set itself is unchanged.
ACK_WORDS = {"ok", "oki", "okey", "vale", "si", "sii", "sip", "aja", "mmm", "mm", "mhm", "jaja", "jajaja", "jeje", "jaj", "xd", "xD"}
# Interrogative openers used to flag questions that carry no "?". The text is
# lower-cased but keeps its accents, so both the accented and the unaccented
# spelling of each word are listed (the list previously repeated "quien " and
# had an accented form only for "por qué ").
QUESTION_PREFIXES = (
    "que ", "qué ",
    "como ", "cómo ",
    "por que ", "por qué ",
    "cuando ", "cuándo ",
    "cuanto ", "cuánto ",
    "donde ", "dónde ",
    "quien ", "quién ",
    "cual ", "cuál ",
    "cuales ", "cuáles ",
)
# Substrings that suggest the user is stating a fact about themselves.
SELF_FACT_HINTS = ["me llamo", "mi nombre", "me llaman", "soy de", "vivo en", "trabajo", "curro", "estudio", "tengo "]
# Case-folded values of MY_NAME that mean "not configured".
MY_NAME_PLACEHOLDERS = {"", "tu nombre", "your name", "mi nombre"}


def _normalize_text(t: str) -> str:
    t = str(t or "")
    t = unicodedata.normalize("NFKC", t)
    t = t.replace("\u200b", "").replace("\ufeff", "")
    t = re.sub(r"\s+", " ", t).strip()
    return t


def _normalize_sender_name(name: str) -> str:
    return unicodedata.normalize("NFKC", str(name or "")).strip()


def _sender_key(name: str) -> str:
    """Return the case-insensitive comparison key of a sender: the normalised name, case-folded."""
    display_name = _normalize_sender_name(name)
    return display_name.casefold()


def _has_url(t: str) -> bool:
    return bool(re.search(r"(https?://|www\.)", t or "", flags=re.I))


def _is_question(t: str) -> bool:
    t = (t or "").strip().lower()
    if "?" in t:
        return True
    return any(t.startswith(p) for p in QUESTION_PREFIXES)


def _is_low_signal(t: str) -> bool:
    t = (t or "").strip().lower()
    if not t:
        return True
    if t in ACK_WORDS:
        return True
    if len(t) <= 2:
        return True
    if re.fullmatch(r"[\W_]+", t):
        return True
    if sum(ch.isalnum() for ch in t) == 0:
        return True
    return False


def _signal_score(t: str, is_question: bool = False, has_url: bool = False) -> float:
    t = (t or "").strip()
    tokens = len(t.split())
    score = min(1.0, tokens / 12)
    if len(t) >= 80:
        score += 0.10
    if is_question:
        score += 0.05
    if has_url:
        score -= 0.10
    return max(0.0, min(1.0, score))


def _is_self_fact(t: str) -> bool:
    """Tell whether the lower-cased text contains any of the SELF_FACT_HINTS substrings."""
    lowered = (t or "").lower()
    for hint in SELF_FACT_HINTS:
        if hint in lowered:
            return True
    return False


def _resolve_my_name(df: pd.DataFrame, configured_name: str) -> tuple[str, str]:
    """Decide which sender name represents the user.

    The candidates are tried in this order:

    1. ``config_match``: the configured name is not a placeholder and matches
       (case-insensitively) a sender in ``df``; the most frequent spelling of
       that sender is returned.
    2. ``fallback_default``: ``df`` has no non-empty sender; the configured
       name, or "Tu Nombre" when it is empty, is returned.
    3. ``inferred_chat_coverage``: there are at least two chats and the sender
       present in the most chats covers at least ``max(2, half of them)``.
    4. ``config_no_match``: the configured name is not a placeholder but was
       not found among the senders; it is returned unchanged.
    5. ``inferred_top_count``: the sender with the most messages.

    Args:
        df: Messages with at least the columns ``sender`` and ``chat_name``.
        configured_name: Name taken from the configuration (may be a placeholder).

    Returns:
        ``(name, source)`` where ``source`` is one of the labels listed above.
    """
    cfg_raw = _normalize_sender_name(configured_name)
    cfg_key = _sender_key(cfg_raw)

    senders = df["sender"].astype(str).map(_normalize_sender_name)
    sender_keys = senders.map(_sender_key)

    # 1. The configured name exists among the senders: pick its most common spelling.
    if cfg_key and cfg_key not in MY_NAME_PLACEHOLDERS and (sender_keys == cfg_key).any():
        chosen = senders[sender_keys == cfg_key].value_counts().idxmax()
        return str(chosen), "config_match"

    tmp = pd.DataFrame({
        "sender": senders,
        "sender_key": sender_keys,
        "chat_name": df["chat_name"].astype(str),
    })
    tmp = tmp[tmp["sender_key"] != ""]
    # 2. Nothing to infer from.
    if tmp.empty:
        fallback = cfg_raw or "Tu Nombre"
        return fallback, "fallback_default"

    # Number of distinct chats in which each sender appears, largest first.
    coverage = tmp.groupby("sender_key")["chat_name"].nunique().sort_values(ascending=False)
    top_key = str(coverage.index[0])
    top_coverage = int(coverage.iloc[0])
    total_chats = max(1, int(tmp["chat_name"].nunique()))

    # 3. A sender who shows up in at least half of the chats (and in two or more).
    if total_chats >= 2 and top_coverage >= max(2, int(total_chats * 0.5)):
        chosen = tmp.loc[tmp["sender_key"] == top_key, "sender"].value_counts().idxmax()
        return str(chosen), "inferred_chat_coverage"

    # 4. Trust an explicit configuration even though it matched no sender.
    if cfg_raw and cfg_key not in MY_NAME_PLACEHOLDERS:
        return cfg_raw, "config_no_match"

    # 5. Last resort: the sender with the most messages.
    chosen = tmp["sender"].value_counts().idxmax()
    return str(chosen), "inferred_top_count"


# Work on a copy so the frame produced by the ingestion step is not modified in place.
messages = messages.copy()
messages["sender"] = messages["sender"].astype(str).map(_normalize_sender_name)
messages["sender_key"] = messages["sender"].map(_sender_key)

# Resolve the effective user name (it may differ from the configured one) and flag own messages.
configured_name = str(MY_NAME or "").strip()
MY_NAME, my_name_source = _resolve_my_name(messages, configured_name)
my_name_key = _sender_key(MY_NAME)
messages["is_me"] = messages["sender_key"] == my_name_key if my_name_key else False

print(f"MY_NAME configurado={repr(configured_name)} | efectivo={repr(MY_NAME)} | source={my_name_source}")
print(f"Mensajes propios detectados: {int(messages['is_me'].sum())} / {len(messages)}")

# Keep the raw text and derive the normalised text plus the per-message labels.
messages["text_raw"] = messages["text"].astype(str)
messages["text"] = messages["text_raw"].apply(_normalize_text)
# .copy() makes the filtered result an independent frame, so the column assignments
# below do not write to a slice of the previous frame (SettingWithCopyWarning).
messages = messages[messages["text"].str.len() > 0].copy()
messages["has_url"] = messages["text"].apply(_has_url)
messages["is_question"] = messages["text"].apply(_is_question)
messages["is_low_signal"] = messages["text"].apply(_is_low_signal)
messages["signal_score"] = messages.apply(lambda r: _signal_score(r["text"], r["is_question"], r["has_url"]), axis=1)
messages["token_len"] = messages["text"].str.split().str.len()
messages["self_fact"] = messages["is_me"] & messages["text"].apply(_is_self_fact)

# Drop exact duplicates (same chat, time, sender and normalised text).
before = len(messages)
messages = messages.drop_duplicates(subset=["chat_name", "timestamp", "sender", "text"])
print(f"Limpieza lista. Filas: {before} -> {len(messages)}")
print(messages[["sender", "text", "is_question", "is_low_signal", "signal_score"]].head(3))

# Persist the cleaned messages as parquet, falling back to CSV if the parquet write fails.
try:
    messages.to_parquet(DATA_DIR / "messages_clean.parquet", index=False)
    print(f"Guardado messages_clean.parquet con {len(messages)} filas en {DATA_DIR}")
except Exception as e:
    print(f"⚠️ No se pudo escribir parquet ({e}) -> guardando CSV")
    messages.to_csv(DATA_DIR / "messages_clean.csv", index=False)
    print(f"Guardado messages_clean.csv con {len(messages)} filas en {DATA_DIR}")



## 5. Construir memoria por eventos (no solo Q/A)
Se crean tres tipos de unidades para recuperar mejor evidencia en WhatsApp real:
- `qa_turn`: contexto previo -> respuesta tuya.
- `my_message`: mensajes tuyos individuales (con micro-contexto).
- `topic_block`: bloques de 3 a 8 mensajes dentro de un mismo tema (segmentado por pausas de tiempo).


In [None]:
from datetime import timedelta
import hashlib
import time

# NOTE(review): the QA_* values mirror the keyword defaults of build_qa_examples and are
# presumably passed to it further down — confirm at the call site.
QA_MAX_GAP_MIN = 30
QA_REQUIRE_OTHER = True
QA_REQUIRE_QUESTION = False
# Topic blocks whose mean signal_score is below this value are discarded.
BLOCK_MIN_SIGNAL = 0.15
# Version tag and file locations of the memory build; the tag is part of the build key.
MEMORY_BUILD_VERSION = "v10_dual_rag_units_is_me"
MEMORY_BUILD_MANIFEST = DATA_DIR / "memory_build_manifest.json"
MEMORY_UNITS_PATH = DATA_DIR / "memory_units.parquet"

# Derived datasets for the dual RAG
DUAL_UNITS_MANIFEST = DATA_DIR / "dual_units_manifest.json"
RESPONSE_UNITS_PATH = DATA_DIR / "response_units.parquet"
STYLE_UNITS_PATH = DATA_DIR / "style_units.parquet"

# Cost controls for step 5
# Window start step for topics with fewer than 200 messages, fewer than 2000, and larger ones.
BLOCK_STRIDE_SMALL = 2
BLOCK_STRIDE_MEDIUM = 6
BLOCK_STRIDE_LARGE = 12
# Upper bounds on the number of blocks emitted per topic and overall.
MAX_BLOCKS_PER_TOPIC = 120
MAX_BLOCKS_TOTAL = 250000
# Progress is printed for the first chat, every N-th chat and the last chat.
LOG_EVERY_CHATS = 10

# Quality filters
# Minimum character lengths for response/style units and maximum word count for style units.
MIN_RESPONSE_CHARS = 2
MIN_STYLE_CHARS = 2
MAX_STYLE_WORDS = 60


def _quick_messages_fingerprint(df: pd.DataFrame):
    cols = ["timestamp", "sender", "text", "chat_name"]
    base = df[cols].copy()
    base["timestamp"] = pd.to_datetime(base["timestamp"], errors="coerce")

    n = len(base)
    if n == 0:
        return "empty"

    ts_min = str(base["timestamp"].min())
    ts_max = str(base["timestamp"].max())

    # Muestra estable y acotada: evita ordenar todo el dataset.
    step = max(1, n // 5000)
    sample = base.iloc[::step].head(5000).fillna("")
    sample_hash = pd.util.hash_pandas_object(sample, index=False)

    h = hashlib.sha256()
    h.update(f"{n}|{ts_min}|{ts_max}|{step}".encode("utf-8"))
    h.update(sample_hash.values.tobytes())
    return h.hexdigest()


def _ingest_manifest_hash():
    """Return the SHA-256 hex digest of the ingest manifest file.

    Returns ``"missing"`` when ``DATA_DIR/ingest_cache/manifest.json`` does
    not exist and ``"unreadable"`` when it cannot be read.
    """
    manifest_path = DATA_DIR / "ingest_cache" / "manifest.json"
    if not manifest_path.exists():
        return "missing"
    try:
        content = manifest_path.read_bytes()
        digest = hashlib.sha256(content).hexdigest()
    except Exception:
        return "unreadable"
    return digest


def _build_key(messages_df: pd.DataFrame):
    """Compute the cache key of the step-5 memory build.

    The key is the SHA-256 of a JSON document (sorted keys) holding the build
    version, the ingest-manifest hash, the messages fingerprint and every
    tunable that influences the generated units.

    Returns:
        ``(key, payload)``: the hex digest and the dict it was computed from.
    """
    payload = dict(
        version=MEMORY_BUILD_VERSION,
        ingest_manifest_hash=_ingest_manifest_hash(),
        messages_fp=_quick_messages_fingerprint(messages_df),
        qa_max_gap_min=QA_MAX_GAP_MIN,
        qa_require_other=QA_REQUIRE_OTHER,
        qa_require_question=QA_REQUIRE_QUESTION,
        ctx_window=CTX_WINDOW,
        block_min_signal=BLOCK_MIN_SIGNAL,
        my_name=MY_NAME,
        block_stride_small=BLOCK_STRIDE_SMALL,
        block_stride_medium=BLOCK_STRIDE_MEDIUM,
        block_stride_large=BLOCK_STRIDE_LARGE,
        max_blocks_per_topic=MAX_BLOCKS_PER_TOPIC,
        max_blocks_total=MAX_BLOCKS_TOTAL,
        min_response_chars=MIN_RESPONSE_CHARS,
        min_style_chars=MIN_STYLE_CHARS,
        max_style_words=MAX_STYLE_WORDS,
    )
    serialized = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
    return digest, payload


def _load_json(path: Path, default):
    if not path.exists():
        return default
    try:
        return json.loads(path.read_text())
    except Exception:
        return default


def _save_json(path: Path, payload):
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))


def _progress(chat_idx: int, total_chats: int):
    return chat_idx == 1 or chat_idx % LOG_EVERY_CHATS == 0 or chat_idx == total_chats


def _style_signature(text: str):
    t = str(text or "").strip()
    if not t:
        return "empty"
    words = t.split()
    n_words = len(words)
    q = int(t.endswith("?"))
    ex = int(t.endswith("!"))
    caps_ratio = sum(1 for ch in t if ch.isupper()) / max(1, sum(1 for ch in t if ch.isalpha()))

    if n_words <= 2:
        length_bucket = "tiny"
    elif n_words <= 6:
        length_bucket = "short"
    elif n_words <= 14:
        length_bucket = "mid"
    else:
        length_bucket = "long"

    if caps_ratio > 0.35:
        tone_bucket = "caps"
    else:
        tone_bucket = "normal"

    return f"{length_bucket}|q{q}|e{ex}|{tone_bucket}"


def build_qa_examples(df: pd.DataFrame, ctx_window: int = 4,
                      max_gap_minutes: int = 30, require_other: bool = True, require_question: bool = False) -> pd.DataFrame:
    """Build ``qa_turn`` units: the preceding context paired with one of the user's replies.

    For each chat, every message flagged ``is_me`` that is not low-signal
    becomes a unit whose context is the up-to-``ctx_window`` messages right
    before it.

    Args:
        df: Messages with ``timestamp``, ``sender``, ``text`` and ``chat_name``;
            the flags ``is_me``, ``is_low_signal``, ``is_question``,
            ``signal_score`` and ``self_fact`` are read with defaults when absent.
        ctx_window: Number of previous messages used as context.
        max_gap_minutes: A reply is skipped when it comes more than this many
            minutes after the last message from someone else in the context.
        require_other: Skip replies whose context holds no message from someone else.
        require_question: Keep only replies whose last four context messages
            include a question asked by someone else.

    Returns:
        DataFrame with one row per unit (``unit_type`` is ``"qa_turn"``);
        empty when no unit qualifies. Progress lines are printed while building.
    """
    df = df.sort_values("timestamp")
    samples = []
    total_chats = int(df["chat_name"].nunique())
    t0 = time.perf_counter()

    for c_idx, (chat, group) in enumerate(df.groupby("chat_name"), start=1):
        # Positional index within the chat, so `i` can be used for slicing below.
        group = group.reset_index(drop=True)
        before = len(samples)

        for i, row in group.iterrows():
            if not bool(row.get("is_me", False)):
                continue
            if row.get("is_low_signal", False):
                continue

            start = max(0, i - ctx_window)
            ctx = group.iloc[start:i]
            if ctx.empty:
                continue

            # Most recent context message written by someone else, if any.
            last_other = None
            for r in reversed(list(ctx.itertuples())):
                if not bool(getattr(r, "is_me", False)):
                    last_other = r
                    break

            if require_other and last_other is None:
                continue

            if last_other is not None:
                dt = row.timestamp - last_other.timestamp
                # Too late to count as a reply to that message.
                if pd.notna(dt) and dt > timedelta(minutes=max_gap_minutes):
                    continue
                reply_gap_min = float(dt.total_seconds() / 60.0)
                partner_text = str(last_other.text)
                partner_sender = str(last_other.sender)
            else:
                reply_gap_min = None
                partner_text = ""
                partner_sender = ""

            if require_question:
                recent = ctx.tail(4)
                if not any(getattr(r, "is_question", False) and not bool(getattr(r, "is_me", False)) for r in recent.itertuples()):
                    continue

            ctx_text = "\n".join(f"{r.sender}: {r.text}" for r in ctx.itertuples())
            resp = str(row.text or "").strip()
            samples.append({
                "unit_type": "qa_turn",
                "chat_name": chat,
                "timestamp": row.timestamp,
                "context": ctx_text,
                "partner_text": partner_text,
                "partner_sender": partner_sender,
                "reply_gap_min": reply_gap_min,
                "response": resp,
                "style_text": resp,
                "style_signature": _style_signature(resp),
                "signal_score": float(row.get("signal_score", 0.5)),
                "self_fact": bool(row.get("self_fact", False)),
                # Text that is embedded for retrieval: partner message, context and the real reply.
                "embed_text": (
                    f"Ultimo mensaje del contacto:\n{partner_text}\n\n"
                    f"Contexto reciente:\n{ctx_text}\n\n"
                    f"Mi respuesta real:\n{resp}"
                ),
                "source_id": f"qa::{chat}::{i}",
            })

        if _progress(c_idx, total_chats):
            added = len(samples) - before
            print(f"[Paso 5][qa] chat {c_idx}/{total_chats} | +{added} | total={len(samples)}")

    print(f"[Paso 5][qa] fin: {len(samples)} unidades en {time.perf_counter() - t0:.1f}s")
    return pd.DataFrame(samples)


def split_topics(group: pd.DataFrame, gap_minutes: int = 45):
    """Split one chat into topics wherever two consecutive messages are far apart.

    Args:
        group: Messages of a single chat; must have a ``timestamp`` column.
        gap_minutes: A new topic starts when the gap to the previous message
            exceeds this many minutes, or when the gap cannot be computed
            (missing timestamp).

    Returns:
        List of DataFrames (copies) in chronological order. Each keeps the
        positional index of the time-sorted chat. An empty ``group`` yields
        an empty list.
    """
    group = group.sort_values("timestamp").reset_index(drop=True)
    if group.empty:
        # Nothing to split; the previous implementation raised IndexError here.
        return []

    max_gap = timedelta(minutes=gap_minutes)
    # Pull the timestamps out once instead of doing a .loc lookup per row.
    stamps = group["timestamp"].tolist()

    topics = []
    start = 0
    for i in range(1, len(stamps)):
        dt = stamps[i] - stamps[i - 1]
        if pd.isna(dt) or dt > max_gap:
            topics.append(group.iloc[start:i].copy())
            start = i
    topics.append(group.iloc[start:].copy())
    return topics


def build_my_message_units(df: pd.DataFrame, local_window: int = 2) -> pd.DataFrame:
    """Build ``my_message`` units: each of the user's messages with a small surrounding context.

    For each chat, every message flagged ``is_me`` that is not low-signal
    becomes a unit. Its context is the window of ``local_window`` messages on
    each side (the message itself included), and its partner text is the most
    recent earlier message written by someone else ("" when there is none).

    Args:
        df: Messages with ``timestamp``, ``sender``, ``text`` and ``chat_name``;
            the flags ``is_me``, ``is_low_signal``, ``signal_score`` and
            ``self_fact`` are read with defaults when absent.
        local_window: Number of messages taken before and after the message.

    Returns:
        DataFrame with one row per unit (``unit_type`` is ``"my_message"``);
        empty when no unit qualifies. Progress lines are printed while building.
    """
    units = []
    total_chats = int(df["chat_name"].nunique())
    t0 = time.perf_counter()

    for c_idx, (chat, group) in enumerate(df.groupby("chat_name"), start=1):
        group = group.sort_values("timestamp").reset_index(drop=True)
        before = len(units)

        # Precompute, for every row, the latest earlier message from someone else,
        # so it does not have to be searched for again per message.
        prev_other_text = [""] * len(group)
        prev_other_sender = [""] * len(group)
        last_other_text = ""
        last_other_sender = ""
        for i, r in enumerate(group.itertuples()):
            # Store first, update afterwards: row i only sees messages strictly before it.
            prev_other_text[i] = last_other_text
            prev_other_sender[i] = last_other_sender
            if not bool(getattr(r, "is_me", False)):
                last_other_text = str(r.text)
                last_other_sender = str(r.sender)

        for i, row in group.iterrows():
            if not bool(row.get("is_me", False)):
                continue
            if row.get("is_low_signal", False):
                continue

            # Window of neighbouring messages, clipped to the bounds of the chat.
            left = max(0, i - local_window)
            right = min(len(group), i + local_window + 1)
            local = group.iloc[left:right]
            local_text = "\n".join(f"{r.sender}: {r.text}" for r in local.itertuples())

            partner_text = prev_other_text[i]
            partner_sender = prev_other_sender[i]
            resp = str(row.text or "").strip()
            units.append({
                "unit_type": "my_message",
                "chat_name": chat,
                "timestamp": row.timestamp,
                "context": local_text,
                "partner_text": partner_text,
                "partner_sender": partner_sender,
                "reply_gap_min": None,
                "response": resp,
                "style_text": resp,
                "style_signature": _style_signature(resp),
                "signal_score": float(row.get("signal_score", 0.5)),
                "self_fact": bool(row.get("self_fact", False)),
                # Text that is embedded for retrieval: partner message, own message and micro-context.
                "embed_text": (
                    f"Ultimo mensaje del contacto:\n{partner_text}\n\n"
                    f"Mensaje mio:\n{resp}\n\n"
                    f"Micro-contexto:\n{local_text}"
                ),
                "source_id": f"msg::{chat}::{i}",
            })

        if _progress(c_idx, total_chats):
            added = len(units) - before
            print(f"[Paso 5][msg] chat {c_idx}/{total_chats} | +{added} | total={len(units)}")

    print(f"[Paso 5][msg] fin: {len(units)} unidades en {time.perf_counter() - t0:.1f}s")
    return pd.DataFrame(units)


def build_topic_blocks(df: pd.DataFrame, min_block: int = 3, max_block: int = 8, gap_minutes: int = 45):
    """Build ``topic_block`` units: sliding windows of consecutive messages inside a topic.

    Each chat is split into topics with ``split_topics``. Inside every topic
    that has at least ``min_block`` messages, windows of ``min_block``, 5 and
    ``max_block`` messages are taken at regular start positions. The step
    between starts depends on the topic size (BLOCK_STRIDE_SMALL below 200
    messages, BLOCK_STRIDE_MEDIUM below 2000, BLOCK_STRIDE_LARGE otherwise).
    Windows whose mean ``signal_score`` is below BLOCK_MIN_SIGNAL are
    discarded. Generation stops at MAX_BLOCKS_PER_TOPIC blocks per topic and
    MAX_BLOCKS_TOTAL blocks overall.

    Args:
        df: Messages with ``timestamp``, ``sender``, ``text`` and ``chat_name``;
            ``signal_score`` and ``self_fact`` are read with defaults when absent.
        min_block: Smallest window size.
        max_block: Largest window size.
        gap_minutes: Pause that separates two topics.

    Returns:
        DataFrame with one row per block (``unit_type`` is ``"topic_block"``);
        empty when no block qualifies. Progress lines are printed while building.
    """
    rows = []
    total_blocks = 0
    total_chats = int(df["chat_name"].nunique())
    t0 = time.perf_counter()

    # Window sizes in their original order, limited to [min_block, max_block] and
    # de-duplicated: when two of the three candidates coincide (e.g. min_block=5)
    # the same window used to be emitted twice with the same source_id.
    sizes = list(dict.fromkeys(
        size for size in (min_block, 5, max_block) if min_block <= size <= max_block
    ))

    for c_idx, (chat, group) in enumerate(df.groupby("chat_name"), start=1):
        topics = split_topics(group, gap_minutes=gap_minutes)
        topics_used = 0

        for t_idx, topic in enumerate(topics):
            topic = topic.reset_index(drop=True)
            n = len(topic)
            if n < min_block:
                continue

            topics_used += 1
            # Larger topics are sampled more sparsely to keep the number of blocks bounded.
            if n < 200:
                stride = BLOCK_STRIDE_SMALL
            elif n < 2000:
                stride = BLOCK_STRIDE_MEDIUM
            else:
                stride = BLOCK_STRIDE_LARGE

            per_topic = 0
            for start in range(0, n, stride):
                for size in sizes:
                    end = start + size
                    if end > n:
                        continue

                    block = topic.iloc[start:end]
                    block_signal = float(block.get("signal_score", pd.Series([0.5] * len(block))).mean())
                    if block_signal < BLOCK_MIN_SIGNAL:
                        continue

                    text = "\n".join(f"{r.sender}: {r.text}" for r in block.itertuples())
                    rows.append({
                        "unit_type": "topic_block",
                        "chat_name": chat,
                        # A block is dated by its last message.
                        "timestamp": block.iloc[-1].timestamp,
                        "context": text,
                        "partner_text": "",
                        "partner_sender": "",
                        "reply_gap_min": None,
                        "response": "",
                        "style_text": "",
                        "style_signature": "block",
                        "signal_score": block_signal,
                        "self_fact": bool(block.get("self_fact", pd.Series([False] * len(block))).any()),
                        "embed_text": f"Bloque de conversacion:\n{text}",
                        "source_id": f"blk::{chat}::{t_idx}::{start}::{end}",
                    })

                    per_topic += 1
                    total_blocks += 1
                    if per_topic >= MAX_BLOCKS_PER_TOPIC or total_blocks >= MAX_BLOCKS_TOTAL:
                        break
                # Propagate the limit out of the start-position loop.
                if per_topic >= MAX_BLOCKS_PER_TOPIC or total_blocks >= MAX_BLOCKS_TOTAL:
                    break
            if total_blocks >= MAX_BLOCKS_TOTAL:
                break

        if _progress(c_idx, total_chats):
            print(
                f"[Paso 5][blk] chat {c_idx}/{total_chats} | "
                f"topics={topics_used} | total_blocks={total_blocks}"
            )

        if total_blocks >= MAX_BLOCKS_TOTAL:
            print(f"[Paso 5][blk] alcanzado MAX_BLOCKS_TOTAL={MAX_BLOCKS_TOTAL}")
            break

    print(f"[Paso 5][blk] fin: {len(rows)} unidades en {time.perf_counter() - t0:.1f}s")
    return pd.DataFrame(rows)


def build_response_units(memory_units_df: pd.DataFrame):
    """Derive the response-retrieval dataset from the memory units.

    Keeps only ``qa_turn`` rows whose stripped ``partner_text`` is non-empty
    and whose stripped ``response`` has at least ``MIN_RESPONSE_CHARS``
    characters, then adds ``response_embed_text`` (the text to embed) and
    ``retrieval_role="response"``. Returns a new frame with a fresh index.
    """
    qa = memory_units_df.loc[memory_units_df["unit_type"] == "qa_turn"].copy()

    # Normalise both text columns the same way before measuring their length.
    for col in ("partner_text", "response"):
        qa[col] = qa[col].fillna("").astype(str).str.strip()

    has_partner = qa["partner_text"].str.len() >= 1
    long_enough = qa["response"].str.len() >= MIN_RESPONSE_CHARS
    qa = qa.loc[has_partner & long_enough].copy()

    context = qa["context"].fillna("").astype(str)
    qa["response_embed_text"] = (
        "Mensaje del contacto:\n" + qa["partner_text"] +
        "\n\nContexto:\n" + context +
        "\n\nMi respuesta real:\n" + qa["response"]
    )
    qa["retrieval_role"] = "response"
    return qa.reset_index(drop=True)


def build_style_units(memory_units_df: pd.DataFrame):
    """Derive the style-retrieval dataset from the memory units.

    Keeps ``qa_turn`` and ``my_message`` rows, uses the stripped ``response``
    as ``style_text``, and drops rows shorter than ``MIN_STYLE_CHARS``
    characters or longer than ``MAX_STYLE_WORDS`` whitespace-separated words.
    Adds ``style_signature``, ``style_embed_text`` and
    ``retrieval_role="style"``. Returns a new frame with a fresh index.
    """
    wanted = memory_units_df["unit_type"].isin(["qa_turn", "my_message"])
    own = memory_units_df.loc[wanted].copy()

    own["style_text"] = own["response"].fillna("").astype(str).str.strip()
    own = own.loc[own["style_text"].str.len() >= MIN_STYLE_CHARS].copy()

    word_counts = own["style_text"].str.split().str.len().clip(lower=0)
    own = own.loc[word_counts <= MAX_STYLE_WORDS].copy()

    own["style_signature"] = own["style_text"].map(_style_signature)
    context = own["context"].fillna("").astype(str)
    own["style_embed_text"] = (
        "Mi mensaje:\n" + own["style_text"] +
        "\n\nContexto breve:\n" + context
    )
    own["retrieval_role"] = "style"
    return own.reset_index(drop=True)


# Step 5 driver: load the memory units from cache or rebuild them, then do the
# same for the derived response/style datasets, and print a summary.
step5_t0 = time.perf_counter()
build_key, build_payload = _build_key(messages)
build_manifest = _load_json(MEMORY_BUILD_MANIFEST, {})

# Reuse the parquet only if it exists AND the manifest stores the same build key.
cache_hit = (
    MEMORY_UNITS_PATH.exists() and
    build_manifest.get("build_key") == build_key
)

print(f"[Paso 5] build_key={build_key[:12]}... | version={MEMORY_BUILD_VERSION}")
print(f"[Paso 5] my_name efectivo={MY_NAME!r} | mensajes_propios={int(messages.get('is_me', pd.Series(dtype=bool)).sum())}")

if cache_hit:
    memory_units = pd.read_parquet(MEMORY_UNITS_PATH)
    print(
        f"[Paso 5] cache hit: reutilizando {len(memory_units)} unidades | "
        f"updated_at={build_manifest.get('updated_at', 'n/a')}"
    )
else:
    print("[Paso 5] reconstruyendo memoria (puede tardar la primera vez)...")

    # Each phase is timed separately so slow builders are easy to spot in the log.
    t_phase = time.perf_counter()
    qa_units = build_qa_examples(
        messages,
        CTX_WINDOW,
        max_gap_minutes=QA_MAX_GAP_MIN,
        require_other=QA_REQUIRE_OTHER,
        require_question=QA_REQUIRE_QUESTION,
    )
    print(f"[Paso 5] qa_units={len(qa_units)} | {time.perf_counter() - t_phase:.1f}s")

    t_phase = time.perf_counter()
    my_units = build_my_message_units(messages, local_window=2)
    print(f"[Paso 5] my_units={len(my_units)} | {time.perf_counter() - t_phase:.1f}s")
    if len(qa_units) == 0 and len(my_units) == 0:
        print("[Paso 5][WARN] qa_units y my_units quedaron en 0. Revisa MY_NAME en .env y la deteccion de mensajes propios.")

    t_phase = time.perf_counter()
    block_units = build_topic_blocks(messages, min_block=3, max_block=8, gap_minutes=45)
    print(f"[Paso 5] block_units={len(block_units)} | {time.perf_counter() - t_phase:.1f}s")

    t_phase = time.perf_counter()
    # NOTE(review): if all three builders return frames without a "timestamp"
    # column (e.g. pd.DataFrame([]) on empty input), sort_values below raises
    # KeyError before the friendlier empty check in step 6 runs -- confirm.
    memory_units = pd.concat([qa_units, my_units, block_units], ignore_index=True)
    memory_units = memory_units.sort_values("timestamp").reset_index(drop=True)
    print(f"[Paso 5] concat+sort: {len(memory_units)} unidades | {time.perf_counter() - t_phase:.1f}s")

    memory_units.to_parquet(MEMORY_UNITS_PATH, index=False)

    # The manifest is written after the parquet, so a recorded key implies the data file was saved.
    build_manifest = {
        "build_key": build_key,
        "payload": build_payload,
        "rows": int(len(memory_units)),
        "updated_at": pd.Timestamp.now("UTC").isoformat(),
    }
    _save_json(MEMORY_BUILD_MANIFEST, build_manifest)
    print(f"[Paso 5] reconstruido: {len(memory_units)} unidades")

# Build/cache the dual datasets (response/style); keyed by the same build_key as the memory units.
dual_manifest = _load_json(DUAL_UNITS_MANIFEST, {})
dual_hit = (
    RESPONSE_UNITS_PATH.exists() and
    STYLE_UNITS_PATH.exists() and
    dual_manifest.get("build_key") == build_key
)

if dual_hit:
    response_units = pd.read_parquet(RESPONSE_UNITS_PATH)
    style_units = pd.read_parquet(STYLE_UNITS_PATH)
    print(
        f"[Paso 5] dual cache hit: response={len(response_units)} | style={len(style_units)}"
    )
else:
    print("[Paso 5] construyendo datasets duales (response/style)...")
    t_phase = time.perf_counter()
    response_units = build_response_units(memory_units)
    style_units = build_style_units(memory_units)

    response_units.to_parquet(RESPONSE_UNITS_PATH, index=False)
    style_units.to_parquet(STYLE_UNITS_PATH, index=False)

    dual_manifest = {
        "build_key": build_key,
        "rows_response": int(len(response_units)),
        "rows_style": int(len(style_units)),
        "updated_at": pd.Timestamp.now("UTC").isoformat(),
    }
    _save_json(DUAL_UNITS_MANIFEST, dual_manifest)
    print(
        f"[Paso 5] dual reconstruido en {time.perf_counter() - t_phase:.1f}s | "
        f"response={len(response_units)} | style={len(style_units)}"
    )

# Summary output: unit counts per type, most common style signatures, and a small preview.
print(memory_units["unit_type"].value_counts())
if not style_units.empty:
    print("[Paso 5] style_signature top:")
    print(style_units["style_signature"].value_counts().head(10))
print(memory_units.head(3))
print(f"[Paso 5] Guardadas {len(memory_units)} unidades en {DATA_DIR} | tiempo total {time.perf_counter() - step5_t0:.1f}s")





## 6. Embeddings locales + FAISS (indice dual por eventos)
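Se indexan dos conjuntos derivados del paso 5: `response_units` (turnos `qa_turn`) y `style_units` (`qa_turn` + `my_message`), cada uno con su propio índice FAISS y caché incremental de embeddings por hash. Las unidades `topic_block` permanecen en `memory_units`, pero no se indexan en este paso.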


In [None]:
import hashlib
import math

# Preconditions for step 6: step 5 must have produced non-empty datasets.
if memory_units.empty:
    raise SystemExit("No hay unidades: revisa parseo y construccion de memoria")

response_units_path = DATA_DIR / "response_units.parquet"
style_units_path = DATA_DIR / "style_units.parquet"

# Guard first, then load both datasets from disk.
if not (response_units_path.exists() and style_units_path.exists()):
    raise SystemExit("Faltan response_units/style_units del paso 5. Ejecuta paso 5 completo.")
response_units = pd.read_parquet(response_units_path)
style_units = pd.read_parquet(style_units_path)

if any(frame.empty for frame in (response_units, style_units)):
    raise SystemExit("response_units o style_units esta vacio. Revisa limpieza y parseo.")

# Output artifacts of this step, all under INDEX_DIR.
response_index_path = INDEX_DIR / "response.index"
style_index_path = INDEX_DIR / "style.index"
response_meta_path = INDEX_DIR / "response_units.parquet"
style_meta_path = INDEX_DIR / "style_units.parquet"
response_emb_path = INDEX_DIR / "response_embeddings.npy"
style_emb_path = INDEX_DIR / "style_embeddings.npy"
style_cluster_manifest_path = INDEX_DIR / "style_cluster_manifest.json"

# Version tags are mixed into the row hashes / cluster key; changing one changes those keys.
EMBED_INPUT_VERSION_RESPONSE = "v1_response_dual"
EMBED_INPUT_VERSION_STYLE = "v1_style_dual"
STYLE_CLUSTER_VERSION = "v1_faiss_kmeans"


def _row_hash(row, text_col: str, version: str):
    payload = (
        f"{version}\n<SEP>\n"
        f"{row.get('source_id', '')}\n<SEP>\n"
        f"{row.get('chat_name', '')}\n<SEP>\n"
        f"{row.get('timestamp', '')}\n<SEP>\n"
        f"{row.get(text_col, '')}"
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def _ensure_embedder():
    """Create the module-level ``embedder`` on first use; otherwise do nothing."""
    global embedder
    if globals().get("embedder") is None:
        embedder = SentenceTransformer(EMBED_MODEL)


def _build_embeddings(units_df: pd.DataFrame, text_col: str, hash_col: str,
                      meta_path: Path, emb_path: Path, version: str, label: str):
    """Embed ``units_df[text_col]``, reusing cached vectors where the row hash matches.

    A hash per row is stored in ``hash_col``. When both ``meta_path`` and
    ``emb_path`` exist, their rows and vectors are paired positionally to seed
    the cache. Only rows whose hash is not cached are encoded.

    Returns ``(work, emb_matrix)``: a copy of ``units_df`` with ``hash_col``
    added, and a float32 matrix with one row per unit in the same order.
    """
    def _hash_of(record):
        return _row_hash(record, text_col=text_col, version=version)

    work = units_df.copy()
    work[hash_col] = work.apply(_hash_of, axis=1)

    cached_map = {}
    if meta_path.exists() and emb_path.exists():
        cached_units = pd.read_parquet(meta_path)
        cached_emb = np.load(emb_path)
        # Older metadata may lack the hash column; recompute it with the current version.
        if hash_col not in cached_units.columns:
            cached_units[hash_col] = cached_units.apply(_hash_of, axis=1)
        cached_map.update(zip(cached_units[hash_col].tolist(), cached_emb))

    hashes = work[hash_col].tolist()
    missing = [h for h in hashes if h not in cached_map]

    if not missing:
        print(f"[Paso 6][{label}] sin cambios: reutilizando embeddings")
        _ensure_embedder()
    else:
        print(f"[Paso 6][{label}] embeddings nuevos/cambiados: {len(missing)}")
        _ensure_embedder()
        pending = work.loc[work[hash_col].isin(missing)]
        texts = pending[text_col].fillna("").astype(str).tolist()
        new_emb = embedder.encode(
            texts,
            batch_size=32,
            convert_to_numpy=True,
            show_progress_bar=True,
            normalize_embeddings=True,
        ).astype("float32")
        cached_map.update(zip(pending[hash_col].tolist(), new_emb))

    # Reassemble in the order of `work`, whatever the cache order was.
    emb_matrix = np.stack([cached_map[h] for h in hashes]).astype("float32")
    return work, emb_matrix


def _build_faiss_index(embeddings: np.ndarray):
    """Return a ``faiss.IndexFlatIP`` sized to ``embeddings.shape[1]`` and filled with ``embeddings``."""
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    return index


# 1) Response index
# _build_embeddings reads the previous metadata/embeddings at response_meta_path and
# response_emb_path as its cache, so the writes below must come after the call.
response_units_work, response_embeddings = _build_embeddings(
    response_units,
    text_col="response_embed_text",
    hash_col="response_unit_hash",
    meta_path=response_meta_path,
    emb_path=response_emb_path,
    version=EMBED_INPUT_VERSION_RESPONSE,
    label="response",
)

# Persist index, metadata and raw vectors; metadata and vectors are the cache for the next run.
response_index = _build_faiss_index(response_embeddings)
faiss.write_index(response_index, str(response_index_path))
response_units_work.to_parquet(response_meta_path, index=False)
np.save(response_emb_path, response_embeddings)
print(f"[Paso 6] response index listo: {len(response_units_work)} unidades")


# 2) Style index
style_units_work, style_embeddings = _build_embeddings(
    style_units,
    text_col="style_embed_text",
    hash_col="style_unit_hash",
    meta_path=style_meta_path,
    emb_path=style_emb_path,
    version=EMBED_INPUT_VERSION_STYLE,
    label="style",
)

# Only the index is written here; style metadata and vectors are saved later in
# this cell, after the clustering step has run.
style_index = _build_faiss_index(style_embeddings)
faiss.write_index(style_index, str(style_index_path))


# 3) Categorizacion automatica de estilo (clusters semanticos)
def _style_cluster_key(units_df: pd.DataFrame):
    """Return a SHA-256 hex digest of ``STYLE_CLUSTER_VERSION`` plus the ordered ``style_unit_hash`` values."""
    joined_hashes = "\n".join(units_df["style_unit_hash"].astype(str).tolist())
    digest = hashlib.sha256()
    digest.update(f"{STYLE_CLUSTER_VERSION}\n{joined_hashes}".encode("utf-8"))
    return digest.hexdigest()


def _load_json(path: Path, default):
    if not path.exists():
        return default
    try:
        return json.loads(path.read_text())
    except Exception:
        return default


def _save_json(path: Path, payload):
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2))


# Cluster the style embeddings, reusing the previous labels when nothing changed.
cluster_key = _style_cluster_key(style_units_work)
cluster_manifest = _load_json(style_cluster_manifest_path, {})
manifest_matches = cluster_manifest.get("cluster_key") == cluster_key

# style_units_work is rebuilt from the step-5 dataset on every run, so it does not
# carry the "style_cluster" column saved by a previous run. Without restoring it
# from the previous metadata file the cache could never hit. Labels are restored
# only when the stored hashes line up row-for-row with the current ones.
if (
    manifest_matches
    and "style_cluster" not in style_units_work.columns
    and style_meta_path.exists()
):
    previous_meta = pd.read_parquet(style_meta_path)
    labels_reusable = (
        {"style_unit_hash", "style_cluster"}.issubset(previous_meta.columns)
        and previous_meta["style_unit_hash"].astype(str).tolist()
        == style_units_work["style_unit_hash"].astype(str).tolist()
        and bool(previous_meta["style_cluster"].notna().all())
    )
    if labels_reusable:
        style_units_work["style_cluster"] = previous_meta["style_cluster"].to_numpy().astype(int)

cluster_hit = manifest_matches and "style_cluster" in style_units_work.columns

if cluster_hit:
    print("[Paso 6][style] clusters reutilizados")
else:
    n = len(style_embeddings)
    if n < 30:
        # Too few samples to cluster: put everything in a single cluster.
        style_units_work["style_cluster"] = 0
        print("[Paso 6][style] pocos datos: cluster unico")
    else:
        # k grows with sqrt(n / 180), clamped to the range [6, 24].
        k = min(24, max(6, int(math.sqrt(n / 180.0))))
        # Train on at most 50000 vectors, sampled with a fixed seed.
        train_size = min(n, 50000)
        if train_size < n:
            rng = np.random.default_rng(42)
            sel = rng.choice(n, size=train_size, replace=False)
            train_x = style_embeddings[sel]
        else:
            train_x = style_embeddings

        kmeans = faiss.Kmeans(
            d=style_embeddings.shape[1],
            k=k,
            niter=20,
            verbose=False,
            seed=42,
            gpu=False,
        )
        kmeans.train(train_x.astype("float32"))
        # Assign every vector (not just the training sample) to its nearest centroid.
        _, labels = kmeans.index.search(style_embeddings.astype("float32"), 1)
        style_units_work["style_cluster"] = labels.reshape(-1).astype(int)
        print(f"[Paso 6][style] clusters generados: k={k}")

    cluster_manifest = {
        "cluster_key": cluster_key,
        "rows": int(len(style_units_work)),
        "updated_at": pd.Timestamp.now("UTC").isoformat(),
    }
    _save_json(style_cluster_manifest_path, cluster_manifest)

# Metadata (now including style_cluster) and vectors are the cache for the next run.
style_units_work.to_parquet(style_meta_path, index=False)
np.save(style_emb_path, style_embeddings)

print(f"[Paso 6] style index listo: {len(style_units_work)} unidades")
print("[Paso 6] top clusters de estilo:")
print(style_units_work["style_cluster"].value_counts().head(10))
print(f"Indices duales guardados en {INDEX_DIR}")




## 7. Recuperación y chat (OpenAI API + memoria híbrida)
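La recuperación combina dos fuentes de evidencia: respuestas reales a mensajes parecidos (índice `response`, para el contenido) y muestras de estilo diversificadas por cluster (índice `style`, para el tono).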
Si la evidencia está fragmentada, se une; si no hay soporte, responde con cautela.


### Configurar OpenAI desde `.env`
El notebook carga `../.env` automáticamente en la sección de configuración.
Crea ese archivo a partir de `../.env.example` y define `OPENAI_API_KEY` (y opcionalmente `GEN_MODEL`).


In [None]:
from dotenv import load_dotenv
from pathlib import Path
import os

# Load ../.env into the process environment (values in the file override existing ones).
PROJECT_ROOT = Path("..").resolve()
load_dotenv(PROJECT_ROOT / ".env", override=True)

print(
    "OPENAI_API_KEY detectado en ../.env"
    if os.getenv("OPENAI_API_KEY", "")
    else "Falta OPENAI_API_KEY en ../.env"
)
print("GEN_MODEL:", os.getenv("GEN_MODEL", "gpt-4.1-mini"))



### Comprobar credenciales de OpenAI
La celda siguiente valida si `OPENAI_API_KEY` está definido en el entorno del kernel.


In [None]:

# Verificar credenciales OpenAI
import os
# Report whether the key is visible to this kernel; the key itself is never printed.
key = os.getenv("OPENAI_API_KEY", "")
print("OPENAI_API_KEY detectado" if key else "Falta OPENAI_API_KEY en el entorno del kernel")
print("GEN_MODEL:", os.getenv("GEN_MODEL", "gpt-4.1-mini"))


In [None]:
import re
import unicodedata

# Step 7 setup: OpenAI client, system prompt, and the artifacts saved by step 6.
api_key = os.getenv("OPENAI_API_KEY", "")
if not api_key:
    raise SystemExit("Falta OPENAI_API_KEY: define la variable en ../.env")

GEN_MODEL = os.getenv("GEN_MODEL", "gpt-4.1-mini")
client = OpenAI(api_key=api_key)
# Sent as the system message on every completion request made by answer().
SYSTEM_PROMPT = (
    "Eres la persona que habla en los mensajes de WhatsApp. Respondes en primera persona y en espanol. "
    "Debes generar una respuesta nueva, no copiar literalmente ejemplos. "
    "Usa la evidencia de respuestas para decidir contenido y la evidencia de estilo para el tono. "
    "Si no hay soporte suficiente para un dato, dilo con honestidad."
)

# Everything is loaded from INDEX_DIR, so this cell does not need the in-memory
# results of step 6. Row i of each metadata frame is looked up by the position
# FAISS returns for the matching index.
response_units_cached = pd.read_parquet(INDEX_DIR / "response_units.parquet")
style_units_cached = pd.read_parquet(INDEX_DIR / "style_units.parquet")
response_index_cached = faiss.read_index(str(INDEX_DIR / "response.index"))
style_index_cached = faiss.read_index(str(INDEX_DIR / "style.index"))
response_embeddings_cached = np.load(INDEX_DIR / "response_embeddings.npy")
style_embeddings_cached = np.load(INDEX_DIR / "style_embeddings.npy")

# Create the sentence embedder only if an earlier cell has not already done so.
if "embedder" not in globals() or embedder is None:
    embedder = SentenceTransformer(EMBED_MODEL)

# Newest timestamp per dataset; the reference point for the recency boost.
max_ts_response = pd.to_datetime(response_units_cached["timestamp"], errors="coerce").max()
max_ts_style = pd.to_datetime(style_units_cached["timestamp"], errors="coerce").max()


def _chat_create(messages, temperature=None):
    """Request a chat completion from ``GEN_MODEL``.

    ``temperature`` is sent only when it is not ``None``. If the call fails
    with an error whose text mentions "temperature" together with
    "unsupported" or "only the default", the request is retried once without
    it. Any other error is re-raised.
    """
    request = dict(model=GEN_MODEL, messages=messages)
    if temperature is not None:
        request["temperature"] = temperature

    try:
        return client.chat.completions.create(**request)
    except Exception as exc:
        detail = str(exc).lower()
        temperature_rejected = "temperature" in detail and any(
            marker in detail for marker in ("unsupported", "only the default")
        )
        if not temperature_rejected:
            raise
        return client.chat.completions.create(model=GEN_MODEL, messages=messages)


def _contact_hint(query: str):
    """Return the chat name mentioned in ``query``, or ``None`` if none is.

    Chat names come from both cached unit tables and are matched as
    case-insensitive substrings of the query. Blank names are ignored because
    an empty string is a substring of every query. When several names match,
    the longest one wins (e.g. "Ana Maria" over "Ana"); ties keep the first
    name encountered.
    """
    q = (query or "").lower()
    all_chats = pd.concat(
        [response_units_cached["chat_name"], style_units_cached["chat_name"]],
        ignore_index=True,
    ).dropna().unique().tolist()

    best_chat = None
    best_len = 0
    for chat in all_chats:
        low = str(chat).strip().lower()
        if low and len(low) > best_len and low in q:
            best_chat, best_len = chat, len(low)
    return best_chat


def _recency_boost(ts, max_ts, alpha=0.05):
    ts = pd.to_datetime(ts, errors="coerce")
    if pd.isna(ts) or pd.isna(max_ts):
        return 0.0
    days = max((max_ts - ts).days, 0)
    return max(0.0, alpha - days * 0.001)


def _norm_text(s: str):
    s = (s or "").lower().strip()
    s = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in s if not unicodedata.combining(ch))
    s = re.sub(r"[¿?¡!.,;:\-_\"'()\[\]]", "", s)
    s = re.sub(r"\s+", " ", s)
    return s


def _retrieve_response_hits(query: str, k: int = 6, min_score: float = 0.22, contact_hint: str = None):
    """Return up to ``k`` response units most relevant to ``query``.

    The query is embedded and searched against ``response_index_cached``.
    Candidates scoring below ``min_score`` are dropped. The rest are re-ranked
    by ``rank_score`` = similarity + 0.08 * signal + recency boost + contact
    boost + quick-reply bonus, then de-duplicated by ``source_id``.

    Each returned dict is the unit's row plus ``score``, ``rank_score`` and
    ``retrieval_role="response"``.
    """
    q_vec = embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    # Over-fetch so filtering and de-duplication still leave enough candidates.
    scores, idxs = response_index_cached.search(q_vec, max(60, k * 8))

    out = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx < 0:
            continue
        s = float(score)
        if s < min_score:
            continue

        row = response_units_cached.iloc[idx].to_dict()

        # A missing value read back through pandas is NaN, which is truthy and
        # would slip past a plain `or 0.4`, turning rank_score into NaN and
        # making the sort below unreliable. Treat None/NaN/0 as "no signal".
        raw_signal = row.get("signal_score")
        if raw_signal is None or pd.isna(raw_signal) or not raw_signal:
            signal = 0.4
        else:
            signal = float(raw_signal)

        recency = _recency_boost(row.get("timestamp"), max_ts_response, alpha=0.05)
        contact_boost = 0.04 if (contact_hint and row.get("chat_name") == contact_hint) else 0.0

        # Faster replies get a slightly larger bonus; it reaches 0 at 120 minutes.
        gap = row.get("reply_gap_min")
        if gap is None or pd.isna(gap):
            gap_bonus = 0.0
        else:
            gap = float(gap)
            gap_bonus = max(0.0, 0.03 - min(gap, 120.0) * 0.00025)

        row["score"] = s
        row["rank_score"] = s + 0.08 * signal + recency + contact_boost + gap_bonus
        row["retrieval_role"] = "response"
        out.append(row)

    out = sorted(out, key=lambda x: x["rank_score"], reverse=True)
    dedup = []
    seen = set()
    for r in out:
        key = r.get("source_id")
        # Only de-duplicate on real ids; rows without one are distinct units
        # and must not all collapse into a single "None" entry.
        if key is not None and not pd.isna(key):
            if key in seen:
                continue
            seen.add(key)
        dedup.append(r)
        if len(dedup) >= k:
            break
    return dedup


def _retrieve_style_hits(query: str, k: int = 10, min_score: float = 0.20, contact_hint: str = None):
    """Return up to ``k`` style samples for ``query``, spread across style clusters.

    The query is embedded and searched against ``style_index_cached``.
    Candidates scoring below ``min_score`` are dropped. The rest are ranked by
    ``rank_score`` = similarity + 0.07 * signal + recency boost + contact
    boost + length bonus, then picked round-robin across ``style_cluster``
    values, skipping samples whose normalised text is empty or already picked.

    Each returned dict is the unit's row plus ``score``, ``rank_score`` and
    ``retrieval_role="style"``.
    """
    q_vec = embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    # Over-fetch so filtering and diversification still leave enough candidates.
    scores, idxs = style_index_cached.search(q_vec, max(80, k * 8))

    ranked = []
    for score, idx in zip(scores[0], idxs[0]):
        # Negative positions are skipped (presumably FAISS padding when fewer neighbours exist).
        if idx < 0:
            continue
        s = float(score)
        if s < min_score:
            continue

        row = style_units_cached.iloc[idx].to_dict()
        # NOTE(review): a NaN signal_score is truthy, so it passes through `or 0.4` unchanged.
        signal = float(row.get("signal_score", 0.4) or 0.4)
        recency = _recency_boost(row.get("timestamp"), max_ts_style, alpha=0.04)
        contact_boost = 0.03 if (contact_hint and row.get("chat_name") == contact_hint) else 0.0

        # Longer samples get a small bonus, capped at 0.04 (reached at 16 words).
        text = str(row.get("style_text", "") or "").strip()
        n_words = len(text.split())
        len_bonus = min(0.04, n_words / 400.0)

        row["score"] = s
        row["rank_score"] = s + 0.07 * signal + recency + contact_boost + len_bonus
        row["retrieval_role"] = "style"
        ranked.append(row)

    ranked = sorted(ranked, key=lambda x: x["rank_score"], reverse=True)

    # Diversify by cluster: avoid always repeating the same response pattern.
    # Buckets keep rank order, and clusters are visited in order of their best-ranked member.
    buckets = {}
    for r in ranked:
        cluster = int(r.get("style_cluster", -1))
        buckets.setdefault(cluster, []).append(r)

    diversified = []
    seen_text = set()
    clusters = list(buckets.keys())
    # cursor[c] is the next unread position in cluster c's bucket.
    cursor = {c: 0 for c in clusters}

    # Round-robin: each pass takes at most one new sample from every cluster that
    # still has candidates; exhausted clusters drop out of the next pass.
    while len(diversified) < k and clusters:
        next_clusters = []
        for c in clusters:
            i = cursor[c]
            arr = buckets[c]
            while i < len(arr):
                cand = arr[i]
                i += 1
                t_key = _norm_text(str(cand.get("style_text", "")))
                # Skip empty texts and duplicates of something already selected.
                if not t_key or t_key in seen_text:
                    continue
                seen_text.add(t_key)
                diversified.append(cand)
                break
            cursor[c] = i
            if i < len(arr):
                next_clusters.append(c)
            if len(diversified) >= k:
                break
        clusters = next_clusters

    return diversified


def retrieve_bundle(query: str, k_response: int = 6, k_style: int = 8, min_score: float = 0.22):
    """Retrieve response and style evidence for ``query`` in one call.

    The contact hint is detected once and shared by both retrievers. Style
    retrieval uses a threshold of ``min_score * 0.9``. Returns a dict with the
    keys ``"response_hits"`` and ``"style_hits"``.
    """
    hint = _contact_hint(query)
    return {
        "response_hits": _retrieve_response_hits(
            query, k=k_response, min_score=min_score, contact_hint=hint
        ),
        "style_hits": _retrieve_style_hits(
            query, k=k_style, min_score=min_score * 0.9, contact_hint=hint
        ),
    }


def retrieve(query: str, k_total: int = 10, min_score: float = 0.22):
    """Return the ``k_total`` best hits for ``query`` from both indices combined.

    Response and style hits are pooled and ordered by ``rank_score``, falling
    back to ``score`` and then 0.0, highest first.
    """
    def _rank(hit):
        return hit.get("rank_score", hit.get("score", 0.0))

    bundle = retrieve_bundle(
        query,
        k_response=max(4, k_total // 2),
        k_style=max(4, k_total),
        min_score=min_score,
    )
    pooled = [*bundle["response_hits"], *bundle["style_hits"]]
    pooled.sort(key=_rank, reverse=True)
    return pooled[:k_total]


def answer(prompt: str, k_total: int = 10, temperature: float = None, min_score: float = 0.22):
    """Generate a reply to ``prompt`` in the user's voice from retrieved evidence.

    Retrieves response and style evidence, builds one user message containing
    both, and asks the chat model for a reply. Returns a fixed "not enough
    evidence" sentence when nothing is retrieved or the model returns no text.
    If the reply merely echoes the prompt after normalisation, it asks once
    more with an extra instruction and prefers that second reply when it is
    non-empty.
    """
    no_evidence = "No tengo evidencia suficiente en mis chats para responder eso con seguridad."

    bundle = retrieve_bundle(
        prompt,
        k_response=max(4, k_total // 2),
        k_style=max(6, k_total),
        min_score=min_score,
    )
    response_hits, style_hits = bundle["response_hits"], bundle["style_hits"]
    if not (response_hits or style_hits):
        return no_evidence

    response_parts = []
    for h in response_hits:
        response_parts.append(
            f"[response | score {h['score']:.3f} | chat: {h['chat_name']} | fecha: {h['timestamp']}]\n"
            f"Mensaje del contacto: {h.get('partner_text', '')}\n"
            f"Mi respuesta real: {h.get('response', '')}\n"
            f"Contexto: {h.get('context', '')}"
        )
    response_block = "\n\n".join(response_parts)

    # At most 8 style samples go into the prompt, however many were retrieved.
    style_lines = [
        f"- ({h.get('style_cluster', -1)}) {h.get('style_text', '')}"
        for h in style_hits[:8]
    ]
    style_block = "\n".join(style_lines)

    user_msg = (
        "Consulta nueva:\n" + prompt +
        "\n\nEvidencia de como respondo a mensajes parecidos:\n" + (response_block or "(sin evidencia de respuesta)") +
        "\n\nMuestras de mi estilo de escritura:\n" + (style_block or "(sin muestras de estilo)") +
        "\n\nInstrucciones de salida:\n"
        "1) Genera una sola respuesta como la escribiria yo.\n"
        "2) Usa la evidencia de respuesta para el contenido.\n"
        "3) Usa las muestras de estilo para tono/forma, sin copiar literal.\n"
        "4) Si falta evidencia para afirmar algo, dilo con claridad."
    )

    def _ask(content):
        reply = _chat_create(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": content},
            ],
            temperature=temperature,
        )
        return (reply.choices[0].message.content or "").strip()

    text = _ask(user_msg)
    if not text:
        return no_evidence

    if _norm_text(text) != _norm_text(prompt):
        return text

    # Light second attempt, only when the model echoed the query exactly.
    return _ask(user_msg + "\n\nReformula sin repetir la consulta.") or text




### Probar recuperación


In [None]:

# Smoke test: shows the merged response/style hits for a sample query (no API call is made).
retrieve("¿Que tal?")



### Probar chat (OpenAI API)
Si `OPENAI_API_KEY` y `GEN_MODEL` están definidos en `../.env`, ejecuta la celda de abajo con `answer(...)`.



### Nota de costes y privacidad
En modo API, cada consulta envía contexto recuperado al proveedor y consume tokens facturables.
Ajusta `k_total` en `answer(prompt, k_total=...)` para controlar coste/contexto.


In [None]:
# Smoke test: retrieves evidence and sends it to the OpenAI chat completions API.
answer("Te gusta la carbonara?")
