# In [None]:  (notebook cell marker left over from export)
# =============================================================================
#  fix_translations.py  ·  v1.0
#  ––– Corrects Spanish translations given English source lines
# =============================================================================
import os, re, time, json
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Set

import pandas as pd
import spacy
from spacy.matcher import PhraseMatcher
from dotenv import load_dotenv

# ── LLM provider selection ---------------------------------------------------
# Backend used for the corrections; the accepted values are "openai" and "gemini".
MODEL_PROVIDER = "openai"

# Read the .env file (OPENAI_API_KEY or GOOGLE_API_KEY) before any SDK setup.
load_dotenv()

# Reject anything that is not a known backend, then set up the chosen SDK.
if MODEL_PROVIDER not in ("openai", "gemini"):
    raise ValueError(f"Unknown provider: {MODEL_PROVIDER}")

if MODEL_PROVIDER == "gemini":
    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    model_name = "gemini-2.5-pro"
else:
    import openai

    client = openai.OpenAI()
    model_name = "gpt-4o-mini"       # swap in any other chat model here

# ── Optional EN–ES dictionary support ---------------------------------------
# Switch for the dictionary hints and the file they are read from.
ENABLE_DICTIONARY = True
DICTIONARY_PATH = "en_es_dictionary.txt"

def load_en_es_dictionary(file_path:str) -> Dict[str, List[str]]:
    """Load a tab‑separated EN–ES dictionary into {head → [variants…]}.

    Each data line has the form ``english<TAB>spanish[<TAB>extra…]``. Everything
    after the first tab is joined with single spaces into one variant string.
    Blank lines and lines starting with ``#`` are ignored. A head that appears
    on several lines accumulates its variants in file order.

    Lines without a tab, or with an empty head or empty translation, are
    skipped (and counted in a printed notice) instead of aborting the load.

    Args:
        file_path: Location of the dictionary file (UTF‑8, with or without BOM).

    Returns:
        A ``defaultdict(list)`` mapping each English head to its variants.
        It is empty when the file does not exist; a notice is printed then.
    """
    d: Dict[str, List[str]] = defaultdict(list)
    skipped = 0
    try:
        # "utf-8-sig" also reads plain UTF-8 but drops a leading BOM, which
        # would otherwise hide a "#" on the first line or end up in a key.
        with Path(file_path).open(encoding="utf-8-sig") as f:
            for raw in f:
                if raw.startswith("#") or not raw.strip():         # skip comments/blank
                    continue
                en, tab, tail = raw.rstrip("\n").partition("\t")
                en = en.strip()
                variant = " ".join(tail.split("\t")).strip()
                if not tab or not en or not variant:
                    skipped += 1
                    continue
                d[en].append(variant)
    except FileNotFoundError:
        print(f"Dictionary not found at {file_path}. Continuing without it.")
    if skipped:
        print(f"Skipped {skipped} malformed dictionary line(s) in {file_path}.")
    return d

# --- spaCy matcher (keeps duplicates, boosts relevant dictionary terms) -----
class DictionaryMatcher:
    def __init__(
        self,
        dictionary: Dict[str, List[str]],
        model: str = "en_core_web_sm",
        pos_focus: Set[str] = {"NOUN", "ADJ"},
    ):
        self.dictionary = dictionary
        self.nlp  = spacy.load(model, disable=["ner", "parser"])
        self.posF = pos_focus

        single, multi = set(), set()
        for head in dictionary:
            (multi if " " in head else single).add(head)

        # lemma → head  (single‑word entries)
        self.lemma2head = {h.lower(): h for h in single}

        # multi‑word entries
        self.phraser = PhraseMatcher(self.nlp.vocab, attr="LOWER")
        for phrase in multi:
            self.phraser.add(phrase, [self.nlp.make_doc(phrase)])

    def find_terms(self, text: str) -> Dict[str, List[str]]:
        doc  = self.nlp(text)
        hits = {}

        # lemma‑based (single word)
        for t in doc:
            if t.pos_ in self.posF:
                head = self.lemma2head.get(t.lemma_.lower())
                if head:
                    hits[head] = self.dictionary[head]

        # multi‑word
        for mid, s, e in self.phraser(doc):
            head = self.nlp.vocab.strings[mid]
            span = doc[s:e].text.lower()
            if span == head.lower():
                hits[head] = self.dictionary[head]
        return hits

# ── Prompt helpers -----------------------------------------------------------
def read_file(path:str) -> str:
    """Return the UTF‑8 text of *path* with surrounding whitespace removed.

    A missing file is reported on stdout and yields an empty string.
    """
    source = Path(path)
    try:
        contents = source.read_text(encoding="utf8")
    except FileNotFoundError:
        print(f"⚠️  File not found: {path}")
        return ""
    else:
        return contents.strip()

def build_fix_prompt(
    batch_pairs: dict,
    system_prompt: str,
    dictionary_block: str,
    previous_context: list[str] | None = None,
    provider: str = "openai",
) -> str:
    """
    Build the user prompt requesting corrected Spanish lines.
    Returns a markdown prompt string; system_prompt goes separately.
    """
    context_block = ""
    if previous_context:
        context_block = (
            "## Previous corrections (context)\n"
            + "\n".join(f"- {line}" for line in previous_context)
            + "\n\n"
        )

    pairs_json = json.dumps(batch_pairs, indent=2, ensure_ascii=False)

    if provider == "openai":
        user_prompt = (
            "# Correction Task\n\n"
            "You will receive a JSON object where each key maps to an object:\n"
            '`{"en": "...", "es": "..."}\'.\n'
            "Return **one JSON object** with the **same keys** and the **corrected Spanish\n"
            "string** as each value. Keep strings unchanged when already correct.\n\n"
            "Follow all style rules in the system prompt. Use the dictionary below\n"
            "only if it helps. Do not output anything except the JSON object.\n\n"
            f"{context_block}"
            f"{dictionary_block}"
            "## Input JSON\n```json\n" + pairs_json + "\n```\n"
        )
    else:      # gemini
        user_prompt = (
            "Correct the Spanish lines in the JSON (same keys). "
            "Return JSON only.\n\n"
            f"{context_block}{dictionary_block}"
            "Input:\n" + pairs_json
        )
    return user_prompt

# ── LLM call -----------------------------------------------------------------
def correct_batch_json(
    batch_pairs: dict,
    system_prompt: str,
    dictionary_block: str,
    previous_context: list[str] | None = None,
    temperature: float = 0.0,
) -> dict:
    """Send one batch to the configured LLM and return its JSON answer.

    The backend is chosen by the module‑level ``MODEL_PROVIDER``; both backends
    are asked for JSON output and the reply text is parsed with ``json.loads``.

    Args:
        batch_pairs: Mapping of key → ``{"en": ..., "es": ...}`` for the batch.
        system_prompt: Style rules, sent as the system message / instruction.
        dictionary_block: Pre‑rendered dictionary section for the user prompt.
        previous_context: Earlier corrected lines to include as context.
        temperature: Sampling temperature passed to the model.

    Returns:
        The parsed JSON object. Best effort: on any failure (API error,
        invalid JSON, or a reply that is not a JSON object) a warning is
        printed and ``{}`` is returned.
    """
    user_prompt = build_fix_prompt(
        batch_pairs,
        system_prompt,
        dictionary_block,
        previous_context,
        provider=MODEL_PROVIDER,
    )

    try:
        if MODEL_PROVIDER == "openai":
            response = client.chat.completions.create(
                model=model_name,
                response_format={"type": "json_object"},
                temperature=temperature,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user",   "content": user_prompt},
                ],
            )
            raw = response.choices[0].message.content

        else:  # gemini
            model = genai.GenerativeModel(
                model_name=model_name,
                system_instruction=system_prompt,
                generation_config={
                    "temperature": temperature,
                    "response_mime_type": "application/json",
                },
            )
            raw = model.generate_content(user_prompt).text

        parsed = json.loads(raw)
        # Callers index the result by key, so anything but an object (e.g. a
        # JSON array or bare string) is treated as a failed call.
        if not isinstance(parsed, dict):
            raise ValueError(
                f"expected a JSON object, got {type(parsed).__name__}"
            )
        return parsed

    except Exception as e:
        print("⚠️  LLM call failed:", e)
        return {}

# ── Pipeline -----------------------------------------------------------------
def process_file(
    input_file: str,
    output_file: str,
    system_prompt: str,
    dictionary: Dict[str, List[str]] | None,
    matcher: DictionaryMatcher | None,
    batch_size: int = 40,
):
    """Correct the Spanish column of a CSV in batches and write the result.

    Rows that have both 'Segment Text' and 'Translated Text' are sent to the
    LLM ``batch_size`` at a time. Corrections go into a new
    'Translated Text (corrected)' column; rows without a usable correction
    stay ``pd.NA``. Up to half of each batch's corrections (at least one) are
    passed to the next batch as context.

    Args:
        input_file: CSV with 'Segment Text' and 'Translated Text' columns.
        output_file: Destination CSV (written as UTF‑8 with BOM).
        system_prompt: Style rules forwarded to the LLM.
        dictionary: EN→ES variants, used for a regex lookup when no matcher
            is given.
        matcher: Preferred term finder; when given it takes precedence over
            the regex lookup.
        batch_size: Number of rows per LLM call; must be at least 1.

    Raises:
        ValueError: If a required column is missing or ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1.")

    df = pd.read_csv(input_file)
    if not {"Segment Text", "Translated Text"}.issubset(df.columns):
        raise ValueError("CSV must have 'Segment Text' and 'Translated Text' columns.")

    corrected_col = "Translated Text (corrected)"
    df[corrected_col] = pd.NA

    pairs_series = df[["Segment Text", "Translated Text"]].dropna().to_dict("index")
    ids = list(pairs_series.keys())

    # Regex fallback: compile one pattern per head, once. Heads are escaped so
    # entries like "C++" or "(s)" are matched literally; the lookarounds act
    # like \b for ordinary words but also work for heads edged by punctuation.
    term_patterns = {}
    if not matcher and dictionary:
        term_patterns = {
            k: re.compile(rf"(?<!\w){re.escape(k)}(?!\w)", re.I)
            for k in dictionary
        }

    previous_context: list[str] = []

    for b in range(0, len(ids), batch_size):
        batch_ids   = ids[b:b+batch_size]
        batch_pairs = {
            str(i): {"en": pairs_series[i]["Segment Text"],
                     "es": pairs_series[i]["Translated Text"]}
            for i in batch_ids
        }
        # str() keeps numeric cells from breaking the join / regex search
        sources = [str(pairs_series[i]["Segment Text"]) for i in batch_ids]

        # build dictionary context
        if matcher:
            hits = matcher.find_terms(" ".join(sources))
        elif term_patterns:
            hits = {k: dictionary[k] for k, pattern in term_patterns.items()
                    if any(pattern.search(s) for s in sources)}
        else:
            hits = {}

        dict_block = ""
        if hits:
            dict_block = (
                "## Dictionary Context\n"
                + "\n".join(f"• {k} → {' | '.join(v)}" for k, v in hits.items())
                + "\n\n"
            )

        print(f"• Processing batch {b//batch_size + 1}")
        fixed = correct_batch_json(
            batch_pairs,
            system_prompt,
            dict_block,
            previous_context=previous_context,
        )
        if not isinstance(fixed, dict):
            fixed = {}

        # commit to dataframe; only string values count as corrections
        accepted = {}
        for i in batch_ids:
            value = fixed.get(str(i))
            if isinstance(value, str):
                accepted[i] = value
                df.loc[i, corrected_col] = value

        missing = len(batch_ids) - len(accepted)
        if missing:
            print(f"⚠️  {missing} line(s) in this batch got no usable correction")

        # save some of this batch as rolling context
        previous_context = list(accepted.values())[:max(1, len(batch_ids)//2)]

    df.to_csv(output_file, index=False, encoding="utf-8-sig")
    print(f"✅  Finished → {output_file}")

# ── Main entry ---------------------------------------------------------------
def corrected_output_path(input_path: str) -> str:
    """Return *input_path* with ``_corrected`` inserted before its extension.

    Only the final file name is changed, so a ".csv" elsewhere in the path is
    left alone and the result always differs from the input (even for ".CSV"
    or a missing extension), so the source file is never overwritten.
    """
    src = Path(input_path)
    return str(src.with_name(f"{src.stem}_corrected{src.suffix}"))


if __name__ == "__main__":
    # ‣‣‣ Adjust these paths --------------------------------------------------
    input_csv_path   = r"D:\SOKM\11 Identity 2\identity_unit_translations.csv"
    system_prompt_file = "system_prompt_fix_v1.0.txt"      # your 'Span‑Eng Fixer' prompt
    output_csv_path  = corrected_output_path(input_csv_path)
    # ------------------------------------------------------------------------

    sys_prompt = read_file(system_prompt_file)
    en_es_dict = load_en_es_dictionary(DICTIONARY_PATH) if ENABLE_DICTIONARY else {}
    # Build the matcher (which loads the spaCy model) only when there are
    # dictionary entries for it to find.
    matcher    = DictionaryMatcher(en_es_dict) if en_es_dict else None

    process_file(
        input_file=input_csv_path,
        output_file=output_csv_path,
        system_prompt=sys_prompt,
        dictionary=en_es_dict,
        matcher=matcher,
        batch_size=30,
    )
