# MultiModal RAG

The implementation below is based on this [Medium article](https://medium.com/artificial-corner/multimodal-retrieval-augmented-generation-for-sustainable-finance-with-code-5a910f3b666c).

The complete notebook for this article can be found [here](./references/00_multimodal-rag-esg-main/notebooks/ESG_Multimodal_RAG_v2.ipynb). It covers several modalities: audio transcriptions of videos, images, tables, and text. Most of our enterprise use cases, however, deal only with images, tables, and text inside PDFs. We therefore created this version of the notebook to simplify the implementation for future reference.

We also use pgvector as the vector database instead of Weaviate, which the article uses.

This is a minimal example meant as a starting point for understanding the fundamentals. Note that it does not consider the complete contents of the PDF during chunking.

Refer to [README.md](./README.md) for more references. 

## Using Unstructured Library

In [1]:
!pip install "unstructured[pdf]"

In [None]:
!pip install matplotlib

### 1. Parse PDF

In [1]:
esg_report_path = "./data/Global_ESG_Q1_2024_Flows_Report.pdf"

In [None]:
from unstructured.partition.pdf import partition_pdf

esg_report_raw_data = partition_pdf(
    filename=esg_report_path,
    strategy="hi_res",
    extract_images_in_pdf=True,
    extract_image_block_to_payload=False,
    extract_image_block_output_dir="./data/images/",
)

In [None]:
esg_report_raw_data

### 2. Extract Textual Component

This may not extract all of the text. unstructured assigns some elements to other categories such as `ListItem` and `Title`, and the function below keeps only `NarrativeText`.

To capture the whole text, we can use a different parser such as `pymupdf4llm`.
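You could switch parsers. Alternatively, you can stay with unstructured and widen the filter. The sketch below is a hypothetical variant of `extract_text_with_metadata` called `extract_text_by_category`. It assumes each element exposes a `category` string (unstructured elements do, e.g. `"NarrativeText"`, `"Title"`, `"ListItem"`). The chosen set of categories is an assumption you should adjust to your documents.

```python
# Categories to keep: an assumption, adjust to the document at hand
TEXT_CATEGORIES = {"NarrativeText", "Title", "ListItem"}


def extract_text_by_category(elements, source_document, categories=TEXT_CATEGORIES):
    """Hypothetical variant of extract_text_with_metadata that keeps several categories."""
    text_data = []
    paragraph_counters = {}

    for element in elements:
        if element.category not in categories:
            continue

        page_number = element.metadata.page_number
        # Number the kept elements per page, starting at 1
        paragraph_counters[page_number] = paragraph_counters.get(page_number, 0) + 1

        text_data.append({
            "source_document": source_document,
            "page_number": page_number,
            "paragraph_number": paragraph_counters[page_number],
            "category": element.category,  # extra key; ingestion ignores it
            "text": element.text,
        })

    return text_data
```

The output keeps the same keys as `extract_text_with_metadata` plus `category`, so it can be passed to the text ingestion step unchanged.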

In [22]:
from unstructured.documents.elements import NarrativeText

In [23]:
def extract_text_with_metadata(esg_report, source_document):

    text_data = []
    paragraph_counters = {}

    for element in esg_report:
        if isinstance(element, NarrativeText):
            page_number = element.metadata.page_number

            if page_number not in paragraph_counters:
                paragraph_counters[page_number] = 1
            else:
                paragraph_counters[page_number] += 1

            paragraph_number = paragraph_counters[page_number]

            text_content = element.text
            text_data.append({
                "source_document": source_document,
                "page_number": page_number,
                "paragraph_number": paragraph_number,
                "text": text_content
            })

    return text_data

In [24]:
extracted_data = extract_text_with_metadata(esg_report_raw_data, esg_report_path)

In [None]:
extracted_data
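Each `NarrativeText` paragraph currently becomes its own chunk. A minimal sketch of one possible alternative is below. It greedily merges consecutive paragraphs from the same page into chunks of up to `max_chars` characters. `merge_paragraphs` and the 1000-character default are hypothetical choices, not part of the original article.

```python
from itertools import groupby


def merge_paragraphs(text_data, max_chars=1000):
    """Hypothetical helper: merge consecutive same-page paragraphs into larger chunks.

    Expects records shaped like extract_text_with_metadata's output
    (source_document, page_number, paragraph_number, text).
    """
    chunks = []
    # groupby only groups *consecutive* records, which preserves document order
    for (source, page), records in groupby(
        text_data, key=lambda r: (r["source_document"], r["page_number"])
    ):
        current = None
        for record in records:
            if current and len(current["text"]) + 1 + len(record["text"]) <= max_chars:
                # Append to the open chunk, separated by a newline
                current["text"] += "\n" + record["text"]
            else:
                # Start a new chunk; it keeps the number of its first paragraph
                current = {
                    "source_document": source,
                    "page_number": page,
                    "paragraph_number": record["paragraph_number"],
                    "text": record["text"],
                }
                chunks.append(current)
    return chunks
```

A single paragraph longer than `max_chars` still becomes its own chunk. Splitting it further would need an additional rule.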

### 3. Extract Image components

In [26]:
from unstructured.documents.elements import Image

In [27]:
def extract_image_metadata(esg_report, source_document):
    image_data = []

    for element in esg_report:
        if isinstance(element, Image):
            page_number = element.metadata.page_number
            image_path = element.metadata.image_path if hasattr(element.metadata, 'image_path') else None

            image_data.append({
                "source_document": source_document,
                "page_number": page_number,
                "image_path": image_path
            })

    return image_data

In [28]:
extracted_image_data = extract_image_metadata(esg_report_raw_data, esg_report_path)

In [29]:
import matplotlib.pyplot as plt
from PIL import Image as PILImage  # alias avoids shadowing unstructured's Image element class
import math

In [30]:
def display_images_from_metadata(extracted_image_data, images_per_row=4):
    valid_images = [img for img in extracted_image_data if img['image_path']]
    if not valid_images:
        print("No valid image data available.")
        return

    num_images = len(valid_images)
    num_rows = math.ceil(num_images / images_per_row)

    # squeeze=False always returns a 2D array, so flatten() works for a single row too
    fig, axes = plt.subplots(num_rows, images_per_row, figsize=(20, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    for ax, img_data in zip(axes, valid_images):
        try:
            img = PILImage.open(img_data['image_path'])
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f"Page {img_data['page_number']}", fontsize=10)
        except Exception as e:
            print(f"Error loading image {img_data['image_path']}: {str(e)}")
            ax.text(0.5, 0.5, f"Error loading image\n{str(e)}", ha='center', va='center')
            ax.axis('off')

    # Remove the unused axes in the last row
    for ax in axes[num_images:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

In [None]:
display_images_from_metadata(extracted_image_data)

### 4. Extract Table Components

In [35]:
from unstructured.documents.elements import Table

In [36]:
def extract_table_metadata(esg_report, source_document):
    table_data = []

    for element in esg_report:
        if isinstance(element, Table):
            page_number = element.metadata.page_number

            # Extract table content as a string
            table_content = str(element)

            table_data.append({
                "source_document": source_document,
                "page_number": page_number,
                "table_content": table_content
            })

    return table_data
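`str(element)` flattens a table into plain text, which loses its row and column structure. One option is to run `partition_pdf` with `infer_table_structure=True`. The cell in step 1 does not set it, so treat this as an assumption. With that flag, unstructured can attach an HTML rendering in `element.metadata.text_as_html`. The hypothetical helper below prefers that HTML and falls back to plain text.

```python
def table_to_text(element):
    """Hypothetical helper: prefer a table's HTML rendering, else its plain text.

    text_as_html is expected only when partition_pdf(..., infer_table_structure=True)
    was used; getattr with a default keeps this safe when it is missing or None.
    """
    html = getattr(element.metadata, "text_as_html", None)
    return html if html else element.text
```

Using `table_to_text(element)` in place of `str(element)` in `extract_table_metadata` and `extract_table_metadata_with_summary` passes the structured version to the summarizer.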

In [37]:
extracted_table_data = extract_table_metadata(esg_report_raw_data, esg_report_path)

### 5. Image and Table Summarization

Each image and table is summarized in a few sentences, so that its content can be embedded and searched as text.

For both images and tables, we first generate a description using a dedicated prompt.

#### 1. Table summarization

In [None]:
!pip install langchain-core
!pip install langchain-openai
!pip install python-dotenv

In [39]:
tables_summarizer_prompt = """
As an ESG analyst for emerging markets investments, provide a concise and exact summary of the table contents.
Focus on key ESG metrics (Environmental, Social, Governance) and their relevance to emerging markets.
Highlight significant trends, comparisons, or outliers in the data. Identify any potential impacts on investment strategies or risk assessments.
Avoid bullet points; instead, deliver a coherent, factual summary that captures the essence of the table for ESG investment decision-making.

Table: {table_content}

Limit your summary to 3-4 sentences, ensuring it's precise and informative for ESG analysis in emerging markets."""

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai.chat_models import AzureChatOpenAI
from dotenv import load_dotenv

load_dotenv()

In [41]:
model_id = "gpt-4o-mini"

In [42]:
description_model = AzureChatOpenAI(model=model_id)

In [43]:
def extract_table_metadata_with_summary(esg_report,
                                        source_document,
                                        tables_summarizer_prompt):

    table_data = []
    prompt = ChatPromptTemplate.from_template(tables_summarizer_prompt)

    for element in esg_report:
        if isinstance(element, Table):
            page_number = element.metadata.page_number

            # Extract table content as a string
            table_content = str(element)

            # Generate summary using the OpenAI model
            messages = prompt.format_messages(table_content=table_content)
            description = description_model.invoke(messages).content

            table_data.append({
                "source_document": source_document,
                "page_number": page_number,
                "table_content": table_content,
                "description": description
            })

    return table_data

In [None]:
extracted_table_data_with_summary = extract_table_metadata_with_summary(esg_report_raw_data,
                                                                        esg_report_path,
                                                                        tables_summarizer_prompt)

In [None]:
extracted_table_data_with_summary

In [None]:
# Get the first table record from the list
first_table_details = extracted_table_data_with_summary[0]

first_table_details

In [None]:
first_table_details['description']

#### 2. Image Summarization

In [47]:
from PIL import Image as PILImage
from langchain_core.messages import HumanMessage
import base64
import os

In [48]:
images_summarizer_prompt = """
As an ESG analyst for emerging markets investments, please provide a clear interpretation of the data or information shown in the image.
Focus on ESG-relevant content (Environmental, Social, Governance) and any emerging market context. Describe the type of visual (e.g., chart, photograph, infographic) and its key elements.
Highlight significant data points or trends that are relevant to investment analysis. Avoid bullet points; instead, deliver a coherent, factual summary that captures the essence of the image for ESG investment decision-making.

Ground your response based on the provided image and do not hallucinate.

Limit your description to 3-4 sentences, ensuring it's precise and informative for ESG analysis."""

In [49]:
def extract_image_metadata_with_summary(esg_report_raw_data,
                                        esg_report_path,
                                        images_summarizer_prompt):

    image_data = []

    # Create the AzureChatOpenAI instance used for image descriptions
    description_model = AzureChatOpenAI(model=model_id)

    for element in esg_report_raw_data:
        # Compare the class name so the check does not depend on which `Image`
        # (unstructured's element class or PIL's) is currently bound in the notebook
        if type(element).__name__ == "Image":
            page_number = getattr(element.metadata, 'page_number', None)
            image_path = getattr(element.metadata, 'image_path', None)

            # Check the path before opening the file; opening a missing path would raise
            if not image_path or not os.path.exists(image_path):
                print(f"Warning: Image file not found or path not available for image on page {page_number}")
                continue

            # Read the image file and encode it to base64
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

            # Send the prompt and the image together in one multimodal message
            message = HumanMessage(
                content=[
                    {"type": "text", "text": images_summarizer_prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_string}"},
                    },
                ],
            )
            description = description_model.invoke([message]).content

            image_data.append({
                "source_document": esg_report_path,
                "page_number": page_number,
                "image_path": image_path,
                "description": description,
                "base64_encoding": encoded_string
            })

    return image_data

In [50]:
extracted_image_data = extract_image_metadata_with_summary(esg_report_raw_data,
                                                           esg_report_path,
                                                           images_summarizer_prompt)

In [None]:
extracted_image_data
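The summarization function above hard-codes `image/jpeg` in the data URL it sends to the model. If your extracted images may use other formats, a small hypothetical helper can derive the MIME type from the file extension instead. It uses the standard library `mimetypes` module, falling back to JPEG when the type cannot be guessed.

```python
import base64
import mimetypes


def image_to_data_url(image_path):
    """Hypothetical helper: encode an image file as a base64 data URL."""
    mime_type, _ = mimetypes.guess_type(image_path)
    # Fall back to JPEG, the type the notebook assumes elsewhere
    mime_type = mime_type or "image/jpeg"
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
```

The returned string can be placed directly in the `"image_url": {"url": ...}` field of the `HumanMessage` content.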

In [None]:
# Inspect the sixth image record (index 5); adjust the index to the number of extracted images
sixth_image_details = extracted_image_data[5]

sixth_image_details

In [None]:
sixth_image_details['description']

### Data Upload - Pgvector

In [None]:
!pip install langchain-postgres

#### Test Connection

In [54]:
import psycopg

In [None]:
# Database connection parameters
db_params = {
    "dbname": "pgvector-exploration",
    "user": "admin",
    "password": "admin",
    "host": "172.31.60.199",  # Use the appropriate host
    "port": "15432"           # Non-default port (PostgreSQL's default is 5432)
}

# Connect to the PostgreSQL database
with psycopg.connect(**db_params) as conn:
    print("Postgresql Test connection successful.")

#### Vectorstore Implementation

In [56]:
connection = "postgresql+psycopg://admin:admin@172.31.60.199:15432/pgvector-exploration" 
collection_name = "esg_reports"
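The connection string above repeats the credentials from `db_params` by hand. A hypothetical helper, `make_pg_url`, builds it from the same dictionary instead. It percent-encodes the user name and password with `urllib.parse.quote_plus`, so characters such as `@`, `:` or `/` in a password do not break URL parsing.

```python
from urllib.parse import quote_plus


def make_pg_url(params, driver="postgresql+psycopg"):
    """Hypothetical helper: build a SQLAlchemy-style URL from psycopg-style params."""
    user = quote_plus(params["user"])
    password = quote_plus(params["password"])
    return (
        f"{driver}://{user}:{password}"
        f"@{params['host']}:{params['port']}/{params['dbname']}"
    )
```

With this helper, `connection = make_pg_url(db_params)` keeps the psycopg test connection and the vector store pointed at the same database.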

In [68]:
from langchain_openai.embeddings import AzureOpenAIEmbeddings

embeddings = AzureOpenAIEmbeddings(model="text-embedding-3-small", api_version="2024-02-01")

In [69]:
# Function to get embeddings
def get_embedding(text):
    response = embeddings.embed_query(text)
    return response

In [82]:
from langchain_postgres.vectorstores import PGVector

vectorstore = PGVector(
    embeddings=embeddings,
    collection_name=collection_name,
    connection=connection,
    use_jsonb=True,
)

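To drop the tables created by the vector store, run the cell below. Typical reasons are switching to an embedding model with a different dimension or to a different embedding provider. This deletes every stored collection and embedding, so re-run the `PGVector` cell above afterwards to recreate the tables.

In [81]:
vectorstore.drop_tables()  # Irreversible: removes all collections and embeddings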

#### Add Documents

In [60]:
from langchain_core.documents import Document
from tqdm import tqdm
import uuid

In [83]:
metadata_template = {
    "id": None,
    "source_document": None,
    "page_number": None,
    "paragraph_number": None,
    "image_path": None,
    "base64_encoding": None,
    "table_content": None,
    "content_type": None
}

In [84]:
def ingest_text_data(text_data):
    docs = []
    for text in tqdm(text_data, desc="Ingesting text data"):
        metadata = metadata_template.copy()
        metadata["id"] = str(uuid.uuid4())
        metadata["source_document"] = text['source_document']
        metadata["page_number"] = text['page_number']
        metadata["paragraph_number"] = text['paragraph_number']
        metadata["content_type"] = "text"
        
        # Instantiate Document Object and append to the list
        docs.append(Document(page_content=text['text'], metadata=metadata))
    
    vectorstore.add_documents(docs, ids=[doc.metadata['id'] for doc in docs])


def ingest_image_data(image_data):
    docs = []
    for image in tqdm(image_data, desc="Ingesting image data"):
        metadata = metadata_template.copy()
        metadata["id"] = str(uuid.uuid4())
        metadata["source_document"] = image['source_document']
        metadata["page_number"] = image['page_number']
        metadata["image_path"] = image['image_path']
        metadata["base64_encoding"] = image['base64_encoding']
        metadata["content_type"] = "image"
        
        # Instantiate Document Object and append to the list
        docs.append(Document(page_content=image['description'], metadata=metadata))
    
    vectorstore.add_documents(docs, ids=[doc.metadata['id'] for doc in docs])

def ingest_table_data(table_data):
    docs = []
    for table in tqdm(table_data, desc="Ingesting table data"):
        metadata = metadata_template.copy()
        metadata["id"] = str(uuid.uuid4())
        metadata["source_document"] = table['source_document']
        metadata["page_number"] = table['page_number']
        metadata["table_content"] = table['table_content']
        metadata["content_type"] = "table"
        
        # Instantiate Document Object and append to the list
        docs.append(Document(page_content=table['description'], metadata=metadata))
    
    vectorstore.add_documents(docs, ids=[doc.metadata['id'] for doc in docs])

def ingest_all_data(text_data, image_data, table_data):
    ingest_text_data(text_data)
    ingest_image_data(image_data)
    ingest_table_data(table_data)
    print("All objects imported successfully")
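The three ingest functions above differ only in two things: which field becomes `page_content` (`text` or `description`), and which extra keys are copied into the metadata. One possible refactor factors that into a single hypothetical builder, `build_records`, that returns plain `(page_content, metadata)` pairs. The keys mirror `metadata_template`.

```python
import uuid

# Same keys as metadata_template
METADATA_KEYS = ("id", "source_document", "page_number", "paragraph_number",
                 "image_path", "base64_encoding", "table_content", "content_type")


def build_records(items, content_type, content_key, extra_keys=()):
    """Hypothetical helper: return (page_content, metadata) pairs for one content type.

    content_key names the field to embed ('text' or 'description');
    extra_keys are copied into metadata, e.g. ('paragraph_number',).
    """
    records = []
    for item in items:
        metadata = dict.fromkeys(METADATA_KEYS)  # every key present, defaulting to None
        metadata.update(
            id=str(uuid.uuid4()),
            source_document=item["source_document"],
            page_number=item["page_number"],
            content_type=content_type,
        )
        for key in extra_keys:
            metadata[key] = item[key]
        records.append((item[content_key], metadata))
    return records
```

Each pair maps onto `Document(page_content=content, metadata=metadata)`. The ids passed to `vectorstore.add_documents` are then `[metadata["id"] for _, metadata in records]`. Calls look like `build_records(extracted_image_data, "image", "description", ("image_path", "base64_encoding"))`.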

#### Start Data Ingestion

In [None]:
ingest_all_data(text_data=extracted_data,
                image_data=extracted_image_data,
                table_data=extracted_table_data_with_summary
            )

### Query PgVector for Most Relevant Data

In [102]:
def search_multimodal(query: str, limit: int = 3):
    retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": limit})
    return retriever.invoke(query)

def search_multimodal_with_score(query: str, limit: int = 3):
    docs_with_score = vectorstore.similarity_search_with_score(query, k=limit)
    return docs_with_score
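`similarity_search_with_score` returns a distance, not a similarity. Under PGVector's default cosine strategy, a smaller value means a closer match. That is why the printing cell below reports `1 - score`, and why `esg_analysis` sorts sources in ascending order. The brute-force, in-memory sketch below shows what a top-k cosine search computes. It uses toy vectors and no index; PGVector performs the equivalent computation in SQL. Non-zero vectors are assumed.

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def top_k(query_vec, doc_vecs, k=3):
    """Return (index, distance) pairs for the k closest vectors, nearest first."""
    scored = [(i, cosine_distance(query_vec, v)) for i, v in enumerate(doc_vecs)]
    return sorted(scored, key=lambda pair: pair[1])[:k]
```

For example, `top_k([1.0, 0.0], [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], k=2)` ranks the identical vector first (distance 0) and the 45-degree vector second.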

In [103]:
def search_and_print_results(query, limit=3):

    search_results = search_multimodal(query, limit)

    print(f"Search Results for query: '{query}'")
    for item in search_results:
        print(f"Type: {item.metadata['content_type']}")
        if item.metadata['content_type'] == 'text':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Paragraph {item.metadata['paragraph_number']}")
            print(f"Text: {item.page_content[:100]}...")
        elif item.metadata['content_type'] == 'image':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Image Source: {item.metadata['image_path']}, Page: {item.metadata['page_number']}")
            print(f"Description: {item.page_content}")
        elif item.metadata['content_type'] == 'table':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Description: {item.page_content}")
        print("---")

In [110]:
def search_and_print_results_with_score(query, limit=3):

    search_results = search_multimodal_with_score(query, limit)

    print(f"Search Results for query: '{query}'")
    for item, score in search_results:
        print(f"Type: {item.metadata['content_type']}")
        print(f"Cosine Similarity: {1-score}")
        if item.metadata['content_type'] == 'text':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Paragraph {item.metadata['paragraph_number']}")
            print(f"Text: {item.page_content[:100]}...")
        elif item.metadata['content_type'] == 'image':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Image Source: {item.metadata['image_path']}, Page: {item.metadata['page_number']}")
            print(f"Description: {item.page_content}")
        elif item.metadata['content_type'] == 'table':
            print(f"Source: {item.metadata['source_document']}, Page: {item.metadata['page_number']}")
            print(f"Description: {item.page_content}")
        print("---")

In [None]:
query = "What are the main environmental challenges in renewable energy?"
search_and_print_results_with_score(query)

### Multimodal RAG for ESG

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai.chat_models import AzureChatOpenAI
from dotenv import load_dotenv

load_dotenv()

In [114]:
model_id = "gpt-4o-mini"

In [115]:
chat_model = AzureChatOpenAI(model=model_id)

In [116]:
def generate_response(query: str, context: str) -> str:
    # Plain string, not an f-string: ChatPromptTemplate fills {context} and {query} itself,
    # so curly braces inside the retrieved context are not parsed as template variables
    template = """
    You are an AI assistant specializing in ESG (Environmental, Social, and Governance) analysis for emerging markets.
    Use the following pieces of information to answer the user's question.
    If you cannot answer the question based on the provided information, say that you don't have enough information to answer accurately.

    Context:
    {context}

    User Question: {query}

    Please provide a detailed and accurate answer based on the given context:
    """

    prompt = ChatPromptTemplate.from_template(template)

    messages = prompt.format_messages(query=query, context=context)

    response = chat_model.invoke(messages)

    return response.content

In [124]:
def esg_analysis(user_query: str):

    # Step 1: Retrieve relevant information
    search_results = search_multimodal_with_score(user_query)

    # Step 2: Prepare context for RAG
    context = ""
    for item, score in search_results:
        if item.metadata['content_type'] == 'text':
            context += f"Text from {item.metadata['source_document']} (Page {item.metadata['page_number']}, Paragraph {item.metadata['paragraph_number']}): {item.page_content}\n\n"
        elif item.metadata['content_type'] == 'image':
            context += f"Image Description from {item.metadata['source_document']} (Page {item.metadata['page_number']}, Path: {item.metadata['image_path']}): {item.page_content}\n\n"
        elif item.metadata['content_type'] == 'table':
            context += f"Table Description from {item.metadata['source_document']} (Page {item.metadata['page_number']}): {item.page_content}\n\n"

    # Step 3: Generate response using RAG
    response = generate_response(user_query, context)

    # Step 4: Format and return the final output
    sources = []
    for item, score in search_results:
        source = {
            "type": item.metadata["content_type"],
            "distance": score
        }
        if item.metadata["content_type"] == 'text':
            source.update({
                "document": item.metadata["source_document"],
                "page": item.metadata["page_number"],
                "paragraph": item.metadata["paragraph_number"]
            })
        elif item.metadata["content_type"] == 'image':
            source.update({
                "document": item.metadata["source_document"],
                "page": item.metadata["page_number"],
                "image_path": item.metadata["image_path"]
            })
        elif item.metadata["content_type"] == 'table':
            source.update({
                "document": item.metadata["source_document"],
                "page": item.metadata["page_number"]
            })
        
        sources.append(source)

    # Sort sources by distance (ascending order)
    sources.sort(key=lambda x: x['distance'])

    final_output = {
        "user_query": user_query,
        "ai_response": response,
        "sources": sources
    }

    return final_output
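`esg_analysis` and both printing helpers repeat the same `if`/`elif` chain over `content_type`. An alternative design keeps one formatter per content type in a dictionary, so adding a new modality means adding one entry. The sketch below is hypothetical (`CONTEXT_FORMATTERS`, `build_context`). It works on `(metadata, page_content)` pairs and reproduces the labels used in step 2 of `esg_analysis`.

```python
# Hypothetical alternative: one formatter per content type instead of if/elif chains
CONTEXT_FORMATTERS = {
    "text": lambda m: (f"Text from {m['source_document']} "
                       f"(Page {m['page_number']}, Paragraph {m['paragraph_number']})"),
    "image": lambda m: (f"Image Description from {m['source_document']} "
                        f"(Page {m['page_number']}, Path: {m['image_path']})"),
    "table": lambda m: f"Table Description from {m['source_document']} (Page {m['page_number']})",
}


def build_context(results):
    """Join (metadata, page_content) pairs into one context string; unknown types are skipped."""
    parts = []
    for metadata, content in results:
        formatter = CONTEXT_FORMATTERS.get(metadata.get("content_type"))
        if formatter is not None:
            parts.append(f"{formatter(metadata)}: {content}")
    return "\n\n".join(parts)
```

With the retriever output, the call would be `build_context((doc.metadata, doc.page_content) for doc, _ in search_results)`.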

In [118]:
import textwrap

def wrap_text(text, width=120):
    wrapped_text = textwrap.fill(text, width=width)
    return wrapped_text
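`textwrap.fill` treats its input as a single paragraph. With the default `replace_whitespace=True`, it turns newlines into spaces, so paragraph breaks and lists in the model's answer are lost. A hypothetical variant, `wrap_paragraphs`, wraps each line separately so those breaks survive.

```python
import textwrap


def wrap_paragraphs(text, width=120):
    """Wrap each line on its own so blank lines and line breaks are preserved."""
    return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())
```

Swapping `wrap_text` for `wrap_paragraphs` in `analyze_and_print_esg_results` keeps the response's structure while still limiting line width.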

In [119]:
def analyze_and_print_esg_results(user_question):
    result = esg_analysis(user_question)

    print("User Query:", result["user_query"])
    print("\nAI Response:", wrap_text(result["ai_response"]))
    print("\nSources (sorted by relevance):")
    for source in result["sources"]:
        print(f"- Type: {source['type']}, Distance: {source['distance']:.3f}")
        if source['type'] == 'text':
            print(f"  Document: {source['document']}, Page: {source['page']}, Paragraph: {source['paragraph']}")
        elif source['type'] == 'image':
            print(f"  Document: {source['document']}, Page: {source['page']}, Image Path: {source['image_path']}")
        elif source['type'] == 'table':
            print(f"  Document: {source['document']}, Page: {source['page']}")
        print("---")

In [None]:
user_question = "Is ESG investment a fraud?"
analyze_and_print_esg_results(user_question)