# Atopic Eczema VLM Extraction — RAG Pipeline (Gemma 3 27B IT)

This notebook prototypes the full pipeline:
1) Load RAG cards (fields, policies, abbreviations, ranges, meds lexicon)  
2) Candidate extraction (get page tokens from image)  
3) RAG context assembly  
4) Build prompts and call **google/gemma-3-27b-it** (Hugging Face)  
5) Validate + compute confidences  
6) Merge into a patient JSON

> **Note:** You must accept the Gemma 3 license on Hugging Face and set `HF_TOKEN` in your environment to pull weights.


## 0. Environment & Installs

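- Make sure you have a suitable GPU: the 27B weights alone need roughly 54 GB in bf16, so an 80 GB A100 / H100 (or several smaller GPUs sharded by `device_map="auto"`) is a practical baseline.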
- Install pinned `transformers` with Gemma 3 support and `accelerate`.
- Login to Hugging Face or set an access token.
- Accept the model license on its model card.

**References:**
- Hugging Face blog guide for Gemma 3 (inference & API usage).


In [None]:
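# Gemma 3 support landed in transformers 4.50.0; the 4.49.0 release on PyPI predates it.
# !pip install -U 'transformers>=4.50.0' accelerate torch torchvision pillow
# Optional: for better throughput / paged attention, also consider vLLM or TGI serving later

import os

# Ensure your token is set, e.g. `export HF_TOKEN=hf_xxx` in your shell.
# Print only whether it is set, so the secret never lands in the cell output.
print("HF_TOKEN set:", bool(os.environ.get("HF_TOKEN")))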
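
To fail fast before a multi-gigabyte download starts, the token lookup can be wrapped in a small guard. This is a minimal sketch; `require_hf_token` is a hypothetical helper and not part of the notebook's modules.

```python
import os


def require_hf_token(env=None) -> str:
    """Return the Hugging Face token, or raise a clear error if it is missing.

    `env` defaults to os.environ; passing a plain dict makes the check testable.
    """
    env = os.environ if env is None else env
    token = (env.get("HF_TOKEN") or "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set; export it before loading gated weights."
        )
    return token
```

Calling `require_hf_token()` at the top of the notebook would surface a missing token immediately instead of during model loading.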

## 1. Imports — Local Modules

In [None]:
from pathlib import Path
from rag_store import RAGPaths, RAGStore, ContextAssembler, fields_for_section
from candidate_extractor import CandidateExtractor


## 2. Load RAG Cards
Point to your cards directory (where we placed `field_cards.jsonl`, `policy/`, `abbr/`, `range/`, `lexicon/`).

In [None]:
CARDS_DIR = Path("/home/rijul/Gitlaboratory/Context_Engineering_LLM/cards")  # <-- update if different
store = RAGStore(RAGPaths.from_base(CARDS_DIR)).load()

print("Fields loaded:", len(store.fields_by_name))
print("Policies:", list(store.policy.keys()))
print("Abbr:", list(store.abbr.keys()))
print("Ranges:", list(store.ranges.keys()))
print("Lexicons:", list(store.lexicons.keys()))
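
The internals of `RAGStore` are not shown here. As an illustration of what loading a file such as `field_cards.jsonl` might involve, the sketch below reads one JSON object per line and indexes the cards by name. The function `load_cards_jsonl` is hypothetical, and the assumption that each card carries a `canonical_name` key is taken from the system prompt used later in this notebook, not from the store's actual schema.

```python
import json
from pathlib import Path


def load_cards_jsonl(path, key="canonical_name"):
    """Read one JSON object per line and index the cards by `key`.

    Assumption: every card has a `canonical_name`; the real store may differ.
    """
    cards = {}
    with Path(path).open(encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines between cards
            card = json.loads(line)
            if key not in card:
                raise KeyError(f"line {line_no}: missing '{key}'")
            cards[card[key]] = card
    return cards
```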


## 3. Candidate Extraction (VLM-assisted, no OCR)
Given a form page image, ask the VLM to list headings/labels/short snippets likely to be variable names or medications.

In [None]:
# Replace with your actual page path(s)
PAGE_IMAGE = "/path/to/form_page1.png"

# Stub runner for demo. Replace with a real VLM call later.
extractor = CandidateExtractor()
page_tokens = extractor.extract_candidates(PAGE_IMAGE)
print("Candidate tokens:", page_tokens)
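
Once the stub is replaced by a real VLM call, the model's free-text list of headings has to be turned into clean tokens. The sketch below shows one way to do that, assuming the model replies with one item per line, possibly bulleted or numbered. `parse_candidates` is a hypothetical helper; the actual `CandidateExtractor` may post-process differently.

```python
import re

# Leading bullet ("-", "*", "•") or list number ("1.", "2)") to strip
_BULLET = re.compile(r"^\s*(?:[-*•]+|\d+[.)])\s*")


def parse_candidates(reply: str) -> list[str]:
    """Split a line-per-item reply into unique tokens, keeping first-seen order."""
    seen = set()
    tokens = []
    for line in reply.splitlines():
        token = _BULLET.sub("", line).strip()
        if not token:
            continue  # skip empty lines
        key = token.lower()  # de-duplicate case-insensitively
        if key not in seen:
            seen.add(key)
            tokens.append(token)
    return tokens
```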


## 4. Assemble RAG Context
Select a section (e.g., `history`, `scorad`, `investigations`, `followups`) and build the compact context payload.

In [None]:
assembler = ContextAssembler(store)
target_fields = fields_for_section("history")  # change to other sections as needed

ctx = assembler.build_context(target_fields, page_tokens=page_tokens)
chunks = assembler.to_prompt_chunks(ctx)

for i, ch in enumerate(chunks, 1):
    print(f"=== Chunk {i} ===\n{ch[:600]}\n")


## 5. Load Gemma 3 27B IT (Hugging Face)

Use `transformers` **pipeline** for simple VLM calls (image + text → text).  
Make sure you've accepted the model license on the model card: `google/gemma-3-27b-it`.


In [None]:
from transformers import pipeline
import torch

# Use bfloat16 on GPU if available
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# The "image-text-to-text" pipeline supports interleaved image/text messages per HF blog
pipe = pipeline(
    task="image-text-to-text",
    model="google/gemma-3-27b-it",
    token=os.environ.get("HF_TOKEN"),
    device_map="auto",
    torch_dtype=dtype
)

# Minimal sanity query (uncomment when you have a real image path and HF auth)
# messages = [{
#     "role": "user",
#     "content": [
#         {"type": "image", "image": PAGE_IMAGE},
#         {"type": "text", "text": "List all headings visible on this page."}
#     ]
# }]
# out = pipe(text=messages, max_new_tokens=200)
# print(out[0]["generated_text"][-1]["content"])


## 6. Build the Field Extractor Prompt and Call the Model

We send:
- A **short system instruction**
- The **RAG context chunks**
- The **user instruction** describing what to extract
- The **page image**

We ask for **strict JSON** back.


In [None]:
SYSTEM = (
    "You are a medical data extractor. "
    "Use the provided field cards, policies, abbreviations, ranges, and meds lexicon to extract values. "
    "If a value is missing or illegible, return null and set a low confidence. "
    "Return only JSON."
)

USER_INSTR = (
    "Extract the requested fields from this page. "
    "For each field, return {value, confidence (0..1), provenance: short description}. "
    "Field set is in the context."
)

def build_messages(chunks, system_text, user_text, image_path):
    # Build interleaved messages for the VLM pipeline
    content = [{"type": "text", "text": system_text}]
    # Append context chunks
    for ch in chunks:
        content.append({"type": "text", "text": ch})
    # Append user instruction + image
    content.append({"type": "text", "text": user_text})
    content.append({"type": "image", "image": image_path})
    return [{"role": "user", "content": content}]

messages = build_messages(chunks, SYSTEM, USER_INSTR, PAGE_IMAGE)

# Example extraction call (uncomment to run with the actual model)
# resp = pipe(text=messages, max_new_tokens=800)
# raw = resp[0]["generated_text"][-1]["content"]
# print(raw)
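
Even when asked to return only JSON, a model can wrap its reply in a Markdown code fence or add a sentence around it. The sketch below pulls the outermost JSON object out of such a reply before parsing. `extract_json_text` is a hypothetical helper, not part of the notebook's modules.

```python
import json
import re

_TICKS = "`" * 3  # a Markdown code fence marker
_FENCE = re.compile(
    _TICKS + r"(?:json)?\s*(.*?)" + _TICKS, re.DOTALL | re.IGNORECASE
)


def extract_json_text(reply: str) -> str:
    """Return the substring spanning the outermost {...} in the reply.

    A fenced block is preferred when one is present; otherwise the whole
    reply is searched.
    """
    match = _FENCE.search(reply)
    candidate = match.group(1) if match else reply
    start, end = candidate.find("{"), candidate.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in model reply")
    return candidate[start:end + 1]
```

The result can then be passed to `json.loads` or to the validator in the next section.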


## 7. Validate & Normalize
Apply range checks, normalize units and dates, and compute flags. The cell below implements only a SCORAD range check as an example.

In [None]:
import json

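def validate_record(raw_json_text: str, store: RAGStore):
    """Parse the model's JSON reply and flag values that fail simple checks."""
    try:
        data = json.loads(raw_json_text)
    except (json.JSONDecodeError, TypeError) as e:
        return {"ok": False, "error": f"JSON parse failed: {e}", "flags": [], "data": None}
    if not isinstance(data, dict):
        return {"ok": False, "error": "Expected a JSON object", "flags": [], "data": None}

    flags = []
    # Simple example: the SCORAD total must fall inside the configured range (default 0..103)
    scorad_range = store.ranges.get("range/scorad:v1", {}).get("ranges", {})
    entry = data.get("scorad_final")
    if isinstance(entry, dict) and entry.get("value") is not None:
        lo, hi = scorad_range.get("scorad_total", [0, 103])
        try:
            in_range = lo <= float(entry["value"]) <= hi
            reason = f"Out of range [{lo},{hi}]"
        except (TypeError, ValueError):
            # Flag non-numeric values instead of letting float() raise
            in_range, reason = False, "Non-numeric value"
        if not in_range:
            flags.append({"field": "scorad_final", "reason": reason})

    return {"ok": True, "flags": flags, "data": data}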

# Example usage after a real model response:
# result = validate_record(raw, store)
# result
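
Date normalization is mentioned above but not shown. A minimal sketch follows, assuming the forms record dates day-first (for example `12/03/2021` meaning 12 March 2021); that ordering is an assumption and should be checked against the actual forms. `normalize_date` is a hypothetical helper.

```python
from datetime import datetime

# Accepted input layouts, tried in order. Day-first is an assumption.
_DATE_FORMATS = ("%d/%m/%Y", "%d-%m-%Y", "%d.%m.%Y", "%Y-%m-%d")


def normalize_date(text):
    """Return the date as ISO 8601 (YYYY-MM-DD), or None if it cannot be parsed."""
    if not isinstance(text, str):
        return None
    cleaned = text.strip()
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(cleaned, fmt).date().isoformat()
        except ValueError:
            continue  # try the next layout
    return None
```

A validator could call this on date fields and add a flag whenever it returns `None` for a non-null value.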


## 8. (Optional) Multi-page Merge
If your PDF has multiple pages, run steps 3–7 per page and merge field-wise by confidence and page provenance.

In [None]:
def merge_records(records: list[dict]) -> dict:
    merged = {}
    for rec in records:
        for k, v in rec.items():
            if k not in merged:
                merged[k] = v
            else:
                # keep the value with higher confidence (a null confidence counts as 0)
                if (v.get("confidence") or 0) > (merged[k].get("confidence") or 0):
                    merged[k] = v
    return merged
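
The merge above keeps the winning value but not the page it came from. The variant below is a sketch that also records page provenance, assuming each per-page record is paired with its 1-based page number. `merge_with_pages` and `_confidence` are hypothetical names.

```python
def _confidence(entry: dict) -> float:
    """Read a numeric confidence; missing or null values count as 0."""
    value = entry.get("confidence")
    return float(value) if isinstance(value, (int, float)) else 0.0


def merge_with_pages(page_records):
    """Merge (page_number, record) pairs field-wise by confidence.

    The winning entry for each field gains a "page" key naming its source page.
    On a tie, the earlier page is kept.
    """
    merged = {}
    for page, record in page_records:
        for name, entry in record.items():
            candidate = {**entry, "page": page}
            best = merged.get(name)
            if best is None or _confidence(candidate) > _confidence(best):
                merged[name] = candidate
    return merged
```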


---

### Next
- Replace the **stub candidate extractor** with a real Gemma call (Section 5) to list headings/blocks.  
- Tune **Section selection** (batch fields in chunks of 10–15).  
- Expand the **validator** with tighter clinical rules & ontology checks.  
- Add a **resolver** step to re-query flagged fields with more focused crops.
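
The field batching mentioned in the list above can be sketched as follows, reading it as groups of roughly 10 to 15 fields per model call. The assumption here is that `fields_for_section` returns a plain list of field names; `batch_fields` is a hypothetical helper.

```python
from typing import Iterator, Sequence


def batch_fields(fields: Sequence[str], size: int = 12) -> Iterator[list[str]]:
    """Yield consecutive groups of at most `size` fields, preserving order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(fields), size):
        yield list(fields[start:start + size])
```

Each batch could then be passed to `assembler.build_context` in place of the full field list, giving one model call per batch.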


## 9. PDF → Image Conversion (in-order)

In [None]:

# Requires: pip install pdf2image pillow  AND  poppler-utils (apt)
from pathlib import Path
from typing import Iterator, Tuple
from pdf2image import convert_from_path

def pdfs_to_images_in_series(pdf_dir: str, out_dir: str, dpi: int = 200) -> Iterator[Tuple[str, int, str]]:
    """
    Convert all PDFs in pdf_dir to images in deterministic order.
    Yields (pdf_file_path, page_index_1based, image_path).
    """
    pdf_dir = Path(pdf_dir)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for pdf_file in sorted(pdf_dir.glob("*.pdf")):
        patient_id = pdf_file.stem
        pages = convert_from_path(str(pdf_file), dpi=dpi)
        for i, page in enumerate(pages, start=1):
            out_path = out_dir / f"{patient_id}_page{i}.png"
            page.save(out_path, "PNG")
            yield str(pdf_file), i, str(out_path)
            
print("PDF→Image helper ready. Set your input/output paths below.")            
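
One caveat on ordering: `sorted(pdf_dir.glob("*.pdf"))` compares names lexicographically, so `patient10.pdf` comes before `patient2.pdf`. If the file names contain unpadded numbers, a natural sort key gives the order a reader would expect. This is a sketch; `natural_key` is a hypothetical helper and the file names below are made up.

```python
import re


def natural_key(name: str) -> list:
    """Split a name into text and integer runs so numbers compare numerically."""
    parts = re.split(r"(\d+)", name)
    # re.split with a capture group alternates text, digits, text, ...
    return [int(p) if p.isdigit() else p.lower() for p in parts]


names = ["patient10.pdf", "patient2.pdf", "patient1.pdf"]
ordered = sorted(names, key=natural_key)
```

With `Path` objects the same key can be applied as `sorted(pdf_dir.glob("*.pdf"), key=lambda p: natural_key(p.name))`.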


## 10. End-to-End Batch Loop (All PDFs → Extraction)

In [None]:

# Configure your folders
PDF_INPUT_DIR = "/home/rijul/Academic/Atopic Eczema/cropped"     # source PDFs
IMG_OUTPUT_DIR = "/home/rijul/Academic/Atopic Eczema/images"     # where rendered PNGs will go

# Choose which section to extract in this pass (you can run multiple passes for other sections)
SECTION = "history"  # options: "history", "scorad", "investigations", "followups"

# Instantiate helpers
extractor = CandidateExtractor()
assembler = ContextAssembler(store)
target_fields = fields_for_section(SECTION)

# Optional: real Gemma pipeline (uncomment if configured)
# from transformers import pipeline
# import torch, os
# dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# pipe = pipeline(
#     task="image-text-to-text",
#     model="google/gemma-3-27b-it",
#     token=os.environ.get("HF_TOKEN"),
#     device_map="auto",
#     torch_dtype=dtype
# )

def build_messages(chunks, system_text, user_text, image_path):
    content = [{"type": "text", "text": system_text}]
    for ch in chunks:
        content.append({"type": "text", "text": ch})
    content.append({"type": "text", "text": user_text})
    content.append({"type": "image", "image": image_path})
    return [{"role": "user", "content": content}]

SYSTEM = ("You are a medical data extractor. Use the provided field cards, policies, abbreviations, ranges, "
          "and meds lexicon to extract values. If a value is missing or illegible, return null and set a low confidence. "
          "Return only JSON with keys matching canonical_name.")

USER_INSTR = ("Extract the requested fields from this page. For each field, return "
              "{value, confidence (0..1), provenance: short description}. Field set is in the context.")

results = []  # collect per-page outputs (replace with writing to disk if you prefer)

for pdf_path, page_idx, img_path in pdfs_to_images_in_series(PDF_INPUT_DIR, IMG_OUTPUT_DIR, dpi=200):
    # 1) Candidate tokens from the page
    page_tokens = extractor.extract_candidates(img_path)

    # 2) Assemble context for the chosen section
    ctx = assembler.build_context(target_fields, page_tokens=page_tokens)
    chunks = assembler.to_prompt_chunks(ctx)

    # 3) Build messages for the VLM
    messages = build_messages(chunks, SYSTEM, USER_INSTR, img_path)

    # 4) Call the model (stub shown; uncomment for real call)
    # resp = pipe(text=messages, max_new_tokens=800)
    # raw = resp[0]["generated_text"][-1]["content"]
    # For now, use a placeholder JSON string so the loop runs:
    raw = '{"duration": {"value": "6 months", "confidence": 0.8, "provenance": "upper right"}, "symptoms": {"value": ["itching"], "confidence": 0.9, "provenance": "middle"}}'

    # 5) Validate & store
    vr = validate_record(raw, store)
    results.append({
        "pdf": pdf_path,
        "page": page_idx,
        "image": img_path,
        "raw": raw,
        "validated": vr
    })

# Example summary print (guarded so an empty input folder does not raise IndexError)
print(f"Processed pages: {len(results)}")
if results:
    first = results[0]
    print("Sample record:", first["pdf"], first["page"], first["validated"]["ok"])
else:
    print("Sample record: N/A")
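
For larger batches, the per-page outputs can be written to disk instead of kept in memory. The sketch below writes one JSON object per line (JSONL), assuming every entry in `results` is JSON-serializable, as the dictionaries built in the loop above are. `write_results_jsonl` and the output path are hypothetical.

```python
import json
from pathlib import Path


def write_results_jsonl(results, out_path) -> int:
    """Write each result as one JSON line; return the number of rows written."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with out_path.open("w", encoding="utf-8") as fh:
        for row in results:
            # ensure_ascii=False keeps non-ASCII text readable in the file
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
            count += 1
    return count
```

A call such as `write_results_jsonl(results, "outputs/history.jsonl")` would produce a file that can be read back line by line with `json.loads`.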
