# README — SFT Data Generation (supervised_fine_tuning)

**Purpose:** Generate high-quality SFT training examples from PDF text using a teacher + (optional) auditor pipeline. This notebook supports caching, batching, and configurable concurrency so you can balance speed vs. strict auditing. 

### Quick start ✅
- Install dependencies (local, no Docker required):

```bash
pip install -r requirements.txt
```

- Copy `.env.example` to `.env` and edit values (do NOT commit your `.env`):

```bash
cp .env.example .env
# edit .env to match your environment
```

- Prompts and templates are stored in `config/prompts.json` and are loaded by the notebook at runtime; edit that file to customize system prompts or templates.

- `.env` recommendations:

```text
OLLAMA_URL=http://localhost:11434
TEACHER_MODEL=qwen2.5:72b-instruct
AUDITOR_MODEL=deepseek-r1:70b
# Optional: run Redis for shared cache
REDIS_URL=redis://localhost:6379/0
MAX_LLM_CONCURRENCY=8
USE_SINGLE_CALL=1         # recommended for speed
USE_BATCHING=0            # optional: 0/1
BATCH_SIZE=4
AUDIT_SAMPLE_RATE=0.05    # fraction of single-call results re-checked with the full teacher + auditor pass
```
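
Values in `.env` arrive as plain strings, so the flags and rates above have to be coerced before use. A minimal sketch of that coercion, assuming a hypothetical `env_flag` helper (not part of the notebook) that treats `1`, `true`, and `yes` as truthy, and a hypothetical run of 1000 chunks to show what the sample rate means in practice:

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    """Hypothetical helper: read a 0/1-style flag from the environment."""
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes"}

# Simulate the recommended .env values for this sketch
os.environ["USE_SINGLE_CALL"] = "1"
os.environ["USE_BATCHING"] = "0"
os.environ["AUDIT_SAMPLE_RATE"] = "0.05"

use_single_call = env_flag("USE_SINGLE_CALL", "1")
use_batching = env_flag("USE_BATCHING")
audit_sample_rate = float(os.getenv("AUDIT_SAMPLE_RATE", "0.05"))

# With single-call mode on, roughly this many of 1000 chunks (hypothetical count)
# would also go through the strict two-step audit.
n_chunks = 1000
expected_strict_audits = round(n_chunks * audit_sample_rate)
print(use_single_call, use_batching, expected_strict_audits)
```

Raising `AUDIT_SAMPLE_RATE` trades throughput for more strict audits; `1.0` would re-audit every single-call result.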

### Prompts & Templates ✅
- Location: `config/prompts.json`.
- What it contains: system prompts and small templates used by the pipeline (keys include `system_gen`, `gen_prompt_template`, `audit_system`, `audit_prompt_template`, `single_call_system`, `single_call_prompt_template`, `batch_system`, `batch_block_template`).
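- How it works: the notebook loads `config/prompts.json` at runtime. If the file is missing, it falls back to built-in defaults. The fallback is all-or-nothing: a file that exists but omits a key is not merged with the defaults, so keep every key listed above.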
- Editing tips:
  - Edit `config/prompts.json` with your custom wording and keep values valid JSON.
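  - Templates use `{chunk}` and `{generated}` placeholders (`batch_block_template` also uses `{i}` for the source number). They are filled with Python's `str.format` when the notebook runs, so write any literal brace in a template as `{{` or `}}`.
  - After editing, re-run the top configuration cells (or restart the kernel and run them again) to pick up changes.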
- Example: modify the `system_gen` value to shift the teacher's style or constraints, or adjust `audit_prompt_template` to change strictness.
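
To make the placeholder substitution concrete, here is a minimal sketch that fills a template the same way the pipeline cells do (with `str.format`). The JSON snippet, the sample chunk, and the sample pair are hypothetical stand-ins, not the contents of the real `config/prompts.json`:

```python
import json

# Hypothetical, shortened stand-in for config/prompts.json
prompts_json = r'{"audit_prompt_template": "Source Text:\n{chunk}\n\nGenerated Pair:\n{generated}"}'
prompts = json.loads(prompts_json)

# Hypothetical generated pair and source chunk
pair = {"instruction": "What is the waiting period?", "output": "90 days."}
prompt = prompts["audit_prompt_template"].format(
    chunk="The waiting period is 90 days.",
    generated=json.dumps(pair),
)
print(prompt)

# Literal braces in a template must be doubled so str.format leaves them alone
template = 'Return {{"instruction": "...", "output": "..."}} for:\n{chunk}'
filled = template.format(chunk="example text")
print(filled)
```

Braces inside substituted values (such as the JSON in `generated`) are safe; only braces written in the template itself need doubling.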


# --- ORGANIZED ENTRYPOINTS (New ordering) ---
# Use the cells that follow as the canonical, ordered implementation. Older, scattered cells remain below for history but **do not** run them.


In [None]:
# Install dependencies (one-off helper cell)
# Run this in the notebook when you need to install packages for this project
!pip install -r requirements.txt

In [1]:
# Imports & Configuration (canonical, single cell)
import os, sys, json, time, re, random
from pathlib import Path
from datetime import datetime
from dotenv import load_dotenv
from concurrent.futures import ThreadPoolExecutor
load_dotenv()

# Core endpoints & models
OLLAMA_URL = os.getenv("OLLAMA_URL")
if OLLAMA_URL is None:
    raise RuntimeError("Error: OLLAMA_URL is not set in the environment variables. Please add OLLAMA_URL to your .env file")
OLLAMA_URL = OLLAMA_URL.rstrip("/")
TEACHER_MODEL = os.getenv("TEACHER_MODEL", "qwen2.5:72b-instruct")
AUDITOR_MODEL = os.getenv("AUDITOR_MODEL", "deepseek-r1:70b")

# Cache & concurrency defaults
REDIS_URL = os.getenv("REDIS_URL") or None
CACHE_DIR = os.getenv("CACHE_DIR", str(Path.cwd() / "cache"))
CHUNK_TTL = int(os.getenv("CHUNK_TTL", 60 * 60 * 24))
SFT_TTL = int(os.getenv("SFT_TTL", 60 * 60 * 24 * 7))
MAX_LLM_CONCURRENCY = int(os.getenv("MAX_LLM_CONCURRENCY", 8))
USE_SINGLE_CALL = os.getenv("USE_SINGLE_CALL", "1").strip().lower() in {"1", "true", "yes"}
AUDIT_SAMPLE_RATE = float(os.getenv("AUDIT_SAMPLE_RATE", "0.05"))
USE_BATCHING = os.getenv("USE_BATCHING", "0").strip().lower() in {"1", "true", "yes"}
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4"))
BATCH_CONCURRENCY = int(os.getenv("BATCH_CONCURRENCY", "2"))
MAX_BATCH_CHARS = int(os.getenv("MAX_BATCH_CHARS", "20000"))
# Chunking defaults (chars)
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2000"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200"))
# Optional override to force cache backend: 'redis' or 'disk'
CACHE_BACKEND = os.getenv("CACHE_BACKEND", "").lower()  # set to 'disk' to force DiskCache

# Paths
try:
    SCRIPT_DIR = Path(__file__).parent.resolve()
except NameError:
    SCRIPT_DIR = Path.cwd().resolve()
RAW_DATA_DIR = SCRIPT_DIR / "data" / "raw" / "in-progress"
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)

TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

PROCESSED_DIR = SCRIPT_DIR / "data" / "processed" / TIMESTAMP
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

In [2]:
# Prompts (load or defaults)
PROMPTS_PATH = Path("config") / "prompts.json"
if PROMPTS_PATH.exists():
    with PROMPTS_PATH.open("r", encoding="utf-8") as f:
        PROMPTS = json.load(f)
else:
    PROMPTS = {
        "system_gen": "You are a senior Life Insurance Operations expert... Output ONLY valid JSON in this exact format: {\"instruction\": \"...\", \"output\": \"...\"}",
        "gen_prompt_template": "Using only the following extract...\n\n{chunk}",
        "audit_system": "You are a meticulous Life Insurance Regulatory Auditor... Output ONLY the final corrected JSON...",
        "audit_prompt_template": "Source Text:\n{chunk}\n\nGenerated Pair:\n{generated}\n\nVerify factual accuracy...",
        "single_call_system": "You are a senior Life Insurance Operations expert and a meticulous auditor... Output ONLY final JSON...",
        "single_call_prompt_template": "Source Text:\n{chunk}\n\nCreate the training example and self-audit it; return only the final JSON.",
        "batch_system": "You are a senior Life Insurance Operations expert and a strict auditor. Return a JSON array or newline-separated JSON objects.",
        "batch_block_template": "--- SOURCE {i} ---\n{chunk}\n"
    }

In [4]:
# Cache backend init + helper to detect Redis usage
try:
    import redis
except Exception:
    redis = None
try:
    from diskcache import Cache as DiskCache
except Exception:
    DiskCache = None

_cache_client = None

def init_cache():
    """Initialize a cache client. Prefer Redis (when available and healthy),
    otherwise fall back to DiskCache. Honours CACHE_BACKEND env override.
    """
    global _cache_client
    if _cache_client is not None:
        return
    _info = globals().get("log", print)

    # Force disk backend when requested
    if CACHE_BACKEND == "disk":
        _info("CACHE_BACKEND=disk -> forcing DiskCache backend")
        if DiskCache is None:
            _info("DiskCache not available. Install with `pip install diskcache`")
            raise RuntimeError("No cache backend available")
        _cache_client = DiskCache(CACHE_DIR)
        _info(f"Using DiskCache at {CACHE_DIR} (forced by CACHE_BACKEND)")
        return

    # Try Redis when configured
    if REDIS_URL and redis is not None:
        try:
            client = redis.from_url(REDIS_URL, socket_connect_timeout=2, socket_timeout=2, decode_responses=True)
            # require a ping to verify health
            client.ping()
            _cache_client = client
            _info(f"Using Redis cache at {REDIS_URL}")
            return
        except Exception as e:
            _info(f"Could not use Redis at {REDIS_URL} ({type(e).__name__}: {e}). Falling back to DiskCache.")
            try:
                client.close()
            except Exception:
                pass
            _cache_client = None

    # Fall back to DiskCache
    if DiskCache is None:
        _info("DiskCache not available. Install with `pip install diskcache` or set REDIS_URL to a running Redis server.")
        raise RuntimeError("No cache backend available")

    _cache_client = DiskCache(CACHE_DIR)
    _info(f"Using DiskCache at {CACHE_DIR}")


def _using_redis():
    """Return True when the active cache client is Redis-backed."""
    # Check the client type directly; isinstance() against DiskCache would raise
    # TypeError when diskcache is not installed (DiskCache is None).
    return redis is not None and isinstance(_cache_client, redis.Redis)

In [5]:
# Hashing & cache helpers (robust to Redis bytes/str)
import hashlib

def pdf_sha256(path: Path) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()

def chunk_sha256(chunk: str) -> str:
    return hashlib.sha256(chunk.encode("utf-8")).hexdigest()

def cache_chunks(pdf_hash: str, chunks, ttl=CHUNK_TTL):
    global _cache_client
    init_cache()
    key = f"chunks:{pdf_hash}"
    if _using_redis():
        try:
            _cache_client.set(key, json.dumps(chunks, ensure_ascii=False), ex=ttl)
            return
        except Exception as e:
            # on redis failure, fall back to diskcache
            print(f"Redis cache_chunks error: {e}; falling back to DiskCache")
            try:
                _cache_client.close()
            except Exception:
                pass
            _cache_client = None
            init_cache()
    # DiskCache path
    _cache_client.set(key, chunks, expire=ttl)

def get_cached_chunks(pdf_hash: str):
    global _cache_client
    init_cache()
    key = f"chunks:{pdf_hash}"
    if _using_redis():
        try:
            v = _cache_client.get(key)
        except Exception as e:
            print(f"Redis get_cached_chunks error: {e}; falling back to DiskCache")
            try:
                _cache_client.close()
            except Exception:
                pass
            _cache_client = None
            init_cache()
            return _cache_client.get(key)
        if not v:
            return None
        try:
            return json.loads(v)
        except Exception:
            return None
    else:
        return _cache_client.get(key)

def cache_sft_pair(chunk_hash: str, pair, ttl=SFT_TTL):
    global _cache_client
    init_cache()
    key = f"sft:{chunk_hash}"
    if _using_redis():
        try:
            _cache_client.set(key, json.dumps(pair, ensure_ascii=False), ex=ttl)
            return
        except Exception as e:
            print(f"Redis cache_sft_pair error: {e}; falling back to DiskCache")
            try:
                _cache_client.close()
            except Exception:
                pass
            _cache_client = None
            init_cache()
    _cache_client.set(key, pair, expire=ttl)

def get_cached_sft_pair(chunk_hash: str):
    global _cache_client
    init_cache()
    key = f"sft:{chunk_hash}"
    if _using_redis():
        try:
            v = _cache_client.get(key)
        except Exception as e:
            print(f"Redis get_cached_sft_pair error: {e}; falling back to DiskCache")
            try:
                _cache_client.close()
            except Exception:
                pass
            _cache_client = None
            init_cache()
            return _cache_client.get(key)
        if not v:
            return None
        try:
            return json.loads(v)
        except Exception:
            return None
    else:
        return _cache_client.get(key)

# Text chunking helper

def chunk_text_to_chunks(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split `text` into chunks of approximately `chunk_size` characters with `overlap`.

    Strategy:
    - Split text into paragraphs on two newlines
    - Accumulate paragraphs until adding would exceed chunk_size
    - If a single paragraph is larger than chunk_size, split it into slices with overlap
    - Return list of chunk strings (stripped)
    """
    if not text:
        return []

    paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    chunks = []
    current = []
    current_len = 0

    def flush_current():
        nonlocal current, current_len
        if current:
            chunk = "\n\n".join(current).strip()
            if chunk:
                chunks.append(chunk)
        current = []
        current_len = 0

    for p in paragraphs:
        p_len = len(p)
        # Length of the "\n\n" joiner; counted only when a paragraph is already buffered
        sep_len = 2 if current else 0
        if current_len + p_len + sep_len <= chunk_size:
            current.append(p)
            current_len += p_len + sep_len
        else:
            flush_current()
            if p_len <= chunk_size:
                current.append(p)
                current_len = p_len
            else:
                # paragraph itself is larger than chunk_size; split it
                start = 0
                while start < p_len:
                    end = min(start + chunk_size, p_len)
                    slice_ = p[start:end].strip()
                    if slice_:
                        chunks.append(slice_)
                    if end >= p_len:
                        break
                    # Always advance by at least one character so overlap >= chunk_size cannot loop forever
                    start = max(start + 1, end - overlap)
    # flush remaining
    flush_current()
    return chunks

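The oversized-paragraph branch of `chunk_text_to_chunks` is a sliding window: each slice starts `overlap` characters before the previous one ended. A standalone sketch of just that step, using a hypothetical `sliding_slices` helper and tiny sizes for readability (it assumes `overlap` is smaller than `chunk_size`):

```python
def sliding_slices(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Hypothetical helper: fixed-size slices that overlap by `overlap` characters."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - overlap  # how far the window moves each time
    slices = []
    start = 0
    while start < len(text):
        slices.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last slice reached the end of the text
        start += step
    return slices

result = sliding_slices("abcdefghij", chunk_size=4, overlap=1)
print(result)  # ['abcd', 'defg', 'ghij']
```

Each slice repeats the last character of the previous one (`d`, then `g`), which is the same context carry-over that `CHUNK_OVERLAP` provides at the default sizes.
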
In [6]:
# HTTP session (pooling & retries) and METRICS
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import requests

SESSION = requests.Session()
retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
adapter = HTTPAdapter(pool_connections=MAX_LLM_CONCURRENCY*2, pool_maxsize=MAX_LLM_CONCURRENCY*2, max_retries=retries)
SESSION.mount("http://", adapter)
SESSION.mount("https://", adapter)

METRICS = {"calls": []}
def record_call(model: str, duration: float, success: bool, error: str | None = None):
    METRICS["calls"].append({"model": model, "duration": duration, "success": bool(success), "error": str(error) if error else None})
def summarise_metrics():
    import statistics
    by_model = {}
    for c in METRICS["calls"]:
        m = c["model"]
        by_model.setdefault(m, []).append(c)
    lines = []
    for m, calls in by_model.items():
        durations = [c["duration"] for c in calls if c["duration"] is not None]
        successes = sum(1 for c in calls if c["success"])
        total = len(calls)
        mean = statistics.mean(durations) if durations else 0
        p95 = sorted(durations)[int(len(durations) * 0.95)] if durations else 0
        lines.append(f"{m}: calls={total} success={successes} mean={mean:.2f}s p95={p95:.2f}s")
    return "\n".join(lines)

In [7]:
# LLM call + parsing helpers
def call_ollama(model, prompt, system_prompt="", session=None):
    session = session or SESSION
    url = f"{OLLAMA_URL}/api/chat"
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt} if system_prompt else None,
            {"role": "user", "content": prompt}
        ],
        "stream": False,
        "options": {"temperature": 0.1, "num_predict": 1024}
    }
    payload["messages"] = [m for m in payload["messages"] if m is not None]
    for attempt in range(3):
        start = time.time()
        try:
            response = session.post(url, json=payload, timeout=300)
            response.raise_for_status()
            duration = time.time() - start
            data = response.json()
            raw_content = data["message"]["content"]
            content = raw_content.strip()
            if content.startswith("```json"):
                content = content[7:]
            if content.endswith("```"):
                content = content[:-3]
            content = content.strip()
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                match = re.search(r'\{.*\}', content, re.DOTALL)
                if match:
                    parsed = json.loads(match.group(0))
                else:
                    record_call(model, duration, False, error="invalid-json")
                    return None
            record_call(model, duration, True)
            return parsed
        except Exception as e:
            duration = time.time() - start
            record_call(model, duration, False, error=str(e))
            time.sleep(1)
    record_call(model, None, False, error="all attempts failed")
    return None

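def _parse_json_array_or_objects(text: str):
    """Parse a JSON array, a single JSON object, or several concatenated /
    newline-separated JSON objects. Returns a list of objects, or None."""
    text = text.strip()
    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            return [parsed]
        if isinstance(parsed, list):
            return parsed
        return None
    except Exception:
        pass
    # Python's `re` module has no recursive patterns such as (?R), so scan for
    # objects with the standard-library decoder instead.
    decoder = json.JSONDecoder()
    results = []
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one
            pos = text.find("{", pos + 1)
            continue
        if isinstance(obj, dict):
            results.append(obj)
        pos = text.find("{", end)
    return results or None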

In [8]:
# Generate & audit pipelines

# PDF text extraction fallbacks (docling -> fitz -> pdfminer)

def extract_text_with_fitz(pdf_path: Path):
    try:
        import fitz
    except Exception:
        return None
    try:
        doc = fitz.open(str(pdf_path))
        texts = []
        for p in doc:
            texts.append(p.get_text("text"))
        return "\n\n".join(t.strip() for t in texts if t and t.strip())
    except Exception as e:
        print("fitz extraction failed:", e)
        return None


def extract_text_with_pdfminer(pdf_path: Path):
    try:
        from pdfminer.high_level import extract_text
    except Exception:
        return None
    try:
        return extract_text(str(pdf_path))
    except Exception as e:
        print("pdfminer extraction failed:", e)
        return None


def get_markdown_for_pdf(pdf_path: Path, converter=None):
    """Return markdown text for a PDF by trying docling conversion first, then fallbacks.

    Returns empty string if no usable text is found.
    """
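    # try docling conversion; import lazily so the fallbacks still run when docling is missing
    try:
        if converter is None:
            from docling.document_converter import DocumentConverter
            converter = DocumentConverter()
        result = converter.convert(pdf_path)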
        md = result.document.export_to_markdown()
        if md and len(md.strip()) > 50:
            return md
        print("Docling conversion produced insufficient text; trying fallbacks.")
    except Exception as e:
        print("Docling conversion failed:", e)

    # try PyMuPDF (fitz)
    md = extract_text_with_fitz(pdf_path)
    if md and len(md.strip()) > 50:
        return md

    # try pdfminer
    md = extract_text_with_pdfminer(pdf_path)
    if md and len(md.strip()) > 50:
        return md

    print("Fallback extractors returned no usable text. Check OCR engines or install fitz/pdfminer.")
    return ""


def generate_and_audit(chunk):
    start = time.time()
    system_gen = PROMPTS.get("system_gen")
    gen_prompt = PROMPTS.get("gen_prompt_template").format(chunk=chunk)
    raw_pair = call_ollama(TEACHER_MODEL, gen_prompt, system_gen)
    if not raw_pair:
        return None
    audit_system = PROMPTS.get("audit_system")
    audit_prompt = PROMPTS.get("audit_prompt_template").format(chunk=chunk, generated=json.dumps(raw_pair, indent=2))
    final_pair = call_ollama(AUDITOR_MODEL, audit_prompt, audit_system)
    record_call("pipeline:generate_and_audit", time.time() - start, True if final_pair else False)
    return final_pair

def generate_and_audit_single(chunk):
    system_prompt = PROMPTS.get("single_call_system")
    prompt = PROMPTS.get("single_call_prompt_template").format(chunk=chunk)
    start = time.time()
    out = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    record_call("pipeline:single_generate_and_audit", time.time() - start, True if out else False)
    return out

def generate_and_audit_batch(chunks: list[str]):
    total_chars = sum(len(c) for c in chunks)
    if total_chars > MAX_BATCH_CHARS:
        print("Batch too large")
        return None
    system_prompt = PROMPTS.get("batch_system")
    block_template = PROMPTS.get("batch_block_template")
    parts = [block_template.format(i=i, chunk=c) for i,c in enumerate(chunks, start=1)]
    prompt = "\n".join(parts)
    start = time.time()
    raw = call_ollama(TEACHER_MODEL, prompt, system_prompt)
    record_call("pipeline:batch_generate_and_audit", time.time() - start, True if raw else False)
    if not raw:
        return None
    if isinstance(raw, (list, dict)):
        return raw if isinstance(raw, list) else [raw]
    return _parse_json_array_or_objects(str(raw))

In [9]:
# Processing orchestration (single-call and batching)
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_chunk(chunk: str, idx: int, semaphore: threading.BoundedSemaphore):
    chunk_hash = chunk_sha256(chunk)
    cached_pair = get_cached_sft_pair(chunk_hash)
    if cached_pair:
        return idx, json.dumps(cached_pair, ensure_ascii=False) + "\n"
    with semaphore:
        if USE_BATCHING:
            raise RuntimeError("process_chunk shouldn't be used in BATCHING mode")
        if USE_SINGLE_CALL:
            sft_pair = generate_and_audit_single(chunk)
            if sft_pair and random.random() < AUDIT_SAMPLE_RATE:
                strict_pair = generate_and_audit(chunk)
                if strict_pair:
                    sft_pair = strict_pair
        else:
            sft_pair = generate_and_audit(chunk)
    if sft_pair and isinstance(sft_pair, dict) and "instruction" in sft_pair:
        cache_sft_pair(chunk_hash, sft_pair)
        return idx, json.dumps(sft_pair, ensure_ascii=False) + "\n"
    return idx, None

def process_pdfs(max_chunks_per_pdf: int = None):
    from docling.document_converter import DocumentConverter
    converter = DocumentConverter()
    pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
    if not pdf_files:
        print("No PDFs found")
        return
    for pdf_path in pdf_files:
        pdf_stem = pdf_path.stem
        output_file = PROCESSED_DIR / f"{pdf_stem}.train.jsonl"
        md_content = get_markdown_for_pdf(pdf_path, converter)
        if not md_content:
            print("Conversion failed or produced no text for", pdf_path)
            continue
        pdf_hash = pdf_sha256(pdf_path)
        chunks = get_cached_chunks(pdf_hash) or chunk_text_to_chunks(md_content)
        cache_chunks(pdf_hash, chunks)
        process_count = len(chunks) if max_chunks_per_pdf is None else min(len(chunks), max_chunks_per_pdf)
        uncached_indices = [i for i in range(process_count) if get_cached_sft_pair(chunk_sha256(chunks[i])) is None]
        results_buffer = [None] * process_count
        if uncached_indices:
            semaphore = threading.BoundedSemaphore(MAX_LLM_CONCURRENCY)
            with ThreadPoolExecutor(max_workers=MAX_LLM_CONCURRENCY) as ex:
                futures = {ex.submit(process_chunk, chunks[i], i, semaphore): i for i in uncached_indices}
                for future in as_completed(futures):
                    try:
                        idx, line = future.result()
                    except Exception as e:
                        print("Chunk job failed", e)
                        continue
                    if line:
                        results_buffer[idx] = line
        for i in range(process_count):
            if results_buffer[i] is None:
                cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
                if cached_pair:
                    results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"
        with open(output_file, "w", encoding="utf-8") as f:
            for line in results_buffer:
                if line:
                    f.write(line)
        print(f"Saved entries → {output_file}")

def process_pdfs_with_batching(max_chunks_per_pdf: int = None):
    from docling.document_converter import DocumentConverter
    converter = DocumentConverter()
    pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
    if not pdf_files:
        print("No PDFs found")
        return
    for pdf_path in pdf_files:
        pdf_stem = pdf_path.stem
        output_file = PROCESSED_DIR / f"{pdf_stem}.train.jsonl"
        md_content = get_markdown_for_pdf(pdf_path, converter)
        if not md_content:
            print("Conversion failed or produced no text for", pdf_path)
            continue
        pdf_hash = pdf_sha256(pdf_path)
        chunks = get_cached_chunks(pdf_hash) or chunk_text_to_chunks(md_content)
        cache_chunks(pdf_hash, chunks)
        process_count = len(chunks) if max_chunks_per_pdf is None else min(len(chunks), max_chunks_per_pdf)
        uncached_indices = [i for i in range(process_count) if get_cached_sft_pair(chunk_sha256(chunks[i])) is None]
        batches = []
        current = []
        current_chars = 0
        for idx in uncached_indices:
            c = chunks[idx]
            # `current and` prevents an empty batch when one chunk alone exceeds MAX_BATCH_CHARS
            if current and (len(current) >= BATCH_SIZE or (current_chars + len(c)) > MAX_BATCH_CHARS):
                batches.append(current)
                current = []
                current_chars = 0
            current.append(idx)
            current_chars += len(c)
        if current:
            batches.append(current)
        results_buffer = [None] * process_count
        with ThreadPoolExecutor(max_workers=BATCH_CONCURRENCY) as executor:
            future_to_batch = {}
            for batch_idxs in batches:
                batch_chunks = [chunks[i] for i in batch_idxs]
                future = executor.submit(generate_and_audit_batch, batch_chunks)
                future_to_batch[future] = batch_idxs
            for future in as_completed(future_to_batch):
                batch_idxs = future_to_batch[future]
                try:
                    out_list = future.result()
                except Exception as e:
                    print("Batch failed", e)
                    out_list = None
                if out_list and isinstance(out_list, list):
                    for idx_in_batch, obj in enumerate(out_list):
                        target_idx = batch_idxs[idx_in_batch] if idx_in_batch < len(batch_idxs) else None
                        if target_idx is not None and isinstance(obj, dict) and "instruction" in obj:
                            cache_sft_pair(chunk_sha256(chunks[target_idx]), obj)
                            results_buffer[target_idx] = json.dumps(obj, ensure_ascii=False) + "\n"
                else:
                    print("Batch returned invalid output")
        for i in range(process_count):
            if results_buffer[i] is None:
                cached_pair = get_cached_sft_pair(chunk_sha256(chunks[i]))
                if cached_pair:
                    results_buffer[i] = json.dumps(cached_pair, ensure_ascii=False) + "\n"
        with open(output_file, "w", encoding="utf-8") as f:
            for line in results_buffer:
                if line:
                    f.write(line)
        print(f"Saved entries → {output_file}")

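The batch-building loop in `process_pdfs_with_batching` is a greedy packer: it closes the current batch when it already holds `BATCH_SIZE` items or when the next chunk would push it past `MAX_BATCH_CHARS`. A standalone sketch of that idea, using a hypothetical `pack_batches` helper that works on chunk lengths rather than real chunks:

```python
def pack_batches(sizes: list[int], batch_size: int, max_chars: int) -> list[list[int]]:
    """Hypothetical helper: group item indices greedily under a count and a size limit."""
    batches, current, current_chars = [], [], 0
    for idx, size in enumerate(sizes):
        # Close the open batch if it is full or the next item would overflow it
        if current and (len(current) >= batch_size or current_chars + size > max_chars):
            batches.append(current)
            current, current_chars = [], 0
        current.append(idx)
        current_chars += size
    if current:
        batches.append(current)
    return batches

batches = pack_batches([900, 800, 700, 1500, 200], batch_size=2, max_chars=2000)
print(batches)  # [[0, 1], [2], [3, 4]]
```

The `if current and ...` guard keeps a single oversized item from producing an empty batch; such an item ends up alone in its own batch instead.
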
In [10]:
# Autotuner & benchmarks (single, batch)
def _reconfigure_session_for_concurrency(concurrency: int):
    global SESSION
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry
    SESSION = requests.Session()
    retries = Retry(total=3, backoff_factor=0.6, status_forcelist=[429, 500, 502, 503, 504])
    adapter = HTTPAdapter(pool_connections=concurrency*2, pool_maxsize=concurrency*2, max_retries=retries)
    SESSION.mount("http://", adapter)
    SESSION.mount("https://", adapter)

def benchmark_single_call(chunks, concurrency=1, repeat=1):
    if not chunks:
        return {"mode":"single_call","concurrency":concurrency,"throughput":0.0,"total_processed":0,"total_time":0.0}
    _reconfigure_session_for_concurrency(concurrency)
    METRICS["calls"].clear()
    def _run_once():
        t0 = time.perf_counter()
        with ThreadPoolExecutor(max_workers=concurrency) as ex:
            futures = [ex.submit(generate_and_audit_single, c) for c in chunks]
            results = [f.result() for f in futures]
        t1 = time.perf_counter()
        duration = t1 - t0
        success = sum(1 for r in results if r and isinstance(r, dict) and "instruction" in r)
        return duration, success
    runs = [_run_once() for _ in range(repeat)]
    total_processed = sum(s for _, s in runs)
    total_time = sum(d for d, _ in runs)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode": "single_call", "concurrency": concurrency, "throughput": throughput, "total_processed": total_processed, "total_time": total_time}

def benchmark_batch(chunks, batch_size=4, batch_concurrency=1, repeat=1):
    """Benchmark batching pipeline using generate_and_audit_batch and return throughput."""
    if not chunks:
        return {"mode":"batch","batch_size":batch_size,"batch_concurrency":batch_concurrency,"throughput":0.0,"total_processed":0,"total_time":0.0}

    # Build batches respecting batch_size and MAX_BATCH_CHARS
    batches = []
    current = []
    current_chars = 0
    for c in chunks:
        # `current and` prevents an empty batch when one chunk alone exceeds MAX_BATCH_CHARS
        if current and (len(current) >= batch_size or (current_chars + len(c)) > MAX_BATCH_CHARS):
            batches.append(current)
            current = []
            current_chars = 0
        current.append(c)
        current_chars += len(c)
    if current:
        batches.append(current)

    _reconfigure_session_for_concurrency(batch_concurrency)
    METRICS["calls"].clear()

    def _run_once():
        t0 = time.perf_counter()
        with ThreadPoolExecutor(max_workers=batch_concurrency) as ex:
            futures = [ex.submit(generate_and_audit_batch, b) for b in batches]
            results = [f.result() for f in futures]
        t1 = time.perf_counter()
        duration = t1 - t0
        processed = 0
        for res in results:
            if isinstance(res, list):
                processed += len(res)
            elif isinstance(res, dict):
                processed += 1
        return duration, processed

    runs = [_run_once() for _ in range(repeat)]
    total_processed = sum(p for _, p in runs)
    total_time = sum(d for d, _ in runs)
    throughput = total_processed / total_time if total_time > 0 else 0
    return {"mode":"batch","batch_size":batch_size,"batch_concurrency":batch_concurrency,"throughput":throughput,"total_processed":total_processed,"total_time":total_time}

In [11]:
# Diagnostics cell (run immediately after imports if anything seems off)
print("Diagnostics:")
print("ThreadPoolExecutor:", ThreadPoolExecutor)
init_cache()
print("Using Redis:", _using_redis())
print("Cache client type:", type(_cache_client))

Diagnostics:
ThreadPoolExecutor: <class 'concurrent.futures.thread.ThreadPoolExecutor'>
CACHE_BACKEND=disk -> forcing DiskCache backend
Using DiskCache at /home/rahul/dev/sft-data-gen/cache (forced by CACHE_BACKEND)
Using Redis: False
Cache client type: <class 'diskcache.core.Cache'>


In [12]:
# Main run helper and dry-run example
def run_all(max_chunks_per_pdf: int = None):
    if USE_BATCHING:
        process_pdfs_with_batching(max_chunks_per_pdf)
    else:
        process_pdfs(max_chunks_per_pdf)

if __name__ == "__main__":
    # Example: run_all() or keep for CLI usage in notebooks
    print("Notebook loaded. Call run_all() to start processing.")

Notebook loaded. Call run_all() to start processing.


In [None]:
# Dry-run helper (process entire first PDF concurrently; writes .dryrun.jsonl to processed dir)
try:
    from docling.document_converter import DocumentConverter
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm import tqdm
    import threading

    converter = DocumentConverter()
    pdf_files = list(RAW_DATA_DIR.glob("*.pdf"))
    if not pdf_files:
        print("No PDFs found for dry-run.")
    else:
        pdf_path = pdf_files[0]
        print("Dry-run using:", pdf_path)
        md = get_markdown_for_pdf(pdf_path, converter)
        if not md:
            print("Dry-run: no text extracted for", pdf_path)
        else:
            chunks = chunk_text_to_chunks(md)
            n_chunks = len(chunks)
            print(f"PDF produced {n_chunks} chunks")
            output_file = PROCESSED_DIR / f"{pdf_path.stem}.dryrun.jsonl"

            # concurrency bounded by MAX_LLM_CONCURRENCY and number of chunks
            concurrency = min(MAX_LLM_CONCURRENCY, max(1, n_chunks))
            semaphore = threading.BoundedSemaphore(concurrency)
            results_buffer = [None] * n_chunks
            processed = 0
            skipped = 0

            with ThreadPoolExecutor(max_workers=concurrency) as ex:
                future_to_idx = {ex.submit(process_chunk, chunks[i], i, semaphore): i for i in range(n_chunks)}
                for fut in tqdm(as_completed(future_to_idx), total=n_chunks, desc="Dry-run chunks"):
                    try:
                        idx, line = fut.result()
                    except Exception as e:
                        print("Chunk job failed:", e)
                        continue
                    if line:
                        results_buffer[idx] = line
                        processed += 1
                    else:
                        cached = get_cached_sft_pair(chunk_sha256(chunks[idx]))
                        if cached:
                            results_buffer[idx] = json.dumps(cached, ensure_ascii=False) + "\n"
                            skipped += 1

            # Write results in order
            with open(output_file, "w", encoding="utf-8") as out_f:
                for line in results_buffer:
                    if line:
                        out_f.write(line)

            print(f"Dry-run complete: processed={processed} skipped_cached={skipped} written→{output_file}")
except Exception as e:
    print("Dry-run skipped:", e)

2026-01-04 01:03:26,622 - INFO - detected formats: [<InputFormat.PDF: 'pdf'>]


Dry-run using: /home/rahul/dev/sft-data-gen/data/raw/in-progress/accelerated-protection-combined-pds.pdf


2026-01-04 01:03:27,407 - INFO - Going to convert document batch...
2026-01-04 01:03:27,409 - INFO - Initializing pipeline for StandardPdfPipeline with options hash e15bc6f248154cc62f8db15ef18a8ab7
2026-01-04 01:03:27,422 - INFO - Loading plugin 'docling_defaults'
2026-01-04 01:03:27,430 - INFO - Registered picture descriptions: ['vlm', 'api']
2026-01-04 01:03:27,450 - INFO - Loading plugin 'docling_defaults'
2026-01-04 01:03:27,456 - INFO - Registered ocr engines: ['auto', 'easyocr', 'ocrmac', 'rapidocr', 'tesserocr', 'tesseract']
2026-01-04 01:03:27,475 - INFO - rapidocr cannot be used because onnxruntime is not installed.
2026-01-04 01:03:27,478 - INFO - easyocr cannot be used because it is not installed.
2026-01-04 01:03:28,203 - INFO - Accelerator device: 'cuda:0'
[INFO] 2026-01-04 01:03:28,277 [RapidOCR] base.py:22: Using engine_name: torch
[INFO] 2026-01-04 01:03:28,288 [RapidOCR] device_config.py:57: Using GPU device with ID: 0
[INFO] 2026-01-04 01:03:28,