# Multimodal RAG

In [46]:
# Setup: resolve kit/repo directories relative to the notebook's working
# directory, make them importable, and load environment variables from .env.
import os
import sys

# NOTE(review): assumes the kernel's cwd is one level below the kit root
# (kit_dir = parent of cwd, repo_dir = grandparent) — TODO confirm if moved.
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

# Put the kit and repo roots on sys.path so `utils.*` can be imported below.
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from utils.sambanova_endpoint import SambaNovaEndpoint  # not used in the cells shown here
from dotenv import load_dotenv
# Load the repo-level .env into os.environ; presumably this supplies the
# BASE_URL / PROJECT_ID / ENDPOINT_ID / API_KEY values read by llava_call.
load_dotenv(os.path.join(repo_dir,'.env'))

import requests
import json
import base64
from pprint import pprint

## Utilities

In [47]:
def image_to_base64(image_path):
    """Read a file from disk and return its contents base64-encoded.

    Args:
        image_path: path of the (image) file to encode.

    Returns:
        The base64 encoding of the file's raw bytes, as a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()

## Multimodal call

In [71]:
# sambastudio call
def llava_call(
    prompt,
    image_path,
    max_tokens_to_generate=100,
    do_sample=False,
    temperature=1,
    top_k=50,
    top_p=1,
    top_logprobs=0,
    timeout=60,
):
    """Ask a LLaVA model served from a SambaStudio endpoint about a local image.

    The endpoint URL is built from the ``BASE_URL``, ``PROJECT_ID`` and
    ``ENDPOINT_ID`` environment variables; ``API_KEY`` is sent in the ``key``
    header. The default generation parameters reproduce the original payload.

    Args:
        prompt: full prompt string (e.g. a LLaVA template containing ``<image>``).
        image_path: path of the image file to send, base64-encoded.
        max_tokens_to_generate, do_sample, temperature, top_k, top_p,
        top_logprobs: generation parameters forwarded to the endpoint.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The ``completion`` string of the first prediction in the response.

    Raises:
        RuntimeError: if any required environment variable is missing or empty.
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    # Fail early with a clear message instead of POSTing to "None/api/...".
    missing = [
        name
        for name in ("BASE_URL", "PROJECT_ID", "ENDPOINT_ID", "API_KEY")
        if not os.environ.get(name)
    ]
    if missing:
        raise RuntimeError(
            f"Missing SambaStudio environment variable(s): {', '.join(missing)}"
        )
    endpoint_url = (
        f"{os.environ['BASE_URL']}/api/predict/generic/"
        f"{os.environ['PROJECT_ID']}/{os.environ['ENDPOINT_ID']}"
    )

    image = image_to_base64(image_path)
    payload = {
        "instances": [{"prompt": prompt, "image_content": image}],
        # Same format as the original payload: each param is a {type, value}
        # pair whose value is a string.
        "params": {
            "do_sample": {"type": "bool", "value": "true" if do_sample else "false"},
            "max_tokens_to_generate": {"type": "int", "value": str(max_tokens_to_generate)},
            "temperature": {"type": "float", "value": str(temperature)},
            "top_k": {"type": "int", "value": str(top_k)},
            "top_logprobs": {"type": "int", "value": str(top_logprobs)},
            "top_p": {"type": "float", "value": str(top_p)},
        },
    }
    headers = {
        "Content-Type": "application/json",
        "key": os.environ["API_KEY"],
    }
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=timeout)
    # Surface HTTP errors directly rather than as a KeyError on the body below.
    response.raise_for_status()
    return response.json()["predictions"][0]["completion"]

# LLaVA conversation template: a system preamble, a USER turn carrying the
# <image> placeholder (presumably where the endpoint injects the image), and an
# open ASSISTANT turn for the model to complete.
prompt = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions. "
    "USER: <image>\nAre you allowed to swim here? ASSISTANT:"
)
image_path = os.path.join(kit_dir, "data/view.jpg")
llava_call(prompt, image_path)

'Yes, you are allowed to swim in the lake. The image shows a pier extending over the water, and there are no visible signs or barriers that prohibit swimming. The serene environment and the presence of a pier suggest that it is a suitable location for swimming and enjoying the water.'

In [78]:
# Same question, answered by the LLaVA-13B model hosted on Replicate.
import replicate

# The image is sent inline as a base64 data URI.
encoded_view = image_to_base64(os.path.join(kit_dir, 'data/view.jpg'))
output = replicate.run(
    "yorickvp/llava-13b:b5f6212d032508382d61ff00469ddda3e32fd8a0e75dc39d8a4191bb742157fb",
    input=dict(
        image="data:application/octet-stream;base64," + encoded_view,
        top_p=1,
        prompt="Are you allowed to swim here?",
        max_tokens=1024,
        temperature=0.2,
    ),
)

# This model streams its answer: replicate.run returns an iterator of text
# fragments, printed here as they arrive.
# Output schema: https://replicate.com/yorickvp/llava-13b/api#output-schema
for item in output:
    print(item, end="")


Yes, you are allowed to swim in the lake near the pier. The image shows a pier extending out into the water, which suggests that it is a popular spot for swimming and other water-related activities.

## PDF extraction

In [97]:
from typing import Any  # used by the Element model in a later cell

from pydantic import BaseModel  # used by the Element model in a later cell
from unstructured.partition.pdf import partition_pdf

# Directory holding the source PDF; extracted images are written here too.
path = os.path.join(kit_dir, "data/")
pdf_filename = os.path.join(path, "SambaNova_Suite_Solution_Brief_06-21-23.pdf")

# Partition the PDF into layout-aware elements, then chunk them by title.
raw_pdf_elements = partition_pdf(
    filename=pdf_filename,
    # Extract embedded image blocks from the PDF
    extract_images_in_pdf=True,
    # Use the layout model (YOLOX) to get bounding boxes (for tables) and find
    # titles; titles are any sub-section of the document
    infer_table_structure=True,
    # Post-processing: aggregate text into chunks that start at each title
    chunking_strategy="by_title",
    # Chunking params:
    #   hard max of 1000 chars per chunk,
    #   start a new chunk once the current one exceeds 800 chars,
    #   combine consecutive sections shorter than 500 chars
    max_characters=1000,
    new_after_n_chars=800,
    combine_text_under_n_chars=500,
    # NOTE(review): newer `unstructured` releases call this
    # `extract_image_block_output_dir` — confirm against the installed version.
    image_output_dir_path=path,
)

This function will be deprecated in a future release and `unstructured` will simply use the DEFAULT_MODEL from `unstructured_inference.model.base` to set default model name


In [93]:
# Inspect the chunked elements returned by partition_pdf.
raw_pdf_elements

[<unstructured.documents.elements.CompositeElement at 0x2da034550>,
 <unstructured.documents.elements.CompositeElement at 0x2da034520>,
 <unstructured.documents.elements.CompositeElement at 0x2da034160>,
 <unstructured.documents.elements.CompositeElement at 0x2a46b7310>,
 <unstructured.documents.elements.CompositeElement at 0x2da0341f0>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f886d0>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f89420>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f89210>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f8b9a0>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f880a0>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f891e0>,
 <unstructured.documents.elements.CompositeElement at 0x2d9f88dc0>,
 <unstructured.documents.elements.CompositeElement at 0x2b339bee0>]

In [98]:
from collections import Counter

# Count how many elements of each concrete class partition_pdf returned, keyed
# by the class repr (e.g. "<class '...elements.CompositeElement'>"). Wrapped in
# dict so the displayed result is a plain dict; insertion order is the order of
# first occurrence.
category_counts = dict(Counter(str(type(element)) for element in raw_pdf_elements))

# Distinct element classes; a TableChunk is expected here when a table exceeds
# the max_characters chunk limit set above.
unique_categories = set(category_counts.keys())
category_counts

{"<class 'unstructured.documents.elements.CompositeElement'>": 13}

In [99]:
class Element(BaseModel):
    """Simple record pairing a chunk's category with its content.

    Unrelated to unstructured's own ``Element`` class despite the shared name.
    """

    # Category label; the categorization loop below assigns "table" or "text".
    type: str
    # Chunk content; the loop below stores str(element) here.
    text: Any


# Map each partitioned element onto an Element record. Matching is a substring
# test on the class repr, so "...elements.Table" would also match a TableChunk
# class from the same module. Elements of any other class are skipped.
categorized_elements = []
for element in raw_pdf_elements:
    class_repr = str(type(element))
    if "unstructured.documents.elements.Table" in class_repr:
        label = "table"
    elif "unstructured.documents.elements.CompositeElement" in class_repr:
        label = "text"
    else:
        continue
    categorized_elements.append(Element(type=label, text=str(element)))

# Split the records by category, then report the table count and the text count.
table_elements = [record for record in categorized_elements if record.type == "table"]
text_elements = [record for record in categorized_elements if record.type == "text"]
print(len(table_elements))
print(len(text_elements))

0
13
