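# Multimodal RAG

Extract text, tables and images from a PDF, summarize each element with an LLM (LLaVA for images), index the summaries in a multi-vector retriever, and answer questions over the original content with a RetrievalQA chain.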

In [1]:
import os
import sys
import glob

# Resolve the kit directory (parent of the current working directory) and the repo root (its parent).
# NOTE(review): this depends on the notebook being launched from a subdirectory of the kit — TODO confirm.
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

# Put both directories on sys.path so the `utils.*` imports below resolve.
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from utils.sambanova_endpoint import SambaNovaEndpoint
from dotenv import load_dotenv
# Load environment variables from the repo-level .env file
# (presumably the source of the LVLM_* variables read by llava_call — verify).
load_dotenv(os.path.join(repo_dir,'.env'))

import requests
import json
import base64
from pprint import pprint

## Utils

In [2]:
def image_to_base64(image_path):
    """Read a file and return its bytes base64-encoded as a ``str``.

    Args:
        image_path: Path to the image file. It is read in binary mode, so any file works.

    Returns:
        The base64 encoding of the file's raw bytes, decoded to text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()

## Multimodal call

### SambaStudio LLaVA call method

In [3]:
# sambastudio call
def llava_call(prompt, image_path, timeout=120):
    """Send a prompt and one image to a SambaStudio LLaVA endpoint; return the completion text.

    The endpoint URL and API key are built from the environment variables
    LVLM_BASE_URL, LVLM_PROJECT_ID, LVLM_ENDPOINT_ID and LVLM_API_KEY.

    Args:
        prompt: Full chat prompt. The prompts in this notebook embed an
            ``<image>`` token where the image is meant to go.
        image_path: Path to the image file. It is sent base64-encoded.
        timeout: Seconds to wait for the HTTP response before giving up.
            Previously the request had no timeout and could hang indefinitely.

    Returns:
        The ``completion`` field of the first entry in ``predictions``.

    Raises:
        RuntimeError: If any LVLM_* environment variable is unset or empty.
        requests.HTTPError: If the endpoint answers with a non-2xx status.
    """
    env_names = ("LVLM_BASE_URL", "LVLM_PROJECT_ID", "LVLM_ENDPOINT_ID", "LVLM_API_KEY")
    missing = [name for name in env_names if not os.environ.get(name)]
    if missing:
        # Fail fast. Otherwise the URL silently contains "None" and the request fails obscurely.
        raise RuntimeError(f"Missing LVLM environment variables: {', '.join(missing)}")

    image = image_to_base64(image_path)
    endpoint_url = (
        f"{os.environ['LVLM_BASE_URL']}/api/predict/generic/"
        f"{os.environ['LVLM_PROJECT_ID']}/{os.environ['LVLM_ENDPOINT_ID']}"
    )
    endpoint_key = os.environ["LVLM_API_KEY"]
    # One instance (prompt + base64 image); generation params are {type, value} pairs with string values.
    data = {
        "instances": [{
            "prompt": prompt,
            "image_content": image,
        }],
        "params": {
            "do_sample": {"type": "bool", "value": "false"},
            "max_tokens_to_generate": {"type": "int", "value": "512"},
            "temperature": {"type": "float", "value": "1"},
            "top_k": {"type": "int", "value": "50"},
            "top_logprobs": {"type": "int", "value": "0"},
            "top_p": {"type": "float", "value": "1"}
        }
    }
    headers = {
        "Content-Type": "application/json",
        "key": endpoint_key
    }
    response = requests.post(endpoint_url, headers=headers, data=json.dumps(data), timeout=timeout)
    # Surface HTTP errors directly instead of a KeyError on a missing "predictions" field.
    response.raise_for_status()
    return response.json()["predictions"][0]["completion"]

### QA LLaVA call

In [4]:
# LLaVA chat template: a fixed system preamble, then the user turn with an <image> placeholder.
prompt = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the humans question. "
    "USER: <image>\nhow many birds could you find at 4pm?. ASSISTANT:"
)
image_path = os.path.join(kit_dir, "data", "sample_docs", "sample.png")
llava_call(prompt, image_path)

'At 4 pm, you could find approximately 10 birds on the tree.'

### Summary LLaVA call

In [5]:
# Same chat template as the QA cell, with a detailed-description instruction for the same image.
prompt = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the humans question. "
    "USER: <image>\nDescribe the image in detail. "
    "Be specific about graphs, such as bar plots, scatter plots, or others. ASSISTANT:"
)
llava_call(prompt, image_path)

'The image displays a graph showing the number of birds on a tree at different times of the day. The graph is a combination of a bar plot and a scatter plot, with the bar plot showing the number of birds at various times of the day, and the scatter plot showing the number of birds on a tree at a specific time.\n\nThe graph is divided into two main sections. The first section, which is the bar plot, shows the number of birds on a tree at different times of the day, with the bars extending from 10 am to 11 pm. The second section, which is the scatter plot, shows the number of birds on a tree at a specific time, with the x-axis representing the time and the y-axis representing the number of birds.\n\nThe graph is labeled with the time of the day, and the number of birds is represented by the number of orange dots on the graph. The dots are scattered throughout the graph, with some appearing closer to the bars and others appearing closer to the scatter plot.'

## Doc Extraction

### Unstructured PDF extraction

In [6]:
from unstructured.partition.pdf import partition_pdf

# Source PDF. Extracted images go to a directory with the same name minus the extension
# (e.g. sample_docs/invoicesample/).
file_path = os.path.join(kit_dir, "data", "sample_docs", "invoicesample.pdf")
output_path = os.path.splitext(file_path)[0]

# Partition the PDF into elements with the hi-res pipeline, then chunk by title.
raw_pdf_elements = partition_pdf(
    filename=file_path,
    strategy="hi_res",
    # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles.
    # Titles are any sub-section of the document.
    hi_res_model_name="yolox",
    infer_table_structure=True,
    extract_images_in_pdf=True,
    extract_image_block_output_dir=output_path,
    # Chunking by title; sizes below are in characters.
    chunking_strategy="by_title",
    max_characters=1000,
    new_after_n_chars=800,
    combine_text_under_n_chars=500,
)

Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### View Elements

In [7]:
# Print a header, class, metadata and text for every partitioned element.
for idx, el in enumerate(raw_pdf_elements):
    element_report = [
        f"\033[95m ELEMENT {idx}\033[00m",
        f"TYPE: {type(el)}",
        f"META: {el.metadata.to_dict()}",
        f"TEXT: {el.text}",
        "\n\n##########\n",
    ]
    print("\n".join(element_report))

[95m ELEMENT 0[00m
TYPE: <class 'unstructured.documents.elements.CompositeElement'>
META: {'filetype': 'application/pdf', 'languages': ['eng'], 'last_modified': '2024-05-08T08:57:57', 'page_number': 1, 'orig_elements': 'eJy9lk2P2zYQhv+KoPbQAqbEb5F7K5qi2EM2KdbpZbEQSGrkZVaWBIleZxP0v5eSnCAfToAYsAFd5tWMTb7PDMW7Dyk0sIU2lL5Kr5JUWF1r4AYRYwBxbCxSuCDIacI0NoWuqzpdJekWgqlMMLHmQ+q6bqh8awKMc9yY524Xygfwm4cQFcqYjjUHee+r8BBVIgWLat/5Nkx1d3da00ysEsJ5xu9XyceYEpmJKSYCFxk9IiwVUUnH5zHAdtrJa/8OmtveOEj/iy8qCOCC79rSNWYcy37obEzDGZVMyJhQ+wbCcw9z7euX6bzgdrMzm3lXdym0m/R+VsdQbrvK1x5mzyimHGGBsFpjdSWK+EzVfaws293WwjDtdlpEgHeTH+kf6/VfN+vrVzfJ+tWU+vF/1z4083K/xmJ4jakzGlErasQlkchKqBBIEx9FrLH8bFhIIbIiui5wRmbXl1gKkbEpVoRPmL6Ol/zToCgmv8PEbydbezOvNc3fjDCM+dtu2ECfv+jcbjJtzM34WPY723hXdvs2Nx6NwQwBBvToQ77dNcFHgqYpH9tu30AVf3KAMHh4giGfDMxHs+0bKKvOjblvnzrvYJHy2m92AyCCSPa235y1Tz7vjetp48d6w2IFzFQOGVuZ2BtOIEtib8i6sMS5ymHGztcbmMzNwChb4H8SpMzULDAuM31MmEtOHVqKxYWH9gW07XPy9641e9N+TubGDIMJ/gnWU+YRQlirigscx7Wo4/TyKh6qrrbIcaeU4oopflZCKiOrhM

In [8]:
# Tally the partitioned elements by their class (stringified) to see what partition_pdf produced.
category_counts = {}
for el in raw_pdf_elements:
    key = str(type(el))
    category_counts[key] = category_counts.get(key, 0) + 1

# Unique element classes. A TableChunk appears here if a Table exceeds max_characters.
unique_categories = set(category_counts)
category_counts

{"<class 'unstructured.documents.elements.CompositeElement'>": 2,
 "<class 'unstructured.documents.elements.Table'>": 1}

In [26]:
from langchain.schema import Document


# Convert partitioned elements into LangChain Documents tagged "table" or "text".
# Matching on the class path also catches TableChunk, whose path starts with "...elements.Table".
# Elements of any other class are skipped.
categorized_elements = []
for element in raw_pdf_elements:
    element_type = str(type(element))
    if "unstructured.documents.elements.Table" in element_type:
        meta = element.metadata.to_dict()
        meta["type"] = "table"
        # text_as_html may be None (e.g. no HTML from table-structure inference); fall back to plain text.
        table_content = element.metadata.text_as_html or element.text
        categorized_elements.append(Document(page_content=table_content, metadata=meta))
    elif "unstructured.documents.elements.CompositeElement" in element_type:
        meta = element.metadata.to_dict()
        meta["type"] = "text"
        categorized_elements.append(Document(page_content=str(element), metadata=meta))

# Tables
table_docs = [e for e in categorized_elements if e.metadata["type"] == "table"]
print(len(table_docs))

# Text
text_docs = [e for e in categorized_elements if e.metadata["type"] == "text"]
print(len(text_docs))

1
2


## Retrieval with raw text, raw tables and image summaries

### Text and table summaries

In [10]:
from langchain_community.llms.sambanova import Sambaverse
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import load_prompt

In [11]:
# Prompt templates for summarizing text chunks and tables.
text_prompt = load_prompt(os.path.join(kit_dir, "prompts", "llama70b-text_summary.yaml"))
table_prompt = load_prompt(os.path.join(kit_dir, "prompts", "llama70b-table_summary.yaml"))

# Generation settings for the Sambaverse Llama-2-70B summarizer: sampling at a very low temperature.
summary_model_kwargs = {
    "do_sample": True,
    "max_tokens_to_generate": 256,
    "temperature": 0.01,
    "process_prompt": True,
    "select_expert": "llama-2-70b-chat-hf",
}
model = Sambaverse(
    sambaverse_model_name="Meta/llama-2-70b-chat-hf",
    model_kwargs=summary_model_kwargs,
)


def _summarize_chain(prompt_template):
    """Build an input -> prompt -> model -> str chain, binding the raw input to ``element``."""
    return {"element": lambda x: x} | prompt_template | model | StrOutputParser()


text_summarize_chain = _summarize_chain(text_prompt)
table_summarize_chain = _summarize_chain(table_prompt)

### Text Summaries

In [12]:
# Summarize every non-empty text chunk; max_concurrency=1 runs the batch one request at a time.
texts = [doc.page_content for doc in text_docs if doc.page_content != ""]
# Default to [] so the display cell below works even when there is no text.
text_summaries = text_summarize_chain.batch(texts, {"max_concurrency": 1}) if texts else []

In [13]:
# Inspect the generated text-chunk summaries
text_summaries

[" Sure! Here's a concise summary of the text chunk you provided:\n\nDenny Gunawan has an invoice for $39.60 with invoice number #20130304. The address is 221 Queen St Melbourne VIC 3000, and the phone number is (03) 1234 5678.",
 ' A fictitious receipt or invoice with a subtotal, a 10% GST charge, and a total that includes the GST. The total amount is $39.60.']

### Table summaries

In [14]:
# Summarize every table (HTML or plain-text content); max_concurrency=1 runs the batch sequentially.
tables = [doc.page_content for doc in table_docs]
# Default to [] so the display cell below works even when the PDF has no tables.
table_summaries = table_summarize_chain.batch(tables, {"max_concurrency": 1}) if tables else []

In [15]:
# Inspect the generated table summaries
table_summaries

[' The table shows the prices and quantities of various fruits. The fruits included are Apple, Orange, Watermelon, Mango, and Peach. The prices range from $1.69 (Watermelon) to $9.56 (Mango). The quantities range from 1 (Apple, Peach) to 3 (Watermelon). The total cost of each fruit is also shown, which ranges from $5.00 (Apple) to $19.12 (Mango).']

### Image summary

In [16]:
# Fill the LLaVA prompt template with an image-description instruction.
image_prompt = load_prompt(os.path.join(kit_dir, "prompts", "llava.yaml"))
prompt = image_prompt.format(instruction = "Describe the image in detail. Be specific about graphs include name of axis, labels, legends and important numerical information")

# Images extracted by partition_pdf into output_path. Sort them so summary order is stable across runs.
image_paths = sorted(
    glob.glob(os.path.join(output_path, '*.jpg'))
    + glob.glob(os.path.join(output_path, '*.png'))
)

image_summaries = []
image_docs = []

# Use a dedicated loop variable so the global `image_path` (sample.png) used by the LLaVA cells is not clobbered.
for extracted_image_path in image_paths:
    result = llava_call(prompt, extracted_image_path)
    image_summaries.append(result)
    # The "file_directory" key actually holds the full image file path.
    image_docs.append(Document(page_content=result, metadata={"type": "image", 'file_directory': extracted_image_path}))

In [17]:
# Inspect the LLaVA summaries of the extracted images
image_summaries

["The image features a logo for Sunny Farm, a company that specializes in fresh produce. The logo is a gold and yellow color scheme, with a sun in the center, symbolizing the warmth and freshness of the products. The sun is surrounded by trees, which further emphasize the connection to nature and the source of the fresh produce.\n\nThe logo is placed on a white background, making it stand out and be easily recognizable. The sun in the logo is positioned at the top left corner, while the trees are located at the bottom right corner. The overall design of the logo is simple yet effective in conveying the company's message."]

### Add to vectorstore

In [18]:
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryByteStore
from utils.sambanova_endpoint import SambaNovaEmbeddingModel
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

# Vectorstore holding the summary ("child") documents that are embedded and searched.
vectorstore = Chroma(
    embedding_function=SambaNovaEmbeddingModel(),
    collection_name="summaries",
)

# Metadata key that links each summary to its original document.
id_key = "doc_id"

# Storage layer for the original ("parent") documents, keyed by doc_id.
# NOTE(review): this InMemoryByteStore is passed as `docstore` and later receives Document
# objects rather than bytes. It appears to work because the in-memory store does not check
# value types; InMemoryStore would match the intended value type — consider switching.
store = InMemoryByteStore()

# Multi-vector retriever (empty to start); similarity search returns the top 2 summaries per query.
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
    search_kwargs=dict(k=2),
)

In [19]:
def add_summaries_to_retriever(summaries, originals):
    """Embed ``summaries`` in the vectorstore and store ``originals`` in the docstore.

    Each summary/original pair shares a fresh UUID under ``id_key``, so a vector
    hit on a summary resolves to its original document.

    Raises:
        ValueError: If the two lists differ in length (pairs would be mismatched).
    """
    if len(summaries) != len(originals):
        raise ValueError(f"{len(summaries)} summaries for {len(originals)} documents")
    ids = [str(uuid.uuid4()) for _ in originals]
    retriever.vectorstore.add_documents(
        [Document(page_content=summary, metadata={id_key: doc_id})
         for summary, doc_id in zip(summaries, ids)]
    )
    retriever.docstore.mset(list(zip(ids, originals)))


# Add texts. Only non-empty chunks were summarized, so pair the summaries with the same filtered docs.
if texts:
    add_summaries_to_retriever(text_summaries, [doc for doc in text_docs if doc.page_content != ""])

# Add tables
if tables:
    add_summaries_to_retriever(table_summaries, table_docs)

# Add images. The stored "original" is the image-summary Document, which carries the image path.
if image_summaries:
    add_summaries_to_retriever(image_summaries, image_docs)

In [23]:
# Query the retriever: the search runs over summaries, and the linked original documents are returned.
retriever.invoke("what is the final price in the invoice?")

[Document(page_content='THANK YOU\n\n* Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam sodales dapibus fermentum. Nunc adipiscing, magna sed scelerisque cursus, erat lectus dapibus urna, sed facilisis leo dui et ipsum.\n\nSubtotal | Total\n\nSubtotal\n\nGST (10%)\n\nTotal\n\n$36.00\n\n$3.60\n\n$39.60', metadata={'filetype': 'application/pdf', 'languages': ['eng'], 'last_modified': '2024-05-08T08:57:57', 'page_number': 1, 'orig_elements': 'eJzNVVuL3DYY/SvCtNCWkdD9krc8tdB2G9jJQ1mWQdZlVuCxJ7bcZpvmv1f2eJKhmTabwmwHDEbHOrJ0zvk+3b2rQhN2oc2b5KsXoJJeqlALDqlkAvLIPKy1J5BwJxjVtY22rlag2oVsvc22cN5Vrut6n1qbwzCPG/vYjXnzENL2IReEMmYKZ4F/Tz4/FJRIwQq671KbJ97dnVSozCNGEMTvV2AZU6zUYawVQ/oMsDAKUg2PQw676SSv0tvQ3O6tC9X78sGHHFxOXbtxjR2Gzb7v6jINI2kYVmVCTE3Ij/swc1/9XM0bbrej3c6nuqtCu63uZ3TIm13nU0xh1oxiyiEWEOs11i+EKs/E3hfmph13dein006byOHtpEe1/uHlzY/g119eT/OOP12n3Mx7/bsnjkcRbXSwtiIUT6SFtSIGakato45iTOzFPBEUkSI5IRiZSfLj2EhEZwukQfwsMDP+myeGYk6f2ZPvwE9dH3Yg7YdxB3zXdD0YUga2yLoCrmuHsteQxx5Yn/ZpcKndgtCkjMDL

In [21]:
# An image-related query; the LLaVA summary of the extracted logo image is expected among the results.
retriever.invoke("what is the logo of the company")

[Document(page_content="The image features a logo for Sunny Farm, a company that specializes in fresh produce. The logo is a gold and yellow color scheme, with a sun in the center, symbolizing the warmth and freshness of the products. The sun is surrounded by trees, which further emphasize the connection to nature and the source of the fresh produce.\n\nThe logo is placed on a white background, making it stand out and be easily recognizable. The sun in the logo is positioned at the top left corner, while the trees are located at the bottom right corner. The overall design of the logo is simple yet effective in conveying the company's message.", metadata={'type': 'image', 'file_directory': '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/multimodal_knowledge_retriever/data/sample_docs/invoicesample/figure-1-1.jpg'}),
 Document(page_content='ATTENTION TO\n\nDenny Gunawan\n\n221 Queen St Melbourne VIC 3000\n\n123 Somewhere St, Melbourne VIC 3000 (03) 1234 5678\n\n$39.60\n\nInvoice N

### Retrieval

In [22]:
from langchain.chains import RetrievalQA

# Custom QA prompt. from_llm builds the stuff-documents chain with it, so there is no need to patch
# combine_documents_chain.llm_chain.prompt after construction (which also skips prompt validation).
# NOTE(review): from_llm expects the retrieved documents under a {context} variable — verify the YAML.
prompt = load_prompt(os.path.join(kit_dir,"prompts","llama70b-knowledge_retriever_custom_qa_prompt.yaml"))

chain = RetrievalQA.from_llm(
    llm=model,
    prompt=prompt,
    retriever=retriever,
    return_source_documents=True,
    input_key="question",
    output_key="answer",
)


In [24]:
# Ask the full RAG chain; the result dict holds the question, the answer and the source documents.
chain.invoke({"question": "what is the final price in the invoice?"})

{'question': 'what is the final price in the invoice?',
 'answer': ' Sure, I can help you with that! Based on the provided context, the final price in the invoice is $39.60. This can be found in the context labeled "Subtotal | Total" which lists the total price as $39.60, including a $3.60 GST charge.\n\nSo, the answer to your question is:\n\n$39.60',
 'source_documents': [Document(page_content='THANK YOU\n\n* Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam sodales dapibus fermentum. Nunc adipiscing, magna sed scelerisque cursus, erat lectus dapibus urna, sed facilisis leo dui et ipsum.\n\nSubtotal | Total\n\nSubtotal\n\nGST (10%)\n\nTotal\n\n$36.00\n\n$3.60\n\n$39.60', metadata={'filetype': 'application/pdf', 'languages': ['eng'], 'last_modified': '2024-05-08T08:57:57', 'page_number': 1, 'orig_elements': 'eJzNVVuL3DYY/SvCtNCWkdD9krc8tdB2G9jJQ1mWQdZlVuCxJ7bcZpvmv1f2eJKhmTabwmwHDEbHOrJ0zvk+3b2rQhN2oc2b5KsXoJJeqlALDqlkAvLIPKy1J5BwJxjVtY22rlag2oVsvc22cN5Vrut6n1qbwzCPG/vYj

In [25]:
# Pass the question under the chain's input_key ("question"), matching the dict-input call style.
chain.invoke({"question": "what is the logo of the company"})

{'question': 'what is the logo of the company',
 'answer': ' Based on the provided context, the logo of the company is a gold and yellow logo featuring a sun in the center, surrounded by trees. The sun is positioned at the top left corner, while the trees are located at the bottom right corner. The logo is placed on a white background, making it stand out and easily recognizable.',
 'source_documents': [Document(page_content="The image features a logo for Sunny Farm, a company that specializes in fresh produce. The logo is a gold and yellow color scheme, with a sun in the center, symbolizing the warmth and freshness of the products. The sun is surrounded by trees, which further emphasize the connection to nature and the source of the fresh produce.\n\nThe logo is placed on a white background, making it stand out and be easily recognizable. The sun in the logo is positioned at the top left corner, while the trees are located at the bottom right corner. The overall design of the logo is 

## Retrieval with raw text, raw tables and raw images

In [28]:
# WIP: retrieval over raw text, raw tables and raw images is not implemented yet.