In [None]:
# !brew install poppler tesseract libmagic
#install globally
#brew install tesseract poppler libmagic
# echo 'export PATH="/opt/homebrew/bin:$PATH"' >> ~/.zshrc
# source ~/.zshrc

In [None]:
import os
import uuid
from dotenv import load_dotenv
from PIL import Image
from io import BytesIO
import base64

#Langchain
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.embeddings import OllamaEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.documents import Document
from langchain_community.chat_models.ollama import ChatOllama
from unstructured.partition.pdf import partition_pdf
from unstructured.documents.elements import Table, CompositeElement
from langchain_core.messages import SystemMessage, HumanMessage

#Langfuse
!pip install langfuse
from langfuse import Langfuse
from langfuse.langchain import CallbackHandler
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Load variables from a local .env file (if present) into os.environ, e.g. the
# LANGFUSE_* settings read in the next cell.
load_dotenv()

### Set up Langfuse for tracing

In [None]:
# Report which Langfuse settings were loaded WITHOUT echoing the secret key:
# notebook outputs are saved with the file and are easily shared or committed.
for env_name in ("LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_HOST"):
    env_value = os.environ.get(env_name)
    if env_value is None:
        shown = None
    elif env_name == "LANGFUSE_SECRET_KEY":
        shown = "<set, hidden>"
    else:
        shown = env_value
    print(f'os.environ.get("{env_name}")', shown)

langfuse = Langfuse(
    public_key=os.environ.get("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.environ.get("LANGFUSE_SECRET_KEY"),
    host=os.environ.get("LANGFUSE_HOST"),
)

# LangChain callback handler; passed via config={"callbacks": [langfuse_handler]}
# in later cells so chain/model runs are traced in Langfuse.
langfuse_handler = CallbackHandler()

### Simple test to make sure Langfuse is working

In [None]:
chat_model = ChatOllama(
    model="llama3.1:8b",
    base_url="http://localhost:11434",
    temperature=0.7,
)

# Create a simple chat prompt template
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant. Answer questions clearly and concisely.",
        ),
        ("user", "{question}"),
    ]
)

# Create the chain
chain = prompt | chat_model | StrOutputParser()

# Test with one question
question = "Is job market bad currently?"

print("Testing ChatOllama with Langfuse tracing...")
print(f"Question: {question}")
print("-" * 40)

try:
    # Invoke the chain with Langfuse callback
    response = chain.invoke(
        {"question": question}, config={"callbacks": [langfuse_handler]}
    )

    print(f"Response: {response}")

except Exception as e:
    print(f"Error: {e}")
    # Use the configured model name so the hint stays correct if the model changes.
    print(f"Make sure Ollama is running and {chat_model.model} is available")
finally:
    # Langfuse sends events in the background; flush so the trace is actually
    # delivered before the user goes to look for it.
    langfuse.flush()

# Point at the configured Langfuse host rather than assuming a local default.
langfuse_host = os.environ.get("LANGFUSE_HOST") or "http://localhost:3000"
print(f"\nCheck your Langfuse dashboard at {langfuse_host} to see the trace!")

## Set up PATH so Homebrew binaries (tesseract, poppler) are found

In [None]:
import os

# Make Homebrew-installed binaries (tesseract, poppler) visible to this kernel.
# Only append when missing so re-running the cell does not keep growing PATH.
HOMEBREW_BIN = "/opt/homebrew/bin"
_current_path = os.environ.get("PATH", "")
if HOMEBREW_BIN not in _current_path.split(os.pathsep):
    # filter(None, ...) avoids a leading separator when PATH is empty/unset.
    os.environ["PATH"] = os.pathsep.join(filter(None, [_current_path, HOMEBREW_BIN]))

In [None]:
import os
import subprocess
import sys

# Confirm the tesseract binary can be launched from this kernel's environment.
tesseract_cmd = ["tesseract", "--version"]
try:
    result = subprocess.run(tesseract_cmd, text=True, capture_output=True)
except FileNotFoundError:
    print("Tesseract not found in PATH")
else:
    print("Tesseract version:", result.stdout)

# Show the PATH the kernel actually uses, to help debug the lookup above.
current_path = os.environ.get("PATH", "")
print("Current PATH:", current_path)

### Chunk PDF by title

In [None]:
import os
from unstructured.partition.pdf import partition_pdf

# Folder holding the source PDFs.
content_folder = "./content2/"

# Sorted so chunk order (and every index-aligned list derived from `chunks`)
# is reproducible across runs; os.listdir() returns entries in arbitrary order.
# The extension check is case-insensitive so e.g. "Guide.PDF" is not skipped.
pdf_files = sorted(
    f for f in os.listdir(content_folder) if f.lower().endswith(".pdf")
)

print(f"Found {len(pdf_files)} PDF files to process")
chunks = []
# Process each PDF file
for pdf_file in pdf_files:
    file_path = os.path.join(content_folder, pdf_file)
    print(f"Processing: {pdf_file}")

    try:
        # hi_res partitioning, chunked by title; image blocks are extracted with
        # their base64 payload kept in element metadata.
        each_chunks = partition_pdf(
            filename=file_path,
            infer_table_structure=True,
            include_page_breaks=True,
            strategy="hi_res",
            extract_image_block_types=["Image"],
            extract_image_block_to_payload=True,
            chunking_strategy="by_title",
            max_characters=10000,
            combine_text_under_n_chars=2000,
            new_after_n_chars=6000,
        )

        # Tag every chunk with its source PDF so later cells can attribute results.
        for chunk in each_chunks:
            chunk.metadata.filename = pdf_file

        chunks.extend(each_chunks)
    except Exception as e:
        # Best effort: report the failing file and carry on with the others.
        print(f"✗ Error processing {pdf_file}: {e}")

print("Done!")

In [None]:
# Inspect the first chunk as a dict (raises IndexError if no PDF produced chunks).
chunks[0].to_dict()

### Extract and separate tables and text

In [None]:
from unstructured.documents.elements import Table, CompositeElement

# === Extract Content ===
# Texts are the CompositeElement chunks. Tables can appear either as top-level
# chunks or nested inside a CompositeElement's orig_elements, so collect both.
tables, texts = [], []

# enumerate() replaces the old manual counter, which was written as
# `count = count = 1` and therefore never advanced past 1.
for count, chunk in enumerate(chunks, start=1):
    if isinstance(chunk, Table):
        print("chunk" + str(count))
        tables.append(chunk)
    elif isinstance(chunk, CompositeElement):
        texts.append(chunk)
        # orig_elements may be absent or None; treat both as "no nested elements".
        for el in getattr(chunk.metadata, "orig_elements", None) or []:
            if isinstance(el, Table):
                print("chunk" + str(count))
                # Nested tables inherit the parent chunk's source filename.
                el.metadata.filename = chunk.metadata.filename
                tables.append(el)

In [None]:
# Display the extracted Table elements (top-level and nested).
tables

In [None]:
# Display the CompositeElement text chunks.
texts

### Prepare a function to filter out logo images, using OCR and matching the OCR text against logo keywords

In [None]:
!pip install pytesseract
from PIL import Image
import pytesseract
import base64
import io
import cv2
import numpy as np


def is_likely_logo(image_base64, logo_keywords=None, size_threshold=None):
    """
    Check if an image is likely a logo based on OCR text detection and (optionally) size.

    Args:
        image_base64: Base64 encoded image string.
        logo_keywords: Optional list of keywords that indicate a logo. When None,
            a built-in list is used (company suffixes, copyright marks, and
            organisation names such as "MINISTRY OF MANPOWER" and "ACCENTURE").
        size_threshold: Optional (width, height) tuple. When given, images smaller
            than this in BOTH dimensions are reported as logos without OCR.
            None (the default) disables the size check.

    Returns:
        bool: True if likely a logo, False otherwise. Images that cannot be
        decoded or OCR'd return False, so they are kept.
    """
    import re

    if logo_keywords is None:
        # Customize these keywords based on your company logo text
        logo_keywords = [
            "logo",
            "company",
            "inc",
            "ltd",
            "corp",
            "llc",
            "trademark",
            "®",
            "©",
            "copyright",
            "MINISTRY OF MANPOWER",
            "MINISTRY OF",
            "ACCENTURE",
        ]

    try:
        # Decode base64 image
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        print(f"==>> image: {image}")

        # Check image size first (quick filter)
        width, height = image.size
        print(f"==>> height: {height}")
        print(f"==>> width: {width}")
        if (
            size_threshold is not None
            and width < size_threshold[0]
            and height < size_threshold[1]
        ):
            return True  # Small images are likely logos

        # Convert to grayscale for better OCR
        if image.mode != "L":
            image = image.convert("L")

        # Increase contrast (via OpenCV on a numpy array) to improve OCR accuracy.
        img_array = np.array(image)
        img_array = cv2.convertScaleAbs(img_array, alpha=1.5, beta=0)
        enhanced_image = Image.fromarray(img_array)

        # Extract text using OCR
        text = (
            pytesseract.image_to_string(enhanced_image, config="--psm 6")
            .strip()
            .lower()
        )
        # Collapse runs of whitespace/newlines so multi-word keywords can match.
        text = " ".join(text.split())

        print(f"==>> OCR text: '{text}'")
        print(f"==>> Logo keywords: {logo_keywords}")

        for keyword in logo_keywords:
            keyword_lower = " ".join(keyword.lower().split())
            if not keyword_lower:
                continue
            # Match whole words only. Plain substring matching let short keywords
            # hit ordinary words ("inc" in "since"/"income", "logo" in "logout")
            # and threw away content images. Boundaries are only enforced on
            # edges that are word characters, so symbols like "®" still match
            # when glued to a word ("Company®").
            left = r"(?<!\w)" if re.match(r"\w", keyword_lower[0]) else ""
            right = r"(?!\w)" if re.match(r"\w", keyword_lower[-1]) else ""
            if re.search(left + re.escape(keyword_lower) + right, text):
                print(f"==>> MATCH FOUND: '{keyword_lower}' in '{text}' - FILTERING OUT")
                return True

        return False

    except Exception as e:
        print(f"Error processing image: {e}")
        # If we can't process the image, keep it to be safe
        return False


### Old version: returns only the base64 strings

In [None]:
# def get_images_base64_filtered(chunks, filter_logos=True):
#     """
#     Extract images from chunks with optional logo filtering

#     Args:
#         chunks: List of chunks from partition_pdf
#         filter_logos: Whether to filter out logos
#         logo_keywords: List of keywords that indicate a logo
#         size_threshold: Tuple of (width, height) for size-based filtering

#     Returns:
#         List of base64 encoded images (logos filtered out if enabled)
#     """
#     images_b64 = []
#     filtered_count = 0

#     for chunk in chunks:
#         if "CompositeElement" in str(type(chunk)):
#             chunk_els = chunk.metadata.orig_elements
#             for el in chunk_els:
#                 if "Image" in str(type(el)):
#                     image_base64 = el.metadata.image_base64

#                     if filter_logos:
#                         if is_likely_logo(image_base64):
#                             filtered_count += 1
#                             print(
#                                 f"Filtered out likely logo image (total filtered: {filtered_count})"
#                             )
#                             continue

#                     images_b64.append(image_base64)

#     print(
#         f"Total images extracted: {len(images_b64)}, Logos filtered: {filtered_count}"
#     )
#     return images_b64

### New version: returns both the base64 strings and their source filenames for processing

In [None]:
def get_images_base64_filtered(chunks, filter_logos=True):
    """
    Extract images from chunks with optional logo filtering.

    Only images nested in a CompositeElement's ``metadata.orig_elements`` are
    considered. Images without a base64 payload are skipped.

    Args:
        chunks: List of chunks from partition_pdf
        filter_logos: Whether to filter out logos (via ``is_likely_logo``)

    Returns:
        Tuple of (images_b64, image_filenames): two index-aligned lists holding
        each kept image's base64 string and its source PDF filename
        ("unknown.pdf" when the chunk has no filename).
    """
    images_b64 = []
    image_filenames = []
    filtered_count = 0

    for chunk in chunks:
        # Type checks by name avoid importing the unstructured element classes.
        if "CompositeElement" not in str(type(chunk)):
            continue

        # A None/empty filename is treated like a missing one.
        source_filename = getattr(chunk.metadata, "filename", None) or "unknown.pdf"

        # orig_elements can be missing or None; treat both as empty.
        for el in getattr(chunk.metadata, "orig_elements", None) or []:
            if "Image" not in str(type(el)):
                continue

            image_base64 = getattr(el.metadata, "image_base64", None)
            if not image_base64:
                # No payload was extracted for this image; nothing to summarise.
                continue

            if filter_logos and is_likely_logo(image_base64):
                filtered_count += 1
                print(
                    f"Filtered out likely logo image (total filtered: {filtered_count})"
                )
                continue

            images_b64.append(image_base64)
            image_filenames.append(source_filename)

    print(
        f"Total images extracted: {len(images_b64)}, Logos filtered: {filtered_count}"
    )
    return images_b64, image_filenames

### old version

In [None]:
# images = get_images_base64_filtered(
#     chunks,
#     filter_logos=True,
# )

### new version

In [None]:
# Pull every non-logo image out of the chunks, together with the PDF each came from.
images, image_source_filenames = get_images_base64_filtered(chunks, filter_logos=True)

In [None]:
# Display the kept base64 image strings.
images

In [None]:
# Source PDF for each kept image, index-aligned with `images`.
image_source_filenames

In [None]:
from langchain_community.chat_models.ollama import ChatOllama

# === LLM for Text + Table Summarization ===
# Low temperature for more deterministic summaries.
text_model = ChatOllama(temperature=0.1, model="llama3:8b")

In [None]:
# Inspect the first text chunk as a dict (raises IndexError if there are no text chunks).
texts[0].to_dict()

### Old solution: the input trace was not visible in Langfuse when the chain input was mapped with `{"text": lambda x: x}`, because Langfuse could not tell what to extract as the input

In [None]:
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_core.output_parsers import StrOutputParser
# from langfuse.langchain import CallbackHandler

# langfuse_handler = CallbackHandler()

# prompt_text = """
# You are an assistant tasked with summarizing tables and text.
# Give a concise summary of the table or text.
# Respond only with the summary and do not start with any introduction like here is the concise summary.
# Table or text chunk: {text}
# """
# text_prompt = ChatPromptTemplate.from_template(prompt_text)
# summarize_chain = (
#     {"text": lambda x: x} | text_prompt | text_model | StrOutputParser()
# )


# # Convert to proper input format
# text_inputs = [{"text": text} for text in texts]
# text_summaries = summarize_chain.batch(
#     text_inputs, config={"callbacks": [langfuse_handler]}
# )

# text_summaries = summarize_chain.batch(texts, config={"callbacks": [langfuse_handler]})
# table_summaries = summarize_chain.batch([t.metadata.text_as_html for t in tables])


### New approach for summarising text and tables: extract the text explicitly up front. This lets Langfuse detect the input, which is good for tracing

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langfuse.langchain import CallbackHandler

langfuse_handler = CallbackHandler()

prompt_text = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.
Respond only with the summary and do not start with any introduction like here is the concise summary.
Table or text chunk: {element}
"""
text_prompt = ChatPromptTemplate.from_template(prompt_text)
summarize_chain = text_prompt | text_model | StrOutputParser()

# Pass plain strings (not Element objects) so Langfuse can record the input.
text_summaries = summarize_chain.batch(
    [{"element": text.text} for text in texts], config={"callbacks": [langfuse_handler]}
)

# text_as_html can be None (e.g. when no HTML was produced for a table); fall
# back to the table's plain text so the model never summarises the string "None".
table_summaries = summarize_chain.batch(
    [{"element": t.metadata.text_as_html or t.text} for t in tables],
    config={"callbacks": [langfuse_handler]},
)

In [None]:
# One summary per entry in `texts`, in the same order.
text_summaries

In [None]:
# One summary per entry in `tables`, in the same order.
table_summaries

In [None]:
# Vision-capable model used to describe the extracted images
# (alternatives: llava:7b, bakllava).
VISION_MODEL = "gemma3:12b"
vision_model = ChatOllama(
    base_url="http://localhost:11434",
    model=VISION_MODEL,
    temperature=0.1,
)

### Ask the vision LLM to summarise each image from its base64 data

In [None]:
from langchain_core.messages import SystemMessage, HumanMessage

def analyze_image_with_ollama(image_base64: str, mime_type: str = "image/jpeg") -> str:
    """
    Describe one image with the Ollama vision model (the module-level ``vision_model``).

    Args:
        image_base64: Base64-encoded image data (without a ``data:`` prefix).
        mime_type: MIME type used in the data URL. Defaults to "image/jpeg",
            which is what was always sent before; pass the real type if it differs.

    Returns:
        str: The model's description. If the call fails, the error is printed and
        a placeholder string starting with "Error analyzing image:" is returned
        instead of raising, so callers that store the result should check for it.
    """
    prompt_template = """Describe this image in detail. For context, 
    the image is part of a Singapore Ministry of Manpower workpass system. Be specific about images, such as, diagrams, flowchart, screenshot and any text visible in the image. Do not respond with any introduction words like Here\'s a detailed description of the image. """

    # A single human message carrying both the instruction text and the image.
    messages = [
        HumanMessage(
            content=[
                {"type": "text", "text": prompt_template},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
                },
            ]
        )
    ]

    try:
        response = vision_model.invoke(
            messages, config={"callbacks": [langfuse_handler]}
        )
        print(f"==>> response: {response}")
        return response.content
    except Exception as e:
        print(f"Error analyzing image: {e}")
        return f"Error analyzing image: Unable to process with {VISION_MODEL}"

In [None]:
# Describe every kept image; summaries stay index-aligned with `images`.
image_summaries = []
total_images = len(images)
for position, encoded_image in enumerate(images, start=1):
    print(f"Processing image {position}/{total_images}")
    image_summaries.append(analyze_image_with_ollama(encoded_image))

In [None]:
# Vision-model descriptions, index-aligned with `images`; failed calls hold an
# "Error analyzing image: ..." placeholder rather than a description.
image_summaries

### Init embedding model for Vector DB

In [None]:
from langchain.embeddings import OllamaEmbeddings

# Embedding model used to index the summaries in the vector store.
EMBEDDING_MODEL = "nomic-embed-text"
embeddings = OllamaEmbeddings(
    base_url="http://localhost:11434",
    model=EMBEDDING_MODEL,
)

### Initialise the vector DB, currently using PGVector for the embeddings and a byte store (PostgresByteStore) for document storage

In [None]:
!pip install langchain_postgres
!pip install psycopg_binary
from utils.store import PostgresByteStore
from database import COLLECTION_NAME, CONNECTION_STRING
from langchain_postgres import PGVector

# vectorstore = Chroma(
#     collection_name="multi_modal_rag_ollama",
#     embedding_function=embeddings,
#     persist_directory="./chroma_db_8",  # Separate directory for Ollama version
# )
# PGVector holds the summary embeddings that similarity search runs against.
vectorstore = PGVector(
    connection=CONNECTION_STRING,
    collection_name=COLLECTION_NAME,
    embeddings=embeddings,
    use_jsonb=True,
)

# Docstore for the full documents, keyed by doc id.
# Alternatives: InMemoryStore() (non-persistent) or
# LocalFileStore("./document_store_ollama") (persistent, on local disk).
store = PostgresByteStore(CONNECTION_STRING, COLLECTION_NAME)

# Metadata key that links each summary in the vectorstore to its full document.
id_key = "doc_id"

In [None]:
from langchain.retrievers.multi_vector import MultiVectorRetriever

# Searches summaries in `vectorstore`, then returns the matching full documents
# from `store` using the `id_key` metadata field.
retriever = MultiVectorRetriever(
    docstore=store,
    vectorstore=vectorstore,
    id_key=id_key,
)

In [None]:
# Display the configured retriever.
retriever

### Map each text chunk to its source filename

In [None]:
# Source PDF name for each text chunk, index-aligned with `texts`.
# Metadata may be an object with a `filename` attribute or a plain dict; a
# missing or empty (None/"") filename becomes "NO_FILENAME".
filenames_for_text = []
for text in texts:
    metadata = getattr(text, "metadata", None)
    if isinstance(metadata, dict):
        filename = metadata.get("filename")
    else:
        filename = getattr(metadata, "filename", None)
    filename = filename or "NO_FILENAME"

    print(filename)
    filenames_for_text.append(filename)

### Old way - add texts to both the vector store (PGVector) and the docstore in Postgres
### Langfuse can't detect the structure since it is not a Document

In [None]:
# import uuid
# from langchain_core.documents import Document

# print("Adding texts to retriever...")
# doc_ids = [str(uuid.uuid4()) for _ in texts]
# summary_texts = [
#     Document(page_content=summary, metadata={id_key: doc_ids[i]})
#     for i, summary in enumerate(text_summaries)
# ]
# retriever.vectorstore.add_documents(summary_texts)
# retriever.docstore.mset(list(zip(doc_ids, texts, filenames_for_text)))

### New way using Document type

In [None]:
print("Adding texts to retriever...")

# Every text needs exactly one summary and one filename. Check up front so a
# mismatch fails loudly instead of zip() silently dropping rows or ids being
# attached to the wrong text.
assert len(text_summaries) == len(texts) == len(filenames_for_text), (
    f"length mismatch: {len(texts)} texts, {len(text_summaries)} summaries, "
    f"{len(filenames_for_text)} filenames"
)

doc_ids = [str(uuid.uuid4()) for _ in texts]

# Summary documents for the vectorstore (these get searched)
summary_texts = [
    Document(
        page_content=summary,
        metadata={
            id_key: doc_id,
            "doc_type": "text",
            "filename": filename,
            "content_type": "text_summary",
        },
    )
    for doc_id, summary, filename in zip(doc_ids, text_summaries, filenames_for_text)
]

# Full documents for the docstore (these get returned), sharing the same ids
full_text_docs = [
    Document(
        page_content=text if isinstance(text, str) else str(text),
        metadata={
            id_key: doc_id,
            "doc_type": "text",
            "filename": filename,
            "content_type": "text_full",
        },
    )
    for doc_id, text, filename in zip(doc_ids, texts, filenames_for_text)
]

# Add summaries to vectorstore for searching
retriever.vectorstore.add_documents(summary_texts)

# Add full documents to docstore for retrieval.
# NOTE(review): BaseStore.mset expects (key, value) pairs, but this passes
# (key, doc, filename) triples - presumably PostgresByteStore (utils/store.py)
# accepts the extra filename; confirm.
retriever.docstore.mset(list(zip(doc_ids, full_text_docs, filenames_for_text)))

In [None]:
# Display the summary Documents that were embedded into the vector store.
summary_texts

In [None]:
# Inspect the first raw text chunk as a dict (raises IndexError if there are none).
texts[0].to_dict()

### Map each table to its source filename

In [None]:
# Source PDF name for each table, index-aligned with `tables`.
# Metadata may be an object with a `filename` attribute or a plain dict; a
# missing or empty (None/"") filename becomes "NO_FILENAME".
filenames_for_tables = []
for table in tables:
    metadata = getattr(table, "metadata", None)
    if isinstance(metadata, dict):
        filename = metadata.get("filename")
    else:
        filename = getattr(metadata, "filename", None)
    filename = filename or "NO_FILENAME"

    print(filename)
    filenames_for_tables.append(filename)

In [None]:
# Number of extracted tables (top-level and nested).
len(tables)

In [None]:
# Display the extracted tables.
tables

### Old way - add tables to both the vector store (PGVector) and the docstore in Postgres
### Langfuse can't detect the structure since it is not a Document

In [None]:
# print("Adding tables to retriever...")
# table_ids = [str(uuid.uuid4()) for _ in tables]
# summary_tables = [
#     Document(page_content=summary, metadata={id_key: table_ids[i]})
#     for i, summary in enumerate(table_summaries)
# ]
# retriever.vectorstore.add_documents(summary_tables)
# retriever.docstore.mset(list(zip(table_ids, tables, filenames_for_tables)))

### New way using Document type

In [None]:
print("Adding tables to retriever...")

# Every table needs exactly one summary and one filename; fail loudly on a
# mismatch instead of letting zip() silently drop rows.
assert len(table_summaries) == len(tables) == len(filenames_for_tables), (
    f"length mismatch: {len(tables)} tables, {len(table_summaries)} summaries, "
    f"{len(filenames_for_tables)} filenames"
)

table_ids = [str(uuid.uuid4()) for _ in tables]

# Summary documents for the vectorstore (these get searched). They carry the
# source filename like the text summaries do, so similarity_search hits can be
# attributed to a PDF.
summary_tables = [
    Document(
        page_content=summary,
        metadata={
            id_key: table_id,
            "doc_type": "table",
            "filename": filename,
            "content_type": "table_summary",
        },
    )
    for table_id, summary, filename in zip(table_ids, table_summaries, filenames_for_tables)
]

# Full documents for the docstore (these get returned), sharing the same ids
full_table_docs = [
    Document(
        page_content=table if isinstance(table, str) else str(table),
        metadata={
            id_key: table_id,
            "doc_type": "table",
            "filename": filename,
            "content_type": "table_full",
        },
    )
    for table_id, table, filename in zip(table_ids, tables, filenames_for_tables)
]

# Add summaries to vectorstore for searching
retriever.vectorstore.add_documents(summary_tables)

# Add full documents to docstore for retrieval.
# NOTE(review): BaseStore.mset expects (key, value) pairs, but this passes
# (key, doc, filename) triples - presumably PostgresByteStore accepts the
# extra filename; confirm.
retriever.docstore.mset(list(zip(table_ids, full_table_docs, filenames_for_tables)))

### Build a unique identifier (source filename + index) for each image

In [None]:
# Unique per-image identifier: "<source pdf>_image_<position in images>".
filenames_for_images = [
    f"{source_filename}_image_{i}"
    for i, source_filename in enumerate(image_source_filenames)
]

In [None]:
# Display the per-image identifiers ("<source pdf>_image_<index>").
filenames_for_images 

### Old way - Langfuse cannot recognise the structure
### Add images to both the vector store (PGVector) and the docstore in Postgres

In [None]:
# print("Adding images to retriever...")
# img_ids = [str(uuid.uuid4()) for _ in images]
# summary_img = [
#     Document(page_content=summary, metadata={id_key: img_ids[i]})
#     for i, summary in enumerate(image_summaries)
# ]
# retriever.vectorstore.add_documents(summary_img)
# retriever.docstore.mset(list(zip(img_ids, images, filenames_for_images)))

### New way using Document type

In [None]:
# Display the image descriptions before indexing them.
image_summaries

In [None]:
# Display the base64 images before indexing them.
images

In [None]:
print("Adding images to retriever...")

# Every image needs exactly one summary and one identifier; fail loudly on a
# mismatch instead of letting zip() silently drop rows.
assert len(image_summaries) == len(images) == len(filenames_for_images), (
    f"length mismatch: {len(images)} images, {len(image_summaries)} summaries, "
    f"{len(filenames_for_images)} filenames"
)

img_ids = [str(uuid.uuid4()) for _ in images]

# Summary documents for the vectorstore (these get searched). They carry the
# image identifier like the text summaries carry their filename, so
# similarity_search hits can be attributed to a source.
summary_img = [
    Document(
        page_content=summary,
        metadata={
            id_key: img_id,
            "doc_type": "image",
            "filename": filename,
            "content_type": "image_summary",
        },
    )
    for img_id, summary, filename in zip(img_ids, image_summaries, filenames_for_images)
]

# Full documents for the docstore (these get returned): the base64 image itself
full_image_docs = [
    Document(
        page_content=image if isinstance(image, str) else str(image),
        metadata={
            id_key: img_id,
            "doc_type": "image",
            "filename": filename,
            "content_type": "image_full",
        },
    )
    for img_id, image, filename in zip(img_ids, images, filenames_for_images)
]

# Add summaries to vectorstore for searching
retriever.vectorstore.add_documents(summary_img)

# Add full documents to docstore for retrieval.
# NOTE(review): BaseStore.mset expects (key, value) pairs, but this passes
# (key, doc, filename) triples - presumably PostgresByteStore accepts the
# extra filename; confirm.
retriever.docstore.mset(list(zip(img_ids, full_image_docs, filenames_for_images)))

### Check the docstore data
#### You can also check the DB directly: `select * from public.bytestore;`

In [None]:
# Get all the keys currently in the store
all_doc_ids = store.yield_keys()

# Loop through and fetch each document by its ID
for doc_id in all_doc_ids:
    docs = store.mget([doc_id])  # Returns a list with the document(s)
    print(f"Document ID: {doc_id}")
    for doc in docs:
        print(doc)  # `doc` is a Document object

### Check Chroma documents (15 records) - obsolete

In [None]:
# all_docs = vectorstore.get()

# index = 0
# for doc in all_docs["documents"]:
#     print("index is :", index)
#     print(doc)
#     index = index + 1

### List up to 1000 records in the Postgres vector DB (PGVector)
#### You can also check the DB directly: `select * from langchain_pg_embedding`

In [None]:
# similarity_search needs a query, so search with a blank one and a large k
# as a rough way to list the stored summaries.
docs = vectorstore.similarity_search(" ", k=1000)

for position, summary_doc in enumerate(docs):
    print("index is:", position)
    print(summary_doc.page_content)

In [None]:
print("Multi-modal RAG setup complete!")
processed_counts = ", ".join(
    [f"{len(texts)} texts", f"{len(tables)} tables", f"{len(images)} images"]
)
print(f"Processed: {processed_counts}")

### `search_kwargs` is not required unless you need to control how many top results are retrieved

In [None]:
# retriever.search_kwargs = {"k":4}

In [None]:
# End-to-end retrieval: matching summaries are found in the vector store and the
# corresponding full documents are returned from the docstore.
docs = retriever.invoke("what are the nationalities allowed for S-pass holders")

In [None]:
# Display the retrieved full documents.
docs

In [None]:
# Inspect one retrieved document (docs[1] raises IndexError if fewer than two came back).
# docs[0]
docs[1]

In [None]:
# Query the vector store directly: returns the summary Documents (with their
# metadata, including the doc id) rather than the full documents.
docs = retriever.vectorstore.similarity_search(
    "what are the nationalities allowed for S-pass holders"
)

In [None]:
# Display the summary hits and their metadata.
docs

### Some samples: `similarity_search` exposes the doc id (in the summary metadata) but `.invoke` does not

In [None]:
# See what the retriever is actually doing
docs = retriever.invoke("what are the nationalities allowed for S-pass holders")

# The retriever first searches vector store for summaries
query = "what are the nationalities allowed for S-pass holders"
relevant_summaries = retriever.vectorstore.similarity_search(query, k=4)

print("=== RELEVANT SUMMARIES FROM VECTOR STORE ===")
for summary in relevant_summaries:
    print(f"Summary: {summary.page_content[:100]}...")
    print(f"Metadata: {summary.metadata}")
    doc_id = summary.metadata.get("doc_id")
    if doc_id:
        print(f"Will retrieve full doc with ID: {doc_id}")
    print()

# Then retrieves full docs from byte store
if relevant_summaries:
    doc_ids = [
        summary.metadata["doc_id"]
        for summary in relevant_summaries
        if "doc_id" in summary.metadata
    ]
    print(f"Looking up doc_ids: {doc_ids}")
    full_docs = retriever.docstore.mget(doc_ids)
    print(f"Retrieved {len(full_docs)} full documents")

In [None]:
# langchain_core Documents are pydantic models: they expose model_dump(), not
# unstructured's to_dict(). Fall back to to_dict() for other element types.
first_doc = docs[0]
first_doc.model_dump() if hasattr(first_doc, "model_dump") else first_doc.to_dict()

### Print the formatted result

In [None]:
# Print each retrieved document followed by a separator line.
separator = "-" * 80
for retrieved in docs:
    print(str(retrieved) + "\n\n" + separator)

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from base64 import b64decode
import base64

### Define model for Query and Vision

In [None]:
# Configuration - Choose your preferred model
RAG_MODEL = (
    "llama3.1:8b"  # Recommended alternatives: "llama3.2:3b", "mistral:7b", "qwen2:7b"
)
VISION_MODEL = "gemma3:12b"  # For handling images in RAG

print(f"Using RAG model: {RAG_MODEL}")
print(f"Using Vision model: {VISION_MODEL}")

### Separate images from text in the retrieved documents before passing them to the LLM

In [None]:
# Leading magic bytes of the image formats expected in the docstore. Used to
# tell real base64-encoded images apart from text that merely happens to be
# valid base64 (e.g. a short alphanumeric string).
_IMAGE_SIGNATURES = (
    b"\xff\xd8\xff",  # JPEG
    b"\x89PNG\r\n\x1a\n",  # PNG
    b"GIF87a",  # GIF
    b"GIF89a",  # GIF
)


def _extract_image_base64(content, require_signature=True):
    """Return ``content`` with all whitespace removed if it is strict base64 image data, else None.

    Args:
        content: Candidate document content.
        require_signature: When True, the decoded bytes must start with a known
            image signature (JPEG/PNG/GIF/WEBP). When False, any strictly valid,
            non-empty base64 string is accepted.
    """
    if not isinstance(content, str):
        return None
    cleaned = "".join(content.split())
    if not cleaned:
        # b64decode("") succeeds, but an empty string is not an image.
        return None
    try:
        # validate=True rejects non-alphabet characters. The default silently
        # DISCARDS them, which let ordinary text "decode" successfully.
        raw = b64decode(cleaned, validate=True)
    except ValueError:  # binascii.Error is a subclass of ValueError
        return None
    if not require_signature:
        return cleaned
    is_webp = raw[:4] == b"RIFF" and raw[8:12] == b"WEBP"
    if raw.startswith(_IMAGE_SIGNATURES) or is_webp:
        return cleaned
    return None


def parse_docs(docs):
    """
    Split retrieved documents into base64-encoded images and text content.

    A document is treated as an image when its metadata ``doc_type`` is "image"
    and its content is valid base64. If it has no "text"/"table" ``doc_type``,
    it is also treated as an image when its content decodes to bytes with a
    known image signature. Everything else is kept as text.

    Args:
        docs: List of retrieved documents from the vector store (objects with
            ``page_content`` and optionally a ``metadata`` dict); other objects
            are stringified.

    Returns:
        dict: ``{"images": [...], "texts": [...]}``. ``images`` holds base64
        *strings* (whitespace removed) ready to embed in a ``data:...;base64,``
        URL. ``texts`` holds the original document objects.
    """
    print(f"Processing {len(docs)} retrieved documents")

    b64_images = []
    text_docs = []

    for doc in docs:
        if doc is None:
            # Defensive: a docstore lookup can yield None for a missing key.
            continue

        content = doc.page_content if hasattr(doc, "page_content") else str(doc)
        metadata = getattr(doc, "metadata", None)
        doc_type = metadata.get("doc_type") if isinstance(metadata, dict) else None

        image_b64 = None
        if doc_type not in ("text", "table"):
            # Trust an explicit "image" label; otherwise require a real image signature.
            image_b64 = _extract_image_base64(
                content, require_signature=doc_type != "image"
            )

        if image_b64 is not None:
            # Keep the base64 STRING. build_prompt_with_vision interpolates it
            # into a data URL, so decoded bytes would yield a broken "b'...'" URL.
            b64_images.append(image_b64)
            print("Found base64 image document")
        else:
            text_docs.append(doc)
            print(f"Found text document: {content[:100]}...")

    return {"images": b64_images, "texts": text_docs}

### Construct a text-only prompt for the non-vision Llama 3 model (the prompt could perhaps be improved)

In [None]:
def build_prompt_text_only(kwargs):
    """
    Build prompt for text-only RAG (when no images are present).
    Uses the main RAG model for faster processing.

    Args:
        kwargs: dict with ``"context"`` (output of ``parse_docs``; only its
            ``"texts"`` list is used) and ``"question"`` (the user's question).

    Returns:
        ChatPromptTemplate: a single fixed human message containing the
        retrieved context, the question and the answering instructions.
    """
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Combine all text content
    context_text = ""
    for text_doc in docs_by_type["texts"]:
        # Handle both string content and Document objects
        if hasattr(text_doc, "page_content"):
            context_text += text_doc.page_content + "\n\n"
        else:
            context_text += str(text_doc) + "\n\n"

    # Simple text-based prompt template
    prompt_template = f"""You are a helpful assistant answering questions based on the provided context.

Context:
{context_text.strip()}

Question: {user_question}

Instructions:
- Answer based only on the provided context
- If the context doesn't contain relevant information, say "I don't have enough information to answer this question based on the provided context"
- Be concise and accurate
- If referencing specific data or facts, mention them clearly

Answer:"""

    # The text is already fully rendered, so wrap it in a HumanMessage (not
    # templated) instead of from_template(). from_template() re-parses "{...}"
    # as template variables, so any braces in the retrieved context or the
    # question (JSON, code, HTML) would fail with a missing-variable error.
    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_template)])

### Construct prompt using a vision model (Gemma) for images and text

In [None]:
def build_prompt_with_vision(kwargs):
    """
    Build prompt for multi-modal RAG (when images are present).
    Uses the vision model to handle both text and images.

    Args:
        kwargs: dict with ``"context"`` (output of ``parse_docs``: ``"texts"``
            and ``"images"`` lists) and ``"question"`` (the user's question).
            Images may be base64 strings or raw bytes; bytes are base64-encoded.

    Returns:
        ChatPromptTemplate: one human message whose content is the instruction
        text followed by one ``image_url`` part (JPEG data URL) per image.
    """
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Combine text content
    context_text = ""
    for text_doc in docs_by_type["texts"]:
        if hasattr(text_doc, "page_content"):
            context_text += text_doc.page_content + "\n\n"
        else:
            context_text += str(text_doc) + "\n\n"

    # Base prompt text
    prompt_text = f"""You are a helpful assistant answering questions based on the provided context, which includes both text and images.

Text Context:
{context_text.strip()}

Question: {user_question}

Instructions:
- Answer based on both the text context and the images provided
- If analyzing images, describe what you see that's relevant to the question
- Be specific about information from images (charts, diagrams, etc.)
- If the context doesn't contain relevant information, say so clearly

Answer:"""

    # Build content list starting with text
    prompt_content = [{"type": "text", "text": prompt_text}]

    # Add images if present
    images = docs_by_type["images"]
    if images:
        print(f"Adding {len(images)} images to prompt")
        for image_b64 in images:
            if isinstance(image_b64, (bytes, bytearray)):
                # Raw bytes must be base64-encoded first; interpolating them
                # directly would embed their "b'...'" repr in the data URL.
                image_b64 = base64.b64encode(image_b64).decode("ascii")
            prompt_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                }
            )

    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

### Dynamically choose between Llama and Gemma. Gemma could handle everything, but using Llama for text-only context is faster

In [None]:
def choose_model_and_prompt(kwargs):
    """
    Route a RAG request to the vision or the text-only model.

    Args:
        kwargs: dict with "context" (mapping holding "texts" and "images"
            lists from the retrieval step) and "question" (the user's query).

    Returns:
        The runnable ``prompt | model | StrOutputParser()``. When this
        function is wrapped in a RunnableLambda, LangChain invokes the
        returned runnable with the same input, so the chain yields a string.
    """
    docs_by_type = kwargs["context"]
    images = docs_by_type["images"]
    # Log a compact summary rather than the whole context: the images are
    # base64 strings and printing them would flood the notebook output.
    print(
        f"==>> docs_by_type: {len(docs_by_type['texts'])} text(s), "
        f"{len(images)} image(s)"
    )

    if len(images) > 0:
        # Multi-modal content needs the vision model.
        print("Using vision model for multi-modal RAG")
        prompt = build_prompt_with_vision(kwargs)
        model_name = VISION_MODEL
    else:
        # Text-only content can use the faster text model.
        print("Using text model for text-only RAG")
        prompt = build_prompt_text_only(kwargs)
        model_name = RAG_MODEL

    model = ChatOllama(
        model=model_name, temperature=0.1, base_url="http://localhost:11434"
    )
    return prompt | model | StrOutputParser()

### Setting up the chain variants (five in total)

In [None]:
# Main RAG chain plus four variants. Every chain starts from the same retrieval
# step, which maps the question string to
# {"context": parse_docs(retrieved docs), "question": question}.
print("Setting up RAG chain...")

retrieval_inputs = {
    "context": retriever | RunnableLambda(parse_docs),
    "question": RunnablePassthrough(),
}


def _ollama(model_name):
    """Create a ChatOllama client with the settings shared by the chains below."""
    return ChatOllama(
        model=model_name, temperature=0.1, base_url="http://localhost:11434"
    )


# Answer steps that consume the retrieval_inputs output.
text_answer = (
    RunnableLambda(build_prompt_text_only) | _ollama(RAG_MODEL) | StrOutputParser()
)
vision_answer = (
    RunnableLambda(build_prompt_with_vision)
    | _ollama(VISION_MODEL)
    | StrOutputParser()
)

# Routes each query to the vision or text model depending on whether images
# were retrieved.
chain = retrieval_inputs | RunnableLambda(choose_model_and_prompt)

# Alternative: always use the text model (faster, but ignores images).
simple_text_chain = retrieval_inputs | text_answer

# Routed chain that returns the retrieved context alongside "response".
chain_with_sources = retrieval_inputs | RunnablePassthrough().assign(
    response=RunnableLambda(choose_model_and_prompt)
)

# Alternative: always use the vision model (slower, but handles all content types).
vision_chain = retrieval_inputs | vision_answer

# Vision-only chain that also returns the retrieved context.
vision_chain_with_sources = retrieval_inputs | RunnablePassthrough().assign(
    response=vision_answer
)

In [None]:
# Display `tables` (defined in an earlier cell; presumably the table elements
# extracted from the PDF — confirm) for a quick visual inspection.
tables

### Testing the chain

In [None]:
# Smoke test: send one simple question through the routed chain,
# with the Langfuse callback handler attached for tracing.
response = chain.invoke(
    "How is fin created?",
    config={"callbacks": [langfuse_handler]},
)
print(f"==>> response: {response}")

In [None]:
import base64

# Alias IPython's Image so it does not shadow the PIL `Image` imported at the
# top of the notebook (re-running earlier PIL cells would otherwise break).
from IPython.display import Image as IPyImage, display


def display_base64_image(base64_code):
    """
    Render a base64-encoded image inline in the notebook.

    Args:
        base64_code: base64 string of the image bytes (e.g. an entry of `images`).
    """
    image_data = base64.b64decode(base64_code)
    display(IPyImage(data=image_data))


# Preview the first extracted image, if there is one.
if images:
    display_base64_image(images[0])
else:
    print("No images to display")

In [None]:
# Ask a question through the routed chain and inspect which retrieved text
# chunks and images were used to produce the answer.
# Alternative: force the vision model for every query.
# response = vision_chain_with_sources.invoke(
#     "if a company has 2 CSN account, would a levy default bl applied to both CSN if the company failed to pay the levy for one of the CSN, is there a table base on what type of CSN they are holding?",
#     config={"callbacks": [langfuse_handler]},
# )
response = chain_with_sources.invoke(
    "Is security bond (SB) needed for all work permit (WP) holders? How about for S-Pass and Employment Pass (EP)?",
    config={"callbacks": [langfuse_handler]},
)

response_context = response["context"]
print("Response:", response["response"])
# Summarize the context instead of printing it raw: the images are base64
# strings that would flood the output (they are rendered below instead).
print(
    f"Context used: {len(response_context['texts'])} text chunk(s), "
    f"{len(response_context['images'])} image(s)"
)
for text in response_context["texts"]:
    print(text.text)
    print("Page number: ", text.metadata.page_number)
    print("\n" + "-" * 50 + "\n")
for image in response_context["images"]:
    display_base64_image(image)