In [1]:
%%capture
!pip install chromadb
!pip install datasets
!pip install pyarrow
!pip install open-clip-torch
!pip install sentence-transformers
!pip install langchain_core langchain_openai

In [2]:
import base64
import json
import os
from typing import Any, Optional, Sequence, List, Dict, Union

import IPython
from IPython.display import HTML, display, Image, Markdown, Video, Audio

import soundfile as sf

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from google.colab import userdata

from sentence_transformers import SentenceTransformer
from transformers import ClapModel, ClapProcessor
from datasets import load_dataset

import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
from chromadb.api.types import Document, Embedding, EmbeddingFunction, URI, DataLoader

import numpy as np
import torchaudio
import torch
import cv2

In [3]:
def set_css(info=None):
  """Inject a CSS rule so long lines inside ``<pre>`` output wrap.

  Args:
    info: Execution info that IPython may pass to ``pre_run_cell``
      callbacks. It is ignored; accepting it (with a default) lets the hook
      work whether or not IPython supplies the argument.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Run the hook before every cell so the style is present in each cell's output.
get_ipython().events.register('pre_run_cell', set_css)

In [4]:
# Local directory handed to Chroma's persistent client; all collections
# created below go through this client.
path = "mm_vdb"
client = chromadb.PersistentClient(path=path)

## Load the Dataset

In [5]:
# Load the "ashraq/esc50" audio dataset; the clips are read from ds['train'] below.
ds = load_dataset("ashraq/esc50")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/345 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


dataset_infos.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

(…)-00000-of-00002-2f1ab7b824ec751f.parquet:   0%|          | 0.00/387M [00:00<?, ?B/s]

(…)-00001-of-00002-27425e5c1846b494.parquet:   0%|          | 0.00/387M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2000 [00:00<?, ? examples/s]

## Saving Audio Files Locally

Saving the audio files locally will allow us to display them later, as Chroma will not store the actual files for modalities outside of text.

In [6]:
# Directory that will hold the exported audio files.
# NOTE: this rebinds the notebook-level name `path`.
path = "esc50"
os.makedirs(path, exist_ok=True)

# Write every clip of the train split to disk under its dataset filename
for item in ds['train']:
    audio_array = item['audio']['array']
    sample_rate = item['audio']['sampling_rate']
    file_name = item['filename']
    target_path = os.path.join(path, file_name)

    # Save the decoded samples at the clip's own sampling rate
    sf.write(target_path, audio_array, sample_rate)

print("All audio files have been processed and saved.")

All audio files have been processed and saved.


In [7]:
class AudioLoader(DataLoader[List[Optional[Dict[str, Any]]]]):
    """Chroma data loader that turns audio file URIs into mono waveforms.

    Each URI is read with ``torchaudio``, resampled to ``target_sample_rate``
    when needed, and down-mixed to a single channel.

    Args:
        target_sample_rate: Sampling rate (Hz) every waveform is converted to.
    """

    def __init__(self, target_sample_rate: int = 48000) -> None:
        self.target_sample_rate = target_sample_rate

    def _load_audio(self, uri: Optional[URI]) -> Optional[Dict[str, Any]]:
        """Load one audio file.

        Args:
            uri: Path of the audio file, or ``None``.

        Returns:
            A dict with ``"waveform"`` (1-D tensor), ``"uri"`` and
            ``"sample_rate"`` (the rate of the returned waveform), or ``None``
            when ``uri`` is ``None`` or the file could not be loaded.
        """
        if uri is None:
            return None

        try:
            waveform, sample_rate = torchaudio.load(uri)

            # Resample if necessary
            if sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
                waveform = resampler(waveform)

            # Convert to mono if there is more than one channel
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            return {
                # Drop only the channel axis; a bare squeeze() would also
                # collapse a one-sample clip to a 0-d tensor.
                "waveform": waveform.squeeze(0),
                "uri": uri,
                # Lets consumers pass the true rate on instead of assuming one.
                "sample_rate": self.target_sample_rate,
            }
        except Exception as e:
            # Best effort: report the failure and let the caller see None.
            print(f"Error loading audio file {uri}: {str(e)}")
            return None

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Dict[str, Any]]]:
        """Load every URI in order; entries that fail to load are ``None``."""
        return [self._load_audio(uri) for uri in uris]

In [8]:
class CLAPEmbeddingFunction(EmbeddingFunction[Union[Document, Dict[str, Any]]]):
    """Chroma embedding function backed by a CLAP model.

    Strings are embedded with the model's text tower and dicts containing a
    ``"waveform"`` tensor with its audio tower, so text queries can be matched
    against audio embeddings.

    Args:
        model_name: Hugging Face model id to load.
        device: Torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        sampling_rate: Sampling rate (Hz) assumed for waveforms that do not
            carry their own ``"sample_rate"`` entry.
    """

    def __init__(
        self,
        model_name: str = "laion/larger_clap_general",
        device: Optional[str] = None,
        sampling_rate: int = 48000
    ) -> None:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ClapModel.from_pretrained(model_name).to(device)
        self.processor = ClapProcessor.from_pretrained(model_name)
        self.device = device
        self.sampling_rate = sampling_rate

    def _encode_audio(self, audio: torch.Tensor, sampling_rate: Optional[int] = None) -> Embedding:
        """Embed a single waveform.

        Args:
            audio: Waveform tensor for one clip.
            sampling_rate: Rate of ``audio`` in Hz; falls back to the rate
                given at construction time.
        """
        rate = self.sampling_rate if sampling_rate is None else sampling_rate
        # Passing the real rate (rather than a constant) means a mismatch is
        # visible to the processor instead of being silently mislabelled.
        # NOTE(review): presumably the processor rejects unexpected rates - confirm.
        samples = audio.detach().cpu().numpy()
        inputs = self.processor(audios=samples, sampling_rate=rate, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            audio_embedding = self.model.get_audio_features(**inputs)
        # Remove only the leading axis of the single-item result.
        return audio_embedding.squeeze(0).cpu().numpy().tolist()

    def _encode_text(self, text: Document) -> Embedding:
        """Embed a single text string."""
        inputs = self.processor(text=text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        return text_embedding.squeeze(0).cpu().numpy().tolist()

    def __call__(self, input: Union[List[Document], List[Optional[Dict[str, Any]]]]) -> List[Optional[Embedding]]:
        """Embed a list of texts and/or loaded audio items, preserving order.

        Raises:
            ValueError: If an item is neither a string, a dict with a
                ``"waveform"`` key, nor ``None``.
        """
        embeddings = []
        for item in input:
            if isinstance(item, dict) and 'waveform' in item:
                embeddings.append(self._encode_audio(item['waveform'], item.get('sample_rate')))
            elif isinstance(item, str):
                embeddings.append(self._encode_text(item))
            elif item is None:
                # Keeps positions aligned with the input when a load failed.
                embeddings.append(None)
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return embeddings

In [9]:
# Audio collection: URIs are loaded by AudioLoader and embedded by
# CLAPEmbeddingFunction (both defined above).
audio_collection = client.get_or_create_collection(
    name='audio_collection',
    embedding_function=CLAPEmbeddingFunction(),
    data_loader=AudioLoader()
)

config.json:   0%|          | 0.00/643 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/776M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/776M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

## Adding Audio Files to Collection

Iterating through our audio files and adding them to the audio collection.

In [10]:
# Takes a couple mins with GPU
def add_audio(audio_collection, folder_path, batch_size=100):
    """Add every ``.wav`` file in a folder to a collection by URI.

    Files are added in sorted order and in chunks, so that only
    ``batch_size`` URIs are handed to the collection per ``add`` call
    instead of the whole folder at once.

    Args:
        audio_collection: Object exposing ``add(ids=..., uris=...)``.
        folder_path: Directory to scan (non-recursive).
        batch_size: Maximum number of files per ``add`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Sorted for a reproducible insertion order (os.listdir order is arbitrary)
    wav_files = sorted(name for name in os.listdir(folder_path) if name.endswith('.wav'))

    # The id is the filename without its extension; the URI is the file path
    ids = [os.path.splitext(name)[0] for name in wav_files]
    uris = [os.path.join(folder_path, name) for name in wav_files]

    # An empty folder simply results in no add() call
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        audio_collection.add(ids=ids[start:stop], uris=uris[start:stop])

# Index the audio files in the "esc50" folder into the audio collection
folder_path = 'esc50'
add_audio(audio_collection, folder_path)


## Testing Function for Audio Retrieval

Now that they're in there, we can test out retrieval!

In [11]:
def display_audio_files(query_text, max_distance=None, debug=False, max_results=5):
    """Query the audio collection with text and play the matching clips.

    Args:
        query_text: Text used to query ``audio_collection``.
        max_distance: If given, only results whose distance is at most this
            value are displayed.
        debug: When true, print the URI and distance of every result,
            including the ones filtered out.
        max_results: Number of nearest results requested from the collection.
    """
    # Query the audio collection with the specified text
    results = audio_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    # Results are nested per query; a single query was sent, so take index 0
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Display the audio files that meet the distance criteria
    for uri, distance in zip(uris, distances):
        if max_distance is None or distance <= max_distance:
            if debug:
                print(f"URI: {uri} - Distance: {distance}")
            display(Audio(uri))
        elif debug:
            print(f"URI: {uri} - Distance: {distance} (Filtered out)")

# Example: play clips matching "dog", printing each result's URI and distance
display_audio_files("dog", max_distance=1.5, debug=True)


URI: esc50/1-30226-A-0.wav - Distance: 1.2870548963546753


URI: esc50/2-118072-A-0.wav - Distance: 1.3339946269989014
URI: esc50/1-110389-A-0.wav - Distance: 1.3397095203399658


URI: esc50/5-231762-A-0.wav - Distance: 1.3429148197174072


URI: esc50/4-183992-A-0.wav - Distance: 1.3598130941390991


---

# Modality 2: Image Retrieval

Image retrieval will work similarly to our Audio Retrieval setup, except this time it's already built into ChromaDB.



In [12]:
# Load the image dataset. NOTE: this rebinds `ds`, replacing the audio dataset.
ds = load_dataset("KoalaAI/StockImages-CC0")

README.md:   0%|          | 0.00/2.79k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/422M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/467M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3999 [00:00<?, ? examples/s]

## Couple Corrupted Files, Cutting them Out

When processing this dataset, I noticed a few corrupted files. These lines remove those specifically so that no hiccups happen during processing.

In [13]:
# Row positions (in the freshly loaded dataset) of the corrupted files.
# NOTE: these are positions, not stable ids. Re-running this cell without
# reloading the dataset would drop two further, unrelated rows.
indices_to_remove = {586, 1002}

# Generate the indices to keep, excluding the problematic ones
all_indices = set(range(len(ds['train'])))
# sorted() keeps the remaining rows in their original order; iterating a set
# directly gives no ordering guarantee.
indices_to_keep = sorted(all_indices - indices_to_remove)
# Select the remaining entries in the dataset
ds['train'] = ds['train'].select(indices_to_keep)

# Verification
print(ds['train'])

Dataset({
    features: ['image', 'tags'],
    num_rows: 3997
})


## Save Images Locally

We are once again saving these images locally, as the multimodal feature requires a path to each object rather than storing the object itself in the database.

In [14]:
output_folder = "StockImages-cc0"
os.makedirs(output_folder, exist_ok=True)

def process_and_save_image(idx, item):
    """Save one dataset image as ``image_<idx>.jpg`` in ``output_folder``.

    Images whose mode is neither ``"RGB"`` nor ``"L"`` are converted to RGB
    first, so they are written instead of being dropped when the JPEG writer
    rejects their mode. Any failure is reported and the image is skipped.

    Args:
        idx: Number used in the output filename.
        item: Mapping with an ``'image'`` entry holding an already decoded
            image object (it must provide ``save`` and ``convert``).
    """
    try:
        # Since the image is already a PIL image, it can be saved directly
        image = item['image']
        if getattr(image, "mode", "RGB") not in ("RGB", "L"):
            image = image.convert("RGB")
        image.save(os.path.join(output_folder, f"image_{idx}.jpg"))
    except Exception as exc:
        # Best effort: keep going, but say which image was skipped and why.
        print(f"Skipping image {idx}: {exc}")

def process_images(dataset):
    """Save every item of ``dataset['train']``, numbering files by position."""
    for idx, item in enumerate(dataset['train']):
        process_and_save_image(idx, item)

# Export every image of the (filtered) dataset to the output folder
process_images(ds)

In [15]:
# Loader that Chroma uses to read the image files referenced by URI
image_loader = ImageLoader()
# CLIP embedding function; reused later for the video-frame collection
CLIP = OpenCLIPEmbeddingFunction()

# Create the image collection (or reuse it if it already exists)
image_collection = client.get_or_create_collection(name="image_collection",
                                                   embedding_function = CLIP,
                                                   data_loader = image_loader)

open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

## Adding Images to Collection

Iterating through our images and adding them to our image collection

In [None]:
# Initialize lists for ids and uris
ids = []
uris = []

dataset_folder="StockImages-cc0"

# Iterate over each file in the dataset folder (sorted for a stable id order)
for i, filename in enumerate(sorted(os.listdir(dataset_folder))):
    if filename.endswith('.jpg'):
        file_path = os.path.join(dataset_folder, filename)

        # The id is the file's position in the sorted listing
        ids.append(str(i))
        uris.append(file_path)

if not ids:
    raise ValueError(f"No .jpg files found in {dataset_folder}")

# Add to the image collection in chunks, so that only one chunk of images is
# handed to the collection per add() call instead of the whole folder.
image_batch_size = 256
for start in range(0, len(ids), image_batch_size):
    stop = start + image_batch_size
    image_collection.add(
        ids=ids[start:stop],
        uris=uris[start:stop]
    )

print("Images added to the database.")


## Testing Function for Image Retrieval

Testing out image retrieval!

In [None]:
def display_images(query_text, max_distance=None, debug=False, max_results=5):
    """Query the image collection with text and show the matching images.

    Args:
        query_text: Text used to query ``image_collection``.
        max_distance: If given, only results whose distance is at most this
            value are displayed.
        debug: When true, print the URI and distance of every result,
            including the ones filtered out.
        max_results: Number of nearest results requested from the collection.
    """
    # Query the image collection with the specified text
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    # Results are nested per query; a single query was sent, so take index 0
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Display the images that meet the distance criteria
    for uri, distance in zip(uris, distances):
        if max_distance is None or distance <= max_distance:
            if debug:
                print(f"URI: {uri} - Distance: {distance}")
            display(Image(uri, width=300))
        elif debug:
            print(f"URI: {uri} - Distance: {distance} (Filtered out)")

# Example: show images matching "dog", printing each result's URI and distance
display_images("dog", max_distance=1.5, debug=True)


---

# Modality 3: Text Retrieval

This is the classic modality that is used in most RAG setups. Text embeddings with text retrieval.

In [None]:
# Load the Wikipedia text dataset. NOTE: this rebinds `ds` once more.
ds = load_dataset("TopicNavi/Wikipedia-example-data")

## Pre-processing the data to Speed Things Up

Because of the way ChromaDB's loader and embedding functions work, it would take quite a while to embed all 25,000 rows of text. In this pre-step we're loading the embedding model [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) ourselves and batch embedding with the help of some GPU power (courtesy of Colab or your own environment). This step is not necessary but saves some time.

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the sentence-embedding model onto the chosen device
model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

# Prepare the data: article text, per-article metadata, and the title as id
# NOTE(review): titles are used as ids, so they are presumably unique - confirm.
documents = [entry['text'] for entry in ds['train']]
metadatas = [{"url": entry['url'], "wiki_id": entry['wiki_id']} for entry in ds['train']]
ids = [entry['title'] for entry in ds['train']]

# Generate embeddings batch by batch
embeddings = []
batch_size = 128  # Adjust based on available GPU memory

for i in range(0, len(documents), batch_size):
    batch = documents[i:i+batch_size]
    batch_embeddings = model.encode(batch, convert_to_tensor=True, device=device)
    # Move each batch to the CPU and collect one array per document
    embeddings.extend(batch_embeddings.cpu().numpy())

# Convert embeddings to plain lists for JSON serialization
embeddings = [emb.tolist() for emb in embeddings]

# Prepare the data for export; load_wiki() reads back exactly these four keys
export_data = {
    "documents": documents,
    "embeddings": embeddings,
    "metadatas": metadatas,
    "ids": ids
}

# Export
with open('wikipedia_embeddings.json', 'w') as f:
    json.dump(export_data, f)

print("Data exported to wikipedia_embeddings.json")

## Loading From our Saved File to Chroma

Since we saved the batched embeddings to a JSON file, let's load them back up into variables here to plug into our text collection

In [None]:
def load_wiki(file_path):
    """Load documents, embeddings, metadatas and ids from a JSON export.

    Args:
        file_path: Path of a JSON file containing the keys ``documents``,
            ``embeddings``, ``metadatas`` and ``ids``.

    Returns:
        The tuple ``(documents, embeddings, metadatas, ids)``.

    Raises:
        KeyError: If one of the four keys is missing from the file.
    """
    # Load the JSON data
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Extract the components
    documents = data['documents']
    embeddings = data['embeddings']
    metadatas = data['metadatas']
    ids = data['ids']

    print(f"Loaded data from {file_path}")
    print(f"Number of entries: {len(documents)}")
    # An empty export has no first embedding to measure
    if embeddings:
        print(f"Embedding dimension: {len(embeddings[0])}")
    else:
        print("Embedding dimension: n/a (no embeddings in file)")

    return documents, embeddings, metadatas, ids

# Read the exported embeddings back in. NOTE: this rebinds the notebook-level `ids`.
docs, embs, metas, ids = load_wiki('wikipedia_embeddings.json')

In [None]:
# No embedding function is passed here; the documents are added below together
# with their precomputed embeddings.
text_collection = client.get_or_create_collection(name="text_collection")

## Adding In Text Data to Collection in Batches

ChromaDB has a limit on how many rows of data you can add at once, and won't allow us to shove in all 25000 entries in one go. This function batches the entries into manageable chunks and inserts them all into the text collection

In [None]:
# Batch and add data to the collection, respecting max batch size
def batch_add_to_collection(collection, documents, embeddings, metadatas, ids, batch_size=5461):
    for i in range(0, len(documents), batch_size):
        # Slice the data into batches
        doc_batch = documents[i:i + batch_size]
        emb_batch = embeddings[i:i + batch_size]
        meta_batch = metadatas[i:i + batch_size]
        id_batch = ids[i:i + batch_size]

        # Add the batch to the collection
        collection.add(
            documents=doc_batch,
            embeddings=emb_batch,
            metadatas=meta_batch,
            ids=id_batch
        )
        print(f"Batch {i // batch_size + 1} added to the collection successfully.")

# Insert the loaded Wikipedia data using the default batch size
batch_add_to_collection(text_collection, docs, embs, metas, ids)


## Testing Function for Text Retrieval

Now that all the wikipedia text is in there, we can test out our text retrieval!

In [None]:
def display_text_documents(query_text, max_distance=None, debug=False, max_results=5):
    """Query the text collection and print the matching documents.

    Args:
        query_text: Text used to query ``text_collection``.
        max_distance: If given, only results whose distance is at most this
            value are printed in full.
        debug: When true, also print distances and the title of every result
            that was filtered out.
        max_results: Number of nearest results requested from the collection.
    """
    # Query the text collection with the specified text
    results = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances', 'metadatas']
    )

    # Results are nested per query; a single query was sent, so take index 0.
    # The ids double as article titles.
    documents = results['documents'][0]
    distances = results['distances'][0]
    metadatas = results['metadatas'][0]
    titles = results['ids'][0]

    # Print the documents that meet the distance criteria
    for title, doc, distance, metadata in zip(titles, documents, distances, metadatas):
        url = metadata.get('url')

        if max_distance is None or distance <= max_distance:
            print(f"Title: {title.replace('_', ' ')}")
            if debug:
                print(f"Distance: {distance}")
            print(f"URL: {url}")
            print(f"Text: {doc}\n")
        elif debug:
            # Report filtered-out documents with their title and distance
            print(f"Title: {title.replace('_', ' ')} - Distance: {distance} (Filtered out)")

# Example: print articles matching "dog", including distances
display_text_documents("dog", max_distance=1.3, debug=True)


---
# Modality 4: Video

Our final modality will be handling video files. The unfortunate truth is that there's no (easily usable!) model like a text embedding, CLAP, or CLIP model that can take raw video input, but with a few tricks we can still make it work relatively seamlessly!

It all hinges on the idea that videos are essentially a collection of frames, which are just images. If we take multiple frames from a video and process them with image retrieval techniques, with each frame linking back to the original video file, we should be able to find the right video for a query, since multiple frames from different times across the video have been indexed!

Of course, this is not perfect and does not perfectly capture the temporal understanding of the video, but it still works decently.

In [None]:
%%capture
!unzip StockVideos-CC0.zip

In [None]:
def extract_frames(video_folder, output_folder, interval_seconds=5):
    """Save periodic still frames from every ``.mp4`` file in a folder.

    For each video a subfolder named after the file (without extension) is
    created in ``output_folder``. The first frame, one frame every
    ``interval_seconds`` and the last frame are written as
    ``frame_<second>.jpg``, where ``<second>`` is the frame's timestamp
    truncated to whole seconds (a later frame with the same second
    overwrites an earlier one).

    Args:
        video_folder: Directory containing the videos (non-recursive).
        output_folder: Directory that receives one subfolder per video.
        interval_seconds: Spacing between sampled frames, in seconds.
    """
    os.makedirs(output_folder, exist_ok=True)

    for video_filename in os.listdir(video_folder):
        if not video_filename.endswith('.mp4'):
            continue

        video_path = os.path.join(video_folder, video_filename)
        video_capture = cv2.VideoCapture(video_path)
        try:
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

            # A file that cannot be opened, or that reports no usable frame
            # rate, would otherwise cause a division/modulo by zero below.
            if not video_capture.isOpened() or not (fps > 0):
                print(f"Skipping {video_path}: could not read the video or its frame rate")
                continue

            # Number of frames between two samples; at least 1 so that very
            # low frame rates do not produce a zero step.
            frame_step = max(1, int(fps * interval_seconds))

            output_subfolder = os.path.join(output_folder, os.path.splitext(video_filename)[0])
            os.makedirs(output_subfolder, exist_ok=True)

            success, image = video_capture.read()
            frame_number = 0
            while success:
                # Frame 0 satisfies the modulo test, so it is always kept
                if frame_number % frame_step == 0 or frame_number == frame_count - 1:
                    frame_time = frame_number / fps
                    output_frame_filename = os.path.join(output_subfolder, f'frame_{int(frame_time)}.jpg')
                    cv2.imwrite(output_frame_filename, image)

                success, image = video_capture.read()
                frame_number += 1
        finally:
            # Release the capture even if reading or writing fails
            video_capture.release()

# Source videos and the folder that receives one subfolder of frames per video
video_folder_path = 'StockVideos-CC0'
output_folder_path = 'StockVideos-CC0-frames'

extract_frames(video_folder_path, output_folder_path)

## Creating Video Collection

Since we're technically processing images, we'll use the Image Loader and Clip Embedding Function from before for our video collection.

In [None]:
# Frames are images, so this collection reuses the CLIP embedding function
# and image loader created for the image collection.
video_collection = client.get_or_create_collection(
    name='video_collection',
    embedding_function=CLIP,
    data_loader=image_loader
)

## Adding Video Frames to Collection

We iterate over the frame folders and embed them into the database, with specific metadata that links back to the video file that the frame comes from

In [None]:
def add_frames_to_chromadb(video_dir, frames_dir, collection=None, batch_size=256):
    """Index extracted frames, linking each one back to its source video.

    For every ``.mp4`` in ``video_dir`` that has a same-named subfolder in
    ``frames_dir``, each ``.jpg`` in that subfolder is added with id
    ``<frame name>_<video title>`` and metadata ``{'video_uri': <video path>}``.

    Args:
        video_dir: Directory containing the ``.mp4`` files.
        frames_dir: Directory containing one frame subfolder per video.
        collection: Object exposing ``add(ids=..., uris=..., metadatas=...)``;
            defaults to the notebook-level ``video_collection``.
        batch_size: Maximum number of frames per ``add`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if collection is None:
        collection = video_collection

    ids = []
    uris = []
    metadatas = []

    # Sorted listings give a reproducible insertion order
    for video_file in sorted(os.listdir(video_dir)):
        if not video_file.endswith('.mp4'):
            continue

        video_title = video_file[:-4]
        frame_folder = os.path.join(frames_dir, video_title)
        # Videos without an extracted-frames folder are skipped
        if not os.path.isdir(frame_folder):
            continue

        video_path = os.path.join(video_dir, video_file)
        for frame in sorted(os.listdir(frame_folder)):
            if not frame.endswith('.jpg'):
                continue
            # Frame names repeat across videos, so the title makes the id unique
            ids.append(f"{frame[:-4]}_{video_title}")
            uris.append(os.path.join(frame_folder, frame))
            metadatas.append({'video_uri': video_path})

    # Add in chunks; with no frames at all this makes no add() call
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            uris=uris[start:stop],
            metadatas=metadatas[start:stop]
        )

# Index the extracted frames; these are the same two folders used for extraction
video_dir = 'StockVideos-CC0'
frames_dir = 'StockVideos-CC0-frames'

add_frames_to_chromadb(video_dir, frames_dir)


## Video retrieval testing Function

Now that all of the frames of every video are embedded into the collection, and point back to their respective video file, we can test out video retrieval!

**Note:** Loading videos is really slow if you're running this on Colab as the Colab file manager is not great

In [None]:
def display_videos(query_text, max_distance=None, max_results=5, debug=False):
    """Query the frame collection with text and show each matching video once.

    Args:
        query_text: Text used to query ``video_collection``.
        max_distance: If given, frames farther away than this are ignored.
        max_results: Number of nearest frames requested from the collection.
        debug: When true, print details of every frame, including those that
            were skipped (too distant, or their video was already shown).
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,  # Adjust the number of results if needed
        include=['uris', 'distances', 'metadatas']
    )

    # Videos already shown, so several frames of one video display it only once
    already_shown = set()

    # Only one query was sent, so its results sit at index 0
    frame_hits = zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0])
    for uri, distance, metadata in frame_hits:
        video_uri = metadata['video_uri']
        within_range = max_distance is None or distance <= max_distance

        if not (within_range and video_uri not in already_shown):
            if debug:
                print(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            print(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
        display(Video(video_uri, embed=True, width=300))
        already_shown.add(video_uri)

# Example: show videos whose frames match "trees", printing per-frame details
display_videos("trees", max_distance=1.55, debug=True)

----
# Full Multimodal Retrieval: Putting it All Together

Now that we have our audio, images, text, and videos embedded in their collections, let's see what it looks like when we run the same query retrieval over all modalities!

In [None]:
# One query string, run against every modality in turn
query = "San Francisco"

# Each modality uses its own distance cutoff
display(Markdown("# Text(s) Retrieved: \n"))
display_text_documents(query, max_distance=1.0)

display(Markdown("# Audio(s) Retrieved: \n"))
display_audio_files(query, max_distance=1.2)

display(Markdown("# Image(s) Retrieved: \n"))
display_images(query, max_distance=1.4)

display(Markdown("# Video(s) Retrieved: \n"))
display_videos(query, max_distance=1.55)

## Image Retrieval

Takes in the query, returns the paths to the relevant images.

In [None]:
def image_uris(query_text, max_distance=None, max_results=5):
    """Return the URIs of images matching a text query.

    Args:
        query_text: Text used to query ``image_collection``.
        max_distance: If given, results farther away than this are dropped.
        max_results: Number of nearest results requested from the collection.

    Returns:
        List of image URIs, in the order returned by the query.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    # Only one query was sent, so its results sit at index 0
    return [
        uri
        for uri, dist in zip(hits['uris'][0], hits['distances'][0])
        if max_distance is None or dist <= max_distance
    ]

# Example usage: paths of the images matching "water droplet"
images = image_uris("water droplet", max_distance=1.5)
print(images)

## Text Retrieval

Takes in the query, returns the retrieved texts.

In [None]:
def text_uris(query_text, max_distance=None, max_results=5):
    """Return the documents matching a text query.

    Despite the name, this returns document texts rather than URIs.

    Args:
        query_text: Text used to query ``text_collection``.
        max_distance: If given, results farther away than this are dropped.
        max_results: Number of nearest results requested from the collection.

    Returns:
        List of document texts, in the order returned by the query.
    """
    hits = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances']
    )

    # Only one query was sent, so its results sit at index 0
    return [
        doc
        for doc, dist in zip(hits['documents'][0], hits['distances'][0])
        if max_distance is None or dist <= max_distance
    ]

# Example usage: texts of the documents matching "water"
texts = text_uris("water", max_distance=1.3)
print(texts)

## Video Retrieval

Takes in the query, returns the retrieved frames.

In [None]:
def frame_uris(query_text, max_distance=None, max_results=5):
    """Return at most one matching frame URI per video.

    Args:
        query_text: Text used to query ``video_collection``.
        max_distance: If given, frames farther away than this are dropped.
        max_results: Number of nearest frames requested from the collection,
            and the cap on the number of URIs returned.

    Returns:
        List of frame URIs, one per frame folder, in query order.
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    kept = []
    # A frame's parent folder identifies its video, so it is the dedup key
    folders_seen = set()

    for candidate, dist in zip(hits['uris'][0], hits['distances'][0]):
        in_range = max_distance is None or dist <= max_distance
        if in_range:
            parent = os.path.dirname(candidate)
            if parent not in folders_seen:
                folders_seen.add(parent)
                kept.append(candidate)

        # Stop as soon as the requested number of frames has been collected
        if len(kept) == max_results:
            break

    return kept

# Example usage: one frame path per video matching "Trees"
vid_uris = frame_uris("Trees", max_distance=1.55)
print(vid_uris)

## LLM Setup

To process our text and images, we'll be using [GPT-4o](https://openai.com/index/hello-gpt-4o/) for easy and fast SOTA text & vision processing. We'll be using LangChain's framework to set this up easily, defined below:

In [None]:
# Read the API key from Colab's secret store rather than hardcoding it
api_key = userdata.get('OPENAI_API_KEY')

# Instantiate the LLM
gpt4o = ChatOpenAI(model="gpt-4o", temperature = 0.0, api_key=api_key)

# Instantiate the Output Parser
parser = StrOutputParser()

# Define the Prompt. Its placeholders are {query}, {texts}, {image_data_1}
# and {image_data_2}; the two image placeholders are embedded in
# base64 JPEG data URLs.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are document retrieval assistant that neatly synthesizes and explains the text and images provided by the user from the query {query}"),
        (
            "user",
            [
                {
                    "type": "text",
                    "text": "{texts}"
                },
                {
                    "type": "image_url",
                    "image_url": {'url': "data:image/jpeg;base64,{image_data_1}"}
                },
                {
                    "type": "text",
                    "text": "This is a frame from a video, refer to it as a video:"
                },
                {
                    "type": "image_url",
                    "image_url": {'url': "data:image/jpeg;base64,{image_data_2}"}
                },

            ],
        ),
    ]
)

# Pipeline: fill the prompt, call the model, parse the reply with StrOutputParser
chain = prompt | gpt4o | parser

## Prompt Setup

The function below will take our query, run it through our new retrieval functions, and format our prompt input, which is expected to be a dictionary like:

```python
{
  "query": "the user query",
  "texts": "the retrieved texts",
  "image_data_1": "The retrieved image, base64 encoded",
  "image_data_2": "The retrieved frame, base64 encoded",
}
```

Note that for the sake of token consumption, context window, and cost we'll only be passing in two images (the image and a single relevant frame) and the text to the model.

In [None]:
def _read_file_as_base64(file_path):
    """Return the contents of ``file_path`` as a base64-encoded string."""
    with open(file_path, 'rb') as file_handle:
        return base64.b64encode(file_handle.read()).decode('utf-8')

def format_prompt_inputs(user_query):
    """Retrieve context for a query and build the prompt's input dictionary.

    Args:
        user_query: The user's query text.

    Returns:
        The tuple ``(frame, image, inputs)``: the path of the best video
        frame, the path of the best image, and a dict with the keys
        ``query``, ``texts``, ``image_data_1`` and ``image_data_2``.

    Raises:
        IndexError: If no video frame or no image lies within the distance
            threshold for this query.
    """
    # Take the closest match of each kind, but fail with a clear message
    # when retrieval returned nothing within the threshold.
    frames = frame_uris(user_query, max_distance=1.55)
    if not frames:
        raise IndexError(f"No video frame within the distance threshold for query {user_query!r}")
    frame = frames[0]

    images = image_uris(user_query, max_distance=1.5)
    if not images:
        raise IndexError(f"No image within the distance threshold for query {user_query!r}")
    image = images[0]

    # NOTE: this is a list of texts (possibly empty) and is passed on as-is
    text = text_uris(user_query, max_distance=1.3)

    inputs = {}

    # save the user query
    inputs['query'] = user_query

    # Insert Text
    inputs['texts'] = text

    # Encode the first image
    inputs['image_data_1'] = _read_file_as_base64(image)

    # Encode the Frame
    inputs['image_data_2'] = _read_file_as_base64(frame)

    return frame, image, inputs

---
# Full Multimodal RAG

Image, Video, Audio, and Text retrieval and LLM processing

In [None]:
query = "San Francisco"
frame, image, inputs = format_prompt_inputs(query)
response = chain.invoke(inputs)

display(Markdown("## Image\n"))
display(Image(image, width=300))
display(Markdown("---"))

display(Markdown("## Video\n"))
# A frame lives in a folder named after its video, so the folder name gives
# the video title regardless of path separator or how deep the frames
# directory is nested.
video_title = os.path.basename(os.path.dirname(frame))
video = os.path.join("StockVideos-CC0", f"{video_title}.mp4")
display(Video(video, embed=True, width=300))
display(Markdown("---"))

display(Markdown("## Audio\n"))
# Audio is only played back; it is not part of the prompt sent to the model
display_audio_files(query, max_distance=1.2)
display(Markdown("---"))

display(Markdown("## AI Response\n"))
display(Markdown(response))
display(Markdown("---"))