In [None]:
#import required libraries
import os
import re
import io
import json
from typing import List 
from collections import defaultdict

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
import base64
import pickle

from dotenv import load_dotenv
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from openai import OpenAI


In [None]:
# --- Configuration shared by the later cells ---------------------------------

# Output locations: the persisted Chroma store and the folder for extracted figure images.
db_dir = "chroma_parent_child"
OUTPUT_IMAGE_DIR = "auto_figures"
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)  # create the image folder now; no-op if it exists

# Input corpus: the two code-of-practice PDFs.
pdf_files = [
    "./data/Code-of-Practice-on-Surface-Water-Drainage.pdf",
    "./data/Code of Practice on Sewerage and Sanitary Works 3rd Edition  Mar 2025.pdf",
]

# OpenAI key is read from the local .env file instead of being hard-coded.
load_dotenv(".env")
API_KEY = os.environ.get("OPENAI_API_KEY")

### Step 1: Extract text from the PDFs and create parent (page) and child chunks

In [None]:
# Pattern for references to annexes, appendices, drawings and figures. Chunks that mention one
# are tagged with it, so that after retrieval the referenced annex/figure can be fetched as
# additional context. Each alternative is kept as its own raw string and OR-ed together below.
_REF_PATTERNS = (
    r"\b(?:annex|appendix)\s+[a-z\d]+\b",  # e.g. "Annex K", "Appendix 2"
    r"\bdrawing\s+no\.?\s*\d+\b",          # e.g. "Drawing No. 5", "Drawing no5"
    r"\bfigure\s+\d+(?:\.\d+)*\b",         # e.g. "Figure 3", "Figure 3.1.2"
)
ref_regex = re.compile("|".join(_REF_PATTERNS), re.IGNORECASE)

# Normalization helper.
# The same reference is written in several ways in the PDFs, e.g. "Annex K", "Annex  K",
# "ANNEX\nK", or "Drawing No.5" / "Drawing No. 5" / "Drawing No5". Every variant of one
# reference must map to a single key so text chunks and figure docs can be matched.
def normalize_label(label: str) -> str:
    """Return a canonical, lower-case key for a reference label matched by ``ref_regex``.

    Steps: lower-case, turn newlines into spaces, trim and collapse runs of whitespace,
    then rewrite "no" + optional "." + optional whitespace + digit as "no <digit>".

    Args:
        label: raw matched text, e.g. "Drawing No. 5" or "Annex\\nK".

    Returns:
        The normalized key, e.g. "drawing no 5", "annex k", "figure 3.1".
    """
    label = label.lower().replace("\n", " ")
    label = re.sub(r"\s+", " ", label.strip())
    # The optional dot is consumed too, so "no.5", "no. 5", "no 5" and "no5" all yield "no 5".
    label = re.sub(r"\bno\.?\s*(\d)", r"no \1", label)
    return label

def extract_annex_refs(text: str):
    """Return the distinct normalized references (annex/appendix, drawing no., figure) in ``text``.

    Each ``ref_regex`` match is passed through ``normalize_label``. Duplicates are removed
    while keeping the order of first appearance. Unlike ``list(set(...))``, this gives the
    same list on every run, so metadata built from it is reproducible.

    Args:
        text: chunk text to scan.

    Returns:
        List of normalized reference keys, e.g. ["annex k", "figure 3.1"].
    """
    normalized = (normalize_label(m) for m in ref_regex.findall(text))
    return list(dict.fromkeys(normalized))

# Split every page into overlapping child chunks. Each PDF page is the "parent": its first
# 500 characters are stored on every child as a preview, and any annex/appendix/drawing/figure
# references found in the child are recorded so the referenced material can be fetched later.
all_chunks = []
child_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

for path in pdf_files:
    basename = os.path.basename(path)
    page_docs = PyPDFLoader(path).load()  # one Document per PDF page (the parent)

    for page_doc in page_docs:
        # Page index and document name, used to cite the source later.
        # NOTE: PyPDFLoader reports "page" as a 0-based index.
        page_num = page_doc.metadata.get("page", None)
        page_doc.metadata["source"] = basename

        # Child chunks inherit the page metadata; the fields below are added or overridden.
        for c in child_splitter.split_documents([page_doc]):
            c.metadata["source"] = basename
            if page_num is not None:
                c.metadata["page"] = int(page_num)
                # 1-based page number, the same convention the figure docs use for
                # "page_number", so text and figure citations agree.
                c.metadata["page_number"] = int(page_num) + 1
            c.metadata["parent_preview"] = page_doc.page_content[:500]
            # References mentioned in this chunk, stored as metadata.
            c.metadata["annex_refs"] = extract_annex_refs(c.page_content)
            all_chunks.append(c)

### Step 2: Parse the PDFs again to extract referenced figures and images, and describe them in text

In [None]:
# Optical character recognition for pages that are scanned drawings (no usable text layer).
def ocr_page(pdf_path, page_num):
    """Render one page of ``pdf_path`` at 300 dpi and return the Tesseract OCR text.

    Args:
        pdf_path: path to the PDF file.
        page_num: 1-based page number (converted to PyMuPDF's 0-based index).

    Returns:
        The recognised text (may be empty).

    Raises:
        ValueError: if ``page_num`` is below 1. Such values would otherwise wrap around
            to pages counted from the end of the document.
    """
    if page_num < 1:
        raise ValueError(f"page_num must be 1-based, got {page_num}")
    # The context manager guarantees the document is closed after rendering.
    with fitz.open(pdf_path) as doc:
        png_bytes = doc[page_num - 1].get_pixmap(dpi=300).tobytes("png")
    with Image.open(io.BytesIO(png_bytes)) as img:
        return pytesseract.image_to_string(img)

# A reference such as "Figure 1" may be mentioned many times. For this project we assume the
# LAST page that mentions a reference is the page holding the figure/drawing itself, so that is
# where its images are extracted from. Pages with images but no reference are collected
# separately as unlabelled diagram pages.
def detect_last_occurrences(pdf_path):
    """Find the last page that mentions each reference in a PDF.

    Pages with no text layer but with embedded images are OCR'd via ``ocr_page``, so labels
    inside scanned drawings are detected as well.

    Args:
        pdf_path: path to the PDF file.

    Returns:
        (last_seen, unlabelled_pages) where
          last_seen: {normalized_label: {"canonical": stripped matched text, "page": 1-based page}}
          unlabelled_pages: set of 1-based page numbers that have images but no reference match.
    """
    last_seen = {}
    unlabelled_pages = set()

    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc, start=1):
            text_layer = page.get_text("text")
            has_images = bool(page.get_images(full=True))  # queried once per page

            ocr_text = ""
            if not text_layer.strip() and has_images:
                ocr_text = ocr_page(pdf_path, page_num)

            merged_text = (text_layer or "") + "\n" + ocr_text
            matches = ref_regex.findall(merged_text)

            if matches:
                for match in matches:
                    norm = normalize_label(match)
                    # Pages are visited in order, so a later page replaces an earlier one;
                    # within a single page the first spelling seen is kept.
                    if norm not in last_seen or page_num > last_seen[norm]["page"]:
                        last_seen[norm] = {"canonical": match.strip(), "page": page_num}
            elif has_images:
                unlabelled_pages.add(page_num)

    return last_seen, unlabelled_pages

# Extract the embedded raster images of one page as PNG files for the vision models.
def extract_images_from_page(pdf_path, page_number, output_dir=OUTPUT_IMAGE_DIR):
    """Save every image embedded on a PDF page as PNG and return the written paths.

    Files are named ``<pdf basename>_p<page>_<image index>.png`` inside ``output_dir``.
    Images with more than three colour components (CMYK) are converted to RGB first,
    because PyMuPDF can only write grey or RGB pixmaps as PNG.

    Args:
        pdf_path: path to the PDF file.
        page_number: 1-based page number.
        output_dir: destination folder; created if it does not exist.

    Returns:
        List of PNG paths, in the order the page lists its images.
    """
    os.makedirs(output_dir, exist_ok=True)
    stem = os.path.basename(pdf_path)
    image_paths = []
    with fitz.open(pdf_path) as pdf_doc:
        page = pdf_doc[page_number - 1]
        for img_index, img in enumerate(page.get_images(full=True)):
            xref = img[0]
            pix = fitz.Pixmap(pdf_doc, xref)
            # pix.n includes the alpha channel, so count colour components only.
            # A plain `pix.n < 5` test would treat alpha-less CMYK (n == 4) as RGB.
            if pix.n - pix.alpha > 3:
                pix = fitz.Pixmap(fitz.csRGB, pix)
            img_path = os.path.join(output_dir, f"{stem}_p{page_number}_{img_index}.png")
            pix.save(img_path)
            image_paths.append(img_path)
    return image_paths

# Both a short caption (BLIP) and a detailed description (GPT vision) are used as the text
# that represents an image in the index. This part produces the BLIP caption.
# Loaded BLIP processors/models, keyed by model name, so weights load once per kernel
# instead of once per image.
_BLIP_CACHE = {}


def _load_blip(model_name):
    """Return a cached (processor, model) pair for ``model_name``, loading it on first use."""
    if model_name not in _BLIP_CACHE:
        _BLIP_CACHE[model_name] = (
            BlipProcessor.from_pretrained(model_name),
            BlipForConditionalGeneration.from_pretrained(model_name),
        )
    return _BLIP_CACHE[model_name]


def blip_describe_image(image_path, model_name="Salesforce/blip-image-captioning-base"):
    """Return a BLIP caption for the image at ``image_path``.

    Args:
        image_path: path to an image file readable by PIL.
        model_name: Hugging Face id of a BLIP captioning model (defaults to the base model).

    Returns:
        The decoded caption string with special tokens removed.
    """
    processor, model = _load_blip(model_name)
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    output = model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
# GPT vision turns what a figure shows into text.
# OpenAI clients keyed by API key, reused across calls instead of being created per image.
_OPENAI_CLIENTS = {}


def gpt_vision_describe_image(image_path, model="gpt-4o",
                              prompt="Describe this engineering diagram in detail.",
                              max_tokens=750):
    """Ask an OpenAI vision model to describe the PNG image at ``image_path``.

    Args:
        image_path: path to a PNG file; it is sent inline as a base64 data URL.
        model: chat-completions model name.
        prompt: instruction sent alongside the image.
        max_tokens: upper bound on the length of the reply.

    Returns:
        The model's description, or "" when the response has no text content.
    """
    client = _OPENAI_CLIENTS.get(API_KEY)
    if client is None:
        client = _OPENAI_CLIENTS[API_KEY] = OpenAI(api_key=API_KEY)

    with open(image_path, "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode("utf-8")

    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ],
        max_tokens=max_tokens,
    )
    # message.content is Optional in the SDK. Returning "" keeps callers that interpolate
    # the result into captions from embedding the literal text "None".
    return response.choices[0].message.content or ""


# Tie the helpers together: locate each reference's page, extract and describe that page's
# images, and emit one Document per (reference, image) pair.
def _describe_page_images(pdf_path, page_num, page_cache):
    """Return [(img_path, blip_caption, gpt_caption), ...] for a page, computed at most once.

    Several references can resolve to the same page. Caching per page avoids re-extracting
    the same images and repeating (paid) BLIP/GPT calls on identical files.
    """
    if page_num not in page_cache:
        page_cache[page_num] = [
            (img_path, blip_describe_image(img_path), gpt_vision_describe_image(img_path))
            for img_path in extract_images_from_page(pdf_path, page_num)
        ]
    return page_cache[page_num]


def _figure_doc(pdf_path, label, norm_label, page_num, img_path, blip_caption, gpt_caption):
    """Build the Document that stands in for one extracted image in the index."""
    caption_combined = f"{label}\n[BLIP] {blip_caption}\n[GPT-4V] {gpt_caption}"
    return Document(
        page_content=f"[DIAGRAM CAPTION]\n{caption_combined}",
        metadata={
            "source": os.path.basename(pdf_path),
            "is_diagram": True,
            "figure_label": label,
            "figure_label_norm": norm_label,
            "page_number": page_num,
            "image_path": img_path,
        },
    )


def build_figure_docs(pdf_path):
    """Describe the diagrams of one PDF and build its reference label map.

    Labelled pages (the last page mentioning each reference) yield one Document per image on
    that page, labelled with the reference. Unlabelled image pages yield Documents labelled
    ``UNLABELLED_DIAGRAM_p<page>``.

    Args:
        pdf_path: path to the PDF file.

    Returns:
        (figure_docs, label_map): the list of caption Documents, and a dict mapping each
        normalized reference label to its canonical text. When a retrieved chunk mentions a
        reference, this map is used to look up the related figure/annex.
    """
    label_map = {}
    figure_docs = []
    page_cache = {}
    last_seen, unlabelled_pages = detect_last_occurrences(pdf_path)

    # Labelled diagrams.
    for norm_label, info in last_seen.items():
        label_map[norm_label] = info["canonical"]
        for described in _describe_page_images(pdf_path, info["page"], page_cache):
            figure_docs.append(
                _figure_doc(pdf_path, info["canonical"], norm_label, info["page"], *described)
            )

    # Unlabelled diagrams, in page order so the output is deterministic.
    for page_num in sorted(unlabelled_pages):
        label = f"UNLABELLED_DIAGRAM_p{page_num}"
        for described in _describe_page_images(pdf_path, page_num, page_cache):
            figure_docs.append(_figure_doc(pdf_path, label, label.lower(), page_num, *described))

    return figure_docs, label_map

# Build figure docs & label map once and cache them on disk: the vision calls are slow and
# billed per image. Delete the cache files to force a full rebuild.
all_figure_docs = []
global_label_map = {}
pickle_file = "figure_docs.pkl"
label_map_file = "label_map.json"

if os.path.exists(pickle_file):
    print(f"✅ Skipping figure processing — loaded from {pickle_file}")
    # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
    with open(pickle_file, "rb") as f:
        all_figure_docs = pickle.load(f)
    if os.path.exists(label_map_file):
        with open(label_map_file, "r", encoding="utf-8") as f:
            global_label_map = json.load(f)
    else:
        # Without this, the label map would silently stay empty and reference look-ups
        # would fail. It only needs page scanning, not the vision models.
        for pdf_path in pdf_files:
            last_seen, _ = detect_last_occurrences(pdf_path)
            global_label_map.update({norm: info["canonical"] for norm, info in last_seen.items()})
        with open(label_map_file, "w", encoding="utf-8") as f:
            json.dump(global_label_map, f, indent=2)
else:
    for pdf_path in pdf_files:
        figs, label_map = build_figure_docs(pdf_path)
        all_figure_docs.extend(figs)
        global_label_map.update(label_map)
    with open(pickle_file, "wb") as f:
        pickle.dump(all_figure_docs, f)
    with open(label_map_file, "w", encoding="utf-8") as f:
        json.dump(global_label_map, f, indent=2)


### Step 3: Process the metadata and create the vector database

In [None]:
def _stringify_primitive(v):
    # allow primitives only
    return v if isinstance(v, (str, int, float, bool)) or v is None else str(v)

def filter_complex_metadata(docs: List[Document]) -> List[Document]:
    """Return new Documents whose metadata Chroma can store.

    For each input document:
      - keep only the metadata keys this pipeline uses (others are dropped);
      - fill "page" from "page_number" when only the latter is present (figure docs);
      - convert the ``annex_refs`` list/tuple/set into an ``annex_refs_csv`` string joined
        with "; " (omitted when empty), because Chroma does not accept list values;
      - cap ``parent_preview`` at 500 characters;
      - stringify any remaining non-primitive value.

    ``annex_refs_csv`` is itself a kept key, so filtering already-filtered documents is a
    no-op rather than silently dropping the references.

    Args:
        docs: Documents whose ``metadata`` may contain non-scalar values.

    Returns:
        A new list of Documents with the same page_content and flattened metadata.
    """
    keep_keys = {
        "source", "page", "parent_preview",
        "annex_refs", "annex_refs_csv", "is_diagram", "figure_label", "figure_label_norm",
        "page_number", "image_path",
    }
    filtered: List[Document] = []
    for d in docs:
        md = d.metadata or {}
        md_simple = {k: md[k] for k in keep_keys if k in md}

        # Normalize 'page' (figure docs only carry 'page_number').
        if "page" not in md_simple and "page_number" in md_simple:
            try:
                md_simple["page"] = int(md_simple["page_number"])
            except (TypeError, ValueError, OverflowError):
                pass  # best effort: leave 'page' unset when page_number is not numeric

        # Chroma disallows list metadata: flatten annex_refs into a scalar string field.
        if "annex_refs" in md_simple:
            v = md_simple["annex_refs"]
            if isinstance(v, (list, tuple, set)):
                # Sets have no stable order; sort them so the CSV is reproducible.
                items = sorted(v, key=str) if isinstance(v, set) else v
                refs = [str(x) for x in items if str(x).strip()]
                if refs:
                    md_simple["annex_refs_csv"] = "; ".join(refs)
                md_simple.pop("annex_refs", None)

        # Cap the preview to keep metadata small.
        if isinstance(md_simple.get("parent_preview"), str):
            md_simple["parent_preview"] = md_simple["parent_preview"][:500]

        md_simple = {k: _stringify_primitive(v) for k, v in md_simple.items()}
        filtered.append(Document(page_content=d.page_content, metadata=md_simple))
    return filtered

In [None]:
# Create the vector DB.
# Merge text chunks with figure captions and flatten metadata. The result goes into a new
# name rather than back into all_chunks: reassigning all_chunks would append the figure docs
# again and re-filter already-filtered metadata every time this cell is re-run.
indexed_docs = filter_complex_metadata(all_chunks + all_figure_docs)

# OpenAI embeddings for both text chunks and figure captions.
embedding = OpenAIEmbeddings(openai_api_key=API_KEY)

# Chroma.from_documents adds to any collection already persisted in db_dir, so drop it first.
# Otherwise every Restart & Run All would store each document again and duplicate results.
Chroma(persist_directory=db_dir, embedding_function=embedding).delete_collection()

vectordb = Chroma.from_documents(
    documents=indexed_docs,
    embedding=embedding,
    persist_directory=db_dir,
)
# NOTE: newer chromadb versions persist automatically; kept for older installs.
vectordb.persist()