In [1]:
"""
Stage 1 OCR Extraction - Mistral Document AI

Extracts structured page-level data from photographs of historical French literary magazines.

Input:  PDF files in data/raw/
Output: JSON files per page in data/interim_pages/
Schema: schemas/stage1_page.py
"""

from __future__ import annotations

import os
import json
import base64
import logging
import sys
import time
import random
from pathlib import Path
from typing import Dict, List, Optional


from pypdf import PdfReader
from pydantic import BaseModel, ValidationError

from mistralai import Mistral
from mistralai.extra import response_format_from_pydantic_model

try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = lambda x, **kwargs: x


# Logging: timestamped records at INFO level and above, shared by the whole run.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("extraction")

# The kernel may be started either at the repository root or inside its
# notebooks/ sub-directory; in the latter case the project root is one level up.
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()

print("Stage 1 OCR Extraction", "=" * 60, f"Project root: {PROJECT_ROOT}", sep="\n")

Stage 1 OCR Extraction
Project root: /home/fabian-ramirez/Documents/These/Code/magazine_graphs


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
"""
Configuration and Path Setup
"""

# Input/Output directories
SRC_ROOT = PROJECT_ROOT / "data" / "raw"
DST_PAGES = PROJECT_ROOT / "data" / "interim_pages"   # <-- where page JSONs will be written

# Create directories if they don't exist
for directory in (SRC_ROOT, DST_PAGES):
    directory.mkdir(parents=True, exist_ok=True)

# Extraction parameters
CONFIG = {
    "model_name": "mistral-ocr-latest",
    "overwrite": False,  # Skip already-extracted pages
    "zero_pad": 3,  # Page number padding (001, 002, ...)
    "max_retries": 3,  # API retry attempts
    "base_delay": 1.0,  # Initial retry delay (seconds)
    "max_delay": 8.0,  # Maximum retry delay (seconds)
}

print("\nConfiguration:")
print(f"  Source directory: {SRC_ROOT}")
print(f"  Output directory: {DST_PAGES}")
print(f"  Model: {CONFIG['model_name']}")
print(f"  Overwrite existing: {CONFIG['overwrite']}")

# API key setup
def read_api_key(
    env_var: str = "MISTRAL_API_KEY",
    fallback_file: Path = PROJECT_ROOT / "api_key",
) -> str:
    """
    Read Mistral API key from environment variable or fallback file.

    The environment variable takes precedence over the file. Surrounding
    whitespace (e.g. a trailing newline in the key file) is stripped, and a
    value that is empty after stripping is treated as missing rather than
    returned. An empty key would otherwise only fail later, at the first API call.

    Args:
        env_var: Environment variable name
        fallback_file: Path to a file containing only the API key

    Returns:
        Non-empty API key string

    Raises:
        RuntimeError: If no non-empty API key is found in either location
    """
    key = os.environ.get(env_var, "").strip()
    if key:
        return key

    # is_file() rather than exists(): a directory at this path must not reach read_text().
    if fallback_file.is_file():
        key = fallback_file.read_text(encoding="utf-8").strip()
        if key:
            return key

    raise RuntimeError(
        f"{env_var} not set (or empty) and fallback file '{fallback_file}' "
        "not found (or empty). "
        f"Please set the {env_var} environment variable or put the key in '{fallback_file}'."
    )

def get_mistral_client() -> Mistral:
    """Build a Mistral SDK client authenticated with the key returned by read_api_key()."""
    api_key = read_api_key()
    return Mistral(api_key=api_key)

print("  API key: Configured")


Configuration:
  Source directory: /home/fabian-ramirez/Documents/These/Code/magazine_graphs/data/raw
  Output directory: /home/fabian-ramirez/Documents/These/Code/magazine_graphs/data/interim_pages
  Model: mistral-ocr-latest
  Overwrite existing: False
  API key: ✓ Configured


In [6]:
"""
Load Stage 1 Schema
"""

# Add schemas directory to Python path
SCHEMAS_DIR = PROJECT_ROOT / "schemas"
if str(SCHEMAS_DIR) not in sys.path:
    sys.path.insert(0, str(SCHEMAS_DIR))

# Import schema
from stage1_page import Stage1PageModel, Stage1Item, ITEM_CLASS

# Generate response format for Mistral API
DOC_ANNOT_FMT = response_format_from_pydantic_model(Stage1PageModel)

print("\nSchema:")
print(f"  Loaded: {Stage1PageModel.__name__}")
print(f"  Item classes: {ITEM_CLASS}")


Schema:
  Loaded: Stage1PageModel
  Item classes: typing.Literal['prose', 'verse', 'ad', 'paratext', 'unknown']


In [None]:
"""
PDF Processing Utilities
"""

def count_pages(pdf_path: Path) -> int:
    """
    Return the number of pages in a PDF, or 0 when it cannot be read.

    Every failure is logged as a warning instead of raised: missing file,
    unparsable PDF, or an encrypted PDF that an empty password does not unlock.
    Callers can therefore treat 0 as "skip this file".

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        Page count, or 0 on any failure.
    """
    try:
        with pdf_path.open("rb") as stream:
            try:
                pdf = PdfReader(stream, strict=False)
            except TypeError:
                # Fallback for PdfReader signatures that do not accept `strict`.
                pdf = PdfReader(stream)

            encrypted = getattr(pdf, "is_encrypted", False)
            if encrypted and pdf.decrypt("") == 0:
                logger.warning(f"Encrypted PDF (cannot decrypt): {pdf_path.name}")
                return 0

            page_total = len(pdf.pages)
        return page_total
    except Exception as e:
        logger.warning(f"Could not read {pdf_path.name}: {e}")
        return 0

def encode_file_to_data_url(path: Path, mime: str = "application/pdf") -> str:
    """
    Read a whole file and return it as a base64 ``data:`` URL.

    The entire file is loaded into memory and base64-encoded, which adds about
    one third to its size.

    Args:
        path: File to encode.
        mime: MIME type written into the URL header.

    Returns:
        ``data:<mime>;base64,<encoded bytes>``
    """
    payload = path.read_bytes()
    encoded = base64.b64encode(payload).decode("utf-8")
    return f"data:{mime};base64,{encoded}"

def chunks(seq, size):
    """
    Yield consecutive slices of ``seq`` of length ``size``; the last may be shorter.

    Args:
        seq: Any sequence supporting ``len()`` and slicing (list, tuple, str, ...).
        size: Slice length; must be >= 1.

    Yields:
        Slices of the same type as ``seq``.

    Raises:
        ValueError: If ``size`` < 1 (raised on first iteration, as this is a generator).
    """
    # Without this guard a negative size silently yields nothing and 0 surfaces
    # as range()'s opaque "arg 3 must not be zero".
    if size < 1:
        raise ValueError(f"chunk size must be >= 1, got {size!r}")
    for start in range(0, len(seq), size):
        yield seq[start:start + size]

def _coerce_annotation(raw) -> Optional[dict]:
    """Return ``raw`` if it is a dict, the decoded object if it is a JSON-object string, else None."""
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return None
        # Valid JSON that is not an object (list, null, number, ...) is not an annotation.
        if isinstance(parsed, dict):
            return parsed
    return None


def parse_annotation_response(resp) -> dict:
    """
    Extract annotation dict from Mistral OCR response.

    Sources are tried in order:
    - resp.document_annotation (dict, or JSON string decoding to an object)
    - resp.pages[0].document_annotation (same forms; fallback)

    A source that is missing, not valid JSON, or valid JSON but not an object is
    skipped, so the result is always a dict.

    Args:
        resp: Mistral OCR API response object

    Returns:
        Annotation dict (empty dict if no usable annotation is found)
    """
    top_level = _coerce_annotation(getattr(resp, "document_annotation", None))
    if top_level is not None:
        return top_level

    pages = getattr(resp, "pages", None) or []
    if pages:
        first_page = _coerce_annotation(getattr(pages[0], "document_annotation", None))
        if first_page is not None:
            return first_page

    return {}

def call_with_retry(fn, *, retries: int = 3, base_delay: float = 1.0, max_delay: float = 8.0):
    """
    Call function with exponential backoff retry logic.

    ``retries`` is the total number of attempts, not the number of extra ones.
    Values below 1 are treated as 1, so ``fn`` is always called at least once
    instead of the function silently returning None.

    The wait before retry ``k`` (0-based) is ``min(max_delay, base_delay * 2**k)``.
    Up to 25% random jitter is then added on top, so an individual sleep may
    exceed ``max_delay`` by up to 25%.

    Args:
        fn: Function to call (no arguments)
        retries: Total number of attempts (minimum 1)
        base_delay: Initial delay between retries (seconds)
        max_delay: Cap on the pre-jitter delay between retries (seconds)

    Returns:
        Function result

    Raises:
        Exception: Whatever the last attempt raised, if all attempts fail
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as e:
            if attempt == attempts - 1:
                raise

            delay = min(max_delay, base_delay * (2 ** attempt))
            jitter = delay * (1 + 0.25 * random.random())

            logger.warning(f"API call failed ({e}). Retrying in {jitter:.1f}s...")
            time.sleep(jitter)

def validate_extraction(annot: dict, page_number: int, pdf_name: str) -> tuple[bool, List[str]]:
    """
    Validate extracted annotation for common issues.

    Structural problems (``annot`` not a dict, ``items`` missing or not a list,
    a schema-validation failure) make the page invalid. Content oddities (zero
    items, very short item text, non-object items) are reported as warnings. A
    non-object item will also make the schema check fail. Malformed input never
    raises.

    Args:
        annot: Annotation dictionary
        page_number: Page number (1-indexed); currently unused in the body
        pdf_name: PDF filename; currently unused in the body

    Returns:
        Tuple of (is_valid, list_of_warnings)
    """
    warnings: List[str] = []

    if not isinstance(annot, dict) or "items" not in annot:
        warnings.append("Missing 'items' field")
        return False, warnings

    items = annot["items"]
    # len()/enumerate() below would crash (or mis-iterate a str/dict) on non-lists.
    if not isinstance(items, list):
        warnings.append(f"'items' is {type(items).__name__}, expected list")
        return False, warnings

    # Empty pages are valid but worth noting
    if not items:
        warnings.append("Zero items extracted (possibly blank page)")

    for idx, item in enumerate(items):
        if not isinstance(item, dict):
            warnings.append(f"Item {idx} is {type(item).__name__}, expected object")
            continue
        text = item.get("item_text_raw")
        # None / non-string text counts as 0 chars; type errors are left to the schema check.
        length = len(text) if isinstance(text, str) else 0
        if length < 3:
            warnings.append(f"Item {idx} has very short text ({length} chars)")

    # Schema validation with Pydantic
    try:
        Stage1PageModel(**annot)
    except ValidationError as e:
        warnings.append(f"Schema validation failed: {e}")
        return False, warnings

    return True, warnings

In [None]:
# # %%
# # TEST CELL: Batch Response Structure Analysis
# # Run this once to understand what Mistral returns for multi-page batches

# def test_batch_response_structure(
#     pdf_path: Path,
#     test_pages: list[int] = [0, 1],  # Test with first 2 pages
#     model_name: str = "mistral-ocr-latest",
# ):
#     """
#     Test what Mistral returns when requesting multiple pages at once.
#     This helps determine if batching is viable for your workflow.
#     """
#     logger.info("="*60)
#     logger.info("BATCH RESPONSE TEST")
#     logger.info("="*60)
#     logger.info(f"Testing with: {pdf_path.name}")
#     logger.info(f"Pages requested: {test_pages}")
    
#     # Encode PDF
#     data_url = encode_file_to_data_url(pdf_path)
#     client = get_mistral_client()
    
#     # Make batch request
#     try:
#         resp = client.ocr.process(
#             model=model_name,
#             document={"type": "document_url", "document_url": data_url},
#             pages=test_pages,  # Request multiple pages
#             document_annotation_format=DOC_ANNOT_FMT,
#             include_image_base64=False,
#         )
#     except Exception as e:
#         logger.error(f"Batch request failed: {e}")
#         return None
    
#     # Analyze the response structure
#     logger.info("\n" + "="*60)
#     logger.info("RESPONSE STRUCTURE ANALYSIS")
#     logger.info("="*60)
    
#     # Check top-level structure
#     logger.info(f"\nResponse type: {type(resp)}")
#     logger.info(f"Response attributes: {dir(resp)}")
    
#     # Check document_annotation
#     doc_annot = getattr(resp, "document_annotation", None)
#     logger.info(f"\n--- document_annotation ---")
#     logger.info(f"Type: {type(doc_annot)}")
    
#     if isinstance(doc_annot, str):
#         try:
#             parsed = json.loads(doc_annot)
#             logger.info(f"Parsed type: {type(parsed)}")
#             logger.info(f"Parsed keys: {parsed.keys() if isinstance(parsed, dict) else 'N/A'}")
            
#             # Pretty print first 500 chars
#             preview = json.dumps(parsed, indent=2, ensure_ascii=False)[:500]
#             logger.info(f"\nFirst 500 chars:\n{preview}...")
            
#         except Exception as e:
#             logger.warning(f"Could not parse as JSON: {e}")
#             logger.info(f"Raw value (first 200 chars): {doc_annot[:200]}...")
#     elif isinstance(doc_annot, dict):
#         logger.info(f"Keys: {doc_annot.keys()}")
#         preview = json.dumps(doc_annot, indent=2, ensure_ascii=False)[:500]
#         logger.info(f"\nFirst 500 chars:\n{preview}...")
#     else:
#         logger.info(f"Value: {doc_annot}")
    
#     # Check pages array
#     pages_list = getattr(resp, "pages", None)
#     logger.info(f"\n--- pages array ---")
#     logger.info(f"Type: {type(pages_list)}")
    
#     if pages_list:
#         logger.info(f"Length: {len(pages_list)}")
        
#         # Check first page structure
#         if len(pages_list) > 0:
#             first_page = pages_list[0]
#             logger.info(f"\nFirst page type: {type(first_page)}")
#             logger.info(f"First page attributes: {dir(first_page)}")
            
#             page_doc_annot = getattr(first_page, "document_annotation", None)
#             logger.info(f"\nFirst page document_annotation type: {type(page_doc_annot)}")
            
#             if page_doc_annot:
#                 if isinstance(page_doc_annot, str):
#                     try:
#                         parsed_page = json.loads(page_doc_annot)
#                         logger.info(f"First page parsed keys: {parsed_page.keys() if isinstance(parsed_page, dict) else 'N/A'}")
#                     except:
#                         pass
#                 elif isinstance(page_doc_annot, dict):
#                     logger.info(f"First page keys: {page_doc_annot.keys()}")
    
#     # Save full response for manual inspection
#     output_test_file = PROJECT_ROOT / "test_batch_response.json"
#     try:
#         # Try to convert response to dict for saving
#         if hasattr(resp, 'model_dump'):
#             resp_dict = resp.model_dump()
#         elif hasattr(resp, 'dict'):
#             resp_dict = resp.dict()
#         else:
#             resp_dict = {"raw": str(resp)}
        
#         output_test_file.write_text(
#             json.dumps(resp_dict, indent=2, ensure_ascii=False),
#             encoding="utf-8"
#         )
#         logger.info(f"\n✓ Full response saved to: {output_test_file}")
#         logger.info(f"  Review this file to understand the exact structure")
#     except Exception as e:
#         logger.warning(f"Could not save response: {e}")
    
#     logger.info("\n" + "="*60)
#     logger.info("KEY QUESTIONS TO ANSWER:")
#     logger.info("="*60)
#     logger.info("1. Is document_annotation ONE merged annotation or ARRAY of annotations?")
#     logger.info("2. Does resp.pages contain separate annotations per page?")
#     logger.info("3. Are page boundaries preserved or merged?")
#     logger.info("4. Review the saved JSON file for the complete picture")
#     logger.info("="*60 + "\n")
    
#     return resp

# # %%
# # RUN THE TEST
# # Pick any PDF from your data directory
# test_pdf = next(SRC_ROOT.rglob("*.pdf"))  # Gets first PDF found

# logger.info(f"Running batch test on: {test_pdf.name}\n")
# test_response = test_batch_response_structure(
#     pdf_path=test_pdf,
#     test_pages=[0, 1]  # Test with first 2 pages
# )

# # %%
# # Optional: Test with more pages to see scaling behavior
# if test_response:
#     logger.info("\n\nTesting with 4 pages...\n")
#     test_response_4 = test_batch_response_structure(
#         pdf_path=test_pdf,
#         test_pages=[0, 1, 2, 3]
#     )

In [5]:
def parse_single_annotation(resp) -> dict:
    """
    Extract one page's annotation from a single-page request response.

    Prefer resp.document_annotation; fall back to resp.pages[0].document_annotation.
    Each source may be a dict or a JSON string. A source that is missing,
    undecodable, or decodes to something other than a JSON object is skipped.
    The result is therefore always a dict (``{}`` when nothing usable is found),
    so callers can index it with string keys safely.
    """
    def _as_dict(raw):
        # Accept a dict as-is, or a string that decodes to a JSON object; else None.
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except (ValueError, RecursionError):
                return None
            return parsed if isinstance(parsed, dict) else None
        return None

    top_level = _as_dict(getattr(resp, "document_annotation", None))
    if top_level is not None:
        return top_level

    pages = getattr(resp, "pages", None) or []
    if pages:
        first_page = _as_dict(getattr(pages[0], "document_annotation", None))
        if first_page is not None:
            return first_page
    return {}

def call_with_retry(fn, *, retries: int = 3, base_delay: float = 1.0, max_delay: float = 8.0):
    """
    Simple exponential backoff with jitter for transient API errors.

    NOTE: this redefines call_with_retry from the utilities cell. Because this
    cell runs later, this definition is the one in effect afterwards.

    ``retries`` is the total number of attempts; values below 1 are treated as 1
    so ``fn`` always runs at least once instead of the call returning None.
    The last failure is re-raised.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as e:
            if attempt == attempts - 1:
                raise
            # Capped exponential delay plus up to 25% random jitter.
            delay = min(max_delay, base_delay * (2 ** attempt)) * (1 + 0.25 * random.random())
            logger.warning("Call failed (%s). Retrying in %.1fs...", e, delay)
            time.sleep(delay)

def _prune_empty_fields(d: dict) -> dict:
    """Remove keys with None/empty values, but keep 'items' even if empty list."""
    if not isinstance(d, dict):
        return {}
    out = {}
    for k, v in d.items():
        if v is None:
            continue
        if isinstance(v, str) and v.strip() == "":
            continue
        if isinstance(v, list):
            out[k] = v  # keep list, even empty
        elif isinstance(v, dict):
            pruned = _prune_empty_fields(v)
            if pruned:
                out[k] = pruned
        else:
            out[k] = v
    return out

def annotate_pdf_per_page(
    pdf_path: Path,
    out_root: Path = DST_PAGES,
    model_name: Optional[str] = None,
    overwrite: Optional[bool] = None,
    src_root: Path = SRC_ROOT,
) -> int:
    """
    Call Document Annotation once per page and write one JSON per page.

    Output path:
    ``out_root/<pdf path relative to src_root, no suffix>/<stem>__page-<NNN>.json``.
    NNN is the 1-based page number, zero-padded to ``CONFIG["zero_pad"]`` digits.
    Only fields the model returned are kept (None/blank values are pruned), and
    ``items`` is always present.

    Args:
        pdf_path: PDF to annotate; must be located under ``src_root``.
        out_root: Root directory for page JSONs.
        model_name: OCR model; ``None`` means ``CONFIG["model_name"]``.
        overwrite: Re-annotate pages whose JSON already exists; ``None`` means
            ``CONFIG["overwrite"]``. It is read at call time, so editing CONFIG
            takes effect without re-defining this function.
        src_root: Root that the per-PDF output sub-directory is computed from.

    Returns:
        Number of page JSONs written by this call.
    """
    if model_name is None:
        model_name = CONFIG["model_name"]
    if overwrite is None:
        overwrite = CONFIG["overwrite"]
    zero_pad = CONFIG["zero_pad"]

    n_pages = count_pages(pdf_path)
    if n_pages == 0:
        return 0

    rel_no_ext = pdf_path.relative_to(src_root).with_suffix("")
    out_dir = out_root / rel_no_ext
    out_dir.mkdir(parents=True, exist_ok=True)

    def _page_json(page_idx: int) -> Path:
        # page_idx is the 0-based index sent to the API; filenames are 1-based.
        return out_dir / f"{pdf_path.stem}__page-{page_idx + 1:0{zero_pad}d}.json"

    pending = [i for i in range(n_pages) if overwrite or not _page_json(i).exists()]
    if not pending:
        # Skip base64-encoding the whole PDF and requiring an API key when there is no work.
        logger.info("Per-page annotated %s → 0/%d JSONs (all pages already exist)", pdf_path.name, n_pages)
        return 0

    data_url = encode_file_to_data_url(pdf_path)
    client = get_mistral_client()

    written = 0
    for page_idx in tqdm(pending, desc=f"Annotating (per-page) {pdf_path.name}", leave=False):
        out_json = _page_json(page_idx)

        def _call(idx: int = page_idx):
            # idx bound as a default so the closure never sees a later loop value.
            return client.ocr.process(
                model=model_name,
                document={"type": "document_url", "document_url": data_url},
                pages=[idx],
                document_annotation_format=DOC_ANNOT_FMT,
                include_image_base64=False,
            )

        try:
            resp = call_with_retry(
                _call,
                retries=CONFIG["max_retries"],
                base_delay=CONFIG["base_delay"],
                max_delay=CONFIG["max_delay"],
            )
        except Exception as e:
            logger.warning("Page %d of %s failed after retries: %s", page_idx + 1, pdf_path.name, e)
            continue

        annot = parse_single_annotation(resp)
        if not isinstance(annot, dict):
            # Guard: a non-dict here would crash the item-key handling below.
            annot = {}
        # Prune first (returns a new dict, so the response object is not mutated),
        # then guarantee `items`; a None `items` would otherwise be pruned away.
        annot = _prune_empty_fields(annot)
        annot.setdefault("items", [])

        # Write to a temp file and atomically rename. An interrupted run then never
        # leaves a truncated <page>.json that later runs would skip as "done".
        # A leftover *.json.tmp does not match *.json.
        tmp_json = out_json.with_name(out_json.name + ".tmp")
        try:
            tmp_json.write_text(json.dumps(annot, ensure_ascii=False, indent=2), encoding="utf-8")
            os.replace(tmp_json, out_json)
            written += 1
        except Exception as e:
            logger.warning("Failed to write %s: %s", out_json, e)

    logger.info("Per-page annotated %s → %d/%d JSONs", pdf_path.name, written, n_pages)
    return written

def annotate_all_pdfs_per_page(
    src_root: Path = SRC_ROOT,
    overwrite: Optional[bool] = None,
) -> int:
    """
    Annotate every ``*.pdf`` under ``src_root`` (recursively), page by page.

    PDFs are processed in sorted path order so runs are reproducible.

    Args:
        src_root: Directory searched recursively. It is also passed through so
            output sub-directories are computed relative to it.
        overwrite: Passed to annotate_pdf_per_page; ``None`` defers to CONFIG["overwrite"].

    Returns:
        Total number of page JSONs written.
    """
    pdfs = sorted(p for p in src_root.rglob("*.pdf") if p.is_file())
    total = 0
    for pdf in tqdm(pdfs, desc="Annotating PDFs (per-page)"):
        total += annotate_pdf_per_page(pdf, overwrite=overwrite, src_root=src_root)
    logger.info("Total per-page annotated JSONs: %d", total)
    return total

In [6]:
# Run per-page extraction over every PDF under SRC_ROOT. Pages whose JSON already
# exists are skipped unless annotate_pdf_per_page runs with overwrite enabled.
# NOTE: assigning `OVERWRITE = True` in this cell, after the definition cell has
# run, does not change annotate_pdf_per_page's default.
total_per_page = annotate_all_pdfs_per_page(src_root=SRC_ROOT)
total_per_page

Annotating PDFs (per-page):   0%|          | 0/1 [00:00<?, ?it/s]2025-10-16 20:30:01,412 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:30:27,009 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:31:02,848 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:31:47,287 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:32:14,115 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:32:44,629 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:33:18,420 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:33:45,248 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:34:28,869 | INFO | HTTP Request: POST https://api.mistral.ai/v1/ocr "HTTP/1.1 200 OK"
2025-10-16 20:34:59,178 | INFO | HT

14

In [None]:
# # %%
# # =============================================================================
# # BATCH API IMPLEMENTATION (Alternative to per-page synchronous calls)
# # =============================================================================
# # This section uses Mistral's Batch API for 50% cost reduction
# # It processes all pages from all PDFs as a single asynchronous job

# import time
# from typing import List, Tuple

# def create_batch_requests_file(
#     src_root: Path = SRC_ROOT,
#     output_file: Path = PROJECT_ROOT / "batch_requests.jsonl",
#     overwrite: bool = False,
# ) -> Tuple[Path, int]:
#     """
#     Create a JSONL file with one OCR request per page for all PDFs.
    
#     Returns:
#         Tuple of (jsonl_path, total_requests)
#     """
#     if output_file.exists() and not overwrite:
#         logger.info("Batch file already exists: %s", output_file)
#         # Count lines to return request count
#         with output_file.open('r') as f:
#             count = sum(1 for _ in f)
#         return output_file, count
    
#     logger.info("Creating batch requests file...")
    
#     # Generate proper JSON schema from Pydantic model
#     # This is what the Batch API expects
#     doc_annot_format = {
#         "type": "json_schema",
#         "json_schema": {
#             "name": Stage1PageModel.__name__,
#             "schema": Stage1PageModel.model_json_schema(),
#             "strict": True
#         }
#     }
    
#     requests = []
    
#     for pdf in tqdm([p for p in src_root.rglob("*.pdf") if p.is_file()], desc="Preparing batch requests"):
#         n_pages = count_pages(pdf)
#         if n_pages == 0:
#             continue
            
#         data_url = encode_file_to_data_url(pdf)
        
#         for page_idx in range(n_pages):
#             # Create unique ID: pdf_name__page-XXX
#             rel_no_ext = pdf.relative_to(src_root).with_suffix("")
#             custom_id = f"{rel_no_ext}__page-{page_idx+1:0{ZERO_PAD}d}".replace("/", "__")
            
#             request = {
#                 "custom_id": custom_id,
#                 "body": {
#                     "document": {
#                         "type": "document_url",
#                         "document_url": data_url
#                     },
#                     "pages": [page_idx],
#                     "document_annotation_format": doc_annot_format,
#                     "include_image_base64": False
#                 }
#             }
#             requests.append(request)
    
#     # Write JSONL file
#     with output_file.open('w', encoding='utf-8') as f:
#         for req in requests:
#             f.write(json.dumps(req, ensure_ascii=False) + '\n')
    
#     logger.info("Created batch file with %d requests: %s", len(requests), output_file)
#     return output_file, len(requests)


# def submit_batch_job(
#     batch_file: Path,
#     model_name: str = "mistral-ocr-latest",
# ) -> str:
#     """
#     Upload batch file and create batch job.
    
#     Returns:
#         job_id
#     """
#     client = get_mistral_client()
    
#     logger.info("Uploading batch file: %s", batch_file)
#     batch_data = client.files.upload(
#         file={
#             "file_name": batch_file.name,
#             "content": batch_file.open("rb")
#         },
#         purpose="batch"
#     )
#     logger.info("File uploaded with ID: %s", batch_data.id)
    
#     logger.info("Creating batch job...")
#     created_job = client.batch.jobs.create(
#         input_files=[batch_data.id],
#         model=model_name,
#         endpoint="/v1/ocr",
#         metadata={"job_type": "stage1_ocr", "source": "notebook"}
#     )
    
#     logger.info("Batch job created!")
#     logger.info("  Job ID: %s", created_job.id)
#     logger.info("  Status: %s", created_job.status)
#     logger.info("  Total requests: %s", created_job.total_requests)
    
#     return created_job.id


# def monitor_batch_job(
#     job_id: str,
#     poll_interval: int = 10,
#     max_wait_minutes: int = 120,
# ) -> dict:
#     """
#     Monitor batch job until completion.
    
#     Returns:
#         Final job info dict
#     """
#     client = get_mistral_client()
#     start_time = time.time()
#     max_wait_seconds = max_wait_minutes * 60
    
#     logger.info("Monitoring batch job: %s", job_id)
#     logger.info("Will check every %d seconds (max wait: %d minutes)", poll_interval, max_wait_minutes)
    
#     while True:
#         elapsed = time.time() - start_time
#         if elapsed > max_wait_seconds:
#             raise TimeoutError(f"Job did not complete within {max_wait_minutes} minutes")
        
#         retrieved_job = client.batch.jobs.get(job_id=job_id)
        
#         status = retrieved_job.status
#         total = retrieved_job.total_requests or 0
#         succeeded = retrieved_job.succeeded_requests or 0
#         failed = retrieved_job.failed_requests or 0
        
#         if total > 0:
#             percent = round((succeeded + failed) / total * 100, 2)
#         else:
#             percent = 0
        
#         logger.info(
#             "Status: %s | Total: %d | Success: %d | Failed: %d | Complete: %s%%",
#             status, total, succeeded, failed, percent
#         )
        
#         if status in ["SUCCEEDED", "FAILED", "CANCELLED"]:
#             logger.info("Job finished with status: %s", status)
#             return {
#                 "job_id": job_id,
#                 "status": status,
#                 "total_requests": total,
#                 "succeeded_requests": succeeded,
#                 "failed_requests": failed,
#                 "output_file": retrieved_job.output_file,
#                 "error_file": getattr(retrieved_job, "error_file", None)
#             }
        
#         time.sleep(poll_interval)


# def download_and_parse_batch_results(
#     job_info: dict,
#     out_root: Path = DST_PAGES,
#     results_file: Path = PROJECT_ROOT / "batch_results.jsonl",
# ) -> int:
#     """
#     Download batch results and save as individual page JSONs.
    
#     Returns:
#         Number of pages written
#     """
#     client = get_mistral_client()
    
#     output_file_id = job_info.get("output_file")
#     if not output_file_id:
#         logger.error("No output file ID in job info")
#         return 0
    
#     logger.info("Downloading results from file ID: %s", output_file_id)
#     result_content = client.files.download(file_id=output_file_id)
    
#     # Save raw results for backup
#     results_file.write_text(result_content, encoding='utf-8')
#     logger.info("Raw results saved to: %s", results_file)
    
#     # Parse and save individual page JSONs
#     written = 0
#     failed = 0
    
#     for line in tqdm(result_content.strip().split('\n'), desc="Processing batch results"):
#         if not line.strip():
#             continue
            
#         try:
#             result = json.loads(line)
#             custom_id = result.get("custom_id", "")
#             response_body = result.get("response", {}).get("body", {})
            
#             # Parse the custom_id to get output path
#             # Format: path__to__pdf__page-001
#             parts = custom_id.rsplit("__page-", 1)
#             if len(parts) != 2:
#                 logger.warning("Could not parse custom_id: %s", custom_id)
#                 failed += 1
#                 continue
            
#             pdf_rel_path = parts[0].replace("__", "/")
#             page_num_str = parts[1]
            
#             # Create output directory
#             out_dir = out_root / pdf_rel_path
#             out_dir.mkdir(parents=True, exist_ok=True)
            
#             # Create output file path
#             out_json = out_dir / f"{Path(pdf_rel_path).name}__page-{page_num_str}.json"
            
#             # Extract annotation from response
#             # The response structure should match what we get from sync API
#             annot = {}
#             if "document_annotation" in response_body:
#                 doc_annot = response_body["document_annotation"]
#                 if isinstance(doc_annot, str):
#                     annot = json.loads(doc_annot)
#                 elif isinstance(doc_annot, dict):
#                     annot = doc_annot
            
#             # Ensure items key exists
#             if "items" not in annot:
#                 annot["items"] = []
            
#             # Prune empty fields
#             annot = _prune_empty_fields(annot)
            
#             # Write JSON file
#             out_json.write_text(json.dumps(annot, ensure_ascii=False, indent=2), encoding="utf-8")
#             written += 1
            
#         except Exception as e:
#             logger.warning("Failed to process result line: %s", e)
#             failed += 1
    
#     logger.info("Batch results processed: %d written, %d failed", written, failed)
#     return written


# def run_batch_ocr_pipeline(
#     src_root: Path = SRC_ROOT,
#     out_root: Path = DST_PAGES,
#     overwrite_batch_file: bool = False,
#     poll_interval: int = 10,
# ) -> int:
#     """
#     Complete pipeline: Create batch file, submit job, monitor, download results.
    
#     Returns:
#         Number of page JSONs written
#     """
#     logger.info("="*60)
#     logger.info("STARTING BATCH OCR PIPELINE")
#     logger.info("="*60)
    
#     # Step 1: Create batch requests file
#     batch_file, total_requests = create_batch_requests_file(
#         src_root=src_root,
#         overwrite=overwrite_batch_file
#     )
#     logger.info("Batch file ready with %d requests", total_requests)
    
#     # Step 2: Submit batch job
#     job_id = submit_batch_job(batch_file)
    
#     # Step 3: Monitor until completion
#     job_info = monitor_batch_job(job_id, poll_interval=poll_interval)
    
#     # Step 4: Download and parse results
#     if job_info["status"] == "SUCCEEDED":
#         written = download_and_parse_batch_results(job_info, out_root=out_root)
#         logger.info("="*60)
#         logger.info("BATCH OCR PIPELINE COMPLETE")
#         logger.info("Total pages written: %d", written)
#         logger.info("="*60)
#         return written
#     else:
#         logger.error("Batch job failed with status: %s", job_info["status"])
#         return 0

In [None]:
# # %%
# # =============================================================================
# # RUN BATCH OCR PIPELINE
# # =============================================================================
# # Uncomment the lines below to run the batch pipeline
# # This will process ALL PDFs in SRC_ROOT using the Batch API

# total_batch = run_batch_ocr_pipeline(
#     src_root=SRC_ROOT,
#     out_root=DST_PAGES,
#     overwrite_batch_file=False,  # Set True to regenerate batch file
#     poll_interval=10  # Check status every 10 seconds
# )
# total_batch

In [None]:
# # %%
# # =============================================================================
# # TEST BATCH WITH 4 PAGES ONLY
# # =============================================================================

# def create_small_test_batch(
#     pdf_path: Path,
#     num_pages: int = 2,
#     output_file: Path = PROJECT_ROOT / "batch_requests_test.jsonl",
# ) -> Tuple[Path, int]:
#     """
#     Create a test batch file with just a few pages.
#     """
#     logger.info(f"Creating test batch with {num_pages} pages from {pdf_path.name}")
    
#     # Generate proper JSON schema from Pydantic model
#     doc_annot_format = {
#         "type": "json_schema",
#         "json_schema": {
#             "name": Stage1PageModel.__name__,
#             "schema": Stage1PageModel.model_json_schema(),
#             "strict": True
#         }
#     }
    
#     n_pages = count_pages(pdf_path)
#     if n_pages == 0:
#         logger.error("PDF has no pages!")
#         return output_file, 0
    
#     # Limit to available pages
#     num_pages = min(num_pages, n_pages)
#     data_url = encode_file_to_data_url(pdf_path)
    
#     requests = []
#     for page_idx in range(num_pages):
#         rel_no_ext = pdf_path.relative_to(SRC_ROOT).with_suffix("")
#         custom_id = f"{rel_no_ext}__page-{page_idx+1:0{ZERO_PAD}d}".replace("/", "__")
        
#         request = {
#             "custom_id": custom_id,
#             "body": {
#                 "document": {
#                     "type": "document_url",
#                     "document_url": data_url
#                 },
#                 "pages": [page_idx],
#                 "document_annotation_format": doc_annot_format,
#                 "include_image_base64": False
#             }
#         }
#         requests.append(request)
    
#     # Write JSONL file
#     with output_file.open('w', encoding='utf-8') as f:
#         for req in requests:
#             f.write(json.dumps(req, ensure_ascii=False) + '\n')
    
#     logger.info(f"Created test batch file with {len(requests)} requests: {output_file}")
#     return output_file, len(requests)


# def run_test_batch(
#     pdf_path: Path,
#     num_pages: int = 2,
#     out_root: Path = DST_PAGES / "batch_test",
#     poll_interval: int = 5,
# ) -> int:
#     """
#     Test batch pipeline with just a few pages.
#     """
#     logger.info("="*60)
#     logger.info("STARTING TEST BATCH (2 PAGES)")
#     logger.info("="*60)
    
#     # Create test batch file
#     batch_file, total_requests = create_small_test_batch(
#         pdf_path=pdf_path,
#         num_pages=num_pages
#     )
#     logger.info(f"Test batch ready with {total_requests} requests")
    
#     # Submit batch job
#     try:
#         job_id = submit_batch_job(batch_file)
#     except Exception as e:
#         logger.error(f"Failed to submit batch job: {e}")
#         return 0
    
#     # Monitor until completion
#     job_info = monitor_batch_job(job_id, poll_interval=poll_interval)
    
#     # Download and parse results
#     if job_info["status"] == "SUCCEEDED":
#         written = download_and_parse_batch_results(job_info, out_root=out_root)
#         logger.info("="*60)
#         logger.info("TEST BATCH COMPLETE")
#         logger.info(f"Total pages written: {written}")
#         logger.info(f"Results saved to: {out_root}")
#         logger.info("="*60)
#         return written
#     else:
#         logger.error(f"Batch job failed with status: {job_info['status']}")
#         return 0

# # %%
# # RUN THE TEST
# test_pdf = next(SRC_ROOT.rglob("*.pdf"))  # Your La Plume PDF

# logger.info(f"Testing batch with first 2 pages of: {test_pdf.name}\n")
# test_result = run_test_batch(
#     pdf_path=test_pdf,
#     num_pages=4,
#     poll_interval=5  # Check every 5 seconds
# )
# test_result