# README — SFT Data Generation (supervised_fine_tuning)

**Purpose:** Generate high-quality SFT training examples from PDF text using a teacher + (optional) auditor pipeline. This notebook supports caching, batching, and configurable concurrency so you can balance speed vs. strict auditing. 

### Quick start ✅
- Install dependencies (local, no docker required):

```bash
pip install -r requirements.txt
```

- Copy `.env.example` to `.env` and edit values (do NOT commit your `.env`):

```bash
cp .env.example .env
# edit .env to match your environment
```

- Prompts and templates are stored in `config/prompts.json` and are loaded by the notebook at runtime; edit that file to customize system prompts or templates.

- `.env` recommendations:

```text
OLLAMA_URL=http://localhost:11434
TEACHER_MODEL=qwen2.5:72b-instruct
AUDITOR_MODEL=deepseek-r1:70b
# Optional: run Redis for shared cache
REDIS_URL=redis://localhost:6379/0
MAX_LLM_CONCURRENCY=8
USE_SINGLE_CALL=1         # recommended for speed
USE_BATCHING=0            # optional: 0/1
BATCH_SIZE=4
AUDIT_SAMPLE_RATE=0.05    # sample strict audits when using single-call
```

### Prompts & Templates ✅
- Location: `config/prompts.json`.
- What it contains: system prompts and small templates used by the pipeline (keys include `system_gen`, `gen_prompt_template`, `audit_system`, `audit_prompt_template`, `single_call_system`, `single_call_prompt_template`, `batch_system`, `batch_block_template`).
- How it works: the notebook loads `config/prompts.json` at runtime; if the file is missing, the notebook falls back to safe built-in defaults so nothing breaks.
- Editing tips:
  - Edit `config/prompts.json` with your custom wording and keep values valid JSON.
  - Templates use `{chunk}` and `{generated}` placeholders for prompt composition (these are substituted when the notebook runs).
  - After editing, re-run the top configuration cells (or restart the kernel and run the top cells again) to pick up the changes.
- Example: modify the `system_gen` value to shift the teacher's style or constraints, or adjust `audit_prompt_template` to change strictness.


# --- ORGANIZED ENTRYPOINTS (New ordering) ---
# Use the cells that follow as the canonical, ordered implementation. Older, scattered cells remain below for historical reference only — **do not** run them.


In [None]:
# Imports & Configuration (ordered)
import os
import sys
import json
import time
from pathlib import Path
from dotenv import load_dotenv
from datetime import datetime

# Load environment
# Reads .env from the current working directory into os.environ; every os.getenv call
# below depends on this running first.
load_dotenv()
print("Current working directory:", os.getcwd())
print(".env exists?", os.path.exists(".env"))

# Core endpoints & models
# OLLAMA_URL is mandatory; the trailing "/" is stripped so f"{OLLAMA_URL}/api/..." joins cleanly.
OLLAMA_URL = os.getenv("OLLAMA_URL")
if OLLAMA_URL is None:
    raise RuntimeError("Error: OLLAMA_URL is not set in the environment variables. Please add OLLAMA_URL to your .env file")
OLLAMA_URL = OLLAMA_URL.rstrip("/")
TEACHER_MODEL = os.getenv("TEACHER_MODEL", "qwen2.5:72b-instruct")
AUDITOR_MODEL = os.getenv("AUDITOR_MODEL", "deepseek-r1:70b")

# Cache & concurrency defaults
REDIS_URL = os.getenv("REDIS_URL") or None  # an empty string is treated as "no Redis"
CACHE_DIR = os.getenv("CACHE_DIR", str(Path.cwd() / "cache"))  # DiskCache directory used when Redis is unavailable
CHUNK_TTL = int(os.getenv("CHUNK_TTL", 60 * 60 * 24))  # seconds; default 1 day
SFT_TTL = int(os.getenv("SFT_TTL", 60 * 60 * 24 * 7))  # seconds; default 7 days
MAX_LLM_CONCURRENCY = int(os.getenv("MAX_LLM_CONCURRENCY", 8))
# Boolean flags: only the exact strings "1", "true" or "True" enable them. os.getenv returns a
# str here, so the literal True entry in the list can never match.
USE_SINGLE_CALL = os.getenv("USE_SINGLE_CALL", "1") in ["1", "true", "True", True]
# Fraction of single-call chunks that are additionally run through the two-step pipeline.
AUDIT_SAMPLE_RATE = float(os.getenv("AUDIT_SAMPLE_RATE", "0.05"))
USE_BATCHING = os.getenv("USE_BATCHING", "0") in ["1", "true", "True", True]
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4"))
BATCH_CONCURRENCY = int(os.getenv("BATCH_CONCURRENCY", "2"))
MAX_BATCH_CHARS = int(os.getenv("MAX_BATCH_CHARS", "20000"))  # generate_and_audit_batch rejects larger batches

# PATH setup
# __file__ is undefined inside a notebook kernel, so fall back to the working directory.
try:
    SCRIPT_DIR = Path(__file__).parent.resolve()
except NameError:
    SCRIPT_DIR = Path.cwd().resolve()
RAW_DATA_DIR = SCRIPT_DIR / "data" / "raw" / "in-progress"  # input PDFs are globbed from here
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
# Each run of this cell writes to a fresh timestamped output directory.
TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
PROCESSED_DIR = SCRIPT_DIR / "data" / "processed" / TIMESTAMP
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

print(f"Input PDFs: {RAW_DATA_DIR}")
print(f"Outputs will be saved to: {PROCESSED_DIR}")

In [None]:
# Prompts, Cache backend, HTTP session & METRICS (ordered)
import warnings

# Prompts loader
# The path is relative to the current working directory (not SCRIPT_DIR).
# When the file is missing, PROMPTS is empty. The *_template lookups then use the inline
# defaults passed to PROMPTS.get(...). The *_system lookups return None, in which case
# call_ollama sends no system message at all.
PROMPTS_PATH = Path("config") / "prompts.json"
if PROMPTS_PATH.exists():
    with PROMPTS_PATH.open("r", encoding="utf-8") as f:
        PROMPTS = json.load(f)
    print(f"Loaded prompts from {PROMPTS_PATH}")
else:
    print(f"Prompts config not found at {PROMPTS_PATH}; using built-in defaults")
    PROMPTS = {}

# Cache backend init (prefer Redis, fallback to disk)
try:
    import redis
except Exception:
    redis = None
try:
    from diskcache import Cache as DiskCache
except Exception:
    DiskCache = None

_cache_client = None  # set once by init_cache(): a redis.Redis client or a diskcache.Cache

def init_cache():
    """Initialise the module-level cache client once (Redis preferred, DiskCache fallback).

    Redis is used when REDIS_URL is set, the ``redis`` package imported, and the
    server answers PING. Otherwise a DiskCache rooted at CACHE_DIR is used. Once a
    client has been chosen, later calls return immediately.

    Raises:
        RuntimeError: if Redis is unusable and ``diskcache`` is not installed.
    """
    global _cache_client
    if _cache_client is not None:
        return
    if REDIS_URL and redis is not None:
        try:
            # Keep the candidate local until PING succeeds. Assigning the global first would
            # leave a dead Redis client behind if the fallback below raised, and every later
            # call would then return early and silently reuse it.
            client = redis.from_url(REDIS_URL)
            client.ping()
        except Exception as e:
            print(f"Could not connect to Redis ({e}), falling back to DiskCache")
        else:
            _cache_client = client
            print(f"Using Redis cache at {REDIS_URL}")
            return
    if DiskCache is None:
        raise RuntimeError("No cache backend available (install diskcache or provide REDIS_URL)")
    _cache_client = DiskCache(CACHE_DIR)
    print(f"Using DiskCache at {CACHE_DIR}")

# HTTP session & retries
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import requests

# Shared HTTP session; call_ollama uses it when no session is passed in. The connection
# pool is sized at 2x MAX_LLM_CONCURRENCY.
# NOTE(review): by default urllib3's Retry only retries idempotent methods (POST is not in
# its default allowed_methods), so status_forcelist probably does not apply to the
# /api/chat POSTs. call_ollama runs its own retry loop — confirm before relying on this policy.
SESSION = requests.Session()
retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
adapter = HTTPAdapter(pool_connections=MAX_LLM_CONCURRENCY * 2, pool_maxsize=MAX_LLM_CONCURRENCY * 2, max_retries=retries)
SESSION.mount("http://", adapter)
SESSION.mount("https://", adapter)

# METRICS
METRICS = {"calls": []}

def record_call(model: str, duration: float, success: bool, error: str | None = None):
    METRICS["calls"].append({"model": model, "duration": duration, "success": bool(success), "error": str(error) if error else None})

def summarise_metrics():
    import statistics
    by_model = {}
    for c in METRICS["calls"]:
        m = c["model"]
        by_model.setdefault(m, []).append(c)
    lines = []
    for m, calls in by_model.items():
        durations = [c["duration"] for c in calls if c["duration"] is not None]
        successes = sum(1 for c in calls if c["success"])
        total = len(calls)
        mean = statistics.mean(durations) if durations else 0
        p95 = sorted(durations)[int(len(durations) * 0.95)] if durations else 0
        lines.append(f"{m}: calls={total} success={successes} mean={mean:.2f}s p95={p95:.2f}s")
    return "\n".join(lines)


In [None]:
# Helpers (hashing, caching helpers, chunking, parsing)
import hashlib

# Hashing
def pdf_sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, streamed in 8 KiB blocks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while block := stream.read(8192):
            digest.update(block)
    return digest.hexdigest()

def chunk_sha256(chunk: str) -> str:
    """Return the hex SHA-256 digest of *chunk* encoded as UTF-8 (the key for cached SFT pairs)."""
    encoded = chunk.encode("utf-8")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()

# Cache get/set helpers
import json

def _using_redis() -> bool:
    """Return True when the initialised cache client is a Redis connection."""
    return redis is not None and isinstance(_cache_client, redis.Redis)


def _cache_set(key: str, value, ttl):
    """Store *value* under *key* with a TTL in seconds on whichever backend init_cache() chose."""
    init_cache()
    if _using_redis():
        # Redis stores strings/bytes, so serialise explicitly; DiskCache stores Python objects directly.
        _cache_client.set(key, json.dumps(value, ensure_ascii=False), ex=ttl)
    else:
        _cache_client.set(key, value, expire=ttl)


def _cache_get(key: str):
    """Return the value stored under *key*, or None when it is missing."""
    init_cache()
    if _using_redis():
        raw = _cache_client.get(key)
        return json.loads(raw) if raw else None
    return _cache_client.get(key)


def cache_chunks(pdf_hash: str, chunks, ttl=CHUNK_TTL):
    """Cache the text chunks extracted from the PDF whose SHA-256 is *pdf_hash*."""
    _cache_set(f"chunks:{pdf_hash}", chunks, ttl)


def get_cached_chunks(pdf_hash: str):
    """Return the cached chunk list for *pdf_hash*, or None on a cache miss."""
    return _cache_get(f"chunks:{pdf_hash}")


def cache_sft_pair(chunk_hash: str, pair, ttl=SFT_TTL):
    """Cache one generated SFT pair keyed by the SHA-256 of its source chunk."""
    _cache_set(f"sft:{chunk_hash}", pair, ttl)


def get_cached_sft_pair(chunk_hash: str):
    """Return the cached SFT pair for *chunk_hash*, or None on a cache miss."""
    return _cache_get(f"sft:{chunk_hash}")

# Chunking and parsing helpers

def chunk_text_to_chunks(md_content: str, min_len: int = 300, max_len: int = 3000):
    """Split Markdown on blank lines and keep paragraphs whose stripped length is strictly between min_len and max_len.

    Paragraphs outside that window (too short or too long) are dropped, not split.
    """
    kept = []
    for paragraph in md_content.split("\n\n"):
        paragraph = paragraph.strip()
        if len(paragraph) <= min_len or len(paragraph) >= max_len:
            continue
        kept.append(paragraph)
    return kept


def _parse_json_array_or_objects(text: str):
    """Parse model output that should be a JSON object, a JSON array, or several JSON objects.

    Returns:
        ``[obj]`` when *text* is a single JSON object; the list itself when it is a JSON
        array. If *text* is not valid JSON, returns every top-level ``{...}`` object that
        can be decoded from it (e.g. newline-separated objects or objects wrapped in
        prose). Returns None when nothing usable is found, or when *text* is valid JSON of
        another type (string, number, ...).
    """
    text = text.strip()
    try:
        parsed = json.loads(text)
    except ValueError:
        pass
    else:
        if isinstance(parsed, dict):
            return [parsed]
        if isinstance(parsed, list):
            return parsed
        return None

    # Fallback: scan for balanced objects with the stdlib decoder. A regex using the PCRE
    # recursion token "(?R)" does not work here — Python's ``re`` rejects it with re.error.
    decoder = json.JSONDecoder()
    results = []
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except ValueError:
            # Not the start of a valid object; try the next brace.
            pos = text.find("{", pos + 1)
            continue
        results.append(obj)
        # Resume after the decoded object so nested objects are not reported twice.
        pos = text.find("{", end)
    return results or None


In [None]:
# LLM call + pipeline functions (ordered)

def call_ollama(model, prompt, system_prompt="", session=None, attempts=3, retry_delay=3.0):
    """Send one non-streaming chat request to Ollama and return the reply parsed as JSON.

    Args:
        model: Ollama model tag to use.
        prompt: user message content.
        system_prompt: optional system message; omitted from the request when empty/None.
        session: requests-compatible session. Defaults to the module-level SESSION, looked
            up at call time so a reconfigured SESSION is picked up.
        attempts: maximum number of request attempts.
        retry_delay: seconds to sleep between failed attempts (not after the last one).

    Returns:
        The decoded JSON value (normally a dict). Returns None when the reply contains no
        ``{...}`` object or every attempt failed. Every attempt is recorded via record_call.
    """
    import re  # local import: the ordered cells above never import ``re``

    session = session or SESSION
    url = f"{OLLAMA_URL}/api/chat"
    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "options": {"temperature": 0.1, "num_predict": 1024},
    }
    print(f"      > Sending to {model}...", end="")
    for attempt in range(attempts):
        start = time.time()
        try:
            response = session.post(url, json=payload, timeout=300)
            response.raise_for_status()
            duration = time.time() - start
            raw_content = response.json()["message"]["content"]
            print(" [Done]")
            # Reasoning models (such as the deepseek-r1 auditor) may prepend a <think>...</think>
            # section whose braces would confuse the regex fallback below.
            content = re.sub(r"<think>.*?</think>", "", raw_content, flags=re.DOTALL).strip()
            # Drop a Markdown code fence, with or without a "json" language tag.
            content = re.sub(r"^```(?:json)?", "", content).removesuffix("```").strip()
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                print("      ! Direct parse failed. Trying regex extraction...")
                match = re.search(r'\{.*\}', content, re.DOTALL)
                if not match:
                    record_call(model, duration, False, error="invalid-json")
                    return None
                # If this still fails, the exception is handled below and triggers a fresh attempt.
                parsed = json.loads(match.group(0))
            record_call(model, duration, True)
            return parsed
        except Exception as e:
            duration = time.time() - start
            record_call(model, duration, False, error=str(e))
            print(f"\n      ! Attempt {attempt+1} failed: {str(e)[:200]}")
            if attempt + 1 < attempts:
                time.sleep(retry_delay)
    record_call(model, None, False, error="all attempts failed")
    print(f"      ! All attempts failed for {model}")
    return None


def generate_and_audit(chunk):
    """Two-step pipeline: the teacher drafts one SFT pair from *chunk*, then the auditor corrects it.

    System prompts and templates come from PROMPTS; inline default templates are used
    when a template key is missing. A ``pipeline:generate_and_audit`` metric is recorded
    for every outcome, including a teacher failure.

    Returns:
        The auditor's parsed JSON output, or None if either model call failed.
    """
    start = time.time()
    system_gen = PROMPTS.get("system_gen")
    gen_prompt = PROMPTS.get("gen_prompt_template", "Using only the following extract from Australian life insurance regulatory documentation, create one high-quality training example...\n\n{chunk}").format(chunk=chunk)
    raw_pair = call_ollama(TEACHER_MODEL, gen_prompt, system_gen)
    if not raw_pair:
        # Record the failure too, so teacher failures show up in summarise_metrics().
        record_call("pipeline:generate_and_audit", time.time() - start, False, error="teacher-failed")
        return None
    audit_system = PROMPTS.get("audit_system")
    audit_prompt = PROMPTS.get("audit_prompt_template", "Source Text:\n{chunk}\n\nGenerated Pair:\n{generated}\n\nVerify factual accuracy...").format(
        chunk=chunk,
        # ensure_ascii=False keeps non-ASCII characters readable for the auditor instead of \uXXXX escapes.
        generated=json.dumps(raw_pair, indent=2, ensure_ascii=False),
    )
    final_pair = call_ollama(AUDITOR_MODEL, audit_prompt, audit_system)
    record_call("pipeline:generate_and_audit", time.time() - start, bool(final_pair))
    return final_pair


def generate_and_audit_single(chunk):
    """Single-call pipeline: the teacher generates and self-audits one SFT pair for *chunk*.

    Returns call_ollama's parsed JSON (or None). Records a
    ``pipeline:single_generate_and_audit`` metric with the elapsed time.
    """
    default_template = "Source Text:\n{chunk}\n\nCreate the training example and self-audit it; return only the final JSON."
    system_prompt = PROMPTS.get("single_call_system")
    template = PROMPTS.get("single_call_prompt_template", default_template)
    prompt = template.format(chunk=chunk)
    started = time.time()
    result = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    elapsed = time.time() - started
    record_call("pipeline:single_generate_and_audit", elapsed, bool(result))
    return result


def generate_and_audit_batch(chunks: list[str]):
    """Generate (and self-audit) one SFT pair per chunk with a single teacher call.

    Args:
        chunks: source texts; their combined length must not exceed MAX_BATCH_CHARS.

    Returns:
        The list of items parsed from the model reply (expected: one object per chunk, in
        order). Returns None for an empty or oversized batch, or when the call failed.
    """
    if not chunks:
        # An empty batch would be sent to the model as an empty prompt.
        return None
    total_chars = sum(len(c) for c in chunks)
    if total_chars > MAX_BATCH_CHARS:
        print(f"      ! Batch too large ({total_chars} chars) — reduce BATCH_SIZE")
        return None
    system_prompt = PROMPTS.get("batch_system")
    # The default must contain real newlines; a "\\n" escape would put a literal backslash-n
    # into the prompt instead of a line break.
    block_template = PROMPTS.get("batch_block_template", "--- SOURCE {i} ---\n{chunk}\n")
    prompt = "\n".join(block_template.format(i=i, chunk=c) for i, c in enumerate(chunks, start=1))
    start = time.time()
    raw = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    record_call("pipeline:batch_generate_and_audit", time.time() - start, bool(raw))
    if not raw:
        return None
    if isinstance(raw, list):
        items = raw
    elif isinstance(raw, dict):
        items = [raw]
    else:
        # NOTE(review): call_ollama returns already-decoded JSON, so this branch only sees
        # scalars; newline-separated objects never reach here as raw text.
        items = _parse_json_array_or_objects(str(raw))
    if items is not None and len(items) != len(chunks):
        print(f"      ! Batch returned {len(items)} item(s) for {len(chunks)} source(s); items may not align with sources")
    return items

In [None]:
# Processing orchestration (process_chunk, process_pdfs, process_pdfs_with_batching)
import random
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed


def process_chunk(chunk: str, idx: int, semaphore: threading.BoundedSemaphore):
    """Produce the JSONL line for one chunk, reusing the SFT cache when possible.

    Args:
        chunk: source text.
        idx: position of the chunk in its PDF; returned unchanged so callers can restore order.
        semaphore: bounds how many pipelines call the LLM at once.

    Returns:
        ``(idx, line)``, where *line* is one JSON document plus a newline. Returns
        ``(idx, None)`` when generation failed or did not produce a complete pair.

    Raises:
        RuntimeError: when an uncached chunk is processed while USE_BATCHING is enabled.
    """
    def _is_complete(pair) -> bool:
        # The prompts require {"instruction": ..., "output": ...}; a pair missing either
        # key must not be cached (for SFT_TTL) or written to the dataset.
        return isinstance(pair, dict) and "instruction" in pair and "output" in pair

    chunk_hash = chunk_sha256(chunk)
    cached_pair = get_cached_sft_pair(chunk_hash)
    if cached_pair:
        print(f"      > Cache hit for chunk {idx+1}")
        return idx, json.dumps(cached_pair, ensure_ascii=False) + "\n"
    with semaphore:
        if USE_BATCHING:
            raise RuntimeError("process_chunk shouldn't be used in BATCHING mode")
        if USE_SINGLE_CALL:
            sft_pair = generate_and_audit_single(chunk)
            # Spot-check a random sample with the stricter two-step pipeline. Keep the
            # single-call result unless the strict run yields a complete pair.
            if sft_pair and random.random() < AUDIT_SAMPLE_RATE:
                strict_pair = generate_and_audit(chunk)
                if _is_complete(strict_pair):
                    sft_pair = strict_pair
        else:
            sft_pair = generate_and_audit(chunk)
    if _is_complete(sft_pair):
        cache_sft_pair(chunk_hash, sft_pair)
        return idx, json.dumps(sft_pair, ensure_ascii=False) + "\n"
    print(f"      ! Failed chunk {idx+1}")
    return idx, None


def process_pdfs(max_chunks_per_pdf: int | None = None):
    """Generate one SFT JSONL file per PDF in RAW_DATA_DIR.

    For each PDF, chunks come from the chunk cache (keyed by the PDF's SHA-256). The PDF
    is converted with docling only on a cache miss. Chunks that already have a cached SFT
    pair reuse it; the rest are processed concurrently via process_chunk. Lines are
    written in chunk order to ``PROCESSED_DIR/<pdf stem>.train.jsonl``.

    Args:
        max_chunks_per_pdf: process only the first N chunks of each PDF (all when None).
    """
    # tqdm is normally bound by a later notebook cell (notebook-aware variant); fall back to
    # the console bar so this function also works when that cell has not been run.
    try:
        progress = tqdm
    except NameError:
        from tqdm import tqdm as progress

    pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
    if not pdf_files:
        print(f"Error: No PDF files found in {RAW_DATA_DIR}")
        return
    print(f"Found {len(pdf_files)} PDF(s)")
    converter = None  # docling converter, created only when some PDF actually needs converting
    for pdf_path in pdf_files:
        pdf_stem = pdf_path.stem
        output_file = PROCESSED_DIR / f"{pdf_stem}.train.jsonl"
        print(f"\n--- Processing: {pdf_path.name} → {output_file.name} ---")

        # Check the chunk cache BEFORE converting: conversion is the expensive step, and
        # its result is exactly what the cache holds.
        pdf_hash = pdf_sha256(pdf_path)
        chunks = get_cached_chunks(pdf_hash)
        if chunks is not None:
            print(f"Using {len(chunks)} cached chunks for {pdf_path.name}")
        else:
            if converter is None:
                from docling.document_converter import DocumentConverter
                converter = DocumentConverter()
            try:
                result = converter.convert(pdf_path)
                md_content = result.document.export_to_markdown()
            except Exception as e:
                print(f"      ! Failed to convert PDF: {e}")
                continue
            chunks = chunk_text_to_chunks(md_content)
            cache_chunks(pdf_hash, chunks)
            print(f"Extracted and cached {len(chunks)} chunks.")

        process_count = len(chunks) if max_chunks_per_pdf is None else min(len(chunks), max_chunks_per_pdf)
        print(f"Processing {process_count} chunks (single-call={USE_SINGLE_CALL}, batching={USE_BATCHING})...")

        # Pre-fill lines from the SFT cache in one pass; only cache misses go to the LLM.
        results_buffer = [None] * process_count
        uncached_indices = []
        for i in range(process_count):
            cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
            if cached_pair is None:
                uncached_indices.append(i)
            elif cached_pair:
                results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"

        if uncached_indices:
            semaphore = threading.BoundedSemaphore(MAX_LLM_CONCURRENCY)
            with ThreadPoolExecutor(max_workers=MAX_LLM_CONCURRENCY) as executor:
                future_to_idx = {executor.submit(process_chunk, chunks[i], i, semaphore): i for i in uncached_indices}
                for future in progress(as_completed(future_to_idx), total=len(future_to_idx), desc=f"SFT [{pdf_stem}]"):
                    try:
                        idx, line = future.result()
                    except Exception as e:
                        print(f"      ! Chunk job failed: {e}")
                        continue
                    if line:
                        results_buffer[idx] = line
            # A failed job may still have a cached pair, e.g. when an identical chunk elsewhere
            # in this PDF succeeded.
            for i in uncached_indices:
                if results_buffer[i] is None:
                    cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
                    if cached_pair:
                        results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"

        # Written from this thread only, in chunk order, so no lock is needed.
        lines = [line for line in results_buffer if line]
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(lines)
        print(f"Saved {len(lines)}/{process_count} entries → {output_file}")
    print(f"\nAll processing complete! Outputs in: {PROCESSED_DIR}")

In [None]:
# Autotuner & Benchmarks (ordered)
def _reconfigure_session_for_concurrency(concurrency: int):
    """Replace the global SESSION with one whose connection pool fits *concurrency* threads.

    The new pool holds ``concurrency * 2`` connections and uses the same Retry policy as
    the initial SESSION. The previous session is closed afterwards, so its pooled
    connections are released instead of lingering until garbage collection.
    """
    global SESSION
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry
    previous = SESSION
    SESSION = requests.Session()
    retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
    adapter = HTTPAdapter(pool_connections=concurrency * 2, pool_maxsize=concurrency * 2, max_retries=retries)
    SESSION.mount("http://", adapter)
    SESSION.mount("https://", adapter)
    # Benchmarks call this before starting their thread pools, so nothing still uses the old session.
    previous.close()


def benchmark_single_call(chunks, concurrency=1, repeat=1):
    """Measure single-call throughput (valid pairs per second) at a given thread concurrency.

    Reconfigures SESSION for *concurrency* and clears METRICS["calls"] first, then runs
    every chunk *repeat* times. Throughput = valid pairs / summed wall time of all passes.
    """
    _reconfigure_session_for_concurrency(concurrency)
    METRICS["calls"].clear()

    def _is_valid(pair):
        return bool(pair) and isinstance(pair, dict) and "instruction" in pair

    def _timed_pass():
        begin = time.perf_counter()
        with ThreadPoolExecutor(max_workers=concurrency) as pool:
            pending = [pool.submit(generate_and_audit_single, chunk) for chunk in chunks]
            outcomes = [job.result() for job in pending]
        elapsed = time.perf_counter() - begin
        return elapsed, sum(map(_is_valid, outcomes))

    passes = [_timed_pass() for _ in range(repeat)]
    total_time = sum(elapsed for elapsed, _ in passes)
    total_processed = sum(count for _, count in passes)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode": "single_call", "concurrency": concurrency, "throughput": throughput, "total_processed": total_processed, "total_time": total_time}


def benchmark_batch(chunks, batch_size=4, batch_concurrency=1, repeat=1):
    """Measure batched-generation throughput (valid pairs per second).

    Chunks are packed greedily, in order, into batches of at most *batch_size* chunks and
    at most MAX_BATCH_CHARS characters. A single over-long chunk still forms its own batch,
    which generate_and_audit_batch rejects. SESSION is reconfigured for *batch_concurrency*
    and METRICS["calls"] is cleared before running.
    """
    batches = []
    current = []
    current_chars = 0
    for c in chunks:
        # Only flush a non-empty batch; otherwise an over-long first chunk would produce an
        # empty batch that gets sent to the model as an empty prompt.
        if current and (len(current) >= batch_size or (current_chars + len(c)) > MAX_BATCH_CHARS):
            batches.append(current)
            current = []
            current_chars = 0
        current.append(c)
        current_chars += len(c)
    if current:
        batches.append(current)
    _reconfigure_session_for_concurrency(batch_concurrency)
    METRICS["calls"].clear()

    def _is_valid(item):
        # Same validity rule as benchmark_single_call, so the two throughputs are comparable.
        return isinstance(item, dict) and "instruction" in item

    def _run_once():
        t0 = time.perf_counter()
        with ThreadPoolExecutor(max_workers=batch_concurrency) as ex:
            futures = [ex.submit(generate_and_audit_batch, b) for b in batches]
            results = [f.result() for f in futures]
        duration = time.perf_counter() - t0
        processed = sum(
            sum(1 for item in res if _is_valid(item))
            for res in results
            if isinstance(res, list)
        )
        return duration, processed

    runs = [_run_once() for _ in range(repeat)]
    total_processed = sum(p for _, p in runs)
    total_time = sum(d for d, _ in runs)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode": "batch", "batch_size": batch_size, "batch_concurrency": batch_concurrency, "throughput": throughput, "total_processed": total_processed, "total_time": total_time}


def autotune(chunks, concurrency_options=(1, 2, 4), batch_sizes=(None, 2, 4), batch_concurrency_options=(1, 2), repeat=1):
    """Benchmark single-call and batched modes over *chunks* and rank configs by throughput.

    Args:
        chunks: sample chunks reused by every benchmark.
        concurrency_options: thread counts to try in single-call mode.
        batch_sizes: batch sizes to try; None entries are skipped.
        batch_concurrency_options: concurrent batch jobs to try for each batch size.
        repeat: passes per configuration.

    Returns:
        ``{"best": <highest-throughput result or None>, "all": <results sorted by throughput, descending>}``
    """
    # Tuple defaults avoid shared mutable defaults. Normalise to lists so one-shot iterables
    # survive the nested loop below, and so the log line prints them as lists.
    concurrency_options = list(concurrency_options)
    batch_sizes = list(batch_sizes)
    batch_concurrency_options = list(batch_concurrency_options)
    results = []
    print(f"Autotune: testing {len(chunks)} chunks | single-call conc: {concurrency_options} | batch_sizes: {batch_sizes} | batch_conc: {batch_concurrency_options}")
    for conc in concurrency_options:
        r = benchmark_single_call(chunks, concurrency=conc, repeat=repeat)
        print(f"single_call conc={conc} -> throughput={r['throughput']:.3f} chunks/sec")
        results.append(r)
    for bsize in [b for b in batch_sizes if b is not None]:
        for bconc in batch_concurrency_options:
            r = benchmark_batch(chunks, batch_size=bsize, batch_concurrency=bconc, repeat=repeat)
            print(f"batch bsize={bsize} bconc={bconc} -> throughput={r['throughput']:.3f} chunks/sec")
            results.append(r)
    # sorted() is stable, so the first entry is the same config max() would pick.
    sorted_results = sorted(results, key=lambda x: x["throughput"], reverse=True)
    best = sorted_results[0] if sorted_results else None
    print("\nTop configs:")
    for s in sorted_results[:5]:
        print(s)
    print("\nBest config:", best)
    return {"best": best, "all": sorted_results}

# Probe helper (keeps the probe cell minimal)
# Collect up to PROBE_CHUNKS chunks from the first PDF for benchmarking.
# Only the chunk cache is consulted (no PDF conversion happens here), so probe_chunks stays
# empty until that PDF has been processed once.
# get_cached_chunks() initialises the cache backend and can raise RuntimeError if none is available.
# NOTE(review): probe_chunks is presumably passed to autotune(); confirm with the calling cell.
PROBE_CHUNKS = 6
pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
if pdf_files:
    pdf_path = pdf_files[0]
    pdf_hash = pdf_sha256(pdf_path)
    probe_chunks = (get_cached_chunks(pdf_hash) or [])[:PROBE_CHUNKS]
    print(f"Probe chunks ready: {len(probe_chunks)} from {pdf_path.name}")
else:
    probe_chunks = []


In [None]:
# Data preview & Main run (ordered)
import random, statistics
from collections import Counter

def load_jsonl(path):
    """Read a JSONL file and return ``(valid_objects, invalid_entries)``.

    Blank lines are skipped. Each invalid entry is ``(line_index, error_message, first_300_chars)``,
    where line_index is 0-based and counts blank lines too.
    """
    parsed_rows = []
    failures = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, text in enumerate(handle):
            stripped = text.strip()
            if stripped:
                try:
                    parsed_rows.append(json.loads(stripped))
                except Exception as err:
                    failures.append((line_no, str(err), stripped[:300]))
    return parsed_rows, failures

# Main run helper
def run_all(max_chunks_per_pdf: int | None = None):
    """Run the pipeline over RAW_DATA_DIR, choosing the processing path by USE_BATCHING.

    Args:
        max_chunks_per_pdf: forwarded to the processing function; None means all chunks.

    NOTE(review): process_pdfs_with_batching is not defined in any cell visible here. With
    USE_BATCHING enabled this raises NameError unless it is defined elsewhere — confirm.
    """
    if USE_BATCHING:
        process_pdfs_with_batching(max_chunks_per_pdf)
    else:
        process_pdfs(max_chunks_per_pdf)

# Quick preview: latest processed file
def quick_preview():
    """Print validity stats for the newest ``*.train.jsonl`` in PROCESSED_DIR, then the LLM call metrics."""
    files = sorted(PROCESSED_DIR.glob("*.train.jsonl"), key=lambda p: p.stat().st_mtime, reverse=True)
    if not files:
        print("No processed files found in:", PROCESSED_DIR)
        return
    path = files[0]
    print("Previewing:", path)
    entries, invalid = load_jsonl(path)
    # Count raw lines (blank ones included) inside a context manager. A bare open() in a
    # generator expression leaves the file handle to the garbage collector.
    with open(path, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines: {total_lines}")
    print(f"Valid: {len(entries)} | Invalid: {len(invalid)}")
    print("Basic stats:", {"count": len(entries)})
    print("LLM metrics:\n", summarise_metrics())


In [None]:
import os
from dotenv import load_dotenv
# LEGACY cell, superseded by the ordered configuration cell (see the note at the top).
# Unlike that cell, it re-reads the model names WITHOUT defaults, so TEACHER_MODEL and
# AUDITOR_MODEL become None when unset, and it does not strip a trailing "/" from OLLAMA_URL.
print("Current working directory:", os.getcwd())
print(".env exists?", os.path.exists(".env"))
load_dotenv()

OLLAMA_URL = os.getenv("OLLAMA_URL")
print("OLLAMA_URL:", OLLAMA_URL)
TEACHER_MODEL = os.getenv("TEACHER_MODEL")
print("TEACHER_MODEL:", TEACHER_MODEL)
AUDITOR_MODEL = os.getenv("AUDITOR_MODEL")
print("AUDITOR_MODEL:", AUDITOR_MODEL)

In [None]:
def log(msg, end="\n"):
    """Write *msg* followed by *end* to stdout and flush at once (the notebook's lightweight logger)."""
    import sys
    text = f"{msg}{end}"
    out = sys.stdout
    out.write(text)
    out.flush()

In [None]:
import json
import requests
import time
import os
import sys
import re
import hashlib
from pathlib import Path
from docling.document_converter import DocumentConverter
# Safe tqdm import: prefer notebook tqdm when ipywidgets is available, otherwise fallback to console tqdm
import warnings
try:
    import ipywidgets  # type: ignore
    from tqdm.notebook import tqdm as tqdm
except Exception:
    from tqdm import tqdm as tqdm
try:
    from tqdm.std import TqdmWarning
    warnings.filterwarnings("ignore", category=TqdmWarning)
except Exception:
    pass

from dotenv import load_dotenv
from datetime import datetime

# --- CONFIGURATION (lightweight) ---
# LEGACY: repeats the ordered configuration cell. Note that sys.exit() inside a notebook
# raises SystemExit rather than stopping cleanly; the ordered cell raises RuntimeError instead.
load_dotenv()

OLLAMA_URL = os.getenv("OLLAMA_URL")
if OLLAMA_URL is None:
    print("Error: OLLAMA_URL is not set in the environment variables.")
    print("Please add OLLAMA_URL=http://your-ip:11434 to your .env file")
    sys.exit(1)

# Strip trailing slash if present
OLLAMA_URL = OLLAMA_URL.rstrip("/")

TEACHER_MODEL = os.getenv("TEACHER_MODEL", "qwen2.5:72b-instruct")
AUDITOR_MODEL = os.getenv("AUDITOR_MODEL", "deepseek-r1:70b")

In [None]:
# Cache and concurrency settings (separate, lightweight)
REDIS_URL = os.getenv("REDIS_URL") or None  # e.g. redis://localhost:6379/0
CACHE_DIR = os.getenv("CACHE_DIR", str(Path.cwd() / "cache"))
CHUNK_TTL = int(os.getenv("CHUNK_TTL", 60 * 60 * 24))  # 1 day
# 7 days in seconds, matching the ordered configuration cell; an extra "* 60" here would make it 420 days.
SFT_TTL = int(os.getenv("SFT_TTL", 60 * 60 * 24 * 7))  # 7 days
MAX_LLM_CONCURRENCY = int(os.getenv("MAX_LLM_CONCURRENCY", 8))
# Strategy: 'two_step' (teacher + auditor) or 'single_call' (teacher self-audits)
USE_SINGLE_CALL = os.getenv("USE_SINGLE_CALL", "1") in ["1", "true", "True", True]
# Audit sampling: when using single-call, run the two-step auditor on a small random sample to detect regressions
AUDIT_SAMPLE_RATE = float(os.getenv("AUDIT_SAMPLE_RATE", "0.05"))
# Batching options (optional speed optimization):
USE_BATCHING = os.getenv("USE_BATCHING", "0") in ["1", "true", "True", True]
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4"))  # number of chunks to send in a single prompt
BATCH_CONCURRENCY = int(os.getenv("BATCH_CONCURRENCY", "2"))  # concurrent batch jobs
# Safety: max batch chars to avoid overly large prompts
MAX_BATCH_CHARS = int(os.getenv("MAX_BATCH_CHARS", "20000"))

In [None]:
# Load prompts configuration (config/prompts.json). If the file is missing, fall back to built-in defaults.
PROMPTS_PATH = Path("config") / "prompts.json"
if PROMPTS_PATH.exists():
    with PROMPTS_PATH.open("r", encoding="utf-8") as f:
        PROMPTS = json.load(f)
    log(f"Loaded prompts from {PROMPTS_PATH}")
else:
    log(f"Prompts config not found at {PROMPTS_PATH}; using built-in defaults")
    PROMPTS = {
        "system_gen": (
            "You are a senior Life Insurance Operations expert with deep knowledge of Australian life insurance processes, "
            "including policy administration, underwriting, claims handling, reinsurance, customer service, and regulatory compliance (APRA standards).\n\n"
            "Your task is to create ONE high-quality supervised fine-tuning example based solely on the provided regulatory text.\n"
            "The example should reflect real-world operational tasks that a life insurance company employee or specialized LLM would perform, "
            "such as explaining requirements, summarizing obligations, answering procedural questions, or guiding compliance actions.\n\n"
            "Output ONLY valid JSON in this exact format:\n{\"instruction\": \"A clear, realistic user query or task related to life insurance operations\", \"output\": \"A precise, professional, and accurate response based strictly on the source text\"}\n\n"
            "No explanations, no markdown, no extra text."
        ),
        "gen_prompt_template": (
            "Using only the following extract from Australian life insurance regulatory documentation, "
            "create one high-quality training example for fine-tuning an LLM to excel in life insurance operations:\n\n{chunk}"
        ),
        "audit_system": (
            "You are a meticulous Life Insurance Regulatory Auditor.\n"
            "Your role is to verify that the generated instruction-output pair is factually accurate, complete, and faithfully represents the source text with no hallucinations or additions. "
            "Correct any inaccuracies, improve clarity if needed, but preserve the original intent.\n\n"
            "Output ONLY the final corrected JSON in this exact format:\n{\"instruction\": \"...\", \"output\": \"...\"}\n\n"
            "No thinking steps, no preamble, no markdown."
        ),
        "audit_prompt_template": "Source Text:\n{chunk}\n\nGenerated Pair:\n{generated}\n\nVerify factual accuracy against the source. Correct errors. Return only the final valid JSON.",
        "single_call_system": (
            "You are a senior Life Insurance Operations expert with deep knowledge of Australian life insurance processes and a meticulous auditor.\n\n"
            "Create ONE high-quality supervised fine-tuning example based ONLY on the provided regulatory text. Also verify and audit the example yourself to ensure factual accuracy and strict fidelity to the source.\n\n"
            "Output ONLY the final valid JSON in this exact format:\n{\"instruction\": \"...\", \"output\": \"...\"}\n\n"
            "No explanations, no markdown, no extra text."
        ),
        "single_call_prompt_template": "Source Text:\n{chunk}\n\nCreate the training example and self-audit it; return only the final JSON.",
        "batch_system": (
            "You are a senior Life Insurance Operations expert and a strict auditor.\n\n"
            "For each of the following SOURCE TEXT blocks, create ONE high-quality supervised fine-tuning example and ensure each example is strictly fact-checked. Return a JSON array where each item corresponds to the source blocks in the same order. "
            "Each item must be an object of the form: {\"instruction\": \"...\", \"output\": \"...\"}.\nOutput only valid JSON — either a JSON array or newline-separated JSON objects."
        ),
        "batch_block_template": "--- SOURCE {i} ---\n{chunk}\n"
    }


In [53]:
# Helper: print loaded PROMPTS for quick review
import pprint
# Report where PROMPTS came from and which keys it defines.
log("PROMPTS loaded from: " + str(PROMPTS_PATH))
keys = list(PROMPTS.keys())
log("PROMPTS keys: " + ", ".join(keys))


def _preview_value(value, limit=400):
    """Return *value* cut to *limit* characters (with a marker) for strings, else its type as text."""
    if not isinstance(value, str):
        return str(type(value))
    if len(value) < limit:
        return value
    return value[:limit] + "... [truncated]"


# Truncate each prompt so the notebook output stays readable.
preview = {name: _preview_value(text) for name, text in PROMPTS.items()}
print("\nPROMPTS preview (truncated):")
pp = pprint.PrettyPrinter(indent=2, width=120)
pp.pprint(preview)


PROMPTS loaded from: config/prompts.json
PROMPTS keys: system_gen, gen_prompt_template, audit_system, audit_prompt_template, single_call_system, single_call_prompt_template, batch_system, batch_block_template

PROMPTS preview (truncated):
{ 'audit_prompt_template': 'Source Text:\n'
                           '{chunk}\n'
                           '\n'
                           'Generated Pair:\n'
                           '{generated}\n'
                           '\n'
                           'Verify factual accuracy against the source. Correct errors. Return only the final valid '
                           'JSON.',
  'audit_system': 'You are a meticulous Life Insurance Regulatory Auditor.\n'
                  'Your role is to verify that the generated instruction-output pair is factually accurate, complete, '
                  'and faithfully represents the source text with no hallucinations or additions. Correct any '
                  'inaccuracies, improve clarity if needed, 

In [None]:
def _normalize_model_name(name):
    """Return ``name`` with an explicit tag, appending ":latest" when none is given.

    Model names listed by the server carry a tag (e.g. "name:latest"), so a
    configured name without one would otherwise never match. Only the last
    "/"-separated segment is checked for ":" so a registry host with a port
    ("host:5000/model") is not mistaken for a tag.
    """
    return name if ":" in name.rsplit("/", 1)[-1] else f"{name}:latest"


def test_ollama_connection():
    """Check that the Ollama server is reachable and the configured models are available.

    Calls GET {OLLAMA_URL}/api/tags, which lists the models available locally on
    the server (pulled models, not necessarily loaded into memory). Names are
    compared after normalising an omitted tag to ":latest".

    Returns:
        True if the server answered with a model list. Missing TEACHER_MODEL /
        AUDITOR_MODEL are only logged as a warning. Returns False on timeout,
        connection failure, or any other error.
    """
    test_url = f"{OLLAMA_URL}/api/tags"  # Lists locally available (pulled) models
    log("Testing Ollama connection...", end=" ")
    try:
        response = requests.get(test_url, timeout=10)
        response.raise_for_status()
        models_data = response.json()
        model_names = [m["name"] for m in models_data.get("models", [])]
        log("[OK]")
        if model_names:
            log(f"Available models: {', '.join(model_names)}")
        else:
            log("No models are available on the server.")

        # Compare tag-normalised names so "llama3" matches a listed "llama3:latest".
        available = {_normalize_model_name(n) for n in model_names}
        missing = [
            model
            for model in (TEACHER_MODEL, AUDITOR_MODEL)
            if _normalize_model_name(model) not in available
        ]
        if missing:
            log(f"Warning: Required models not available: {', '.join(missing)}")
            log("Please pull them with: ollama pull <model>")
        return True

    except requests.exceptions.Timeout:
        log("[FAILED] Connection timeout (10s)")
        return False
    except requests.exceptions.ConnectionError:
        log("[FAILED] Cannot connect to Ollama server")
        log(f"Check if Ollama is running on {OLLAMA_URL} and network is reachable")
        return False
    except Exception as e:
        log(f"[FAILED] Unexpected error: {str(e)}")
        return False

In [None]:
# PATH SETUP (small cell)
# Project root: the script's own directory when run as a file, otherwise the
# current working directory (e.g. in Jupyter, where __file__ is undefined).
try:
    SCRIPT_DIR = Path(__file__).parent.resolve()
except NameError:
    SCRIPT_DIR = Path.cwd().resolve()

# Inputs are read from data/raw/in-progress; every run writes to a fresh
# timestamped folder under data/processed.
RAW_DATA_DIR = SCRIPT_DIR.joinpath("data", "raw", "in-progress")
TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
PROCESSED_DIR = SCRIPT_DIR.joinpath("data", "processed", TIMESTAMP)

for _directory in (RAW_DATA_DIR, PROCESSED_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

log(f"Input PDFs: {RAW_DATA_DIR}")
log(f"Outputs will be saved to: {PROCESSED_DIR}")

In [None]:
# Cache backends imports & initialization (focused)
# Prefer Redis when REDIS_URL is set; otherwise fall back to DiskCache for local use.
try:
    import redis
except Exception:
    redis = None

try:
    from diskcache import Cache as DiskCache
except Exception:
    DiskCache = None

# Process-wide cache client; set once by init_cache() and reused afterwards.
_cache_client = None

def init_cache():
    """Initialise the global cache client once.

    Uses Redis when REDIS_URL is set, the `redis` package is importable and the
    server answers PING. Otherwise falls back to a DiskCache at CACHE_DIR.

    The global client is only assigned after a backend is confirmed usable. A
    failed Redis connection therefore never leaves a broken client behind for
    later calls to reuse.

    Raises:
        RuntimeError: if Redis is unusable and `diskcache` is not installed.
    """
    global _cache_client
    if _cache_client is not None:
        return

    if REDIS_URL and redis is not None:
        try:
            client = redis.from_url(REDIS_URL)
            client.ping()
        except Exception as e:
            log(f"Could not connect to Redis ({e}), falling back to DiskCache")
        else:
            _cache_client = client
            log(f"Using Redis cache at {REDIS_URL}")
            return

    if DiskCache is None:
        log("DiskCache not available. Install with `pip install diskcache` or set REDIS_URL to a running Redis server.")
        raise RuntimeError("No cache backend available")

    _cache_client = DiskCache(CACHE_DIR)
    log(f"Using DiskCache at {CACHE_DIR}")

In [None]:
# Hashing utilities (small cell)
import hashlib

def pdf_sha256(path: Path) -> str:
    """Hash the file at ``path`` with SHA-256, streaming it in 8 KiB reads.

    Returns the hex digest of the raw file bytes.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while block := fh.read(8192):
            digest.update(block)
    return digest.hexdigest()


def chunk_sha256(chunk: str) -> str:
    """Return the SHA-256 hex digest of ``chunk`` encoded as UTF-8 (used as a cache key)."""
    digest = hashlib.sha256(chunk.encode("utf-8"))
    return digest.hexdigest()

In [None]:
# Cache get/set helpers (small cell)
import json

def _cache_is_redis():
    """True when the active cache client is a Redis connection (values are stored JSON-encoded)."""
    return redis is not None and isinstance(_cache_client, redis.Redis)


def _cache_store(key, value, ttl):
    """Write ``value`` under ``key`` with expiry ``ttl``; JSON-encodes for Redis, stores as-is for DiskCache."""
    if _cache_is_redis():
        _cache_client.set(key, json.dumps(value, ensure_ascii=False), ex=ttl)
        return
    _cache_client.set(key, value, expire=ttl)


def _cache_fetch(key):
    """Read ``key`` from the cache, JSON-decoding Redis values; None when absent."""
    if not _cache_is_redis():
        return _cache_client.get(key)
    stored = _cache_client.get(key)
    return json.loads(stored) if stored else None


def cache_chunks(pdf_hash: str, chunks, ttl=CHUNK_TTL):
    """Cache the chunk list extracted from the PDF identified by ``pdf_hash``."""
    init_cache()
    _cache_store(f"chunks:{pdf_hash}", chunks, ttl)


def get_cached_chunks(pdf_hash: str):
    """Return the cached chunk list for ``pdf_hash``, or None on a miss."""
    init_cache()
    return _cache_fetch(f"chunks:{pdf_hash}")


def cache_sft_pair(chunk_hash: str, pair, ttl=SFT_TTL):
    """Cache the generated SFT pair for the chunk identified by ``chunk_hash``."""
    init_cache()
    _cache_store(f"sft:{chunk_hash}", pair, ttl)


def get_cached_sft_pair(chunk_hash: str):
    """Return the cached SFT pair for ``chunk_hash``, or None on a miss."""
    init_cache()
    return _cache_fetch(f"sft:{chunk_hash}")

In [None]:
# HTTP session setup (connection pooling & retries)
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import requests

# One shared session so concurrent LLM calls reuse pooled connections; the pool
# is sized at twice MAX_LLM_CONCURRENCY.
SESSION = requests.Session()
# NOTE(review): urllib3's Retry only retries idempotent methods by default, so the
# status-based retries below presumably do not apply to the POST /api/chat calls
# (call_ollama has its own retry loop) — confirm against the urllib3 Retry docs.
retries = Retry(
    total=3,
    backoff_factor=0.6,
    status_forcelist=[429, 500, 502, 503, 504],
)
adapter = HTTPAdapter(
    pool_connections=MAX_LLM_CONCURRENCY * 2,
    pool_maxsize=MAX_LLM_CONCURRENCY * 2,
    max_retries=retries,
)
for _prefix in ("http://", "https://"):
    SESSION.mount(_prefix, adapter)

In [None]:
# METRICS: simple in-memory timing recorder for LLM calls
METRICS = {"calls": []}

def record_call(model: str, duration: float, success: bool, error: str | None = None):
    METRICS["calls"].append({"model": model, "duration": duration, "success": bool(success), "error": str(error) if error else None})


def summarise_metrics():
    import statistics
    by_model = {}
    for c in METRICS["calls"]:
        m = c["model"]
        by_model.setdefault(m, []).append(c)

    lines = []
    for m, calls in by_model.items():
        durations = [c["duration"] for c in calls if c["duration"] is not None]
        successes = sum(1 for c in calls if c["success"])
        total = len(calls)
        mean = statistics.mean(durations) if durations else 0
        p95 = sorted(durations)[int(len(durations) * 0.95)] if durations else 0
        lines.append(f"{m}: calls={total} success={successes} mean={mean:.2f}s p95={p95:.2f}s")
    return "\n".join(lines)


In [None]:
# Reasoning blocks some models put before their answer; removed before JSON parsing.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def _parse_llm_json(text):
    """Extract the JSON payload from a model reply, or return None.

    Steps:
    - drop any ``<think>...</think>`` blocks;
    - strip a surrounding Markdown code fence (```json ... ``` or ``` ... ```);
    - parse the remainder as JSON. If that fails, scan left to right for the
      first embedded JSON object, or array of objects, that decodes cleanly.
      Other arrays, such as a "[1]" citation in prose, are skipped.
    """
    content = _THINK_BLOCK_RE.sub("", text).strip()
    if content.startswith("```"):
        content = content[3:]
        if content[:4].lower() == "json":
            content = content[4:]
    if content.endswith("```"):
        content = content[:-3]
    content = content.strip()

    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    decoder = json.JSONDecoder()
    for pos, ch in enumerate(content):
        if ch not in "{[":
            continue
        try:
            value, _end = decoder.raw_decode(content, pos)
        except json.JSONDecodeError:
            continue
        if isinstance(value, dict) or (
            isinstance(value, list) and value and all(isinstance(item, dict) for item in value)
        ):
            return value
    return None


def call_ollama(model, prompt, system_prompt="", session=None):
    """Send one non-streaming chat request to Ollama and return the parsed JSON reply.

    Args:
        model: Ollama model name.
        prompt: User message content.
        system_prompt: Optional system message; omitted when falsy.
        session: requests.Session to use; defaults to the shared SESSION.

    Returns:
        The JSON value extracted from the reply (normally a dict, or a list for
        batch prompts). Returns None when the reply contains no usable JSON, or
        when all 3 attempts fail.

    Every attempt is recorded in METRICS. After the final failed attempt an extra
    record with duration None and error "all attempts failed" is added. Transport/HTTP
    failures are retried after a 3 s pause. A reply with no extractable JSON is not
    retried; it returns None immediately.
    """
    session = session or SESSION
    url = f"{OLLAMA_URL}/api/chat"

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "options": {
            "temperature": 0.1,
            "num_predict": 1024
        }
    }

    log(f"      > Sending to {model}...", end="")

    max_attempts = 3
    for attempt in range(max_attempts):
        start = time.time()
        try:
            response = session.post(url, json=payload, timeout=300)
            response.raise_for_status()
            duration = time.time() - start
            raw_content = response.json()["message"]["content"]
        except requests.exceptions.HTTPError as e:
            duration = time.time() - start
            record_call(model, duration, False, error=str(e))
            failed = e.response
            status = failed.status_code if failed is not None else "?"
            log(f"\n      ! Attempt {attempt+1} failed: HTTP {status} {str(e)}")
            if failed is not None and failed.text:
                log(f"      Response: {failed.text[:300]}")
        except Exception as e:
            duration = time.time() - start
            record_call(model, duration, False, error=str(e))
            log(f"\n      ! Attempt {attempt+1} failed: {str(e)[:200]}")
        else:
            log(" [Done]")
            log(f"      Raw output from {model}:\n{raw_content[:500]}...")

            parsed = _parse_llm_json(raw_content)
            if parsed is None:
                log("      ! Could not extract valid JSON")
                record_call(model, duration, False, error="invalid-json")
                return None

            record_call(model, duration, True)
            return parsed

        # Back off before retrying (previously HTTP errors retried immediately).
        if attempt < max_attempts - 1:
            time.sleep(3)

    record_call(model, None, False, error="all attempts failed")
    log(f"      ! All attempts failed for {model}")
    return None

In [None]:
# Chunking helper (focused)

def chunk_text_to_chunks(md_content: str, min_len: int = 300, max_len: int = 3000):
    """Split markdown into paragraph chunks of a usable size.

    Paragraphs are separated by blank lines. Each paragraph is stripped of
    surrounding whitespace. It is kept only if its stripped length is strictly
    between ``min_len`` and ``max_len``; shorter or longer paragraphs are dropped,
    not split.
    """
    kept = []
    for paragraph in md_content.split("\n\n"):
        text = paragraph.strip()
        if min_len < len(text) < max_len:
            kept.append(text)
    return kept

In [None]:
def generate_and_audit(chunk):
    """Two-step pipeline: teacher generation followed by a strict auditor pass.

    1. TEACHER_MODEL generates one SFT example from ``chunk`` using the
       ``system_gen`` / ``gen_prompt_template`` prompts from PROMPTS.
    2. AUDITOR_MODEL gets the source chunk plus the teacher's JSON and returns a
       corrected pair, using the ``audit_system`` / ``audit_prompt_template`` prompts.

    Returns:
        The auditor's parsed JSON (expected to be {"instruction": ..., "output": ...}),
        or None if the teacher or the auditor call fails.

    Total pipeline time is recorded as ``pipeline:generate_and_audit`` in METRICS
    on every path, including teacher failure, so pipeline failure rates are complete.
    """
    start = time.time()
    final_pair = None

    # --- TEACHER: Generate operational-focused training pair (loaded from config) ---
    system_gen = PROMPTS.get("system_gen")
    gen_prompt = PROMPTS.get(
        "gen_prompt_template",
        "Using only the following extract from Australian life insurance regulatory documentation, create one high-quality training example for fine-tuning an LLM to excel in life insurance operations:\n\n{chunk}"
    ).format(chunk=chunk)

    raw_pair = call_ollama(TEACHER_MODEL, gen_prompt, system_gen)

    if raw_pair:
        # --- AUDITOR: Strict fact-checking against source (loaded from config) ---
        audit_system = PROMPTS.get("audit_system")
        audit_prompt = PROMPTS.get(
            "audit_prompt_template",
            "Source Text:\n{chunk}\n\nGenerated Pair:\n{generated}\n\nVerify factual accuracy against the source. Correct errors. Return only the final valid JSON."
        ).format(chunk=chunk, generated=json.dumps(raw_pair, indent=2))

        final_pair = call_ollama(AUDITOR_MODEL, audit_prompt, audit_system)

    pipeline_duration = time.time() - start
    record_call("pipeline:generate_and_audit", pipeline_duration, bool(final_pair))

    return final_pair

In [None]:
def generate_and_audit_single(chunk):
    """Generate and self-audit one SFT pair with a single TEACHER_MODEL call.

    Skips the separate auditor round-trip used by the two-step pipeline. Prompts
    come from PROMPTS (``single_call_system`` / ``single_call_prompt_template``).
    Returns the parsed reply (or None). The elapsed time is recorded as
    ``pipeline:single_generate_and_audit``.
    """
    default_template = "Source Text:\n{chunk}\n\nCreate the training example and self-audit it; return only the final JSON."
    system_prompt = PROMPTS.get("single_call_system")
    template = PROMPTS.get("single_call_prompt_template", default_template)
    prompt = template.format(chunk=chunk)

    started = time.time()
    result = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    elapsed = time.time() - started
    record_call("pipeline:single_generate_and_audit", elapsed, bool(result))
    return result

In [None]:
def _parse_json_array_or_objects(text: str):
    """Parse a reply that is a JSON array, one JSON object, or several JSON objects.

    Returns:
        - a JSON array: returned as-is;
        - a single JSON object: wrapped in a one-element list;
        - otherwise (newline-separated objects, objects surrounded by prose): every
          top-level ``{...}`` object that decodes cleanly, in order;
        - None if nothing usable is found, or the text is a bare JSON scalar.

    The fallback uses json.JSONDecoder.raw_decode to locate objects. The previous
    recursive regex ``(?R)`` is not supported by Python's ``re`` and raised re.error.
    """
    text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(parsed, dict):
            return [parsed]
        if isinstance(parsed, list):
            return parsed
        return None

    decoder = json.JSONDecoder()
    results = []
    pos = 0
    while (start := text.find("{", pos)) != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pos = start + 1  # not a valid object here; keep scanning
            continue
        results.append(obj)
        pos = end  # resume after this object so nested braces aren't re-parsed
    return results or None


# Fallback used when PROMPTS has no "batch_block_template" entry. Real newlines:
# the previous fallback used "\\n", which put a literal backslash-n in the prompt.
_DEFAULT_BATCH_BLOCK_TEMPLATE = "--- SOURCE {i} ---\n{chunk}\n"


def _build_batch_prompt(chunks, block_template):
    """Render each chunk with ``block_template`` ({i} = 1-based index, {chunk} = text) and join with newlines."""
    return "\n".join(
        block_template.format(i=i, chunk=c) for i, c in enumerate(chunks, start=1)
    )


def generate_and_audit_batch(chunks: list[str]):
    """Batch pipeline: send several chunks in one teacher request and return the SFT pairs.

    Args:
        chunks: Source texts; they appear in the prompt as SOURCE 1..N in this order.

    Returns:
        - A list of parsed items (expected: one {"instruction", "output"} dict per
          chunk, in chunk order). The length is not validated against ``chunks``.
        - ``[]`` for an empty ``chunks`` list, without calling the model.
        - None when the batch exceeds MAX_BATCH_CHARS or the model call fails.

    The batch wall-clock time is recorded as ``pipeline:batch_generate_and_audit``.
    """
    if not chunks:
        return []

    # Validate prompt size
    total_chars = sum(len(c) for c in chunks)
    if total_chars > MAX_BATCH_CHARS:
        log(f"      ! Batch too large ({total_chars} chars) — reduce BATCH_SIZE")
        return None

    system_prompt = PROMPTS.get("batch_system")
    block_template = PROMPTS.get("batch_block_template", _DEFAULT_BATCH_BLOCK_TEMPLATE)
    prompt = _build_batch_prompt(chunks, block_template)

    start = time.time()
    raw = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    duration = time.time() - start
    # call_ollama already recorded low-level call metrics; record batch pipeline time explicitly
    record_call("pipeline:batch_generate_and_audit", duration, bool(raw))

    if not raw:
        return None
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        # Single object — wrap
        return [raw]

    # call_ollama returns parsed JSON, so this only handles an unexpected scalar reply.
    return _parse_json_array_or_objects(str(raw))


In [None]:
import random

def process_chunk(chunk: str, idx: int, semaphore: threading.BoundedSemaphore):
    """Produce the JSONL line for one chunk, reusing the SFT cache when possible.

    On a cache miss the pipeline runs while holding ``semaphore``:
    - single-call mode (USE_SINGLE_CALL) uses generate_and_audit_single. A random
      AUDIT_SAMPLE_RATE fraction of successful results is re-run through the strict
      two-step pipeline, and the strict result replaces it when it succeeds.
    - otherwise the two-step generate_and_audit pipeline is used.

    Returns:
        ``(idx, line)``, where ``line`` is a newline-terminated JSON string, or
        ``(idx, None)`` when no dict with an "instruction" key was produced.
        Successful pairs are written to the cache.

    Raises:
        RuntimeError: on a cache miss while USE_BATCHING is enabled.
    """
    chunk_hash = chunk_sha256(chunk)

    hit = get_cached_sft_pair(chunk_hash)
    if hit:
        log(f"      > Cache hit for chunk {idx+1}")
        return idx, json.dumps(hit, ensure_ascii=False) + "\n"

    with semaphore:
        if USE_BATCHING:
            # batching handled at higher level; this function should not be used for batching mode
            raise RuntimeError("process_chunk shouldn't be used in BATCHING mode")

        if not USE_SINGLE_CALL:
            pair = generate_and_audit(chunk)
        else:
            pair = generate_and_audit_single(chunk)
            # Sampled strict audit; random() is only drawn when the single call succeeded.
            if bool(pair) and random.random() < AUDIT_SAMPLE_RATE:
                audited = generate_and_audit(chunk)
                pair = audited or pair  # prefer the strictly audited pair when available

    is_valid = isinstance(pair, dict) and "instruction" in pair
    if not is_valid:
        log(f"      ! Failed chunk {idx+1}")
        return idx, None

    cache_sft_pair(chunk_hash, pair)
    return idx, json.dumps(pair, ensure_ascii=False) + "\n"


def process_pdfs_with_batching(max_chunks_per_pdf: int = None):
    """Process PDFs in batching mode. Groups uncached chunks in batches and sends each batch in one LLM request."""
    converter = DocumentConverter()
    pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))

    if not pdf_files:
        log(f"Error: No PDF files found in {RAW_DATA_DIR}")
        log(f"Note: Place your PDFs in: {RAW_DATA_DIR}")
        return

    log(f"Found {len(pdf_files)} PDF(s) in 'in-progress' folder.")

    for pdf_path in pdf_files:
        pdf_stem = pdf_path.stem
        output_file = PROCESSED_DIR / f"{pdf_stem}.train.jsonl"

        log(f"\n--- Processing (BATCH MODE): {pdf_path.name} → {output_file.name} ---")

        try:
            result = converter.convert(pdf_path)
            md_content = result.document.export_to_markdown()
        except Exception as e:
            log(f"      ! Failed to convert PDF: {e}")
            continue

        # Chunk and cache per-PDF
        pdf_hash = pdf_sha256(pdf_path)
        cached = get_cached_chunks(pdf_hash)
        if cached:
            chunks = cached
            log(f"Using {len(chunks)} cached chunks for {pdf_path.name}")
        else:
            chunks = chunk_text_to_chunks(md_content)
            cache_chunks(pdf_hash, chunks)
            log(f"Extracted and cached {len(chunks)} chunks.")

        process_count = len(chunks) if max_chunks_per_pdf is None else min(len(chunks), max_chunks_per_pdf)
        log(f"Processing {process_count} chunks in batch mode (BATCH_SIZE={BATCH_SIZE})...")

        # Determine uncached indices
        uncached_indices = []
        for i in range(process_count):
            chh = chunk_sha256(chunks[i])
            if get_cached_sft_pair(chh) is None:
                uncached_indices.append(i)

        # Build batches of indices but respect MAX_BATCH_CHARS
        batches = []
        current = []
        current_chars = 0
        for idx in uncached_indices:
            c = chunks[idx]
            if len(current) >= BATCH_SIZE or (current_chars + len(c)) > MAX_BATCH_CHARS:
                batches.append(current)
                current = []
                current_chars = 0
            current.append(idx)
            current_chars += len(c)
        if current:
            batches.append(current)

        successful_entries = 0
        file_lock = threading.Lock()
        results_buffer = [None] * process_count

        # Process batches with a thread pool bounded by BATCH_CONCURRENCY
        with ThreadPoolExecutor(max_workers=BATCH_CONCURRENCY) as executor:
            future_to_batch = {}
            for batch_idxs in batches:
                batch_chunks = [chunks[i] for i in batch_idxs]
                future = executor.submit(generate_and_audit_batch, batch_chunks)
                future_to_batch[future] = batch_idxs

            for future in tqdm(as_completed(future_to_batch), total=len(future_to_batch), desc=f"BATCH SFT [{pdf_stem}]"):
                batch_idxs = future_to_batch[future]
                try:
                    out_list = future.result()
                except Exception as e:
                    log(f"      ! Batch failed: {e}")
                    out_list = None

                if out_list and isinstance(out_list, list):
                    # Expect out_list to be same length (or at least same number) as batch_chunks
                    for idx_in_batch, obj in enumerate(out_list):
                        target_idx = batch_idxs[idx_in_batch] if idx_in_batch < len(batch_idxs) else None
                        if target_idx is not None and isinstance(obj, dict) and "instruction" in obj:
                            cache_sft_pair(chunk_sha256(chunks[target_idx]), obj)
                            results_buffer[target_idx] = json.dumps(obj, ensure_ascii=False) + "\n"
                else:
                    log("      ! Batch returned invalid output")

        # For any still missing (either cache hit earlier or not produced), fill from cache
        for i in range(process_count):
            if results_buffer[i] is None:
                cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
                if cached_pair:
                    results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"

        # Write results in order
        with open(output_file, "w", encoding="utf-8") as f:
            for line in results_buffer:
                if line:
                    with file_lock:
                        f.write(line)
                        f.flush()
                        successful_entries += 1

        log(f"Saved {successful_entries}/{process_count} high-quality entries → {output_file}")

    log(f"\nAll processing complete! Outputs in: {PROCESSED_DIR}")

## Optional: Redis Docker Compose (use if you want a Redis server)

If you prefer Redis, save the following as `docker-compose.redis.yml` and start it with `docker compose -f docker-compose.redis.yml up -d`.

```yaml
version: '3.8'
services:
  redis:
    image: redis:7
    restart: unless-stopped
    ports:
      - '6379:6379'
    volumes:
      - redis-data:/data

volumes:
  redis-data:
```

Notes:
- Point `REDIS_URL` env var to `redis://localhost:6379/0` in your `.env` file to enable Redis caching.
- If `REDIS_URL` is not set or Redis is unreachable, the notebook falls back to `diskcache` (no extra service needed).

---

## Local (no-docker) alternative: DiskCache ✅

- `diskcache` is a fast, pure-Python on-disk cache (`pip install diskcache`) and is used by default if Redis is not configured.
- It provides TTLs, eviction, and works well for local experimentation without running external services.



### Quick setup & env vars

- Install dependencies:

```bash
pip install diskcache redis
```

- .env recommendations:

```
OLLAMA_URL=http://localhost:11434
TEACHER_MODEL=qwen2.5:72b-instruct
AUDITOR_MODEL=deepseek-r1:70b
# Optional: set to use redis
REDIS_URL=redis://localhost:6379/0
# Concurrency
MAX_LLM_CONCURRENCY=8
# Batching (optional)
USE_BATCHING=0
BATCH_SIZE=4
BATCH_CONCURRENCY=2
# Single-call self-audit (recommended)
USE_SINGLE_CALL=1
AUDIT_SAMPLE_RATE=0.05
```

- Run notes:
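  - When `REDIS_URL` is unset (or Redis is unreachable), the notebook uses DiskCache (no Redis required). The example above sets `REDIS_URL`, so remove or comment out that line to stay on DiskCache.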
  - To enable Redis, start the Redis service (Docker compose above) and set `REDIS_URL` in `.env`.
  - `USE_SINGLE_CALL=1` uses a single model call that both generates and self-audits (big speedup).
  - `USE_BATCHING=1` groups up to `BATCH_SIZE` chunks per prompt; a batch is also closed early once its combined text would exceed `MAX_BATCH_CHARS`. This reduces the number of LLM requests but increases per-call latency — test and tune `BATCH_SIZE` and `BATCH_CONCURRENCY` for your machine.
  - Use `AUDIT_SAMPLE_RATE` to run strict two-step auditing on a small random sample (recommended default 5%).

---

This completes the caching, bounded-concurrency, and batching implementation. Next, run the notebook and compare the batch vs. single-call trade-offs using the benchmark cells.


In [None]:
# Batch throughput benchmark: USE_BATCHING True, BATCH_SIZE 4
# NOTE: these assignments overwrite the pipeline globals for the rest of the
# session; re-run the configuration cells to restore the original settings.
USE_BATCHING = True
AUDIT_SAMPLE_RATE = 0.0
BATCH_SIZE = 4
BATCH_CONCURRENCY = 2
MAX_LLM_CONCURRENCY = 2

# Reconfigure session; close the previous one first so its pooled connections are released.
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
if "SESSION" in globals():
    SESSION.close()
SESSION = requests.Session()
retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
adapter = HTTPAdapter(pool_connections=MAX_LLM_CONCURRENCY * 2, pool_maxsize=MAX_LLM_CONCURRENCY * 2, max_retries=retries)
SESSION.mount("http://", adapter)
SESSION.mount("https://", adapter)

# Prepare 8 uncached chunks from the first PDF found
pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
if not pdf_files:
    raise RuntimeError("No PDFs found in RAW_DATA_DIR for batch throughput test")
pdf_path = pdf_files[0]
pdf_hash = pdf_sha256(pdf_path)
chunks = get_cached_chunks(pdf_hash) or []  # also initialises the cache backend
if len(chunks) < 8:
    raise RuntimeError("Not enough chunks to run batch throughput test")
for bench_chunk in chunks[:8]:
    # redis.Redis and diskcache.Cache both provide .delete(key); deleting a missing key is not an error.
    _cache_client.delete(f"sft:{chunk_sha256(bench_chunk)}")

# Clear metrics and run batch processing
# NOTE(review): process_pdfs_with_batching processes every PDF in RAW_DATA_DIR
# (up to 8 chunks each), not just the one prepared above.
METRICS["calls"].clear()
process_pdfs_with_batching(max_chunks_per_pdf=8)
print("\n--- METRICS SUMMARY (batch throughput test) ---")
print(summarise_metrics())

In [None]:
# Autotuner helpers — split into focused cells below
# The heavy lifting functions are defined in subsequent small cells for readability and stepwise execution.

In [52]:
def _reconfigure_session_for_concurrency(concurrency: int):
    """Replace the global SESSION with one whose connection pool fits ``concurrency``.

    The new pool holds ``concurrency * 2`` connections and uses the same Retry
    settings as the module-level session setup. The previous SESSION, if any, is
    closed first so its pooled connections are released now rather than whenever
    the old object is garbage-collected.
    """
    global SESSION
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry

    previous = globals().get("SESSION")
    if previous is not None:
        previous.close()

    SESSION = requests.Session()
    retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
    adapter = HTTPAdapter(pool_connections=concurrency * 2, pool_maxsize=concurrency * 2, max_retries=retries)
    SESSION.mount("http://", adapter)
    SESSION.mount("https://", adapter)

In [46]:
def benchmark_single_call(chunks, concurrency=1, repeat=1):
    """Measure generate_and_audit_single throughput over ``chunks``.

    The shared SESSION is resized for ``concurrency`` and METRICS is cleared. All
    chunks are then processed ``repeat`` times with a thread pool of
    ``concurrency`` workers. Only results that are dicts with an "instruction" key
    count as processed.

    Returns:
        dict with mode, concurrency, throughput (valid pairs per second),
        total_processed and total_time.
    """
    _reconfigure_session_for_concurrency(concurrency)
    METRICS["calls"].clear()

    def _is_valid(result):
        return bool(result) and isinstance(result, dict) and "instruction" in result

    def _timed_pass():
        begin = time.perf_counter()
        with ThreadPoolExecutor(max_workers=concurrency) as pool:
            pending = [pool.submit(generate_and_audit_single, c) for c in chunks]
            outcomes = [job.result() for job in pending]
        elapsed = time.perf_counter() - begin
        return elapsed, sum(map(_is_valid, outcomes))

    timings = [_timed_pass() for _ in range(repeat)]
    total_time = sum(elapsed for elapsed, _ in timings)
    total_processed = sum(count for _, count in timings)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode": "single_call", "concurrency": concurrency, "throughput": throughput, "total_processed": total_processed, "total_time": total_time}

In [47]:
def _split_into_batches(chunks, batch_size, max_chars):
    """Group ``chunks`` (in order) into lists of at most ``batch_size`` items.

    A batch is also closed when the next chunk would push its character total past
    ``max_chars``. An oversized chunk gets its own batch; empty batches are never
    produced.
    """
    batches = []
    current = []
    current_chars = 0
    for c in chunks:
        if current and (len(current) >= batch_size or current_chars + len(c) > max_chars):
            batches.append(current)
            current, current_chars = [], 0
        current.append(c)
        current_chars += len(c)
    if current:
        batches.append(current)
    return batches


def _count_valid_pairs(batch, result):
    """Count usable SFT pairs (dicts with "instruction") in ``result``, at most one per chunk in ``batch``."""
    if not isinstance(result, list):
        return 0
    # Items beyond len(batch) cannot correspond to a source chunk, so they are ignored.
    return sum(1 for item in result[:len(batch)] if isinstance(item, dict) and "instruction" in item)


def benchmark_batch(chunks, batch_size=4, batch_concurrency=1, repeat=1):
    """Benchmark the batching pipeline (generate_and_audit_batch) and return its throughput.

    Chunks are grouped using ``batch_size`` and MAX_BATCH_CHARS. Throughput counts
    only usable pairs (dicts with an "instruction" key, at most one per source
    chunk). This is the same validity check benchmark_single_call applies, so the
    two modes are comparable.

    Returns:
        dict with mode, batch_size, batch_concurrency, throughput (pairs/sec),
        total_processed and total_time.
    """
    batches = _split_into_batches(chunks, batch_size, MAX_BATCH_CHARS)

    _reconfigure_session_for_concurrency(batch_concurrency)
    METRICS["calls"].clear()

    def _run_once():
        t0 = time.perf_counter()
        with ThreadPoolExecutor(max_workers=batch_concurrency) as ex:
            futures = [ex.submit(generate_and_audit_batch, b) for b in batches]
            results = [f.result() for f in futures]
        duration = time.perf_counter() - t0
        processed = sum(_count_valid_pairs(b, r) for b, r in zip(batches, results))
        return duration, processed

    runs = [_run_once() for _ in range(repeat)]
    total_processed = sum(p for _, p in runs)
    total_time = sum(d for d, _ in runs)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode": "batch", "batch_size": batch_size, "batch_concurrency": batch_concurrency, "throughput": throughput, "total_processed": total_processed, "total_time": total_time}

In [48]:
def autotune(chunks, concurrency_options=None, batch_sizes=None, batch_concurrency_options=None, repeat=1):
    """Benchmark single-call and batch configurations and report the fastest.

    Args:
        chunks: Sample chunks used for every benchmark run.
        concurrency_options: Single-call thread counts to try (default [1, 2, 4]).
        batch_sizes: Batch sizes to try (default [None, 2, 4]). None entries are
            skipped (presumably a placeholder for "no batching", covered by the
            single-call runs).
        batch_concurrency_options: Concurrent batch counts to try (default [1, 2]).
        repeat: Passes per configuration.

    Returns:
        {"best": highest-throughput result or None, "all": results sorted by
        throughput, descending}. Ties keep the order in which they were run.
    """
    # None sentinels instead of mutable list defaults shared across calls.
    if concurrency_options is None:
        concurrency_options = [1, 2, 4]
    if batch_sizes is None:
        batch_sizes = [None, 2, 4]
    if batch_concurrency_options is None:
        batch_concurrency_options = [1, 2]

    results = []
    print(f"Autotune: testing {len(chunks)} chunks | single-call conc: {concurrency_options} | batch_sizes: {batch_sizes} | batch_conc: {batch_concurrency_options}")

    for conc in concurrency_options:
        r = benchmark_single_call(chunks, concurrency=conc, repeat=repeat)
        print(f"single_call conc={conc} -> throughput={r['throughput']:.3f} chunks/sec")
        results.append(r)

    for bsize in [b for b in batch_sizes if b is not None]:
        for bconc in batch_concurrency_options:
            r = benchmark_batch(chunks, batch_size=bsize, batch_concurrency=bconc, repeat=repeat)
            print(f"batch bsize={bsize} bconc={bconc} -> throughput={r['throughput']:.3f} chunks/sec")
            results.append(r)

    # sorted() is stable even with reverse=True, so the first entry matches max().
    sorted_results = sorted(results, key=lambda x: x["throughput"], reverse=True)
    best = sorted_results[0] if sorted_results else None
    print("\nTop configs:")
    for s in sorted_results[:5]:
        print(s)

    print("\nBest config:", best)
    return {"best": best, "all": sorted_results}

In [43]:
# Autotuner probe setup (small cell)
# Take the first PROBE_CHUNKS cached chunks of the first PDF in RAW_DATA_DIR as
# the benchmark sample; the chunks must already be cached by a pipeline run.
PROBE_CHUNKS = 6  # small sample by default

pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
if not pdf_files:
    raise RuntimeError("No PDFs found in RAW_DATA_DIR for autotune probe")
pdf_path = next(iter(pdf_files))

pdf_hash = pdf_sha256(pdf_path)
chunks = get_cached_chunks(pdf_hash)
if not chunks:
    raise RuntimeError("No cached chunks found. Run the pipeline or set CACHE first")

probe_chunks = chunks[:PROBE_CHUNKS]
print(f"Using {len(probe_chunks)} probe chunks from {pdf_path.name}")

Using 6 probe chunks from Priority_Protection_Product_Disclosure_Statement.pdf


In [49]:
# Autotuner probe run (separate small cell) — runs the quick probe and reports recommendation
concurrency_options = [1, 2]
batch_sizes = [None, 2, 4]
batch_concurrency_options = [1, 2]

print("Running autotuner (quick probe). This will use the live models and may incur runtime.")
res = autotune(probe_chunks, concurrency_options=concurrency_options, batch_sizes=batch_sizes, batch_concurrency_options=batch_concurrency_options, repeat=1)

best = res.get("best")
best_mode = best.get("mode") if best else None
if best_mode == "single_call":
    # process_chunk refuses to run while USE_BATCHING is set, so recommend disabling it explicitly.
    print(f"\nRecommended: USE_SINGLE_CALL=1, USE_BATCHING=0, MAX_LLM_CONCURRENCY={best['concurrency']}")
elif best_mode == "batch":
    print(f"\nRecommended: USE_SINGLE_CALL=0, USE_BATCHING=1, BATCH_SIZE={best['batch_size']}, BATCH_CONCURRENCY={best['batch_concurrency']}")
else:
    print("No clear best configuration found; inspect results in `res` variable")

AUTOTUNE_RESULTS = res

Running autotuner (quick probe). This will use the live models and may incur runtime.
Autotune: testing 6 chunks | single-call conc: [1, 2] | batch_sizes: [None, 2, 4] | batch_conc: [1, 2]
      > Sending to qwen2.5:72b-instruct... [Done]
      Raw output from qwen2.5:72b-instruct:
{"instruction": "Where can detailed information about the replacement of existing Priority Protection plans be found?", "output": "Detailed information about the replacement of existing Priority Protection Income Protection and Income Protection Accident Only plans, as well as Priority Protection plans with Term Level premiums, is incorporated by reference in this PDS. The material (AIA07702-11/25) can be accessed on the aia.com.au website or requested free of charge by contacting AIA on 1800 33...
      > Sending to qwen2.5:72b-instruct... [Done]
      Raw output from qwen2.5:72b-instruct:
{"instruction": "What is the AIA Insurance Superannuation Scheme No2, and who is the trustee?", "output": "The AIA Insu

In [56]:
# Summarize AUTOTUNE_RESULTS (top 5) for quick review
from pprint import pprint

try:
    best = AUTOTUNE_RESULTS.get("best")
    print("Best configuration recommendation:\n", best)
    print('\nTop 5 configs:')
    top_five = AUTOTUNE_RESULTS.get("all", [])[:5]
    for config in top_five:
        pprint(config)
except Exception as e:
    # Also covers the autotune cell not having been run yet (NameError).
    print("No AUTOTUNE_RESULTS found:", e)


Best configuration recommendation:
 {'mode': 'single_call', 'concurrency': 2, 'throughput': 0.11300484792212409, 'total_processed': 6, 'total_time': 53.09506724998937}

Top 5 configs:
{'concurrency': 2,
 'mode': 'single_call',
 'throughput': 0.11300484792212409,
 'total_processed': 6,
 'total_time': 53.09506724998937}
{'concurrency': 1,
 'mode': 'single_call',
 'throughput': 0.10483676233695516,
 'total_processed': 6,
 'total_time': 57.23183229099959}
{'batch_concurrency': 2,
 'batch_size': 4,
 'mode': 'batch',
 'throughput': 0.0953141482381946,
 'total_processed': 6,
 'total_time': 62.949731083004735}
{'batch_concurrency': 1,
 'batch_size': 4,
 'mode': 'batch',
 'throughput': 0.09520258681452311,
 'total_processed': 6,
 'total_time': 63.02349758299533}
{'batch_concurrency': 1,
 'batch_size': 2,
 'mode': 'batch',
 'throughput': 0.09310181341505104,
 'total_processed': 6,
 'total_time': 64.44557608402101}


In [None]:
# Debug: print the last 16 entries of METRICS["calls"], numbered from 1
recent_calls = METRICS["calls"][-16:]
for position, call_record in enumerate(recent_calls, start=1):
    print(position, call_record)


In [None]:
# Data preview: load latest processed .train.jsonl, show samples and basic stats
import json, random, statistics
from collections import Counter


def load_jsonl(path):
    """Parse a JSON Lines file into accepted records and rejected lines.

    Blank (whitespace-only) lines are skipped. A line is rejected when it is
    not valid JSON, or when it decodes to something other than a JSON object,
    because the preview code in this cell reads fields via ``.get``.

    Args:
        path: Path (str or path-like) of the ``.jsonl`` file, read as UTF-8.

    Returns:
        ``(valid, invalid)``. ``valid`` is a list of decoded dicts in file order.
        ``invalid`` is a list of ``(line_number, error_message, snippet)``
        tuples. ``line_number`` is 1-based, matching editors. ``snippet`` is the
        first 300 characters of the stripped line.
    """
    valid = []
    invalid = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
                invalid.append((lineno, str(e), line[:300]))
                continue
            if not isinstance(obj, dict):
                # e.g. a bare list/string/number: valid JSON but not an SFT record
                invalid.append(
                    (lineno, f"expected a JSON object, got {type(obj).__name__}", line[:300])
                )
                continue
            valid.append(obj)
    return valid, invalid


def summarize_entries(entries):
    """Return character-length statistics for the instruction/output fields.

    A missing field, or one that is JSON ``null``, counts as an empty string.
    A record such as ``{"output": null}`` therefore does not crash the summary.

    Args:
        entries: List of dict records (iterated once per field).

    Returns:
        Dict with keys ``count``, ``inst_mean``, ``inst_median``, ``out_mean``
        and ``out_median``. Means and medians are rounded to one decimal
        place. All statistics are 0 when ``entries`` is empty.
    """
    def _lengths(field):
        # `or ""` also covers keys that are present with a null value
        return [len(e.get(field) or "") for e in entries]

    def _mean(xs):
        return round(statistics.mean(xs), 1) if xs else 0

    def _median(xs):
        # round() rather than int() so even-sized samples keep e.g. 1.5 instead of truncating to 1
        return round(statistics.median(xs), 1) if xs else 0

    inst_lens = _lengths("instruction")
    out_lens = _lengths("output")
    return {
        "count": len(entries),
        "inst_mean": _mean(inst_lens),
        "inst_median": _median(inst_lens),
        "out_mean": _mean(out_lens),
        "out_median": _median(out_lens),
    }

# Find the most recently modified processed file and preview it
files = sorted(PROCESSED_DIR.glob("*.train.jsonl"), key=lambda p: p.stat().st_mtime, reverse=True)
if not files:
    print("No processed .train.jsonl files found in:", PROCESSED_DIR)
else:
    path = files[0]
    print("Previewing:", path)

    entries, invalid = load_jsonl(path)
    # Count raw lines (blank ones included) through a context manager so the
    # second handle on the file is closed immediately rather than leaked.
    with open(path, "r", encoding="utf-8") as fh:
        total_lines = sum(1 for _ in fh)

    print(f"Total lines in file: {total_lines}")
    print(f"Valid entries: {len(entries)} | Invalid lines: {len(invalid)}")
    print("\nBasic stats:", summarize_entries(entries))

    # duplicate check (by instruction)
    insts = [e.get("instruction", "") for e in entries]
    dup_counts = [(t, c) for t, c in Counter(insts).most_common() if c > 1]
    print("\nTop duplicate instructions (first 5):")
    for t, c in dup_counts[:5]:
        print(f"- {c}x: {t[:120]!r}")

    # Show first 5 samples
    print("\nFirst 5 samples:")
    for i, e in enumerate(entries[:5], start=1):
        print(f"\n--- Sample {i} ---")
        print("Instruction:", e.get("instruction", "")[:400])
        print("Output:", e.get("output", "")[:800])

    # Random samples (at most 3, never more than are available)
    print("\nRandom samples:")
    for i, e in enumerate(random.sample(entries, min(3, len(entries))), start=1):
        print(f"\n--- Random {i} ---")
        print("Instruction:", e.get("instruction", "")[:400])
        print("Output:", e.get("output", "")[:800])

    # Show any invalid lines (first 5)
    if invalid:
        print("\nInvalid lines (first 5):")
        for idx, err, snippet in invalid[:5]:
            print(f"- line {idx}: {err} -> {snippet!r}")

    # LLM metrics summary (if any)
    print("\nLLM metrics summary:\n", summarise_metrics())

In [None]:
# Cell 5: run the processing; USE_BATCHING selects the batched or per-chunk pipeline.
if __name__ == "__main__":
    (process_pdfs_with_batching if USE_BATCHING else process_pdfs)()

In [None]:
# Dry-run: convert the first PDF (in sorted path order), extract its first chunk, and generate one SFT example for inspection
from docling.document_converter import DocumentConverter

# sorted() makes "the first PDF" reproducible; Path.glob order is filesystem-dependent.
pdf_files = sorted(RAW_DATA_DIR.glob("*.pdf"))
if not pdf_files:
    raise RuntimeError(f"No PDFs found in {RAW_DATA_DIR}; please add a PDF to test the dry-run")
pdf_path = pdf_files[0]
log(f"Using PDF: {pdf_path.name}")

# Build the converter only once we know there is something to convert.
converter = DocumentConverter()
result = converter.convert(pdf_path)
md = result.document.export_to_markdown()
chunks = chunk_text_to_chunks(md)
if not chunks:
    raise RuntimeError("No chunks extracted from the PDF content")
log(f"Extracted {len(chunks)} chunks; using chunk 0 for dry-run")
chunk0 = chunks[0]
print("--- CHUNK SNIPPET ---")
print(chunk0[:800])
print("--- END SNIPPET ---\n")

# Run a single-call self-audit generation (fast) for inspection
out = generate_and_audit_single(chunk0)
print("Generated SFT pair (single-call):\n", out)

# Show quick LLM metrics
print("\nMETRICS SUMMARY:\n", summarise_metrics())

In [55]:
def process_pdfs(max_chunks_per_pdf: int = None):
    """Process PDFs in non-batching mode and write ordered ``*.train.jsonl`` files.

    For each PDF in ``RAW_DATA_DIR`` (sorted path order, for reproducible runs):

    1. Chunks are taken from ``get_cached_chunks(pdf_sha256(pdf))``. Only on a
       miss is the PDF converted to markdown with ``DocumentConverter``. The
       result is then chunked and stored with ``cache_chunks``. A cache hit
       therefore skips the conversion entirely.
    2. Each of the first ``process_count`` chunks is looked up once with
       ``get_cached_sft_pair``. Hits are used directly. Misses are sent to
       ``process_chunk`` on a thread pool of ``MAX_LLM_CONCURRENCY`` workers.
    3. The non-empty results are written in chunk order to
       ``PROCESSED_DIR / f"{stem}.train.jsonl"``, overwriting any existing file.

    Args:
        max_chunks_per_pdf: If given, only the first N chunks of each PDF are
            processed; ``None`` (default) processes every chunk.

    Returns:
        None. Progress and per-PDF/per-chunk failures are reported via ``log``.
    """
    pdf_files = sorted(RAW_DATA_DIR.glob("*.pdf"))

    if not pdf_files:
        log(f"Error: No PDF files found in {RAW_DATA_DIR}")
        log(f"Note: Place your PDFs in: {RAW_DATA_DIR}")
        return

    log(f"Found {len(pdf_files)} PDF(s) in 'in-progress' folder.")

    # Created lazily: it is never needed when every PDF's chunks are already cached.
    converter = None

    for pdf_path in pdf_files:
        pdf_stem = pdf_path.stem
        output_file = PROCESSED_DIR / f"{pdf_stem}.train.jsonl"

        log(f"\n--- Processing: {pdf_path.name} → {output_file.name} ---")

        # Chunk and cache per-PDF; only convert the PDF on a chunk-cache miss.
        pdf_hash = pdf_sha256(pdf_path)
        chunks = get_cached_chunks(pdf_hash)
        if chunks:
            log(f"Using {len(chunks)} cached chunks for {pdf_path.name}")
        else:
            if converter is None:
                converter = DocumentConverter()
            try:
                result = converter.convert(pdf_path)
                md_content = result.document.export_to_markdown()
            except Exception as e:
                log(f"      ! Failed to convert PDF: {e}")
                continue
            chunks = chunk_text_to_chunks(md_content)
            cache_chunks(pdf_hash, chunks)
            log(f"Extracted and cached {len(chunks)} chunks.")

        process_count = len(chunks) if max_chunks_per_pdf is None else min(len(chunks), max_chunks_per_pdf)
        log(f"Processing {process_count} chunks (single-call={USE_SINGLE_CALL}, batching={USE_BATCHING})...")

        # Single cache pass: hits go straight into the ordered buffer, misses are queued.
        # A cached-but-empty pair is neither regenerated nor written.
        results_buffer = [None] * process_count
        uncached_indices = []
        for i in range(process_count):
            cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
            if cached_pair is None:
                uncached_indices.append(i)
            elif cached_pair:
                results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"

        if uncached_indices:
            semaphore = threading.BoundedSemaphore(MAX_LLM_CONCURRENCY)
            with ThreadPoolExecutor(max_workers=MAX_LLM_CONCURRENCY) as executor:
                future_to_idx = {executor.submit(process_chunk, chunks[i], i, semaphore): i for i in uncached_indices}

                for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc=f"SFT [{pdf_stem}]"):
                    try:
                        idx, line = future.result()
                    except Exception as e:
                        log(f"      ! Chunk job failed: {e}")
                        continue

                    if line:
                        results_buffer[idx] = line

            # Fall back to the cache for generated chunks that produced no line,
            # in case a pair was cached for them anyway.
            for i in uncached_indices:
                if results_buffer[i] is None:
                    cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
                    if cached_pair:
                        results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"

        # Write results in chunk order. Only this thread touches the file, so no lock is needed.
        written = [line for line in results_buffer if line]
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(written)
        successful_entries = len(written)

        log(f"Saved {successful_entries}/{process_count} high-quality entries → {output_file}")

    log(f"\nAll processing complete! Outputs in: {PROCESSED_DIR}")

In [None]:
# Check outputs: list processed .train.jsonl files and preview the most recently modified one
from itertools import islice

files = sorted(PROCESSED_DIR.glob("*.train.jsonl"), key=lambda p: p.stat().st_mtime, reverse=True)
if not files:
    print("No processed files found in:", PROCESSED_DIR)
else:
    print(f"Found {len(files)} processed file(s). Latest: {files[0].name}")
    p = files[0]
    # islice stops after at most 5 lines: no full-file scan and no second file handle.
    with open(p, 'r', encoding='utf-8') as f:
        lines = [ln.strip() for ln in islice(f, 5)]
    print('\nFirst up to 5 lines (sample):')
    for i, ln in enumerate(lines, start=1):
        print(f"{i}: {ln[:400]}")
    # Print metrics summary
    print('\nLLM metrics summary:\n', summarise_metrics())