# RAG-Enhanced Translation System

This notebook builds on the baseline results from `Baseline.ipynb`.

**Features:**
- **Retrieval-Augmented Prompting**: embedding-based glossary retrieval (Chroma + multilingual E5) plus exact-match translation-memory (TM) lookup
- **Comprehensive Evaluation**: RAG vs baseline comparison

---

## 1. Setup & Data Loading

In [81]:
import os
import re
import json
import time
import random
import warnings
import hashlib
import threading
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from tqdm import tqdm

# RAG-specific
from sentence_transformers import SentenceTransformer
import chromadb

# OpenAI client
from openai import OpenAI

# NOTE(review): this silences *every* warning for the whole kernel, including
# ones that may point at real problems (e.g. from Chroma or pandas) — consider
# narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")
# Show all DataFrame columns and let pandas auto-detect the display width.
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)

# Load variables from a local .env file (OPENAI_API_KEY is read from the
# environment further down).
load_dotenv()
print("RAG-Enhanced Translation System (Clean Rewrite)")
print("=" * 70)

def flatten_json_strings(obj: Any, prefix: str = "") -> List[Tuple[str, str]]:
    """Flatten nested JSON into ``(path, text)`` pairs, one per non-blank string leaf.

    Dict keys are joined with ``.`` (no leading dot at the root) and list items
    are addressed as ``[i]``. Whitespace-only strings and non-string scalars
    (numbers, booleans, null) are skipped.

    Args:
        obj: Parsed JSON value (dict / list / str / scalar).
        prefix: Path of ``obj`` within the root document.

    Returns:
        Pairs in document order, e.g. ``[("home.title", "Welcome"), ...]``.
    """
    if isinstance(obj, dict):
        children = ((f"{prefix}.{key}" if prefix else key, value) for key, value in obj.items())
    elif isinstance(obj, list):
        children = ((f"{prefix}[{idx}]", value) for idx, value in enumerate(obj))
    elif isinstance(obj, str) and obj.strip():
        return [(prefix, obj)]
    else:
        return []

    pairs: List[Tuple[str, str]] = []
    for child_path, child in children:
        pairs.extend(flatten_json_strings(child, child_path))
    return pairs

SRC_FILE = Path("data/en.json")
if not SRC_FILE.exists():
    raise FileNotFoundError("Missing data/en.json")

with open(SRC_FILE, "r", encoding="utf-8") as f:
    en_json = json.load(f)

# Flattened (json_path, text) pairs for every non-blank string in the source.
en_segments: List[Tuple[str, str]] = flatten_json_strings(en_json)
print(f"Loaded {len(en_segments)} source segments")


def _sha1(p) -> str:
    """Return the hex SHA-1 digest of the file at ``p``, read in 8 KiB chunks."""
    h = hashlib.sha1()
    with open(p, "rb") as fh:
        for chunk in iter(lambda: fh.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()


# Fingerprint of the source file, stamped into each output record so a
# translation can be traced to the exact source it came from. Best-effort:
# a read failure only drops the fingerprint, but it is reported, not hidden.
# Hash SRC_FILE itself (not a second hard-coded path) so both stay in sync.
try:
    SOURCE_SHA: Optional[str] = _sha1(SRC_FILE)
except OSError as exc:
    print(f"⚠️ Could not hash {SRC_FILE}: {exc}")
    SOURCE_SHA = None

RAG-Enhanced Translation System (Clean Rewrite)
Loaded 76 source segments


## 2. Load Translation Memory (TM) & Glossary

In [82]:
def load_translation_memory() -> Dict[str, Dict[str, str]]:
    """Load exact-match translation memory from ``data/translation_memory.csv``.

    Expected columns: ``tgt_lang``, ``src_text``, ``tgt_text``. Only rows for
    fr/ja/it whose source and target are non-empty strings are kept; a later
    row with the same source overwrites an earlier one.

    Returns:
        ``{"fr": {src: tgt}, "ja": {...}, "it": {...}}`` — empty per-language
        dicts when the file is missing or cannot be parsed (a warning is printed).
    """
    supported = ("fr", "ja", "it")
    memory: Dict[str, Dict[str, str]] = {code: {} for code in supported}
    csv_path = Path("data/translation_memory.csv")
    if not csv_path.exists():
        print("⚠️ No translation_memory.csv found")
        return memory

    try:
        frame = pd.read_csv(csv_path)
        for _, record in frame.iterrows():
            code = record.get("tgt_lang")
            source = record.get("src_text")
            target = record.get("tgt_text")
            if code not in memory:
                continue
            # Missing CSV cells arrive as float NaN, so the str check filters them.
            if not (isinstance(source, str) and isinstance(target, str)):
                continue
            if source and target:
                memory[code][source] = target
        for code in supported:
            print(f"TM {code.upper()}: {len(memory[code])} entries")
    except Exception as e:
        print(f"⚠️ Could not load TM: {e}")
    return memory


def load_glossary() -> Tuple[List[str], Dict[str, Dict[str, str]], List[str]]:
    """Load glossary terms, per-language term mappings and do-not-translate terms.

    Reads ``data/glossary.csv``, which must have a ``source_term`` column and may
    have ``fr`` / ``ja`` / ``it`` target columns and a ``dnt`` flag column.

    Returns:
        ``(glossary_terms, glossary_map, dnt_terms)`` where
        - ``glossary_terms`` are the non-empty source terms (used for retrieval),
        - ``glossary_map[lang]`` maps source term → target term, containing only
          rows whose target cell is actually filled in,
        - ``dnt_terms`` is the sorted, de-duplicated DNT list; ``"NaiLit"`` is
          always included.
        On a missing or unreadable file, a warning is printed and the defaults
        are returned.
    """
    glossary_terms: List[str] = []
    glossary_map: Dict[str, Dict[str, str]] = {"fr": {}, "ja": {}, "it": {}}
    dnt_terms: List[str] = ["NaiLit"]  # default DNT brand

    gl_file = Path("data/glossary.csv")
    if gl_file.exists():
        try:
            df = pd.read_csv(gl_file)
            if "source_term" not in df.columns:
                raise ValueError("glossary.csv must have a 'source_term' column")
            glossary_terms = [t for t in df["source_term"].dropna().astype(str).tolist() if t]
            for lang in ["fr", "ja", "it"]:
                if lang in df.columns:
                    # Test for missing cells on the RAW values: converting the
                    # column with astype(str) first turns NaN into the literal
                    # string "nan", which would then pass pd.notna() and become
                    # a bogus "term → nan" glossary constraint.
                    glossary_map[lang] = {
                        str(st): str(tt)
                        for st, tt in zip(df["source_term"], df[lang])
                        if pd.notna(st) and pd.notna(tt) and str(tt).strip()
                    }
                    print(f"Glossary {lang.upper()}: {len(glossary_map[lang])} mappings")
            if "dnt" in df.columns:
                # pandas may parse TRUE/FALSE as bool; str() + upper() covers both forms.
                mask = df["dnt"].astype(str).str.upper() == "TRUE"
                dnt_terms.extend(df.loc[mask, "source_term"].dropna().astype(str).tolist())
        except Exception as e:
            print(f"⚠️ Could not load glossary: {e}")
    else:
        print("⚠️ No glossary.csv found")

    dnt_terms = sorted({t for t in dnt_terms if t})
    return glossary_terms, glossary_map, dnt_terms

# Module-level resources used below: TM_DICT by tm_lookup(), GLOSSARY_TERMS for
# vector indexing/retrieval, GLOSSARY_MAP by build_constraints(), DNT_TERMS for
# do-not-translate handling.
TM_DICT = load_translation_memory()
GLOSSARY_TERMS, GLOSSARY_MAP, DNT_TERMS = load_glossary()
print(f"DNT terms: {DNT_TERMS}")

TM FR: 2 entries
TM JA: 2 entries
TM IT: 2 entries
Glossary FR: 18 mappings
Glossary JA: 18 mappings
Glossary IT: 18 mappings
DNT terms: ['Gel-X', 'NaiLit']


## 3. Embeddings & Vector DB

In [83]:
print("Loading embedding model…")
EMB_MODEL = SentenceTransformer("intfloat/multilingual-e5-base")

CHROMA_PATH = ".chroma"
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
GLOSSARY_COL = chroma_client.get_or_create_collection(
    name="glossary",
    metadata={"hnsw:space": "cosine"}
)


def _glossary_doc_id(term: str) -> str:
    """Stable, content-derived Chroma id for a glossary term."""
    return hashlib.sha1(term.encode("utf-8")).hexdigest()


if GLOSSARY_TERMS:
    print("Indexing glossary…")
    # The collection is persistent (.chroma), so ids must identify the term
    # itself, not its row position: with positional ids, editing glossary.csv
    # left the old term stored under the old id and never embedded the new one.
    # Duplicate terms collapse onto one id (Chroma rejects duplicate ids in add()).
    wanted: Dict[str, str] = {}
    for term in GLOSSARY_TERMS:
        wanted.setdefault(_glossary_doc_id(term), term)

    existing = set(GLOSSARY_COL.get()["ids"])
    # Drop documents for terms no longer in the glossary (including ids written
    # by the old positional scheme), so retrieval cannot return stale terms.
    stale = sorted(existing - set(wanted))
    if stale:
        GLOSSARY_COL.delete(ids=stale)

    to_add = [(doc_id, term) for doc_id, term in wanted.items() if doc_id not in existing]
    if to_add:
        batch_terms = [t for _, t in to_add]
        embs = EMB_MODEL.encode(batch_terms, batch_size=64, normalize_embeddings=True, show_progress_bar=True)
        GLOSSARY_COL.add(
            ids=[doc_id for doc_id, _ in to_add],
            documents=batch_terms,
            embeddings=[e.tolist() for e in embs],
        )
    print("✅ Glossary ready in Chroma")

Loading embedding model…
Indexing glossary…
✅ Glossary ready in Chroma


## 4. Retrieval & Utility Helpers

In [84]:
HTML_TAG = re.compile(r"</?\w+(?:\s+[^>]*?)?>", re.IGNORECASE)
_WORD = re.compile(r"\w+", re.UNICODE)

def _normalize_for_retrieval(s: str) -> str:
    s = HTML_TAG.sub("", s or "")
    s = re.sub(r"\s+", " ", s).strip().lower()
    return s

def _tokenize(s: str) -> set:
    return set(w.lower() for w in _WORD.findall(s or ""))

def tm_lookup(src_text: str, lang: str) -> Optional[str]:
    return TM_DICT.get(lang, {}).get(src_text)

def tags_preserved(src: str, tgt: str) -> bool:
    return HTML_TAG.findall(src or "") == HTML_TAG.findall(tgt or "")

def retrieve_glossary_terms(
    segment_text: str,
    top_k: int = 5,
    min_score: float = 0.45,
    overfetch: int = 24,
    require_lex_for_long: bool = True
) -> List[str]:
    """Return up to ``top_k`` glossary terms relevant to ``segment_text``.

    Candidates come from a vector query against the Chroma glossary collection
    (E5 ``query:`` prefix on the normalised segment). Each candidate's cosine
    similarity then gets a small lexical boost:

    - +0.12 when the whole term occurs in the normalised segment;
    - +0.05 per shared word token, capped at +0.15.

    Candidates scoring below ``min_score`` are dropped. When
    ``require_lex_for_long`` is set, segments of 80+ normalised characters only
    accept candidates with some lexical overlap.

    Args:
        segment_text: Source text; HTML tags are ignored.
        top_k: Maximum number of terms returned.
        min_score: Minimum combined (similarity + boost) score.
        overfetch: Minimum number of neighbours requested before re-scoring.
        require_lex_for_long: Enforce lexical overlap for long segments.

    Returns:
        Distinct terms by descending score. Returns ``[]`` when any of these
        hold: the glossary is empty, ``top_k <= 0``, the segment has no text
        once tags are stripped, or the vector query fails.
    """
    if not (GLOSSARY_COL and GLOSSARY_TERMS) or top_k <= 0:
        return []
    norm = _normalize_for_retrieval(segment_text)
    if not norm:
        # Markup/whitespace only: an empty query would still return its nearest
        # neighbours and inject unrelated glossary constraints into the prompt.
        return []
    # E5 query style
    q_vec = EMB_MODEL.encode([f"query: {norm}"], normalize_embeddings=True)[0].tolist()
    try:
        # Never request more neighbours than the index holds (some Chroma
        # versions warn or raise when n_results exceeds the collection size).
        n_results = min(max(top_k * 3, overfetch), GLOSSARY_COL.count())
        if n_results <= 0:
            return []
        res = GLOSSARY_COL.query(query_embeddings=[q_vec], n_results=n_results)
        docs: List[str] = (res.get("documents") or [[]])[0] if res else []
        dists = (res.get("distances") or [[]])[0] if res else []
        # The collection uses cosine space, where distance = 1 - cosine similarity.
        sims = [(1.0 - d) if (0.0 <= d <= 2.0) else d for d in dists] if dists else [0.0] * len(docs)
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []

    seg_tokens = _tokenize(norm)

    def _lex_boost(term: str) -> float:
        """Lexical bonus: whole-term substring hit plus capped token overlap."""
        t = term.lower()
        b = 0.0
        if t in norm:
            b += 0.12
        overlap = len(seg_tokens & _tokenize(t))
        if overlap:
            b += min(0.05 * overlap, 0.15)
        return b

    long_text = len(norm) >= 80 if require_lex_for_long else False
    scored: Dict[str, float] = {}
    for term, base in zip(docs, sims):
        lb = _lex_boost(term)
        if long_text and lb == 0.0:
            continue
        score = base + lb
        if score >= min_score:
            if term not in scored or score > scored[term]:
                scored[term] = score

    ranked = sorted(scored.items(), key=lambda x: x[1], reverse=True)
    return [t for t, _ in ranked[:top_k]]

def build_constraints(src_text: str, lang: str, top_k: int = 3) -> List[str]:
    """Glossary constraints for ``src_text`` as ``"<en term> → <target term>"`` strings.

    Terms come from ``retrieve_glossary_terms`` (``min_score=0.45``). Only terms
    with a non-empty string target in ``GLOSSARY_MAP[lang]`` are kept.
    Languages missing from the glossary map yield ``[]`` without a retrieval call.
    """
    lang_map = GLOSSARY_MAP.get(lang)
    if lang_map is None:
        return []
    candidates = retrieve_glossary_terms(src_text, top_k=top_k, min_score=0.45)
    return [
        f"{en_term} → {target}"
        for en_term, target in ((term, lang_map.get(term)) for term in candidates)
        if isinstance(target, str) and target
    ]

## 5. OpenAI Client

In [85]:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("Missing OPENAI_API_KEY")

OPENAI_CLIENT = OpenAI(api_key=OPENAI_API_KEY)

# Model registries: reuse dicts already present in the kernel (e.g. from an
# earlier run of this cell) rather than replacing them.
OPENAI_MODELS = globals().get("OPENAI_MODELS", {})
BASELINE_MODELS = globals().get("BASELINE_MODELS", {})

# The RAG run uses the same model as the baseline unless OPENAI_BASELINE_MODEL overrides it.
OPENAI_MODELS["gpt-4o-mini"] = {"model": os.getenv("OPENAI_BASELINE_MODEL", "gpt-4o-mini")}
BASELINE_MODELS["gpt-4o-mini"] = ("openai", OPENAI_MODELS["gpt-4o-mini"])

print("OPENAI models:", OPENAI_MODELS)

OPENAI models: {'gpt-4o-mini': {'model': 'gpt-4o-mini'}}


## 6. Few-shots & Prompt helpers

In [86]:
# Hand-written demonstrations prepended to every translation prompt.
# "constraints" may be a list (shown for every target language) or a dict keyed
# by language code. Glossary mappings are language-specific: the Japanese
# "press-on nails → ネイルチップ" must not appear in the French/Italian prompts,
# where it contradicted the example targets. The first example previously
# carried "press-on nails → press-on nails", a term absent from its source that
# also conflicted with the Japanese mapping, so it has no constraints.
FEWSHOTS = [
    {
        "src": "Order <strong>custom designs</strong> or choose from our curated nail art collections",
        "constraints": [],
        "tgt_fr": "Commandez des <strong>designs personnalisés</strong> ou choisissez parmi nos collections de nail art sélectionnées",
        "tgt_ja": "<strong>カスタムデザイン</strong>を注文するか、厳選されたネイルアートコレクションから選びましょう",
        "tgt_it": "Ordina <strong>design personalizzati</strong> o scegli dalle nostre collezioni curate di nail art",
    },
    {
        "src": "With <strong>NaiLit</strong>, you get <strong>fully custom</strong>, <strong>hand-painted</strong> Gel-X press-on nails",
        "constraints": {
            "fr": ["NaiLit → NaiLit"],
            "ja": ["NaiLit → NaiLit", "press-on nails → ネイルチップ"],
            "it": ["NaiLit → NaiLit"],
        },
        "tgt_fr": "Avec <strong>NaiLit</strong>, vous obtenez des ongles Gel-X <strong>entièrement personnalisés</strong> et <strong>peints à la main</strong>",
        "tgt_ja": "<strong>NaiLit</strong> なら、<strong>フルカスタム</strong>かつ<strong>手描き</strong>のGel-X ネイルチップが手に入ります",
        "tgt_it": "Con <strong>NaiLit</strong> ottieni unghie Gel-X <strong>completamente personalizzate</strong> e <strong>dipinte a mano</strong>",
    },
    {
        "src": "<strong>How do I place an order?</strong>",
        "constraints": [],
        "tgt_fr": "<strong>Comment passer une commande ?</strong>",
        "tgt_ja": "<strong>注文方法を教えてください。</strong>",
        "tgt_it": "<strong>Come posso effettuare un ordine?</strong>",
    },
]


def render_fewshots(target_lang: str) -> str:
    """Render FEWSHOTS as ``### Example`` prompt blocks for ``target_lang``.

    Constraints given as a dict are filtered to ``target_lang``; list-form
    constraints apply to all languages.

    Raises:
        KeyError: if ``target_lang`` is not one of fr/ja/it.
    """
    lang_key = {"fr": "tgt_fr", "ja": "tgt_ja", "it": "tgt_it"}[target_lang]
    blocks = []
    for ex in FEWSHOTS:
        cons = ex.get("constraints") or []
        if isinstance(cons, dict):
            cons = cons.get(target_lang) or []
        cons_txt = ""
        if cons:
            cons_txt = "Glossary constraints:\n- " + "\n- ".join(cons) + "\n"
        blocks.append(
            f"""### Example
Source:
{ex['src']}
{cons_txt}Target ({target_lang}):
{ex[lang_key]}
"""
        )
    return "\n".join(blocks)

## 7. API Call Helpers: Retry Backoff, Translation Cache & Usage Meter

In [93]:
def _backoff_sleep(attempt: int, base: float = 0.5, jitter: float = 0.2):
    import time, random
    time.sleep(base * (2 ** attempt) + random.random() * jitter)

# In-memory translation cache, read and written by translate_segment_with_rag()
# from every worker thread of the batch pipeline.
# Key: (lang, src_text, sorted_constraints) → translated text.
# Re-running this cell resets it to an empty dict.
CACHE_TRANSLATIONS: Dict[Tuple[str, str, Tuple[str, ...]], str] = {}

def safe_usage_tokens(resp):
    """Return ``(input_tokens, output_tokens)`` from an OpenAI response object.

    Accepts either naming of the usage fields (``prompt_tokens`` /
    ``completion_tokens`` or ``input_tokens`` / ``output_tokens``); ``None``
    counts as 0. Falls back to ``(0, 0)`` when usage is absent or unreadable.
    """
    field_pairs = (
        ("prompt_tokens", "completion_tokens"),
        ("input_tokens", "output_tokens"),
    )
    try:
        usage = getattr(resp, "usage", None)
        if not usage:
            return 0, 0
        for in_name, out_name in field_pairs:
            if hasattr(usage, in_name) and hasattr(usage, out_name):
                return int(getattr(usage, in_name) or 0), int(getattr(usage, out_name) or 0)
    except Exception:
        pass
    return 0, 0

import threading, time

class UsageMeter:
    """Lock-protected accumulator of token counts and summed API-call seconds.

    ``wall_seconds`` is the *sum* of individual call durations; with concurrent
    workers it can exceed the elapsed wall-clock time of the run.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.input_tokens = 0
        self.output_tokens = 0
        self.wall_seconds = 0.0

    def add(self, in_toks: int, out_toks: int, dt: float):
        """Record one call's usage; ``None``/0 arguments count as zero."""
        with self.lock:
            self.input_tokens = self.input_tokens + int(in_toks or 0)
            self.output_tokens = self.output_tokens + int(out_toks or 0)
            self.wall_seconds = self.wall_seconds + float(dt or 0.0)

    def snapshot(self):
        """Consistent copy of the current totals as a plain dict."""
        with self.lock:
            totals = dict(
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                wall_seconds=self.wall_seconds,
            )
        return totals

# Meter that openai_chat_with_usage() charges each successful call to; None
# disables metering. The batch-run cell installs a fresh UsageMeter per language.
CURRENT_METER: Optional[UsageMeter] = None

def openai_chat_with_usage(
    model: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.2,
    max_tokens: int = 2048,
) -> str:
    """Run one Chat Completions call and charge its usage to ``CURRENT_METER``.

    Args:
        model: OpenAI model id.
        system_prompt: System message content.
        user_prompt: User message content.
        temperature: Sampling temperature (default keeps the original 0.2).
        max_tokens: Completion token limit (default keeps the original 2048).

    Returns:
        The first choice's message text, stripped ("" when the content is None).

    Raises:
        Whatever the OpenAI client raises; failed calls are not metered, so
        callers (e.g. the retry loop) can decide how to handle them.
    """
    # perf_counter is monotonic, unlike time.time(), so durations cannot go negative.
    t0 = time.perf_counter()
    resp = OPENAI_CLIENT.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user",   "content": user_prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    dt = time.perf_counter() - t0
    in_toks, out_toks = safe_usage_tokens(resp)
    # Read the global once: the batch loop may rebind it between the check and the add.
    meter = CURRENT_METER
    if meter is not None:
        meter.add(in_toks, out_toks, dt)
    return (resp.choices[0].message.content or "").strip()

## 8. Translation Function

In [94]:
def _extract_translation(raw: Optional[str]) -> str:
    """Return the text inside ``<translation>…</translation>`` from a model reply.

    The user prompt ends with an opening ``<translation>`` tag, so a model can
    continue it and emit only the closing tag (or only the opening one). Lone
    wrapper tags and anything after ``</translation>`` are stripped so they do
    not leak into the output or trip the HTML tag check.
    """
    text = raw or ""
    m = re.search(r"<translation>([\s\S]*?)</translation>", text)
    if m:
        return m.group(1).strip()
    text = text.strip()
    if text.startswith("<translation>"):
        text = text[len("<translation>"):]
    if "</translation>" in text:
        text = text.split("</translation>", 1)[0]
    return text.strip()


def translate_segment_with_rag(src_text: str, lang: str, precomputed_constraints: Optional[List[str]] = None) -> str:
    """Translate one source segment into ``lang`` using TM + glossary-constrained prompting.

    Steps:
      0. An exact translation-memory hit is returned as-is (no API call).
      1. Glossary constraints are retrieved, or taken from
         ``precomputed_constraints``. Together with ``lang`` and ``src_text``
         they form the cache key.
      2. The prompt combines few-shot examples, the constraints and the loaded
         DNT terms.
      3. The model is called up to 3 times with exponential backoff.
      4. If the HTML tags differ from the source, one corrective call is made.
         Its result is used only if it fixes the tags.

    Returns:
        The translation, or ``"[RAG_TRANSLATION_ERROR]"`` if every attempt
        fails. Error results are not cached, so a later call retries the API.
    """
    # 0) TM exact match
    tm_hit = tm_lookup(src_text, lang)
    if tm_hit:
        return tm_hit

    # 1) Build constraints + cache key
    constraints = precomputed_constraints if precomputed_constraints is not None else build_constraints(src_text, lang, top_k=3)
    constraints_sorted = tuple(sorted(constraints)) if constraints else tuple()
    cache_key = (lang, src_text, constraints_sorted)
    cached = CACHE_TRANSLATIONS.get(cache_key)
    # Ignore error markers that an older version of this function may have cached.
    if cached is not None and cached != "[RAG_TRANSLATION_ERROR]":
        return cached

    # 2) Few-shots + constraint text
    fewshots = render_fewshots(lang)
    constraint_text = ""
    if constraints_sorted:
        constraint_text = "Use these glossary mappings exactly when relevant:\n- " + "\n- ".join(constraints_sorted) + "\n"
    # Show the model the DNT list actually loaded from the glossary (e.g. Gel-X),
    # not just the hard-coded brand; falls back to the original "NaiLit" example.
    dnt_terms = [t for t in (globals().get("DNT_TERMS") or []) if t]
    dnt_examples = ", ".join(dnt_terms) if dnt_terms else "NaiLit"

    # 3) Prompts
    system_prompt = (
        "You are a precise, format-strict translator. "
        "Reply with only the target text enclosed in <translation>...</translation>."
    )
    user_prompt = f"""
You are a professional localization translator. Translate the source into {lang}.
Requirements:
- Preserve all HTML tags exactly (do not add/remove/reorder tags).
- Keep brand names and DNT terms as-is (case-sensitive), e.g., {dnt_examples}.
- Be natural, fluent, and consistent with terminology.
- Follow the glossary mappings when relevant (do not hallucinate).
- Return ONLY the translation between <translation> and </translation>. Do not add notes.

{fewshots}

{constraint_text}Source:
{src_text}

<translation>
""".strip()

    model_name = (
        (OPENAI_MODELS.get("gpt-4o-mini", {}).get("model") if "OPENAI_MODELS" in globals() else None)
        or (globals().get("RAG_MODEL", {}).get("config", {}).get("model"))
        or os.getenv("OPENAI_RAG_MODEL", "gpt-4o-mini")
    )

    # 4) Call OpenAI with retries
    def _call_openai(prompt: str) -> str:
        return openai_chat_with_usage(model_name, system_prompt, prompt)

    raw = None
    for attempt in range(3):
        try:
            raw = _call_openai(user_prompt)
            break
        except Exception as e:
            if attempt == 2:
                print(f"⚠️ RAG call failed (final): {e}")
                # Deliberately NOT cached: a transient API failure must not pin
                # this segment to the error marker for the rest of the session.
                return "[RAG_TRANSLATION_ERROR]"
            _backoff_sleep(attempt)

    # 5) Extract inside <translation>…</translation>
    translation = _extract_translation(raw)

    # 6) Validate tags; one corrective retry if needed
    if not tags_preserved(src_text, translation):
        # Insert the reminder only before the FINAL opening tag. A plain
        # str.replace also rewrote the "<translation> and </translation>"
        # requirement line, garbling the instructions on the retry.
        head, sep, tail = user_prompt.rpartition("<translation>")
        retry_prompt = (
            head
            + "IMPORTANT: Copy every HTML tag exactly as in the source. Only the translation.\n\n"
            + sep
            + tail
        )
        try:
            translation2 = _extract_translation(_call_openai(retry_prompt))
            if tags_preserved(src_text, translation2):
                translation = translation2
        except Exception as e:
            # Best-effort: keep the first translation, but say why the fix failed.
            print(f"⚠️ Tag-fix retry failed: {e}")

    CACHE_TRANSLATIONS[cache_key] = translation
    return translation

## 8.1 End-to-End RAG Batch Translation (T9N) Pipeline Config

In [95]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import os, json, time
from pathlib import Path
from typing import Dict, List, Tuple, Any
from tqdm import tqdm

# Config
TARGET_LANGUAGES = ["fr", "ja", "it"]
LANGUAGE_NAMES = {"fr": "French", "ja": "Japanese", "it": "Italian"}
MAX_WORKERS = max(1, int(os.getenv("RAG_MAX_WORKERS", "4")))  # 1 = sequential (safer for rate limits)
# Resolve the model exactly as translate_segment_with_rag() does, so the "model"
# field written to every output record names the model that was actually
# called. The old lookup (RAG_MODEL["name"] or a literal) ignored
# OPENAI_MODELS and could mislabel runs that override the model via env vars.
MODEL_NAME = (
    (globals().get("OPENAI_MODELS") or {}).get("gpt-4o-mini", {}).get("model")
    or (globals().get("RAG_MODEL") or {}).get("config", {}).get("model")
    or os.getenv("OPENAI_RAG_MODEL", "gpt-4o-mini")
)

print("Pipeline configuration:")
print(f" Languages: {', '.join(TARGET_LANGUAGES)}")
print(f" Segments: {len(en_segments)}")
print(f" Max workers: {MAX_WORKERS}")
print(f" Model: {MODEL_NAME}")
print("=" * 70)

def _translate_one_safe(src: str, target_lang: str, constraints: List[str]) -> str:
    """Translate one unique source; log and return the error marker on any exception."""
    try:
        return translate_segment_with_rag(src, target_lang, constraints)
    except Exception as e:
        snippet = (src[:60] + "…") if len(src) > 60 else src
        print(f"⚠️ Translate failed: {e} | src≈ {snippet!r}")
        return "[RAG_TRANSLATION_ERROR]"


def translate_all_segments_rag(
    target_lang: str,
    segments: Optional[List[Tuple[str, str]]] = None,
    out_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """Translate every segment into ``target_lang`` and save the records as JSON.

    Identical source strings are translated once and the result is fanned back
    out to every path that uses them. Constraints and TM hits are computed once
    per unique source. Work runs sequentially when ``MAX_WORKERS == 1``,
    otherwise on a thread pool.

    Args:
        target_lang: Target language code (fr/ja/it).
        segments: ``(path, source)`` pairs; defaults to the global ``en_segments``.
        out_dir: Output directory; defaults to ``translations/rag``.

    Returns:
        One record per input segment, in input order. The same list is written to
        ``<out_dir>/<target_lang>.json``.
    """
    start = time.time()
    segs = en_segments if segments is None else segments

    # --- Deduplicate by source text (order-preserving)
    unique_srcs = list(dict.fromkeys(src for _, src in segs))
    uniq_count = len(unique_srcs)

    # Precompute constraints and TM flags once
    src_to_constraints: Dict[str, List[str]] = {}
    src_tm_hit: Dict[str, bool] = {}
    total_constraints_used = 0
    for src in unique_srcs:
        cons = build_constraints(src, target_lang, top_k=3)
        src_to_constraints[src] = cons
        total_constraints_used += len(cons)
        src_tm_hit[src] = bool(tm_lookup(src, target_lang))
    tm_hits = sum(1 for v in src_tm_hit.values() if v)

    results_map: Dict[str, str] = {}
    if MAX_WORKERS == 1:
        for src in tqdm(unique_srcs, total=uniq_count, desc=f"RAG → {target_lang} (seq)"):
            results_map[src] = _translate_one_safe(src, target_lang, src_to_constraints[src])
    else:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
            futures = {
                ex.submit(_translate_one_safe, src, target_lang, src_to_constraints[src]): src
                for src in unique_srcs
            }
            for fut in tqdm(as_completed(futures), total=len(futures), desc=f"RAG → {target_lang}"):
                results_map[futures[fut]] = fut.result()

    # --- Expand back to all segments
    out: List[Dict[str, Any]] = []
    now = time.strftime("%Y-%m-%d %H:%M:%S")
    for path, src in segs:
        constraints = src_to_constraints.get(src, [])
        out.append({
            "path": path,
            "source": src,
            "translation": results_map.get(src, "[RAG_TRANSLATION_ERROR]"),
            "model": MODEL_NAME,
            "approach": "RAG",
            "target_lang": target_lang,
            "tm_hit": src_tm_hit.get(src, False),
            "constraints_found": len(constraints),
            "constraints_list": constraints,
            "timestamp": now,
            **({"source_sha": SOURCE_SHA} if SOURCE_SHA else {}),
        })

    dur = time.time() - start

    # --- Save
    save_dir = Path(out_dir) if out_dir is not None else Path("translations/rag")
    save_dir.mkdir(parents=True, exist_ok=True)
    out_file = save_dir / f"{target_lang}.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    seg_per_sec = len(out) / max(dur, 1e-6)
    print(f"✅ RAG translation completed: {target_lang}")
    print(f"   • Unique sources: {uniq_count}  | Segments: {len(out)}")
    print(f"   • Duration: {dur:.1f}s  | Speed: {seg_per_sec:.2f} seg/s")
    print(f"   • TM hits (unique-src): {tm_hits}/{uniq_count} ({tm_hits/max(uniq_count,1):.1%})")
    print(f"   • Constraints found (unique-src): {total_constraints_used}")
    print(f"   • Saved: {out_file}")

    return out

Pipeline configuration:
 Languages: fr, ja, it
 Segments: 76
 Max workers: 4


## 9. Run RAG-Enhanced Translation (T9N) for All Languages

In [96]:
# Per-language batch run: translate, meter token usage, and collect summary stats.
RAG_RESULTS: Dict[str, List[Dict[str, Any]]] = {}
RAG_SUMMARY: Dict[str, Any] = {}
RAG_METERS: Dict[str, Dict[str, float]] = {}
summary_rows: List[Dict[str, Any]] = []

try:
    for lang in TARGET_LANGUAGES:
        print(f"\nStarting {LANGUAGE_NAMES[lang]} ({lang})…")
        CURRENT_METER = UsageMeter()  # fresh meter for this language

        t0 = time.time()
        try:
            res = translate_all_segments_rag(lang)   # writes translations/rag/{lang}.json
            dt = time.time() - t0

            usage = CURRENT_METER.snapshot()
            # wall_seconds is the SUM of per-call latencies (can exceed dt with parallel workers).
            RAG_METERS[lang] = {**usage}

            RAG_RESULTS[lang] = res

            total_segments = len(res)
            unique_sources = len({r["source"] for r in res})
            tm_hits = sum(1 for r in res if r.get("tm_hit"))
            errors = sum(1 for r in res if "[RAG_TRANSLATION_ERROR]" in (r.get("translation") or ""))
            total_constraints = sum(r.get("constraints_found", 0) for r in res)
            out_file = Path("translations/rag") / f"{lang}.json"

            RAG_SUMMARY[lang] = {
                "language": lang,
                "total_segments": total_segments,
                "unique_sources": unique_sources,
                "tm_hits_segments": tm_hits,
                "tm_hit_rate_segments": (tm_hits / total_segments) if total_segments else 0.0,
                "errors": errors,
                "error_rate": (errors / total_segments) if total_segments else 0.0,
                "total_constraints": total_constraints,
                "avg_constraints_per_segment": (total_constraints / total_segments) if total_segments else 0.0,
                "duration_sec": round(dt, 1),
                "segments_per_sec": (total_segments / dt) if dt > 0 else 0.0,
                "output_file": str(out_file),
                "input_tokens": usage["input_tokens"],
                "output_tokens": usage["output_tokens"],
                "wall_seconds_metered": usage["wall_seconds"],
            }
            summary_rows.append(RAG_SUMMARY[lang])

        except Exception as e:
            print(f"❌ Failed {LANGUAGE_NAMES[lang]}: {e}")
            RAG_RESULTS[lang] = []
            RAG_SUMMARY[lang] = {"error": str(e), "language": lang}
            summary_rows.append(RAG_SUMMARY[lang])
finally:
    # Stop metering once the batch ends (or is interrupted); otherwise any later
    # ad-hoc API call in this kernel is silently charged to the last language's meter.
    CURRENT_METER = None

# Concise recap (per-metric details live in the saved summary CSV)
print("\n" + "=" * 70)
print("RAG TRANSLATIONS COMPLETE")
print("=" * 70)
for lang in TARGET_LANGUAGES:
    s = RAG_SUMMARY.get(lang, {})
    if "error" in s:
        print(f"❌ {LANGUAGE_NAMES[lang]} ({lang}): {s['error']}")
    else:
        print(f"✅ {LANGUAGE_NAMES[lang]} ({lang}) — saved: {s['output_file']}")

# Save a run summary for later analysis
Path("eval").mkdir(parents=True, exist_ok=True)
summary_path = Path("eval/rag_run_summary.csv")
pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
print(f"\nRun summary saved to: {summary_path}")


Starting French (fr)…


RAG → fr: 100%|████████████████████████████████████████████████████████████████████████| 73/73 [00:15<00:00,  4.63it/s]


✅ RAG translation completed: fr
   • Unique sources: 73  | Segments: 76
   • Duration: 24.9s  | Speed: 3.05 seg/s
   • TM hits (unique-src): 2/73 (2.7%)
   • Constraints found (unique-src): 203
   • Saved: translations\rag\fr.json

Starting Japanese (ja)…


RAG → ja: 100%|████████████████████████████████████████████████████████████████████████| 73/73 [00:16<00:00,  4.44it/s]


✅ RAG translation completed: ja
   • Unique sources: 73  | Segments: 76
   • Duration: 38.3s  | Speed: 1.98 seg/s
   • TM hits (unique-src): 2/73 (2.7%)
   • Constraints found (unique-src): 203
   • Saved: translations\rag\ja.json

Starting Italian (it)…


RAG → it: 100%|████████████████████████████████████████████████████████████████████████| 73/73 [00:16<00:00,  4.38it/s]

✅ RAG translation completed: it
   • Unique sources: 73  | Segments: 76
   • Duration: 38.2s  | Speed: 1.99 seg/s
   • TM hits (unique-src): 2/73 (2.7%)
   • Constraints found (unique-src): 203
   • Saved: translations\rag\it.json

RAG TRANSLATIONS COMPLETE
✅ French (fr) — saved: translations\rag\fr.json
✅ Japanese (ja) — saved: translations\rag\ja.json
✅ Italian (it) — saved: translations\rag\it.json

Run summary saved to: eval\rag_run_summary.csv





## 10. Baseline Loader

In [102]:
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import os, json

def _choose_baseline_dir(base_dir: Path, langs: List[str], prefer: Tuple[str, ...] = ()) -> Optional[Path]:
    if not base_dir.exists():
        return None
    candidates = []
    for d in base_dir.iterdir():
        if not d.is_dir():
            continue
        if all((d / f"{lang}.json").exists() for lang in langs):
            pref_score = 0
            dn = d.name.lower()
            for i, tok in enumerate(prefer):
                if tok and tok.lower() in dn:
                    pref_score = max(pref_score, len(prefer) - i)
            candidates.append((pref_score, d.stat().st_mtime, d))
    if not candidates:
        return None
    candidates.sort(key=lambda t: (t[0], t[1]), reverse=True)
    return candidates[0][2]

def load_baseline_results(target_languages: Optional[List[str]] = None) -> Dict[str, List[Dict]]:
    """Load baseline translation records for each language.

    The baseline directory under ``translations/baseline`` is chosen by
    ``_choose_baseline_dir``. Name preference comes from the space-separated
    ``BASELINE_MODEL_HINT`` env var (default ``"gpt openai claude gemini"``).
    ``target_languages`` falls back to the global ``TARGET_LANGUAGES`` when
    empty or None.

    Returns:
        ``{lang: records}``. A file that fails to load maps to ``[]``. Returns
        ``{}`` when no directory holds every language.
    """
    langs = target_languages or TARGET_LANGUAGES
    hint = os.getenv("BASELINE_MODEL_HINT", "gpt openai claude gemini")
    chosen = _choose_baseline_dir(Path("translations/baseline"), langs, prefer=tuple(hint.split()))
    if not chosen:
        print("⚠️ No baseline results directory with all target languages found in translations/baseline")
        return {}

    print(f"Loading baseline from: {chosen}")
    loaded: Dict[str, List[Dict]] = {}
    for lang in langs:
        try:
            with open(chosen / f"{lang}.json", "r", encoding="utf-8") as fh:
                loaded[lang] = json.load(fh)
            print(f"  • {lang}: {len(loaded[lang])} segments")
        except Exception as e:
            print(f"  ❌ Failed to load {lang}: {e}")
            loaded[lang] = []
    return loaded

## 11. Quality Evaluation

In [103]:
def comprehensive_rag_evaluation():
    """Score the RAG translations per language and compare speed with the baseline.

    For each language in TARGET_LANGUAGES that has RAG_RESULTS, computes
    per-segment checks and averages them:
    - DNT preservation, glossary adherence, tag preservation;
    - retrieval precision, TM hits, constraints per segment, error rate;
    - mean embedding cosine similarity between RAG and baseline translations
      of the same path. This measures agreement with the baseline, not
      reference quality.
    Also pairs RAG latency/speed from RAG_SUMMARY with baseline latency/speed.

    Writes eval/rag_comprehensive_evaluation.csv and
    eval/rag_quality_vs_baseline.csv.

    Returns:
        (rag_df, compare_df); either is None when there were no rows.

    NOTE(review): dnt_preserved, glossary_adherence, retrieval_precision and
    _load_baseline_eval_metrics are not defined in the cells shown here —
    presumably left over from another notebook/cell. A fresh-kernel
    "Run All" will raise NameError until they are defined or imported.
    """
    print("\nCOMPREHENSIVE RAG EVALUATION (with baseline latency/speed compare)")
    print("=" * 72)

    # Handle both return styles: dict OR (dict, chosen_path)
    _loaded = load_baseline_results()
    if isinstance(_loaded, tuple):
        baseline_results, _baseline_dir = _loaded
    else:
        baseline_results, _baseline_dir = _loaded, None

    # NOTE(review): baseline translations and baseline eval metrics are located
    # independently. In the saved output, translations came from gpt-4o-mini but
    # metrics from claude-3-5-sonnet, so the latency comparison may be against a
    # different model than the quality comparison — TODO align the two.
    baseline_eval    = _load_baseline_eval_metrics()  # for baseline latency/speed
    results_rows = []
    compare_rows  = []

    rag_model_name = (
        (OPENAI_MODELS.get("gpt-4o-mini", {}).get("model") if "OPENAI_MODELS" in globals() else None)
        or (globals().get("MODEL_NAME"))
        or os.getenv("OPENAI_RAG_MODEL", "gpt-4o-mini")
    )

    for lang in TARGET_LANGUAGES:
        if lang not in RAG_RESULTS:
            print(f"⚠️ No RAG results for {lang}")
            continue

        rag  = RAG_RESULTS[lang] or []
        base = (baseline_results or {}).get(lang, [])
        # Baseline translations keyed by JSON path, for pairing with RAG records.
        base_lookup = {x.get("path"): x.get("translation", "") for x in base if "path" in x}

        seg_rows = []
        rag_hyps, base_refs = [], []

        for item in rag:
            src = item.get("source", "")
            tgt = item.get("translation", "")
            path = item.get("path", "")
            retrieved = item.get("constraints_list", [])

            seg_rows.append({
                "path": path,
                "dnt_preserved": dnt_preserved(src, tgt),
                "glossary_adherence": glossary_adherence(src, tgt, lang),
                "tag_preserved": tags_preserved(src, tgt),
                # Constraints are stored as "en → target"; keep the English side.
                "retrieval_precision": retrieval_precision([t.split(" → ")[0] for t in retrieved], src),
                "tm_hit": item.get("tm_hit", False),
                "constraints_count": len(retrieved),
                "has_error": "[RAG_TRANSLATION_ERROR]" in (tgt or ""),
            })

            # Only segments present in both runs feed the similarity score.
            if path in base_lookup:
                rag_hyps.append(tgt or "")
                base_refs.append(base_lookup[path] or "")

        tot = len(seg_rows)
        if tot == 0:
            print(f"⚠️ No evaluable segments for {lang}")
            continue

        dnt_rate   = sum(r["dnt_preserved"] for r in seg_rows) / tot
        gloss_avg  = sum(r["glossary_adherence"] for r in seg_rows) / tot
        tag_rate   = sum(r["tag_preserved"] for r in seg_rows) / tot
        ret_prec   = sum(r["retrieval_precision"] for r in seg_rows) / tot
        tm_hits    = sum(r["tm_hit"] for r in seg_rows)
        avg_constr = sum(r["constraints_count"] for r in seg_rows) / tot
        err_rate   = sum(r["has_error"] for r in seg_rows) / tot

        # Semantic similarity to baseline (cosine via embedding)
        sem_sim = None
        if rag_hyps and base_refs:
            try:
                h_emb = EMB_MODEL.encode(rag_hyps, batch_size=64, normalize_embeddings=True)
                r_emb = EMB_MODEL.encode(base_refs, batch_size=64, normalize_embeddings=True)
                # Row-wise dot product of unit vectors = cosine similarity per pair.
                sims = (h_emb * r_emb).sum(axis=1).astype(float)
                sims = np.clip(sims, -1.0, 1.0)
                sem_sim = float(np.mean(sims))
            except Exception as e:
                print(f"  ⚠️ Semantic similarity failed for {lang}: {e}")
                sem_sim = None

        # RAG latency/speed from RAG_SUMMARY (measured)
        rag_lat   = (RAG_SUMMARY.get(lang) or {}).get("duration_sec")
        rag_speed = (RAG_SUMMARY.get(lang) or {}).get("segments_per_sec")

        # Baseline latency/speed from saved metrics (if present)
        base_lat   = None
        base_speed = None
        if lang in baseline_eval:
            base_lat   = baseline_eval[lang].get("total_duration_sec")
            base_speed = baseline_eval[lang].get("segments_per_minute")
            if isinstance(base_speed, (int, float)) and base_speed:
                base_speed = round(base_speed / 60.0, 4)  # seg/sec

        row = {
            "language": lang,
            "total_segments": tot,
            "dnt_preservation_rate": round(dnt_rate, 3),
            "glossary_adherence_avg": round(gloss_avg, 3),
            "tag_preservation_rate": round(tag_rate, 3),
            "retrieval_precision_avg": round(ret_prec, 3),
            "semantic_similarity_avg": round(sem_sim, 3) if sem_sim is not None else None,
            "tm_hits_found": int(tm_hits),
            "tm_entries_available": len(TM_DICT.get(lang, {})),
            "avg_constraints_per_segment": round(avg_constr, 2),
            "error_rate": round(err_rate, 3),
            "evaluation_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        results_rows.append(row)

        compare_rows.append({
            "language": lang,
            "rag_model": rag_model_name,
            "rag_duration_sec": rag_lat,
            "rag_segments_per_sec": rag_speed,
            "baseline_duration_sec": base_lat,
            "baseline_segments_per_sec": base_speed,
            # headline quality signal
            "rag_tag_preservation": row["tag_preservation_rate"],
        })

        def _fmt(x, unit=""):
            return ("n/a" if x is None else f"{x:.2f}{unit}")

        print(f"\n{LANGUAGE_NAMES[lang]} ({lang.upper()})")
        print("-" * 40)
        print(f"  • DNT: {row['dnt_preservation_rate']:.1%} | Glossary: {row['glossary_adherence_avg']:.1%} | Tags: {row['tag_preservation_rate']:.1%}")
        print(f"  • Retrieval precision: {row['retrieval_precision_avg']:.1%}")
        if row["semantic_similarity_avg"] is not None:
            print(f"  • Semantic similarity: {row['semantic_similarity_avg']:.3f}")
        print("  • Latency & Speed")
        print(f"     - RAG:       {_fmt(rag_lat,'s')}  | speed {_fmt(rag_speed,' seg/s')}")
        print(f"     - Baseline:  {_fmt(base_lat,'s')} | speed {_fmt(base_speed,' seg/s')}")

    # Save aggregated quality table
    if results_rows:
        rag_df = pd.DataFrame(results_rows)
        out1 = Path("eval/rag_comprehensive_evaluation.csv")
        out1.parent.mkdir(parents=True, exist_ok=True)
        rag_df.to_csv(out1, index=False)
        print(f"\nSaved RAG quality table: {out1}")
    else:
        rag_df = None

    # Save latency/speed compare
    if compare_rows:
        compare_df = pd.DataFrame(compare_rows)
        out2 = Path("eval/rag_quality_vs_baseline.csv")
        out2.parent.mkdir(parents=True, exist_ok=True)
        compare_df.to_csv(out2, index=False)
        print(f"Saved RAG vs Baseline latency/speed table: {out2}")
    else:
        compare_df = None

    # Overall rollup
    if rag_df is not None and not rag_df.empty:
        print("\nOVERALL RAG PERFORMANCE SUMMARY")
        print("=" * 50)
        print(f"Languages evaluated: {len(rag_df)}")
        print(f"Avg DNT: {rag_df['dnt_preservation_rate'].mean():.1%}")
        print(f"Avg Glossary: {rag_df['glossary_adherence_avg'].mean():.1%}")
        print(f"Avg Tags: {rag_df['tag_preservation_rate'].mean():.1%}")
        print(f"Avg Retrieval precision: {rag_df['retrieval_precision_avg'].mean():.1%}")

    return rag_df, compare_df

# Run the integrated evaluation
RAG_DF, RAG_VS_BASELINE_DF = comprehensive_rag_evaluation()



COMPREHENSIVE RAG EVALUATION (with baseline latency/speed compare)
Loading baseline from: translations\baseline\gpt-4o-mini
  • fr: 76 segments
  • ja: 76 segments
  • it: 76 segments
Using baseline eval metrics from: eval\baseline\claude-3-5-sonnet

French (FR)
----------------------------------------
  • DNT: 100.0% | Glossary: 97.4% | Tags: 98.7%
  • Retrieval precision: 18.4%
  • Semantic similarity: 0.972
  • Latency & Speed
     - RAG:       24.90s  | speed 3.05 seg/s
     - Baseline:  263.57s | speed 0.29 seg/s

Japanese (JA)
----------------------------------------
  • DNT: 100.0% | Glossary: 99.3% | Tags: 96.1%
  • Retrieval precision: 18.4%
  • Semantic similarity: 0.967
  • Latency & Speed
     - RAG:       38.30s  | speed 1.98 seg/s
     - Baseline:  181.96s | speed 0.42 seg/s

Italian (IT)
----------------------------------------
  • DNT: 100.0% | Glossary: 94.3% | Tags: 97.4%
  • Retrieval precision: 18.4%
  • Semantic similarity: 0.976
  • Latency & Speed
     - RAG:   