# 

In [11]:
import os
import shutil

import nltk
from unstructured.partition.pdf import partition_pdf

import PyPDF2
import google.generativeai as genai

import chromadb
from IPython.display import Markdown

from dotenv import load_dotenv

In [None]:
# Load variables from a local .env file into the process environment;
# GEMINI_API_KEY is read from the environment further down in this notebook.
load_dotenv()

In [None]:
# Download the NLTK tokenizer and POS-tagger resources.
# NOTE(review): presumably needed by unstructured's PDF partitioning — confirm.
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

In [148]:
# System instruction for the answering model: answer only from the supplied
# context, keep Markdown image references intact, and fall back to a fixed
# refusal sentence when the context is insufficient. Includes three few-shot
# examples (image in context, plain text, unanswerable question).
RAG_SYSTEM_PROMPT = """You are a helpful assistant. You will receive a context and a question. Your task is to generate a complete answer based only on the provided context.

The context may contain image references in the format:
![image-description-here](image-path-here)
When generating your response, you must properly format images in Markdown like this:
![image-description-here](image-path-here)

Your response must be in Markdown to ensure images display correctly.

If the context does not have enough information to answer the question, respond with:
"The given documents do not contain the required information."

Examples:
Example 1:
Context:
"The Eiffel Tower is a famous landmark in Paris. ![A beautiful view of the Eiffel Tower](eiffel.jpg)."

Question:
"Where is the Eiffel Tower located?"

Response:
"![A beautiful view of the Eiffel Tower](eiffel.jpg) \n\nThe Eiffel Tower is located in Paris."

Example 2:
Context:
"Mars is known as the Red Planet due to its reddish appearance."

Question:
"What color is Mars?"

Response:
"Mars is known as the Red Planet due to its reddish appearance."

Example 3:
Context:
"An ear, nose, and throat doctor (ENT) specializes in everything having to do with those parts of the body."

Question:
"Who discovered gravity?"

Response:
"The given documents do not contain the required information."
"""

# System instruction for the image-summarisation model; the summaries it
# produces are used as the alt text of the Markdown image references.
IMAGE_SYSTEM_PROMPT = "Given an image, you need to generate a summary that describes the image precisely. You need to ensure all details are covered and the summary is concise and clear."

In [154]:
class MultimodalRag:
    """Retrieval-augmented generation over the text and images of a PDF.

    A PDF is partitioned into elements, every image is replaced by a
    Gemini-generated summary written as a Markdown image reference, the
    elements are grouped into one chunk per page, and the chunks are embedded
    and stored in a persistent ChromaDB collection. Questions are answered by
    retrieving the most similar chunks and passing them to a Gemini model.
    """

    def __init__(self, api_key, collection_name, db_path="./chroma_db"):
        """Configure the Gemini SDK, the ChromaDB client and both models.

        Args:
            api_key: Gemini API key passed to ``genai.configure``.
            collection_name: Name of the ChromaDB collection to read/write.
            db_path: Directory of the persistent ChromaDB store.
        """
        self.api_key = api_key
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name

        genai.configure(api_key=self.api_key)
        self.generation_config = {
            "temperature": 1,
            "top_p": 0.95,
            "top_k": 40,
            "max_output_tokens": 8192,
            "response_mime_type": "text/plain",
        }

        # Same underlying model for both tasks; only the system instruction differs.
        self.rag_model = genai.GenerativeModel(
            model_name="gemini-2.0-flash-exp",
            generation_config=self.generation_config,
            system_instruction=RAG_SYSTEM_PROMPT,
        )

        self.image_model = genai.GenerativeModel(
            model_name="gemini-2.0-flash-exp",
            generation_config=self.generation_config,
            system_instruction=IMAGE_SYSTEM_PROMPT,
        )

    def delete_collection(self):
        """Delete this instance's collection from the ChromaDB store."""
        self.client.delete_collection(self.collection_name)

    def upload_to_gemini(self, path, mime_type=None):
        """Upload a local file through the Gemini SDK and return the file handle."""
        return genai.upload_file(path, mime_type=mime_type)

    def summarise_image(self, image_path):
        """Return a text summary of the image at ``image_path``.

        The image is uploaded as ``image/jpeg`` and placed in the chat history
        before the summarisation request is sent.
        """
        file = self.upload_to_gemini(image_path, mime_type="image/jpeg")
        chat_session = self.image_model.start_chat(history=[{"role": "user", "parts": [file]}])
        response = chat_session.send_message("Analyze the provided image and generate a concise, detailed summary.")
        return response.text

    def parse_pdf(self, pdf_path):
        """Partition the PDF into elements with image and table extraction enabled.

        NOTE(review): the character-limit arguments presumably only take effect
        when a chunking strategy is also requested — confirm against the
        ``unstructured`` documentation.
        """
        parsed_pdf = partition_pdf(
            pdf_path,
            extract_images_in_pdf=True,
            infer_table_structure=True,
            max_characters=4000,
            new_after_n_chars=3800,
            combine_text_under_n_chars=2000,
        )
        return parsed_pdf

    def replace_image_with_summary(self, parsed_pdf):
        """Convert parsed elements to dicts, adding ``image_summary`` to images.

        Args:
            parsed_pdf: Iterable of elements exposing ``to_dict()``.

        Returns:
            A list of element dicts; those of type ``"Image"`` carry an extra
            ``image_summary`` key produced by :meth:`summarise_image`.
        """
        data_to_embed = []

        for parsed_object in parsed_pdf:
            element = parsed_object.to_dict()
            if element['type'] == "Image":
                element['image_summary'] = self.summarise_image(element['metadata']['image_path'])
            data_to_embed.append(element)

        return data_to_embed

    def group_data_by_page(self, data_to_embed):
        """Group element dicts by their page number, dropping footers.

        Args:
            data_to_embed: Element dicts with ``type`` and
                ``metadata['page_number']`` (1-based) keys.

        Returns:
            A list in which index ``i`` holds the non-footer elements of page
            ``i + 1``. Pages with no elements yield empty lists, and the result
            always contains at least one list.
        """
        data_by_page = [[]]

        for data in data_to_embed:
            if data['type'] == 'Footer':
                continue
            page_number = data['metadata']['page_number']
            # Index by the real page number rather than counting page changes,
            # so a skipped page cannot shift or split later pages.
            while len(data_by_page) < page_number:
                data_by_page.append([])
            data_by_page[page_number - 1].append(data)

        return data_by_page

    @staticmethod
    def _relative_image_path(image_path):
        """Rewrite a path under the current working directory as ``./...``."""
        prefix = os.getcwd().rstrip(os.sep) + os.sep
        # Only strip a leading cwd prefix; replacing every occurrence could
        # corrupt the path (for example when the cwd is the filesystem root).
        if image_path.startswith(prefix):
            return "." + os.sep + image_path[len(prefix):]
        return image_path

    def create_chunks(self, data_by_page):
        """Render each page's elements into one newline-joined text chunk.

        Image elements become ``![summary](path)`` Markdown references; every
        other element contributes its ``text``. An empty page yields ``""``.
        """
        chunks = []
        for page in data_by_page:
            chunk_text = []
            for data in page:
                if data['type'] == "Image":
                    relative_image_path = self._relative_image_path(data['metadata']['image_path'])
                    text = f"![{data['image_summary']}]({relative_image_path})"
                else:
                    text = data['text']
                chunk_text.append(text)
            chunks.append("\n".join(chunk_text))
        return chunks

    def process_pdf(self, pdf_path):
        """Parse a PDF and return one text chunk per page, printing progress."""
        print("0% processing done. Parsing PDF")
        parsed_pdf = self.parse_pdf(pdf_path)
        print("40% processing done. Replacing Images with Summaries")
        data_to_embed = self.replace_image_with_summary(parsed_pdf)
        print("75% processing done. Creating Chunks")
        data_by_page = self.group_data_by_page(data_to_embed)
        chunks = self.create_chunks(data_by_page)
        return chunks

    def get_query_embedding(self, query):
        """Return the ``text-embedding-004`` embedding of ``query``."""
        result = genai.embed_content(model="models/text-embedding-004", content=query)
        return result['embedding']

    def ingest_pdf(self, pdf_path):
        """Process a PDF and store its page chunks and embeddings in ChromaDB.

        Blank pages are skipped. Ids stay ``doc_<page index>`` and the stored
        ``page_number`` metadata is the chunk's 1-based page position.

        Raises:
            ValueError: If the PDF produced no non-empty chunk.
        """
        chunks = self.process_pdf(pdf_path)
        print("80% processing done. Creating Embeddings")

        # Keep the original index with each chunk so that dropping blank pages
        # does not shift ids or page numbers.
        indexed_chunks = [(index, chunk) for index, chunk in enumerate(chunks) if chunk.strip()]
        if not indexed_chunks:
            raise ValueError(f"No embeddable content was extracted from {pdf_path!r}")

        collection = self.client.get_or_create_collection(name=self.collection_name)
        embeddings = [self.get_query_embedding(chunk) for _, chunk in indexed_chunks]
        collection.add(
            ids=[f"doc_{index}" for index, _ in indexed_chunks],
            documents=[chunk for _, chunk in indexed_chunks],
            embeddings=embeddings,
            metadatas=[{"page_number": index + 1} for index, _ in indexed_chunks],
        )
        print("100% processing done. PDF ingested successfully")

    def retrieve_similar_documents(self, query_text, top_k=3):
        """Return the ``top_k`` stored chunks most similar to ``query_text``."""
        collection = self.client.get_collection(name=self.collection_name)
        query_embedding = self.get_query_embedding(query_text)
        results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
        # One query embedding was sent, so the matches are in the first result list.
        return list(results['documents'][0])

    def prompt_builder(self, context, question):
        """Build the user prompt from retrieved context and the question.

        Args:
            context: Either a single string or an iterable of document strings;
                multiple documents are separated by a blank line.
            question: The user's question.
        """
        if not isinstance(context, str):
            # Join the documents so the prompt holds their text, not a list repr.
            context = "\n\n".join(context)
        return f"""Find the context below:
        
        Context:
        
        {context}
        
        Question: {question}
        """

    def invoke(self, question):
        """Answer ``question`` using retrieved chunks as context.

        Returns:
            The text of the first part of the model's response.
        """
        context = self.retrieve_similar_documents(question)
        chat_session = self.rag_model.start_chat()
        prompt = self.prompt_builder(context, question)
        response = chat_session.send_message(prompt)
        return response.parts[0].text

In [155]:
# Build the RAG pipeline with the API key from the environment and the
# ChromaDB collection "attention_sample_2" (default store path ./chroma_db).
gemini_rag = MultimodalRag(os.getenv('GEMINI_API_KEY'), "attention_sample_2")

In [None]:
# Parse test.pdf, summarise its images, and store the per-page chunks with
# their embeddings in the collection.
gemini_rag.ingest_pdf("test.pdf")

In [157]:
# Retrieve the most similar chunks and ask the model to answer from them.
response = gemini_rag.invoke("Explain Encoder decoder with images")

In [None]:
# Render the answer as Markdown so embedded image references are displayed.
Markdown(response)