In [11]:
import mimetypes
import os
import shutil

import nltk
from unstructured.partition.pdf import partition_pdf

import PyPDF2
import google.generativeai as genai

import chromadb
from IPython.display import Markdown
from dotenv import load_dotenv

In [None]:
# Load variables from a local .env file into the process environment, so that
# GEMINI_API_KEY can be read with os.getenv in the configuration cell below.
load_dotenv()

In [None]:
# NLTK tokenizer and POS-tagger data; presumably needed by unstructured's text
# partitioning below — TODO confirm which partitioning paths require them.
for nltk_resource in ('punkt_tab', 'averaged_perceptron_tagger_eng'):
    nltk.download(nltk_resource)

In [13]:
# PDF to ingest into the RAG index (relative path, resolved against the working directory).
SAMPLE_PDF = "test.pdf"

In [14]:
# Read the Gemini API key (typically provided via the .env file loaded above).
# Fail fast with an explicit message when it is missing, instead of handing None
# to genai.configure and letting the problem surface later in the pipeline.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
if not GEMINI_API_KEY:
    raise RuntimeError(
        "GEMINI_API_KEY is not set; add it to your environment or to a .env file."
    )
genai.configure(api_key=GEMINI_API_KEY)

In [95]:
# Single place to choose the Gemini model used by both the RAG and the image model.
# NOTE(review): "-exp" model names denote experimental releases that may be
# retired — confirm this model is still available before running.
GEMINI_MODEL_NAME = "gemini-2.0-flash-exp"

# Sampling settings shared by both models below.
generation_config = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,
    "response_mime_type": "text/plain",
}

# System prompt for answer generation: answer only from the supplied context,
# keep Markdown image references so images render, and reply with a fixed
# sentence when the context is insufficient (few-shot examples included).
RAG_SYSTEM_PROMPT = """You are a helpful assistant. You will receive a context and a question. Your task is to generate a complete answer based only on the provided context.

The context may contain image references in the format:
![image-description-here](image-path-here)
When generating your response, you must properly format images in Markdown like this:
![image-description-here](image-path-here)

Your response must be in Markdown to ensure images display correctly.

If the context does not have enough information to answer the question, respond with:
"The given documents do not contain the required information."

Examples:
Example 1:
Context:
"The Eiffel Tower is a famous landmark in Paris. ![A beautiful view of the Eiffel Tower](eiffel.jpg)."

Question:
"Where is the Eiffel Tower located?"

Response:
"![A beautiful view of the Eiffel Tower](eiffel.jpg) \n\nThe Eiffel Tower is located in Paris."

Example 2:
Context:
"Mars is known as the Red Planet due to its reddish appearance."

Question:
"What color is Mars?"

Response:
"Mars is known as the Red Planet due to its reddish appearance."

Example 3:
Context:
"An ear, nose, and throat doctor (ENT) specializes in everything having to do with those parts of the body."

Question:
"Who discovered gravity?"

Response:
"The given documents do not contain the required information."
"""

# System prompt for the model that turns extracted PDF images into text summaries.
IMAGE_SYSTEM_PROMPT = "Given an image, you need to generate a summary that describes the image precisely. You need to ensure all details are covered and the summary is concise and clear."

rag_model = genai.GenerativeModel(
    model_name=GEMINI_MODEL_NAME,
    generation_config=generation_config,
    system_instruction=RAG_SYSTEM_PROMPT,
)

image_model = genai.GenerativeModel(
    model_name=GEMINI_MODEL_NAME,
    generation_config=generation_config,
    system_instruction=IMAGE_SYSTEM_PROMPT,
)

In [16]:
# Chroma client backed by an on-disk store in ./chroma_db, so ingested collections
# remain available to later runs of this notebook.
client = chromadb.PersistentClient(path="./chroma_db")

In [17]:
def upload_to_gemini(path, mime_type=None):
        """Uploads the given file to Gemini."""
        file = genai.upload_file(path, mime_type=mime_type)
        return file
    
def analyze_image(image_path):
    """Analyzes the given image and returns a detailed summary."""

    file = upload_to_gemini(image_path, mime_type="image/jpeg")

    chat_session = image_model.start_chat(
        history=[{
            "role": "user",
            "parts": [file],
        }]
    )

    response = chat_session.send_message("Analyze the provided image and generate a concise, detailed summary that captures the key elements and context.")
    return response.text

In [None]:
# Partition the PDF into typed elements (text, images, tables, footers, ...).
# With extract_images_in_pdf=True the image files are saved locally; their paths
# are read from metadata['image_path'] in the next cell.
# NOTE(review): max_characters / new_after_n_chars / combine_text_under_n_chars are
# unstructured's chunking options and, as far as I know, only apply when a
# `chunking_strategy` is passed — none is given here, so they likely have no effect.
parsed_pdf = partition_pdf(
    SAMPLE_PDF,
    extract_images_in_pdf=True,
    infer_table_structure=True,
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
)

In [19]:
# Convert every partitioned element to a dict. Image elements additionally get a
# Gemini-generated description under "image_summary"; the image is swapped for
# that description later, when the per-page chunks are built.
data_to_embed = []

for element in parsed_pdf:
  element_dict = element.to_dict()
  if element_dict['type'] == "Image":
    element_dict['image_summary'] = analyze_image(element_dict['metadata']['image_path'])
  data_to_embed.append(element_dict)

In [20]:
# Group the non-footer elements by page: each inner list holds one page's
# elements in the order they were partitioned (assumes elements arrive ordered
# by page, as produced by partition_pdf).
data_by_page = []
cur_page_number = None

for data in data_to_embed:
  # Footers are dropped so they do not end up in the page chunks.
  if data['type'] == 'Footer':
    continue

  page_number = data['metadata']['page_number']
  # Open a new group whenever the page changes. Tracking the actual page number
  # (rather than incrementing a counter) keeps all elements of a page together
  # even when a preceding page contributed no elements.
  if page_number != cur_page_number:
    cur_page_number = page_number
    data_by_page.append([])

  data_by_page[-1].append(data)

# NOTE(review): ingest_chunks labels chunk i as page i + 1, which only holds when
# every page contributes at least one non-footer element.

In [81]:
# Concatenate each page's elements into one text chunk for embedding. Image
# elements become a Markdown image whose alt text is the Gemini summary, so the
# RAG model can copy the reference into its answer and the image can be displayed.

def _markdown_alt_text(summary):
  """Make an image summary usable as Markdown alt text: one line, brackets escaped."""
  # Blank lines or unbalanced "]" inside ![...](...) would break the image syntax.
  one_line = " ".join(summary.split())
  return one_line.replace("[", "\\[").replace("]", "\\]")


def _notebook_relative_path(path, base_dir):
  """Express ``path`` relative to ``base_dir`` (prefixed with "./") when possible."""
  try:
    return os.path.join(".", os.path.relpath(path, base_dir))
  except ValueError:
    # e.g. on Windows when path and base_dir are on different drives.
    return path


current_directory = os.getcwd()
chunks = []

for page in data_by_page:
  chunk_text = []
  for data in page:
    if data['type'] == "Image":
      relative_image_path = _notebook_relative_path(data['metadata']['image_path'], current_directory)
      text = f"![{_markdown_alt_text(data['image_summary'])}]({relative_image_path})"
    else:
      text = data['text']

    chunk_text.append(text)

  chunks.append("\n".join(chunk_text))

In [69]:
def get_query_embedding(query):
    """Embed ``query`` with Gemini's text-embedding-004 model.

    Used both for the stored chunks and for search queries.

    Returns:
        The ``'embedding'`` entry of the ``genai.embed_content`` result.
    """
    embed_response = genai.embed_content(
        content=query,
        model="models/text-embedding-004",
    )
    return embed_response['embedding']

In [70]:
def ingest_chunks(chunks, collection_name):
  """Embed ``chunks`` and store them in the Chroma collection ``collection_name``.

  Chunk ``i`` is stored under id ``doc_{i}`` with metadata ``{"page_number": i + 1}``.
  ``upsert`` is used so that re-running the notebook overwrites entries with the
  same ids instead of keeping previously stored ones.

  Args:
    chunks: List of chunk texts (one per page as built above).
    collection_name: Name of the collection; created if it does not exist.

  Returns:
    The value returned by ``collection.upsert``, or ``None`` when ``chunks`` is empty.
  """
  collection = client.get_or_create_collection(name=collection_name)

  # Nothing to store; avoids sending empty id/document lists to Chroma.
  if not chunks:
    return None

  embeddings = [get_query_embedding(chunk) for chunk in chunks]
  # NOTE(review): page_number = i + 1 assumes chunk i came from page i + 1, which
  # only holds when every page produced a chunk — confirm against the grouping step.
  return collection.upsert(
    ids=[f"doc_{i}" for i in range(len(chunks))],
    documents=chunks,
    embeddings=embeddings,
    metadatas=[{"page_number": i + 1} for i in range(len(chunks))],
  )

In [71]:
# Chroma collection that receives the page chunks of SAMPLE_PDF and is queried by rag().
CHROMA_COLLECTION_NAME = "attention_sample"

In [None]:
# Inspect which collections already exist in the persistent store (e.g. to see
# whether a previous run already ingested this document).
collections = client.list_collections()
collections

In [83]:
# Embed the per-page chunks and write them to the collection.
ingest_chunks(chunks, CHROMA_COLLECTION_NAME)

In [84]:
def retrieve_similar_documents(collection_name, query_text, top_k=3):
    """Return up to ``top_k`` stored chunks most similar to ``query_text``.

    Args:
        collection_name: Name of an existing Chroma collection.
        query_text: Natural-language query, embedded with ``get_query_embedding``.
        top_k: Maximum number of chunks to return.

    Returns:
        List of chunk texts, in the order returned by Chroma.
    """
    collection = client.get_collection(name=collection_name)
    query_embedding = get_query_embedding(query_text)

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )

    # A single query embedding was sent, so its matches are the first result row.
    # Callers only need the texts; distances are not used.
    return list(results['documents'][0])

In [85]:
def prompt_builder(context, question):
    """Build the user prompt sent to the RAG model.

    Args:
        context: Retrieved context — either a single string, or a list/tuple of
            chunk strings (as returned by ``retrieve_similar_documents``), which
            are joined with ``---`` separator lines.
        question: The user's question.

    Returns:
        The prompt text with the context and question filled in.
    """
    # Formatting a list directly would insert its repr (brackets, quotes and
    # escaped newlines), garbling the chunk text and its Markdown image links.
    if isinstance(context, (list, tuple)):
        context = "\n\n---\n\n".join(str(chunk) for chunk in context)

    prompt = """Find the context below:
    
    Context:
    
    {context}

    Question: {question}
    """.format(context=context, question=question)

    return prompt

In [90]:
def rag(question, top_k=3):
    """Answer ``question`` with retrieval-augmented generation.

    Retrieves the ``top_k`` most similar chunks from ``CHROMA_COLLECTION_NAME``,
    builds a prompt from them and asks ``rag_model`` for a single-turn answer.

    Args:
        question: The user's question.
        top_k: Number of chunks to retrieve as context.

    Returns:
        The Gemini response object (the system prompt asks for a Markdown answer).
    """
    context = retrieve_similar_documents(CHROMA_COLLECTION_NAME, question, top_k=top_k)
    prompt = prompt_builder(context, question)
    # One-shot request: a chat session would only add history that is discarded.
    return rag_model.generate_content(prompt)

In [100]:
# Example query; `response` holds the Gemini response object rendered in the next cell.
response = rag("Explain multihead attention")

In [None]:
# Render the answer as Markdown so image references copied from the context display inline.
# NOTE(review): parts[0] keeps only the first part of the reply; `response.text` may be
# preferable if multi-part replies occur — confirm against the installed SDK version.
Markdown(response.parts[0].text)