In [1]:
from __future__ import annotations

import os
import re
import json
import base64
import logging
import sys
import time, random
from pathlib import Path
from typing import Dict, List, Optional, Literal


from pypdf import PdfReader

from pydantic import BaseModel, Field

from mistralai import Mistral
from mistralai.extra import response_format_from_pydantic_model

# Progress bars are optional: if tqdm can't be imported, substitute a
# pass-through that hands back the iterable unchanged and ignores options
# such as `desc=` / `leave=`.
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **_):
        return x

# Root logging: INFO and above, timestamped, pipe-separated.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# The project root is the working directory, or its parent when the kernel
# was started inside `notebooks/`.
_cwd = Path.cwd()
PROJECT_ROOT = _cwd.parent if _cwd.name == "notebooks" else _cwd
del _cwd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Data locations --------------------------------------------------------
# Raw PDFs are read from SRC_ROOT; per-page annotation JSONs are written
# beneath DST_PAGES.
SRC_ROOT = PROJECT_ROOT / "data" / "raw"
DST_PAGES = PROJECT_ROOT / "data" / "interim_pages"

# Make sure both directories exist (no-op when they already do).
for p in (SRC_ROOT, DST_PAGES):
    p.mkdir(exist_ok=True, parents=True)

# --- Run parameters --------------------------------------------------------
ZERO_PAD: int = 3                   # width of the zero-padded page number in output filenames
OVERWRITE: bool = False             # re-annotate pages whose JSON already exists
MAX_PAGES_PER_OCR_REQUEST: int = 8  # NOTE(review): described as an SDK limit — confirm against Mistral OCR docs

# --- Logging ---------------------------------------------------------------
# Named logger for this notebook; it adds no handlers of its own, so output
# goes through the root configuration set up in the first cell.
logger = logging.getLogger("notebook.ocr")
logger.setLevel(logging.INFO)
logger.info("Project root: %s", PROJECT_ROOT)
logger.info("SRC_ROOT=%s | DST_PAGES=%s", SRC_ROOT, DST_PAGES)

# Secrets & client init
def read_api_key(
    env_var: str = "MISTRAL_API_KEY",
    fallback_file: Path = PROJECT_ROOT / "api_key",
) -> str:
    """
    Return the API key from the environment, else from a local key file.

    Lookup order:
      1. ``os.environ[env_var]``, stripped of surrounding whitespace. An unset,
         empty, or whitespace-only value counts as "not set".
      2. The contents of ``fallback_file`` (default: ``<project root>/api_key``),
         stripped of surrounding whitespace.

    Args:
        env_var: name of the environment variable to check first.
        fallback_file: regular file containing the key on its own.

    Returns:
        A non-empty key string.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    # Strip before testing, so a whitespace-only variable falls through to the file.
    key = os.environ.get(env_var, "").strip()
    if key:
        return key

    # is_file() rather than exists(): a directory called `api_key` is not a key file.
    if fallback_file.is_file():
        key = fallback_file.read_text(encoding="utf-8").strip()
        if key:
            return key
        raise RuntimeError(
            f"{env_var} not set and fallback file '{fallback_file}' is empty."
        )

    raise RuntimeError(
        f"{env_var} not set and fallback file '{fallback_file}' not found."
    )

def get_mistral_client() -> Mistral:
    """
    Build and return a new Mistral client authenticated with ``read_api_key()``.

    Nothing is cached: every call re-reads the key and constructs a fresh
    client, so keep the result if you need it more than once.
    """
    api_key = read_api_key()
    return Mistral(api_key=api_key)


2025-10-13 17:04:49,468 | INFO | Project root: /home/fabian-ramirez/Documents/These/Code/magazine_graphs
2025-10-13 17:04:49,470 | INFO | SRC_ROOT=/home/fabian-ramirez/Documents/These/Code/magazine_graphs/data/raw | DST_PAGES=/home/fabian-ramirez/Documents/These/Code/magazine_graphs/data/interim_pages


In [4]:
# Stage-1 page schema, imported from <PROJECT_ROOT>/schemas.
# That directory is prepended to sys.path only when it is not already there,
# so re-running this cell does not keep growing the search path.
SCHEMAS_DIR = PROJECT_ROOT / "schemas"
if all(entry != str(SCHEMAS_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCHEMAS_DIR))

# Pick exactly one schema variant:
#   Option 1 (active): WITH continuation fields
from stage1_page import Stage1PageModel
#   Option 2: WITHOUT continuation fields (comment out Option 1, uncomment the line below)
# from stage1_page_no_continuation import Stage1PageModelNoContinuation as Stage1PageModel

logger.info("Loaded schema: %s", Stage1PageModel.__name__)

2025-10-13 17:04:52,185 | INFO | Loaded schema: Stage1PageModel


In [None]:
# Response-format spec derived once from the Stage-1 pydantic schema; it is
# passed as `document_annotation_format` on every OCR request.
DOC_ANNOT_FMT = response_format_from_pydantic_model(
    Stage1PageModel,
)

def count_pages(pdf_path: Path) -> int:
    """
    Return the page count of ``pdf_path``, or 0 when it cannot be read.

    Reading is best-effort: any error is logged as a warning and reported as
    0 pages. An encrypted PDF that does not open with an empty password is
    also reported as 0.
    """
    try:
        with pdf_path.open("rb") as stream:
            try:
                pdf = PdfReader(stream, strict=False)
            except TypeError:
                # Retry without the keyword in case this pypdf version rejects `strict`.
                pdf = PdfReader(stream)

            encrypted = getattr(pdf, "is_encrypted", False)
            # decrypt("") tries an empty password; a result equal to 0 is treated as failure.
            if encrypted and pdf.decrypt("") == 0:
                logger.warning("Encrypted PDF (cannot decrypt): %s", pdf_path)
                return 0

            page_total = len(pdf.pages)
        return page_total
    except Exception as e:
        logger.warning("Could not read %s: %s", pdf_path, e)
        return 0

def encode_file_to_data_url(path: Path, mime: str = "application/pdf") -> str:
    """
    Read ``path`` fully and return it as a ``data:<mime>;base64,<payload>`` URL.

    The whole file is loaded into memory and base64-encoded in one go.
    """
    payload = path.read_bytes()
    encoded = base64.b64encode(payload).decode("utf-8")
    return "data:" + mime + ";base64," + encoded

def chunks(seq, size):
    """
    Lazily yield consecutive slices of ``seq`` with at most ``size`` items each.

    ``seq`` must support ``len()`` and slicing; each slice keeps the sequence's
    own type (a list yields lists, a str yields strs). The last slice may be
    shorter than ``size``.
    """
    yield from (seq[start:start + size] for start in range(0, len(seq), size))

In [None]:
def _coerce_annotation_dict(raw) -> dict:
    """Turn a raw annotation (dict or JSON string) into a dict; anything else becomes {}."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except (TypeError, ValueError):
            return {}
    # JSON may decode to a list/number/null; only a mapping is a usable annotation.
    return raw if isinstance(raw, dict) else {}


def parse_single_annotation(resp) -> dict:
    """
    Extract one page's annotation from a single-page OCR response.

    Prefers ``resp.document_annotation``. If that is missing, unparsable, not a
    JSON object, or an empty object, falls back to
    ``resp.pages[0].document_annotation``. Either source may be a dict or a
    JSON string.

    Returns:
        The annotation as a dict; ``{}`` when nothing usable is found. Never
        returns a non-dict, so callers can index/assign keys safely.
    """
    top_level = _coerce_annotation_dict(getattr(resp, "document_annotation", None))
    if top_level:
        return top_level

    pages = getattr(resp, "pages", None) or []
    if pages:
        return _coerce_annotation_dict(getattr(pages[0], "document_annotation", None))
    return {}

def call_with_retry(fn, *, retries: int = 3, base_delay: float = 1.0, max_delay: float = 8.0):
    """
    Call ``fn()`` with exponential backoff plus jitter between failed attempts.

    Every ``Exception`` triggers a retry, including ones that are not transient.

    Args:
        fn: zero-argument callable to invoke.
        retries: total number of attempts. Values below 1 are treated as 1, so
            ``fn`` always runs at least once.
        base_delay: seconds to wait before the first retry; doubles on each
            subsequent retry.
        max_delay: hard upper bound, in seconds, on any single wait (jitter included).

    Returns:
        The return value of the first successful ``fn()`` call.

    Raises:
        Whatever exception the final attempt raised, unchanged.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as e:
            if attempt == attempts - 1:
                raise
            backoff = base_delay * (2 ** attempt)
            # Up to +25% jitter spreads out simultaneous retries; cap afterwards so
            # max_delay really is the maximum sleep.
            delay = min(max_delay, backoff * (1 + 0.25 * random.random()))
            logger.warning("Call failed (%s). Retrying in %.1fs...", e, delay)
            time.sleep(delay)

def _prune_empty_fields(d: dict) -> dict:
    """Remove keys with None/empty values, but keep 'items' even if empty list."""
    if not isinstance(d, dict):
        return {}
    out = {}
    for k, v in d.items():
        if v is None:
            continue
        if isinstance(v, str) and v.strip() == "":
            continue
        if isinstance(v, list):
            out[k] = v  # keep list, even empty
        elif isinstance(v, dict):
            pruned = _prune_empty_fields(v)
            if pruned:
                out[k] = pruned
        else:
            out[k] = v
    return out

def _relative_stem(pdf_path: Path, src_root: Path) -> Path:
    """Sub-path for a PDF's outputs: its path under ``src_root`` minus the suffix (bare stem if not under it)."""
    try:
        return pdf_path.relative_to(src_root).with_suffix("")
    except ValueError:
        # Not located under src_root: use the stem rather than aborting the whole batch.
        return Path(pdf_path.stem)


def annotate_pdf_per_page(
    pdf_path: Path,
    out_root: Path = DST_PAGES,
    model_name: str = "mistral-ocr-latest",
    overwrite: Optional[bool] = None,
    src_root: Path = SRC_ROOT,
) -> int:
    """Call Document Annotation once per page and write one JSON per page.

    Outputs go to ``out_root / <pdf path relative to src_root, no suffix> /
    <stem>__page-NNN.json`` (1-based page numbers, zero-padded to ``ZERO_PAD``).
    Only fields/text the model actually returned are kept: ``None`` and blank
    strings are pruned. An ``items`` list is always present (empty if missing).

    Args:
        pdf_path: PDF to annotate.
        out_root: root directory for the per-page JSON files.
        model_name: OCR model passed to ``client.ocr.process``.
        overwrite: re-annotate pages whose JSON already exists. ``None`` (default)
            reads the module-level ``OVERWRITE`` at *call* time, so toggling that
            flag takes effect without re-running this cell.
        src_root: directory ``pdf_path`` is made relative to for the output sub-path.

    Returns:
        Number of JSON files written by this call. Skipped and failed pages are
        not counted.
    """
    if overwrite is None:
        overwrite = OVERWRITE

    n_pages = count_pages(pdf_path)
    if n_pages == 0:
        return 0

    out_dir = out_root / _relative_stem(pdf_path, src_root)

    # Decide which pages still need work *before* the expensive steps, so a fully
    # processed PDF is neither base64-encoded nor triggers client creation.
    pending: List[tuple] = []
    for page_idx in range(n_pages):
        out_json = out_dir / f"{pdf_path.stem}__page-{page_idx+1:0{ZERO_PAD}d}.json"
        if overwrite or not out_json.exists():
            pending.append((page_idx, out_json))

    written = 0
    if pending:
        out_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the entire PDF is re-sent with every single-page request;
        # for large files consider uploading once and referencing it — check the SDK.
        data_url = encode_file_to_data_url(pdf_path)
        client = get_mistral_client()

        for page_idx, out_json in tqdm(pending, desc=f"Annotating (per-page) {pdf_path.name}", leave=False):
            try:
                resp = call_with_retry(
                    # Default arg binds the current page index into the callable.
                    lambda idx=page_idx: client.ocr.process(
                        model=model_name,
                        document={"type": "document_url", "document_url": data_url},
                        pages=[idx],
                        document_annotation_format=DOC_ANNOT_FMT,
                        include_image_base64=False,
                    )
                )
            except Exception as e:
                logger.warning("Page %d of %s failed after retries: %s", page_idx + 1, pdf_path.name, e)
                continue

            annot = parse_single_annotation(resp)
            # Copy (don't mutate the SDK response); a non-dict payload becomes {}
            # instead of crashing the batch on the key assignment below.
            annot = dict(annot) if isinstance(annot, dict) else {}
            # Ensure an items list exists (also when the model returned null), but fabricate nothing else.
            if annot.get("items") is None:
                annot["items"] = []
            # Drop synthetic/default fields; keep only what's present
            annot = _prune_empty_fields(annot)

            # Write to a temp file then rename: an interrupted run must not leave a
            # truncated JSON that the next run would skip as already done.
            tmp_json = out_json.with_suffix(".json.tmp")
            try:
                tmp_json.write_text(json.dumps(annot, ensure_ascii=False, indent=2), encoding="utf-8")
                tmp_json.replace(out_json)
                written += 1
            except Exception as e:
                logger.warning("Failed to write %s: %s", out_json, e)
                try:
                    tmp_json.unlink(missing_ok=True)
                except OSError:
                    pass

    logger.info("Per-page annotated %s → %d/%d JSONs", pdf_path.name, written, n_pages)
    return written

def annotate_all_pdfs_per_page(src_root: Path = SRC_ROOT, overwrite: Optional[bool] = None) -> int:
    """Annotate every PDF under ``src_root`` (recursively), one JSON per page.

    PDFs are matched case-insensitively on the ``.pdf`` suffix and processed in
    sorted path order, so runs are reproducible.

    Args:
        src_root: directory searched recursively for PDFs. It is also used to
            build each PDF's output sub-path.
        overwrite: forwarded to ``annotate_pdf_per_page``. ``None`` means "use the
            current module-level ``OVERWRITE``".

    Returns:
        Total number of per-page JSON files written.
    """
    pdfs = sorted(p for p in src_root.rglob("*") if p.is_file() and p.suffix.lower() == ".pdf")
    total = 0
    for pdf in tqdm(pdfs, desc="Annotating PDFs (per-page)"):
        total += annotate_pdf_per_page(pdf, overwrite=overwrite, src_root=src_root)
    logger.info("Total per-page annotated JSONs: %d", total)
    return total

In [None]:
# Annotate every PDF under SRC_ROOT, writing one JSON per page. Pages that
# already have a JSON are skipped unless OVERWRITE is True. Uncomment the
# line below to force re-annotation. If toggling it has no effect, re-run the
# cell defining annotate_pdf_per_page: its default may have been bound when
# that cell ran.
# OVERWRITE = True
total_per_page: int = annotate_all_pdfs_per_page(src_root=SRC_ROOT)
total_per_page