In [None]:
# One-time system setup (macOS / Homebrew) for the PDF partitioning below.
# Option A — run from inside the notebook:
# !brew install poppler tesseract libmagic
# Option B — install globally from a terminal:
#brew install tesseract poppler libmagic
# Then make Homebrew's bin directory visible to new shells:
# echo 'export PATH="/opt/homebrew/bin:$PATH"' >> ~/.zshrc
# source ~/.zshrc

In [None]:
import os

from dotenv import load_dotenv

# Load key=value pairs from a local .env file into the process environment.
load_dotenv()

# Settings the rest of the notebook depends on (OpenAI, Groq, LangSmith tracing).
# Each must be set, typically via .env; report every missing name at once instead
# of failing on the first one with an unhelpful TypeError.
REQUIRED_ENV_VARS = (
    "OPENAI_API_KEY",
    "GROQ_API_KEY",
    "LANGCHAIN_API_KEY",
    "LANGCHAIN_TRACING_V2",
)
missing_env_vars = [name for name in REQUIRED_ENV_VARS if name not in os.environ]
if missing_env_vars:
    raise RuntimeError(
        "Missing environment variables (add them to .env): "
        + ", ".join(missing_env_vars)
    )

### Partition PDF tables, text, and images

In [None]:

import os

# Homebrew's default bin directory on Apple Silicon; the kernel may not inherit it
# from the login shell, so tools like tesseract would be invisible otherwise.
HOMEBREW_BIN = "/opt/homebrew/bin"


def ensure_on_path(directory):
    """Append `directory` to the PATH environment variable unless it is already there.

    Idempotent: re-running the cell does not keep growing PATH.

    Args:
        directory: filesystem path to make discoverable for subprocess lookups.
    """
    current = os.environ.get("PATH", "")
    if directory in current.split(os.pathsep):
        return
    os.environ["PATH"] = current + os.pathsep + directory if current else directory


ensure_on_path(HOMEBREW_BIN)

In [None]:
import os
import shutil
import subprocess
import sys

# Check that the tesseract binary is reachable from this kernel's PATH.
tesseract_path = shutil.which("tesseract")
if tesseract_path is None:
    print("Tesseract not found in PATH")
else:
    result = subprocess.run(
        [tesseract_path, "--version"], capture_output=True, text=True
    )
    # Depending on the release, the version banner may be written to stderr
    # rather than stdout, so fall back to stderr when stdout is empty.
    print("Tesseract version:", result.stdout or result.stderr)

# Show the PATH the kernel actually sees (useful when tesseract is not found).
print("Current PATH:", os.environ.get("PATH", ""))

In [None]:
from unstructured.partition.pdf import partition_pdf

output_path = "./content/"
file_path = f"{output_path}attention.pdf"

# Partition the PDF and chunk the resulting elements in a single call.
# Reference: https://docs.unstructured.io/open-source/core-functionality/chunking
chunks = partition_pdf(
    filename=file_path,
    # --- layout / tables: the "hi_res" strategy is required to infer table structure
    strategy="hi_res",
    infer_table_structure=True,
    # --- images: keep "Image" blocks (add "Table" to also capture table snapshots)
    # and embed them as base64 in each element's metadata instead of writing files.
    extract_image_block_types=["Image"],
    extract_image_block_to_payload=True,
    # image_output_dir_path=output_path,  # write image files here instead of base64 payloads
    # extract_images_in_pdf=True,         # deprecated
    # --- chunking: group elements along section titles ("basic" is the alternative)
    chunking_strategy="by_title",
    max_characters=10000,             # hard cap per chunk (library default: 500)
    combine_text_under_n_chars=2000,  # merge small sections (library default: 0)
    new_after_n_chars=6000,           # soft cap: prefer to start a new chunk after this
)

In [None]:
{str(type(element)) for element in chunks}  # distinct element classes among the chunks

In [None]:
# Inspect the raw chunks returned by partition_pdf.
chunks

### Inspect an example image element

In [None]:
# Look at the original elements behind one chunk and show the first image found.
# The index was chosen by hand (presumably a chunk with a figure in attention.pdf);
# adjust it for other documents.
EXAMPLE_CHUNK_INDEX = 3

# orig_elements may be missing or None, so normalise to an empty list.
elements = getattr(chunks[EXAMPLE_CHUNK_INDEX].metadata, "orig_elements", None) or []
chunk_images = [el for el in elements if "Image" in str(type(el))]
if not chunk_images:
    print(f"Chunk {EXAMPLE_CHUNK_INDEX} has no Image elements; try another index.")
# Displays the element's dict form, or nothing when no image was found.
chunk_images[0].to_dict() if chunk_images else None

In [None]:
from unstructured.documents.elements import Table, CompositeElement, Text, Title

tables = []
texts = []
images = []

for chunk in chunks:
    if isinstance(chunk, Table):
        # Case 1: a table emitted as its own chunk (not wrapped in a CompositeElement).
        tables.append(chunk)
    elif isinstance(chunk, CompositeElement):
        # Case 2: a CompositeElement (e.g. Title + Text grouped). The chunk itself is
        # treated as text; any Table among its original elements is also collected.
        texts.append(chunk)
        # orig_elements may be absent or present-but-None; treat both as "no elements".
        for el in getattr(chunk.metadata, "orig_elements", None) or []:
            if isinstance(el, Table):
                tables.append(el)

In [None]:
# Tables collected from standalone Table chunks and from CompositeElement orig_elements.
tables

In [None]:
# CompositeElement chunks that will be summarised as text.
texts

In [None]:
def get_images_base64(chunks):
    """Collect base64 payloads of Image elements nested inside CompositeElement chunks.

    Elements are matched by class name so no unstructured types need importing.

    Args:
        chunks: chunk objects as returned by ``partition_pdf`` with a chunking strategy.

    Returns:
        list[str]: base64 payloads in chunk order. Image elements without a payload
        (``image_base64`` missing, None or empty) are skipped, so downstream vision
        calls never receive an empty image.
    """
    images_b64 = []
    for chunk in chunks:
        if "CompositeElement" not in str(type(chunk)):
            continue
        # orig_elements may be missing or None; treat both as "no elements".
        for el in getattr(chunk.metadata, "orig_elements", None) or []:
            if "Image" not in str(type(el)):
                continue
            payload = getattr(el.metadata, "image_base64", None)
            if payload:
                images_b64.append(payload)
    return images_b64


# Base64 payloads of every image found inside CompositeElement chunks.
images = get_images_base64(chunks)

In [None]:
# Base64 payloads are very long; show how many there are and a short prefix of each
# instead of dumping every payload into the notebook output.
len(images), [b64[:40] + "..." for b64 in images]

### Text and Table Summary

In [None]:
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

In [None]:
# Prompt for summarising one text chunk or one table (HTML) at a time.
prompt_text = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.

Respond only with the summary, no additional comment.
Do not start your message by saying "Here is a summary" or anything like that.
Just give the summary as it is.

Table or text chunk: {element}

"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Summary chain: wrap each raw input as {"element": ...}, fill the prompt, call the
# Groq-hosted Llama model, and return the reply as a plain string.
model = ChatGroq(temperature=0.5, model="llama-3.1-8b-instant")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
summarize_chain

In [None]:
# Summarize text chunks one request at a time (max_concurrency=1), presumably to
# stay under the LLM provider's rate limits.
text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})

# Summarize tables from their HTML rendering. text_as_html may be empty/None when no
# table structure was inferred; fall back to the plain text so the prompt never
# receives the literal string "None".
tables_html = [
    getattr(table.metadata, "text_as_html", None) or table.text for table in tables
]
table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 1})

In [None]:
# One summary string per entry in `texts`.
text_summaries

In [None]:
# HTML rendering of each table (the input that was summarised).
# NOTE(review): to inspect the generated table summaries, display `table_summaries`.
tables_html

### Image Summary

In [None]:
from langchain_openai import ChatOpenAI

# Captioning instruction; the `{image}` slot below is filled with each base64 payload.
prompt_template = """Describe the image in detail. For context,
                  the image is part of a research paper explaining the transformers
                  architecture. Be specific about graphs, such as bar plots."""

# prompt_template = """Describe the image in detail. For context,
#                   it should be a flow diagram"""

# A single user turn: the instruction text followed by the image as a data URL.
text_part = {"type": "text", "text": prompt_template}
image_part = {
    "type": "image_url",
    "image_url": {"url": "data:image/jpeg;base64,{image}"},
}
messages = [("user", [text_part, image_part])]

prompt = ChatPromptTemplate.from_messages(messages)
chain = prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()

# One description per image, in the same order as `images`.
image_summaries = chain.batch(images)

In [None]:
# One description per entry in `images`.
image_summaries

### Create a vector store

In [None]:
import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore, LocalFileStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

# The vectorstore used to index the child chunks (the summaries).
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
    persist_directory="./chroma_db",  # This will create a folder on your disk
)

# The collection is persisted on disk, but the parent docstore below is in-memory.
# Anything left over from a previous run would be a duplicate summary whose doc_id
# no longer resolves, so start from an empty collection to keep re-runs idempotent.
stale_ids = vectorstore.get()["ids"]
if stale_ids:
    vectorstore.delete(ids=stale_ids)

# The storage layer for the parent documents (the originals returned at query time).
store = InMemoryStore()
# store = LocalFileStore("./document_store")  # This will create a folder for documents
# Metadata key linking each summary in the vectorstore to its original in the docstore.
id_key = "doc_id"

# The retriever (empty to start): searches summaries, returns docstore originals.
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

In [None]:
# Inspect the retriever object.
retriever

### Loading/inserting into vector store

In [None]:
def add_summaries_to_retriever(originals, summaries):
    """Index `summaries` in the vectorstore and store `originals` in the docstore.

    Each original gets a fresh UUID. Its summary Document carries that UUID under
    `id_key`, which lets the MultiVectorRetriever map a matched summary back to the
    original object.

    Args:
        originals: objects returned at query time (chunks, tables or base64 images).
        summaries: summary strings, aligned one-to-one with `originals`.

    Returns:
        (ids, summary_docs): the generated ids and the Documents added to the vectorstore.

    Raises:
        ValueError: if the two sequences differ in length; pairing them would be wrong.
    """
    if len(originals) != len(summaries):
        raise ValueError(
            f"{len(summaries)} summaries for {len(originals)} originals; they must align"
        )
    ids = [str(uuid.uuid4()) for _ in originals]
    summary_docs = [
        Document(page_content=summary, metadata={id_key: doc_id})
        for doc_id, summary in zip(ids, summaries)
    ]
    # Skip empty groups (e.g. a PDF without tables): a vectorstore may reject an
    # empty batch.
    if summary_docs:
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(ids, originals)))
    return ids, summary_docs


# Add texts
doc_ids, summary_texts = add_summaries_to_retriever(texts, text_summaries)

# Add tables
table_ids, summary_tables = add_summaries_to_retriever(tables, table_summaries)

# Add image summaries
img_ids, summary_img = add_summaries_to_retriever(images, image_summaries)

### Check the in-memory docstore (15 entries for this PDF)

In [None]:
# Walk every key in the in-memory docstore and print the object stored under it.
# NOTE: the values are whatever was passed to mset (chunk elements, tables or
# base64 image strings), not langchain Document objects.
all_doc_ids = store.yield_keys()

for doc_id in all_doc_ids:
    docs = store.mget([doc_id])  # one entry per requested key
    print("Document ID:", doc_id)
    for doc in docs:
        print(doc)

### Check the Chroma vector store (15 summary documents for this PDF)

In [None]:
# Print every summary text stored in the Chroma collection, with its position.
all_docs = vectorstore.get()
for index, doc in enumerate(all_docs["documents"]):
    print("index is :", index)
    print(doc)

In [None]:
# Alternative version of the loading cell above, kept for reference (commented out).
# The pickle.dumps(...) variants serialise each object to bytes before mset,
# presumably for use with a byte-oriented store such as LocalFileStore — TODO confirm.
# Only unpickle data you created yourself: pickle.loads can execute arbitrary code.
# # Add texts
# import pickle

# doc_ids = [str(uuid.uuid4()) for _ in texts]
# summary_texts = [
#     Document(page_content=summary, metadata={id_key: doc_ids[i]})
#     for i, summary in enumerate(text_summaries)
# ]
# retriever.vectorstore.add_documents(summary_texts)
# # retriever.docstore.mset(
# #     [(doc_ids[i], pickle.dumps(texts[i])) for i in range(len(texts))]
# # )
# retriever.docstore.mset(list(zip(doc_ids, texts)))

# # Add tables
# table_ids = [str(uuid.uuid4()) for _ in tables]
# summary_tables = [
#     Document(page_content=summary, metadata={id_key: table_ids[i]})
#     for i, summary in enumerate(table_summaries)
# ]
# retriever.vectorstore.add_documents(summary_tables)
# retriever.docstore.mset(list(zip(table_ids, tables)))
# # retriever.docstore.mset(
# #     [(table_ids[i], pickle.dumps(tables[i])) for i in range(len(tables))]
# # )

# # Add image summaries
# img_ids = [str(uuid.uuid4()) for _ in images]
# summary_img = [
#     Document(page_content=summary, metadata={id_key: img_ids[i]})
#     for i, summary in enumerate(image_summaries)
# ]
# retriever.vectorstore.add_documents(summary_img)
# retriever.docstore.mset(list(zip(img_ids, images)))
# # retriever.docstore.mset(
# #     [(img_ids[i], pickle.dumps(images[i])) for i in range(len(images))]
# # )

In [None]:
# Summary Documents indexed for the text chunks.
summary_texts

In [None]:
# Summary Documents indexed for the tables.
summary_tables

In [None]:
# Summary Documents indexed for the images.
summary_img

### Set the number of results to return (the top k by relevance)

In [None]:
# Uncomment to limit how many results the retriever returns (here: the top 4).
# retriever.search_kwargs = {"k": 4}


In [None]:
# Show the retriever's current search settings.
retriever.search_kwargs

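### Retriever

How the multi-vector retriever answers a query:
1. Embed the query.
2. Run a similarity search over the stored summaries.
3. Map each matching summary to its original object in the docstore (via `doc_id`).
4. Return those original objects.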

In [None]:
# Run only the retrieval step for a sample question (an alternative question is kept commented out).
# docs = retriever.invoke("who are the authors of the paper?")
docs = retriever.invoke("what is multihead attention?")

In [None]:
# Objects returned by the retriever: docstore originals, not the summaries.
docs

In [None]:
# Uncomment to inspect the first result as a dict (element objects only; base64 image strings have no to_dict).
# docs[0].to_dict()

### Print the formatted result

In [None]:
# Print each retrieved object followed by a blank line and an 80-character rule.
for doc in docs:
    print(str(doc), "-" * 80, sep="\n\n")

### RAG PIPELINE

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from base64 import b64decode

def parse_docs(docs):
    """Split retrieved docstore values into base64 images and text-like elements.

    Images are stored in the docstore as base64 strings; texts and tables are stored
    as element objects.

    Args:
        docs: iterable of objects returned by the retriever.

    Returns:
        {"images": [...], "texts": [...]}, preserving input order within each list.
    """
    b64 = []
    text = []
    for doc in docs:
        # Only strings/bytes can be base64 payloads; element objects are texts/tables.
        if not isinstance(doc, (str, bytes)):
            text.append(doc)
            continue
        try:
            # validate=True rejects non-alphabet characters instead of silently
            # dropping them, so ordinary prose is less likely to pass as an image.
            b64decode(doc, validate=True)
        except ValueError:  # binascii.Error is a subclass of ValueError
            text.append(doc)
        else:
            b64.append(doc)
    return {"images": b64, "texts": text}


def build_prompt(kwargs):
    """Build a multimodal chat prompt from the retrieved context and the question.

    Args:
        kwargs: dict with
            "context": output of ``parse_docs``, i.e. {"images": [base64 str, ...],
                "texts": [element, ...]};
            "question": the user's question.

    Returns:
        ChatPromptTemplate holding one HumanMessage whose content is the text prompt
        followed by one image_url part per retrieved image.
    """
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Separate chunks with a blank line so the end of one chunk is not glued to the
    # start of the next. Elements expose `.text`; fall back to str() otherwise.
    context_text = "\n\n".join(
        getattr(text_element, "text", str(text_element))
        for text_element in docs_by_type["texts"]
    )

    # construct prompt with context (including images)
    prompt_template = f"""
    Answer the question based only on the following context, which can include text, tables, and the below image.
    Context: {context_text}
    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template}]
    # Attach each retrieved image inline as a data URL (assumed JPEG).
    prompt_content.extend(
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}}
        for image in docs_by_type["images"]
    )

    # A ready-made HumanMessage (not a template string), so braces inside the
    # retrieved text are kept verbatim rather than parsed as template variables.
    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])


def _rag_inputs():
    """Map a question to {"context": parsed retrieval results, "question": the question}."""
    return {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    }


def _answer_pipeline():
    """Prompt construction -> gpt-4o-mini -> plain-string answer."""
    return (
        RunnableLambda(build_prompt)
        | ChatOpenAI(model="gpt-4o-mini")
        | StrOutputParser()
    )


# Answer-only chain: question in, answer string out.
chain = _rag_inputs() | _answer_pipeline()

# Same pipeline, but the output keeps the parsed context next to the answer:
# {"context": ..., "question": ..., "response": ...}.
chain_with_sources = _rag_inputs() | RunnablePassthrough().assign(
    response=_answer_pipeline()
)

### Query the RAG pipeline (answer only, then answer with its sources and images)

In [None]:
# Ask a question through the answer-only chain and print the model's reply.
response = chain.invoke("What do the authors mean by 'attention'?")

print(response)

In [None]:
from base64 import b64decode
from IPython.display import Image, display

# Ask the same question, but keep the retrieved context alongside the answer.
response = chain_with_sources.invoke("What do the authors mean by 'attention'?")

print("Response:", response["response"])

print("\n\nContext:")
for text in response["context"]["texts"]:
    print(text.text)
    print("Page number: ", text.metadata.page_number)
    print("\n" + "-" * 50 + "\n")
# Render retrieved images directly: the display_base64_image helper is defined in a
# later cell, so depending on it here would fail on a fresh kernel.
for image in response["context"]["images"]:
    display(Image(data=b64decode(image)))

### Image helper using raw base64

In [None]:
import base64
from IPython.display import Image, display


def display_base64_image(base64_code):
    """Render a base64-encoded image inline in the notebook output."""
    raw_bytes = base64.b64decode(base64_code)
    display(Image(data=raw_bytes))


# Sanity-check the helper on the first extracted image, if there is one.
if images:
    display_base64_image(images[0])
else:
    print("No images were extracted from the PDF.")

In [None]:
# Print the description generated for the second image, when at least two exist.
if len(image_summaries) > 1:
    print(image_summaries[1])
else:
    print(f"Only {len(image_summaries)} image summaries available.")