In [1]:
!pip install PyMuPDF pdfminer.six tiktoken langchain==1.0 langchain-community langchain-google-genai langchain-huggingface langchain-pinecone pinecone pinecone-text rank_bm25 tools transformers pillow

Collecting PyMuPDF
  Downloading pymupdf-1.26.6-cp310-abi3-macosx_10_9_x86_64.whl.metadata (3.4 kB)
Collecting pdfminer.six
  Downloading pdfminer_six-20251107-py3-none-any.whl.metadata (4.2 kB)
Collecting tiktoken
  Downloading tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl.metadata (6.7 kB)
Collecting langchain==1.0
  Downloading langchain-1.0.0-py3-none-any.whl.metadata (4.6 kB)
Collecting langchain-community
  Downloading langchain_community-0.4.1-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-google-genai
  Downloading langchain_google_genai-3.1.0-py3-none-any.whl.metadata (2.7 kB)
Collecting langchain-huggingface
  Downloading langchain_huggingface-1.0.1-py3-none-any.whl.metadata (2.1 kB)
Collecting langchain-pinecone
  Downloading langchain_pinecone-0.2.13-py3-none-any.whl.metadata (8.6 kB)
Collecting pinecone
  Downloading pinecone-8.0.0-py3-none-any.whl.metadata (11 kB)
Collecting pinecone-text
  Downloading pinecone_text-0.11.0-py3-none-any.whl.metadata (10 kB)


In [None]:
import os
from dotenv import load_dotenv
load_dotenv()
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", "")
GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY", "")
TEXT_INDEX = "who-text-index"
IMAGE_INDEX = "who-image-index"
NAMESPACE = "who-pdf"

DENSE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
IMAGE_MODEL = "openai/clip-vit-base-patch32"

## Data Processing & Storage



In [2]:
import fitz  # PyMuPDF
import re
import json
import csv
from pathlib import Path
from datetime import datetime

def extract_figures_from_pdf(pdf_path, output_folder="extracted_figures"):
    """
    Extract figures from PDF by detecting figure captions and cropping regions.

    Args:
        pdf_path: Path to the PDF file
        output_folder: Folder to save extracted figures
    """
    # Create output folder
    Path(output_folder).mkdir(exist_ok=True)

    # Open PDF
    doc = fitz.open(pdf_path)
    figures_extracted = []
    metadata = []

    for page_num in range(len(doc)):
        page = doc[page_num]

        # Get all text blocks with formatting info
        blocks = page.get_text("dict")["blocks"]

        figure_regions = []

        for block in blocks:
            if "lines" not in block:
                continue

            for line in block["lines"]:
                for span in line["spans"]:
                    text = span["text"].strip()
                    text = text.split('\n')
                    if len(text)> 1 and text[0]=="":
                        text = text[1]
                    else:
                        text = text[0]

                    # Check if this looks like a figure caption
                    # Pattern: "Figure X.X" or "Figure X"
                    if re.match(r'^Figure\s+\d+\.?\d*', text, re.IGNORECASE):
                        # Get the bounding box of this text
                        bbox = span["bbox"]

                        # Check for bold and underline flags
                        flags = span["flags"]
                        is_bold = flags & 2**4  # Bold flag

                        # Store figure info
                        figure_regions.append({
                            "caption": text,
                            "y_position": bbox[1],  # Top y-coordinate
                            "bbox": bbox,
                            "page": page_num
                        })

        # Process detected figures
        for i, fig_info in enumerate(figure_regions):
            # Determine the crop region
            # Start from caption position
            y_start = fig_info["y_position"]

            # End at next figure or page bottom
            if i + 1 < len(figure_regions):
                y_end = figure_regions[i + 1]["y_position"]
            else:
                y_end = page.rect.height

            # Add some margin above caption (to include it)
            margin_top = 20
            margin_bottom = 20
            margin_sides = 40

            # Create crop rectangle
            crop_rect = fitz.Rect(
                margin_sides,  # left
                max(0, y_start - margin_top),  # top
                page.rect.width - margin_sides,  # right
                min(page.rect.height, y_end + margin_bottom)  # bottom
            )

            # Extract the figure region
            mat = fitz.Matrix(2, 2)  # 2x scale for better quality
            pix = page.get_pixmap(matrix=mat, clip=crop_rect)

            # Extract figure number from caption
            figure_number_match = re.search(r'Figure\s+(\d+\.?\d*)', fig_info["caption"], re.IGNORECASE)
            figure_number = figure_number_match.group(1) if figure_number_match else "unknown"

            # Get full caption text (may extend beyond first line)
            caption_rect = fitz.Rect(fig_info["bbox"][0], fig_info["bbox"][1],
                                    page.rect.width - margin_sides, fig_info["bbox"][3] + 50)
            full_caption = page.get_textbox(caption_rect).strip()

            # Generate filename
            caption_clean = re.sub(r'[^\w\s-]', '', fig_info["caption"])
            caption_clean = re.sub(r'\s+', '_', caption_clean)
            filename = f"page_{page_num + 1}_{caption_clean}.png"
            filepath = f"{output_folder}/{filename}"

            # Save image
            pix.save(filepath)
            figures_extracted.append(filepath)
            # full_caption = full_caption.split("\n")
            # full_caption = full_caption[0] if full_caption[0] != "" else full_caption[1]
            # Create metadata entry
            cleaned = re.sub(r'^Figure\s+\d+\.\d+\s*\n?', '', full_caption)
            full_caption = cleaned.split('\n')[0]
            metadata_entry = {
                "figure_number": figure_number,
                "caption": full_caption,
                "filename": filename,
                "filepath": filepath,
                "page_number": page_num + 1,
                "extraction_date": datetime.now().isoformat()
            }
            metadata.append(metadata_entry)

            print(f"Extracted: Figure {figure_number} -> {filename}")

    doc.close()

    # Save metadata
    save_metadata(metadata, output_folder)

    print(f"\nTotal figures extracted: {len(figures_extracted)}")
    print(f"Metadata saved in: {output_folder}/")
    return figures_extracted, metadata


def extract_figures_alternative(pdf_path, output_folder="extracted_figures_alt"):
    """
    Alternative method: Use text search to find figure captions.
    More robust for different PDF encodings.
    """
    Path(output_folder).mkdir(exist_ok=True)

    doc = fitz.open(pdf_path)
    all_figures = []
    metadata = []

    for page_num in range(len(doc)):
        page = doc[page_num]

        # Search for "Figure" text
        figure_instances = page.search_for("Figure", quads=False)

        for rect in figure_instances:
            # Get text near this position to get full caption
            expanded_rect = fitz.Rect(rect.x0, rect.y0, page.rect.width, rect.y1 + 100)
            caption_text = page.get_textbox(expanded_rect).strip().split('\n')[0]

            # Verify it's actually a figure caption
            if re.match(r'^Figure\s+\d+', caption_text, re.IGNORECASE):
                # Find next figure or end of page
                y_start = rect.y0

                # Look for next "Figure" occurrence
                next_figures = [r for r in figure_instances if r.y0 > rect.y0 + 50]
                if next_figures:
                    y_end = next_figures[0].y0
                else:
                    y_end = page.rect.height

                # Crop region
                crop_rect = fitz.Rect(
                    30,
                    max(0, y_start - 15),
                    page.rect.width - 30,
                    min(page.rect.height, y_end + 10)
                )

                # Extract
                mat = fitz.Matrix(2, 2)
                pix = page.get_pixmap(matrix=mat, clip=crop_rect)

                # Extract figure number
                figure_number_match = re.search(r'Figure\s+(\d+\.?\d*)', caption_text, re.IGNORECASE)
                figure_number = figure_number_match.group(1) if figure_number_match else "unknown"
                # Filename
                caption_clean = re.sub(r'[^\w\s-]', '', caption_text)
                caption_clean = re.sub(r'\s+', '_', caption_clean)[:50]
                filename = f"page_{page_num + 1}_{caption_clean}.png"
                filepath = f"{output_folder}/{filename}"

                pix.save(filepath)
                all_figures.append(filepath)
                cleaned = re.sub(r'^Figure\s+\d+\.\d+\s*\n?', '', text)
                caption_text = cleaned.split('\n')[0]

                # Create metadata entry
                metadata_entry = {
                    "figure_number": figure_number,
                    "caption": caption_text,
                    "filename": filename,
                    "filepath": filepath,
                    "page_number": page_num + 1,
                    "extraction_date": datetime.now().isoformat()
                }
                metadata.append(metadata_entry)

                print(f"Extracted: Figure {figure_number} -> {filename}")

    doc.close()

    # Save metadata
    save_metadata(metadata, output_folder)

    print(f"\nTotal figures extracted: {len(all_figures)}")
    print(f"Metadata saved in: {output_folder}/")
    return all_figures, metadata


def save_metadata(metadata, output_folder):
    """
    Save metadata in multiple formats: JSON, CSV, and TXT.

    Args:
        metadata: List of metadata dictionaries
        output_folder: Folder to save metadata files
    """
    if not metadata:
        print("No metadata to save.")
        return

    # Save as JSON
    json_path = f"{output_folder}/figures_metadata.json"
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    print(f"JSON metadata saved: {json_path}")

    # # Save as CSV
    # csv_path = f"{output_folder}/figures_metadata.csv"
    # with open(csv_path, 'w', newline='', encoding='utf-8') as f:
    #     writer = csv.DictWriter(f, fieldnames=metadata[0].keys())
    #     writer.writeheader()
    #     writer.writerows(metadata)
    # print(f"CSV metadata saved: {csv_path}")

    # # Save as human-readable text
    # txt_path = f"{output_folder}/figures_metadata.txt"
    # with open(txt_path, 'w', encoding='utf-8') as f:
    #     f.write("=" * 80 + "\n")
    #     f.write("EXTRACTED FIGURES METADATA\n")
    #     f.write("=" * 80 + "\n\n")

    #     for entry in metadata:
    #         f.write(f"Figure Number: {entry['figure_number']}\n")
    #         f.write(f"Caption: {entry['caption']}\n")
    #         f.write(f"Filename: {entry['filename']}\n")
    #         f.write(f"Page Number: {entry['page_number']}\n")
    #         f.write(f"Extraction Date: {entry['extraction_date']}\n")
    #         f.write("-" * 80 + "\n\n")
    # print(f"TXT metadata saved: {txt_path}")


def load_metadata(metadata_path):
    """
    Load metadata from JSON file.

    Args:
        metadata_path: Path to JSON metadata file

    Returns:
        List of metadata dictionaries
    """
    with open(metadata_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def search_figures(metadata, search_term):
    """
    Search figures by caption or figure number.

    Args:
        metadata: List of metadata dictionaries
        search_term: Term to search for

    Returns:
        List of matching metadata entries
    """
    results = []
    search_lower = search_term.lower()

    for entry in metadata:
        if (search_lower in entry['caption'].lower() or
            search_lower in entry['figure_number'].lower()):
            results.append(entry)

    return results


# Usage example
if __name__ == "__main__":
    pdf_path = "WHO_document.pdf"  # Change this to your PDF path

    print("Method 1: Using text formatting detection")
    print("=" * 50)
    try:
        figures1, metadata1 = extract_figures_from_pdf(pdf_path, "figures_method1")

    except Exception as e:
        print(f"Method 1 failed: {e}")

    # print("\n\nMethod 2: Using text search")
    # print("=" * 50)
    # try:
    #     figures2, metadata2 = extract_figures_alternative(pdf_path, "figures_method2")

    #     # Example: Load metadata later
    #     print("\n--- Reload Metadata Example ---")
    #     loaded_metadata = load_metadata("figures_method2/figures_metadata.json")
    #     print(f"Loaded {len(loaded_metadata)} figure metadata entries")

    # except Exception as e:
    #     print(f"Method 2 failed: {e}")

Method 1: Using text formatting detection
Extracted: Figure 1.1 -> page_1_Figure_11.png
Extracted: Figure 1.2 -> page_2_Figure_12.png
Extracted: Figure 1.3 -> page_3_Figure_13.png
Extracted: Figure 1.4 -> page_4_Figure_14.png
Extracted: Figure 1.5 -> page_5_Figure_15.png
Extracted: Figure 1.6 -> page_6_Figure_16.png
Extracted: Figure 1.7 -> page_7_Figure_17.png
Extracted: Figure 1.8 -> page_8_Figure_18.png
Extracted: Figure 1.9 -> page_9_Figure_19.png
Extracted: Figure 1.10 -> page_10_Figure_110.png
Extracted: Figure 1.11 -> page_11_Figure_111.png
Extracted: Figure 1.12 -> page_13_Figure_112.png
Extracted: Figure 1.13 -> page_15_Figure_113.png
Extracted: Figure 2.1 -> page_19_Figure_21.png
Extracted: Figure 2.2 -> page_20_Figure_22.png
Extracted: Figure 2.3 -> page_22_Figure_23_shows_the_coverage_of_immunization_with_malaria_and_other_vaccines_in_the_MVIP_areas_in_the_three.png
Extracted: Figure 2.3 -> page_22_Figure_23.png
Extracted: Figure 2.4 -> page_24_Figure_24.png
Extracted: Figu

In [None]:
# pip install pdfminer.six tiktoken
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextContainer
import re, tiktoken

def extract_paragraph_docs(pdf_path, max_tokens=1200, overlap_tokens=120):
    """
    Extract paragraphs from a PDF and return a list of dicts like
    LangChain's create_documents output:
        {"page_content": <text>, "metadata": {"page": <n>, "para": <m>}}
    """
    enc = tiktoken.get_encoding("cl100k_base")

    def normalize_ws(s: str) -> str:
        s = re.sub(r"[ \t]+", " ", s)
        s = re.sub(r" *\n+", "\n", s)
        return s.strip()

    def merge_lines_to_paragraphs(lines):
        paras, buf = [], []
        for ln in lines:
            t = ln.rstrip()
            if not t:
                if buf:
                    paras.append(" ".join(buf).strip())
                    buf = []
                continue
            if t.endswith("-"):
                buf.append(t[:-1])  # de-hyphenate
            else:
                buf.append(t)
        if buf:
            paras.append(" ".join(buf).strip())
        return [p for p in paras if p]

    docs = []
    for page_no, layout in enumerate(extract_pages(pdf_path), 1):
        lines = []
        for element in layout:
            if isinstance(element, LTTextContainer):
                lines.extend(element.get_text().splitlines())
        lines.append("")  # ensure paragraph break at page end

        paras = merge_lines_to_paragraphs([normalize_ws(x) for x in lines])

        # token-aware sub-chunking
        for para_idx, para in enumerate(paras, 1):
            ids = enc.encode(para)
            if len(ids) <= max_tokens:
                docs.append({
                    "page_content": para,
                    "metadata": {"page": page_no, "para": para_idx}
                })
            else:
                start = 0
                sub_idx = 0
                while start < len(ids):
                    end = min(start + max_tokens, len(ids))
                    sub_text = enc.decode(ids[start:end])
                    docs.append({
                        "page_content": sub_text,
                        "metadata": {"page": page_no,
                                     "para": para_idx,
                                     "sub": sub_idx}
                    })
                    sub_idx += 1
                    if end == len(ids): break
                    start = max(end - overlap_tokens, end)

    return docs


In [None]:
docs = extract_paragraph_docs("WHO_document.pdf")

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

def build_text_vectors(docs):
    vectors = []
    for i, doc in enumerate(docs):
        # dense = dense_text_embed(doc['page_content'])
        metadata = doc['metadata'].copy()
        # metadata["context"] = doc['page_content'] # Add the text content under 'context' key

        vectors.append({
            "id": f"text-{i}",
            "chunk_text": doc['page_content'],
        })
    return vectors

# def build_image_vectors(images):
#     vectors = []
#     for i, img in enumerate(images):
#         dense = dense_image_embed(img["bytes"])
#         vectors.append({
#             "id": f"{img['figure_number']}",
#             "values": dense,
#             "metadata": {"page": img["page"],
#                          "ext": img["ext"],
#                          "caption": img["context"]}
#         })

def build_image_caption_vectors(images):
    vectors = []
    for i, img in enumerate(images):
      # dense = dense_text_embed(img["context"])
      vectors.append({"_id": f"{img['figure_number']}",
                      "chunk_text": img['context'],
                      "filename": img['filename'],
                      "caption": img["context"]})
    return vectors


In [None]:
import os
from PIL import Image
import io
import json

image_files = []
image_dir = "./figures_method1"
image_metadata = {}
with open(os.path.join(image_dir, "figures_metadata.json"), "r") as fp:
  json_data = json.load(fp)
for entry in json_data:
  image_metadata[entry['filename']] = entry

for filename in os.listdir(image_dir):
    if filename.endswith(".png"):
        filepath = os.path.join(image_dir, filename)
        with open(filepath, "rb") as f:
            image_bytes = f.read()

        # Extract page number from filename (e.g., page_1_figure_1.png -> page 1)
        try:
            page_num_str = filename.split('_')[1]
            page_num = int(page_num_str) - 1 # Adjust to 0-indexed page
        except (IndexError, ValueError):
            page_num = -1 # Default if page number cannot be extracted

        # table_path = os.path.join("extracted_tables", os.path.basename(filename).replace(".png", ".txt"))
        # table_text = ""
        # try:
        #     with open(table_path, "r", encoding="utf-8") as f:
        #         table_text = f.readlines()
        # except FileNotFoundError:
        #     print(f"No corresponding table found for {filename}")
        try:
          im_info = image_metadata[filename]
          page_num = im_info["page_number"]
          caption = im_info["caption"]
          figure_number = im_info["figure_number"]
        except:
          continue
        image_files.append({
            "page": page_num,
            "bytes": image_bytes,
            "ext": "png",
            "context": f"Figure {figure_number}: {caption}",
            "figure_number": figure_number,
            "filename": filename
        })

images = image_files
print(f"Loaded {len(images)} images from extracted_charts folder.")

Loaded 31 images from extracted_charts folder.


In [None]:
print(images[-1]["context"])

Figure 1.2: Trends in life expectancy and HALE at birth, by sex and by WHO region, 2000–2021


In [None]:


# docs = split_text(pages)

text_vectors = build_text_vectors(docs)
# image_vectors = build_image_vectors(images)
image_caption_vectors = build_image_caption_vectors(images)



In [None]:
print(f"image_caption_vectors: {len(image_caption_vectors)}")
# print(image_caption_vectors[0])

image_caption_vectors: 31


##Store Embeddings in Pinecone

In [None]:
from pinecone import Pinecone
pc = Pinecone(api_key=PINECONE_API_KEY)

In [None]:
# Run one time only
if TEXT_INDEX in pc.list_indexes().names():
    pc.delete_index(TEXT_INDEX)
if IMAGE_INDEX in pc.list_indexes().names():
    pc.delete_index(IMAGE_INDEX)

In [None]:
# Check and create TEXT_INDEX
if TEXT_INDEX not in pc.list_indexes().names():
    print(f"Creating Pinecone index: {TEXT_INDEX}...")
    # Create a dense index with integrated embedding
    index_name = TEXT_INDEX
    if not pc.has_index(index_name):
        pc.create_index_for_model(
            name=index_name,
            cloud="aws",
            region="us-east-1",
            embed={
                "model":"llama-text-embed-v2",
                "field_map":{"text": "chunk_text"}
            }
        )
    print(f"Index {TEXT_INDEX} created.")
else:
    print(f"Pinecone index {TEXT_INDEX} already exists.")

SPARSE_TEXT_INDEX = TEXT_INDEX+"-sparse"
if SPARSE_TEXT_INDEX not in pc.list_indexes().names():
  print(f"Creating Pinecone index: {SPARSE_TEXT_INDEX}...")
  # Create a dense index with integrated embedding
  index_name = SPARSE_TEXT_INDEX
  if not pc.has_index(index_name):
    pc.create_index_for_model(
        name=index_name,
        cloud="aws",
        region="us-east-1",
        embed={
            "model":"pinecone-sparse-english-v0",
            "field_map":{"text": "chunk_text"}
        }
    )
    print(f"Index {TEXT_INDEX} created.")
  else:
    print(f"Pinecone index {TEXT_INDEX} already exists.")

# Check and create IMAGE_INDEX
if IMAGE_INDEX not in pc.list_indexes().names():
    print(f"Creating Pinecone index: {IMAGE_INDEX}...")
    # Create a dense index with integrated embedding
    index_name = IMAGE_INDEX
    if not pc.has_index(index_name):
        pc.create_index_for_model(
            name=index_name,
            cloud="aws",
            region="us-east-1",
            embed={
                "model":"llama-text-embed-v2",
                "field_map":{"text": "chunk_text"}
            }
        )

# Initialize Pinecone Index objects
text_idx = pc.Index(TEXT_INDEX)
sp_text_ix = pc.Index(SPARSE_TEXT_INDEX)
image_idx = pc.Index(IMAGE_INDEX)

print("Pinecone index objects initialized for text_idx and image_idx.")

def upsert_text(vectors):
    text_idx.upsert_records("WHO-doc", vectors)
    sp_text_ix.upsert_records("WHO-doc", vectors)

def upsert_images(vectors):
    image_idx.upsert_records("WHO-doc", vectors)


Pinecone index who-text-index already exists.
Pinecone index objects initialized for text_idx and image_idx.


In [None]:
upsert_text(text_vectors)
# upsert_images(image_vectors)
upsert_images(image_caption_vectors)

##Retrieval System with Context Awareness

In [None]:
import os
from google.colab import userdata
from langchain_google_genai import ChatGoogleGenerativeAI
from google import genai
LLM_MODEL_NAME = "gemini-2.5-flash"

if "GOOGLE_API_KEY" not in os.environ:
    GOOGLE_API_KEY = userdata.get("GEMINI_API_KEY")
    os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

client = genai.Client()



REFORM_PROMPT = """
You are a concise search-query reformulator. Given conversation history and a user query,
produce a single-line standalone search query that preserves intent and context.

History:
{history}

User query:
{query}

Standalone query:
"""

def reformulate(query: str, history: list[str]) -> str:
    history_text = "\n".join(history[-6:])  # last N turns
    prompt = REFORM_PROMPT.format(history=history_text, query=query)
    response = client.models.generate_content(
    model=LLM_MODEL_NAME,
    contents=[prompt]
    )
    return response.text.strip()

In [None]:
from langchain_pinecone import PineconeRerank

from langchain_community.retrievers import (
    PineconeHybridSearchRetriever,
)

def retrieve_context(query: str, history: list[str] = None,
                     top_k_text=10, top_k_rerank=5, top_k_images=3, img_threshold=0.5):

    history = history or []
    reform_q = reformulate(query, history)
    # Search the dense index

    dense_results = text_idx.search(
        namespace="WHO-doc",
        query={
            "top_k": 5,
            "inputs": {
                'text': reform_q
            }
        }
    )
    sparse_results = sp_text_ix.search(
        namespace="WHO-doc",
        query={
            "top_k": 5,
            "inputs": {
                'text': reform_q
            }
        }
    )
    # print(dense_results)
    # print(sparse_results)
    dense_documents = dense_results['result']['hits']
    sp_documents = sparse_results['result']['hits']

    docs = []

    for doc in dense_documents:
      docs.append({'_id': doc['_id'], 'chunk_text': doc['fields']['chunk_text']})
    for doc in sp_documents:
      docs.append({'_id': doc['_id'], 'chunk_text': doc['fields']['chunk_text']})

    # print(docs)
    ranked_results = pc.inference.rerank(
      model="bge-reranker-v2-m3",
      query=reform_q,
      documents= docs,
      top_n=3,
      rank_fields=["chunk_text"],
      return_documents=True,
      parameters={
          "truncate": "END"
      }
    )
    # print(ranked_results)
    top_docs = ranked_results.data
    # top_docs = ranked_results['documents']
    # print(f"REFORMULATED_QUERY: {reform_q}")
    # result_text = text_retriever.invoke(reform_q)
    # top3_docs = reranker_text.compress_documents(result_text, reform_q)
    # # print("Top-3 Result:")
    # # for doc in top3_docs:
    # #   print(f"Score: {doc.metadata['relevance_score']:.4f} | Content: {doc.page_content}")
    image_results = image_idx.search(
        namespace="WHO-doc",
        query={
            "top_k": 3,
            "inputs": {
                'text': reform_q
            }
        }
    )

    image_documents = image_results['result']['hits']
    # print(image_documents)
    # filtered_images = []
    # for img in result_images['matches']:
    #   if img['score'] >= img_threshold:
    #     filtered_images.append(img)
    # # result_images['matches'] = filtered_images
    return top_docs, image_documents

In [None]:
question = 'What were the top causes of death globally in 2021?'
docs, images = retrieve_context(question, history=None)

In [None]:
# print(docs)

[{
    index=0,
    score=0.9981969,
    document={
        _id='text-8',
        chunk_text='This distribution of the leading 10 causes of death by broad cause group at global level remained unchanged from 2019 in 2020 and 2021; however, COVID-19 emerged as the third and second leading causes, respectively, claiming 4.1 million and 8.8 million lives globally. In all but two WHO regions (the African and Western Pacific regions), COVID-19 ranked among the top five causes of deaths in 2020 and 2021, responsible for the largest number of deaths in both years in the Region of the Americas, in 2021 in the South-East Asia Region and the second largest number of deaths in both years in the European and the Eastern Mediterranean regions. In the African Region, the disease only moved up from 12th to sixth in 2021. While in the Western Pacific Region it remained out of the top 10, it rose from being 50th in 2020 to 19th in 2021 (1). Figure 1.9 Top 10 causes of death, by World Bank income group, 

In [None]:
from typing import List, Any, Dict
def serialize_top_docs(docs: List[Any], max_chars: int = 5000) -> str:
    """
    Turn top docs into a short serialized block for the LLM.
    Each doc includes a short id, page/content snippet and source metadata.
    """
    pieces = []
    for i, doc in enumerate(docs, start=1):
        mid = doc.document.get("_id")
        # snip = (doc.page_content or "")[:max_chars].replace("\n", " ")
        snip = doc.document.get("chunk_text").replace("\n", " ")
        pieces.append(f"score={doc.score:.4f}| text={snip}")
    return "\n\n".join(pieces)

In [None]:
from PIL import Image
from google import genai
from google.genai import types
import csv

ANSWER_PROMPT = """

Retrieved evidence:
{retrieved_block}

User question:
{question}

"""
SYSTEM_INSTRUCTION = """
You are an evidence-based assistant. Use ONLY the retrieved evidence and the image provided to \
answer the user's question. Only provide factual answers.

GUIDELINES:
1) Produce a concise direct answer (1-3 sentences).
Maintain conversation context across multiple turns
2) Reference previous messages when relevant.
3) Give more importance to texts that have high scores.
4) Provide natural, engaging responses (not robotic)
5) Appropriately cite figures/tables (e.g., "As shown in Figure 5...")
6) Do NOT hallucinate facts not present in the retrieved texts.
"""


def generate_answer_from_retrieval(question: str, doc_texts:str, artifacts: List[str], image_dir:str) -> Dict[str, Any]:
    retrieved_block = doc_texts
    print(retrieved_block)
    top_image = None
    if len(artifacts) > 0:
        top_image = artifacts[0]
        # retrieved_block += f"\n\nFigure {top_image['id']}: {top_image['metadata']['caption']}"
        image_filename = top_image['fields'].get('filename', '')
        if image_filename.endswith('.png'):
          im_data = Image.open(os.path.join(image_dir, image_filename))
    prompt = ANSWER_PROMPT.format(retrieved_block=retrieved_block, question=question)
    # print(prompt)
    # call your LLM
    response = client.models.generate_content(
      model=LLM_MODEL_NAME,
      config=types.GenerateContentConfig(
        system_instruction=SYSTEM_INSTRUCTION),
      contents=[prompt])
    if top_image:
      return response.text.strip(), top_image['_id']
    return response.text.strip(), None

In [None]:
import time
import pandas as pd
def build_submission(input_csv: str = "266_lab2_questions.csv",
                     output_csv: str = "submission.csv"):
    df = pd.read_csv(input_csv, sep=None, engine="python")  # auto-detect sep
    results = []

    # keep per-conversation history
    conv_hist = {}
    try:
      for ii, row in df.iterrows():
          cid = row["conversation_id"]
          qid = row["id"]
          qnum = row["question_id"]
          question = row["question"]
          print(f"Conversation id: {cid} Qid: {qid} Qnum: {qnum} Question: {question}")
          history = conv_hist.get(cid, [])
          # retrieve
          result_text, images = retrieve_context(question, history, img_threshold=0.0)
          # print(f"{images}")
          serialized_text = serialize_top_docs(result_text)
          # generate answer using intelligent prompting
          answer, fig_refs = generate_answer_from_retrieval(question, serialized_text, images, image_dir="./figures_method1")

          # answer = parsed.get("answer", "").replace("\n", " ").strip()
          # fig_refs = parsed.get("figure_references", "0")
          print(f"Answer: {answer}")
          # ensure '0' if empty
          if not fig_refs:
              fig_refs = "0"
          else:
              fig_refs = f"Figure {fig_refs}"

          results.append({
              "id": qid,
              "conversation_id": cid,
              "question_id": qnum,
              "answer": answer,
              "figure_references": fig_refs
          })


          # update history (you might also include the model answer if you want)
          conv_hist.setdefault(cid, []).append(question)

          # Use Rate-limiter 30s per question
          print("Waiting 20s for next question")
          if ii < len(df) - 1: 
            time.sleep(20)
    except Exception as e:
      print(f"Exception: {e}")

    finally:
      # write CSV

      with open(output_csv, "w", newline="", encoding="utf-8") as f:
          writer = csv.DictWriter(f, fieldnames=["id", "conversation_id", "question_id", "answer", "figure_references"])
          writer.writeheader()
          writer.writerows(results)

      print(f"Wrote {output_csv} ({len(results)} rows)")

In [None]:
build_submission("Questions.csv", "Answers.csv")

Conversation id: 1 Qid: 1 Qnum: 1 Question: What happened to global life expectancy during the COVID-19 pandemic?
score=0.9973| text=Life expectancy, healthy life expectancy and burden of disease in the light of the COVID-19 pandemic However, the COVID-19 pandemic reversed this trend and wiped out the progress that was made in nearly a decade within just two years. Global life expectancy at birth dropped by 0.7 years to 72.5 (UI: 71.9–73.1) years in 2020 (back to the level of 2016), and by a further 1.1 years to 71.4 (UI: 70.8–72.0) years in 2021 (back to the level of 2012). Similarly, global HALE dropped to 62.8 (UI: 62.0– 63.7) years in 2020 (back to the level of 2016) and 61.9 (UI: 61.1–62.8) years in 2021 (back to the level of 2012) (1). The life expectancy at birth for both men and women dropped by about 1.7 years between 2019 and 2021. However, the decline for men was relatively more evenly split in 2020 (by 0.8 years) and 2021 (by 0.9 years), while the decline for women was conc