# Retrieval Augmented Generation (RAG) with Gemini 2.0

## Get started

### Install Dependencies


- `google-genai`: Google Gen AI Python library
- `PyPDF2`: To read PDFs
- `gradio`: To build the chat UI

In [None]:
%%capture

# Install/upgrade the Gen AI SDK, the PDF reader and Gradio (chat UI).
# %%capture suppresses the (long) pip output of this cell.
# NOTE(review): numpy, pandas, scikit-learn and tenacity are imported later but
# not installed here — presumably preinstalled in the runtime; confirm.
%pip install --upgrade --quiet google-genai PyPDF2 gradio

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

# Shut the kernel down with restart enabled so the packages installed above
# are picked up by a fresh interpreter.
app = IPython.Application.instance()
app.kernel.do_shutdown(restart=True)

### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
import sys

# Colab needs an explicit interactive login; outside Colab this cell is a no-op.
if "google.colab" in sys.modules:
    from google.colab import auth as colab_auth

    colab_auth.authenticate_user()

### Set Google Cloud project information

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
import os
import warnings

PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # Fall back to the environment. Default to "" rather than wrapping the
    # lookup in str(), which would turn a missing variable into the literal
    # project id "None" and fail much later with a confusing error.
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
if not PROJECT_ID:
    warnings.warn(
        "No Google Cloud project configured: set PROJECT_ID in this cell or "
        "the GOOGLE_CLOUD_PROJECT environment variable.",
        stacklevel=1,
    )

# Region for Vertex AI calls; overridable through GOOGLE_CLOUD_REGION.
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

### Import libraries

In [None]:
# For asynchronous operations
import asyncio

# For data processing
import glob
from typing import Any

from IPython.display import Audio, Markdown, display
import PyPDF2

# For GenerativeAI
from google import genai
from google.genai import types
from google.genai.types import LiveConnectConfig
import numpy as np
import pandas as pd

# For similarity score
from sklearn.metrics.pairwise import cosine_similarity

# For retry mechanism
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_random_exponential,
)

#### Initialize Gen AI client

- Client for calling the Gemini API in Vertex AI
- `vertexai=True` indicates that the client should communicate with the Vertex AI API endpoints.

In [None]:
# Gen AI client bound to the Vertex AI endpoints for PROJECT_ID / LOCATION.
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)

### Initialize model

In [None]:
MODEL_ID = "gemini-2.0-flash-exp"  # @param {type:"string", isTemplate: true}
# Fully-qualified Vertex AI resource name of the model; the Live API session
# in generate_content connects with this value.
MODEL = "/".join(
    [
        "projects", PROJECT_ID,
        "locations", LOCATION,
        "publishers", "google",
        "models", MODEL_ID,
    ]
)

text_embedding_model = "text-embedding-005"  # @param {type:"string", isTemplate: true}

#### Context Documents

- Download the documents from Google Cloud Storage bucket
- These documents are specific to `Cymbal Bikes` store
  - [`Cymbal Bikes Return Policy`](https://storage.googleapis.com/github-repo/generative-ai/gemini2/use-cases/retail_rag/documents/CymbalBikesReturnPolicy.pdf): Contains information about return policy
  - [`Cymbal Bikes Services`](https://storage.googleapis.com/github-repo/generative-ai/gemini2/use-cases/retail_rag/documents/CymbalBikesServices.pdf): Contains information about services provided by Cymbal Bikes

In [None]:
# Copy the two Cymbal Bikes PDFs from the public GCS bucket into ./documents/
!gsutil cp "gs://github-repo/generative-ai/gemini2/use-cases/retail_rag/documents/CymbalBikesReturnPolicy.pdf" "documents/CymbalBikesReturnPolicy.pdf"
!gsutil cp "gs://github-repo/generative-ai/gemini2/use-cases/retail_rag/documents/CymbalBikesServices.pdf" "documents/CymbalBikesServices.pdf"

### Test

In [None]:
query = "What is the price of a basic tune-up at Cymbal Bikes?"

# Baseline: ask the model directly, with no retrieved context.
response = client.models.generate_content(model=MODEL_ID, contents=query)
display(Markdown(response.text))

### Text

For text generation, set `response_modalities` to `TEXT`.

In [None]:
async def generate_content(query: str) -> str:
    """Generate a text response for `query` with the Gemini Live API.

    Opens a Live API session on the notebook-level `client` using `MODEL`,
    sends `query` as one complete turn, and concatenates every text fragment
    streamed back until the server signals that the turn is complete.

    Args:
      query: The query to generate content for.

    Returns:
      The generated text. If the stream ends without a `turn_complete`
      signal, the text received so far is returned (possibly "").
    """
    config = LiveConnectConfig(response_modalities=["TEXT"])
    fragments: list[str] = []

    async with client.aio.live.connect(model=MODEL, config=config) as session:
        await session.send(input=query, end_of_turn=True)

        async for message in session.receive():
            # Some server messages may carry no text; getattr with a default
            # tolerates a missing attribute just like the old try/except did.
            text = getattr(message, "text", None)
            if text:
                fragments.append(text)

            # `server_content` can be None on non-content messages; reading
            # `.turn_complete` unguarded raised AttributeError.
            server_content = getattr(message, "server_content", None)
            if server_content is not None and server_content.turn_complete:
                break

    # Returning after the loop (rather than only inside it) guarantees a str
    # even if the stream closes before `turn_complete`.
    return "".join(fragments)

- Try a specific query

In [None]:
query = "What is the price of a basic tune-up at Cymbal Bikes?"

response = await generate_content(query)
display(Markdown(response))

## Enhancing LLM Accuracy with RAG

### Context Documents

In [None]:
# build_index can only read PDFs, so list just those. Sorting gives a
# deterministic order, which keeps chunk positions in the index reproducible.
documents = sorted(glob.glob("documents/*.pdf"))
documents

#### Document Embedding and Indexing

In [None]:
def _is_quota_error(error: BaseException) -> bool:
    """Return True if `error` looks like an API quota (RESOURCE_EXHAUSTED) error."""
    return "RESOURCE_EXHAUSTED" in str(error)


@retry(
    wait=wait_random_exponential(multiplier=1, max=120),
    stop=stop_after_attempt(4),
    # Only quota errors are worth backing off for; anything else surfaces at once.
    retry=retry_if_exception(_is_quota_error),
    # When the quota is still exhausted after the last attempt, return None
    # (the documented "quota exhausted" result) instead of raising RetryError.
    retry_error_callback=lambda retry_state: None,
)
def get_embeddings(
    embedding_client: Any, embedding_model: str, text: str, output_dim: int = 768
) -> list[float] | None:
    """
    Generate an embedding vector for `text`, retrying on API quota errors.

    Quota errors (exceptions whose message contains "RESOURCE_EXHAUSTED") are
    retried with random exponential backoff, up to 4 attempts in total. Any
    other error is printed and re-raised immediately, without retrying.

    Args:
        embedding_client: The client object used to generate embeddings
            (must expose `models.embed_content`, as `genai.Client` does).
        embedding_model: The name of the embedding model to use.
        text: The text for which to generate embeddings.
        output_dim: The desired dimensionality of the output embeddings (default is 768).

    Returns:
        The embedding as a flat list of floats, or None if the quota was
        still exhausted after all retry attempts.

    Raises:
        Exception: Any non-quota exception raised while generating the embedding.
    """
    try:
        response = embedding_client.models.embed_content(
            model=embedding_model,
            contents=[text],
            config=types.EmbedContentConfig(output_dimensionality=output_dim),
        )
    except Exception as e:
        # Quota errors are re-raised silently so @retry can back off and try
        # again; swallowing them here would make the retry decorator dead code.
        if not _is_quota_error(e):
            print(f"Error generating embeddings: {str(e)}")
        raise

    # One input text -> exactly one embedding in the response.
    return list(response.embeddings[0].values)

- The code block executes the following steps:

  - Extracts text from PDF documents and segments it into smaller chunks for processing.
  - Employs a Vertex AI model to transform each text chunk into a numerical embedding vector, facilitating semantic representation and search.
  - Constructs a Pandas DataFrame to store the embeddings, enriched with metadata such as document name and page number, effectively creating a searchable index for efficient retrieval.


In [None]:
def build_index(
    document_paths: list[str],
    embedding_client: Any,
    embedding_model: str,
    chunk_size: int = 512,
) -> pd.DataFrame:
    """
    Build searchable index from a list of PDF documents with page-wise processing.

    Each page's text is split into consecutive windows of `chunk_size`
    characters (the last window may be shorter), and every non-blank window
    is embedded with `get_embeddings`. Pages without extractable text and
    chunks whose embedding could not be produced (quota exhausted) are skipped.

    Args:
        document_paths: A list of file paths to PDF documents.
        embedding_client: The client object used to generate embeddings.
        embedding_model: The name of the embedding model to use.
        chunk_size: The maximum size (in characters) of each text chunk.  Defaults to 512.

    Returns:
        A Pandas DataFrame where each row represents a text chunk.  The DataFrame includes columns for:
            - 'document_name': The path to the source PDF document.
            - 'page_number': The 1-based page number within the document.
            - 'page_text': The full text of the page.
            - 'chunk_number': The 0-based chunk number within the page.
            - 'chunk_text': The text content of the chunk.
            - 'embeddings': The embedding vector for the chunk.

    Raises:
        ValueError: If `chunk_size` is not positive, or if no chunks are
            created from the input documents.
        Exception:  Any exceptions encountered during file processing are printed to the console and the function continues to the next document.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    all_chunks = []

    for doc_path in document_paths:
        try:
            with open(doc_path, "rb") as file:
                pdf_reader = PyPDF2.PdfReader(file)

                for page_number, page in enumerate(pdf_reader.pages, start=1):
                    # Pages without a text layer yield no text; `or ""` is a
                    # defensive guard in case extract_text() returns None.
                    page_text = page.extract_text() or ""
                    if not page_text.strip():
                        continue

                    # Fixed-width character windows over the page text.
                    chunks = [
                        page_text[start : start + chunk_size]
                        for start in range(0, len(page_text), chunk_size)
                    ]

                    for chunk_number, chunk_text in enumerate(chunks):
                        if not chunk_text.strip():
                            continue  # nothing meaningful to embed

                        embeddings = get_embeddings(
                            embedding_client, embedding_model, chunk_text
                        )
                        if embeddings is None:
                            # get_embeddings returns None when the quota
                            # stayed exhausted; keep indexing the rest.
                            print(
                                f"Warning: could not embed chunk {chunk_number} "
                                f"on page {page_number} of {doc_path}"
                            )
                            continue

                        all_chunks.append(
                            {
                                "document_name": doc_path,
                                "page_number": page_number,
                                "page_text": page_text,
                                "chunk_number": chunk_number,
                                "chunk_text": chunk_text,
                                "embeddings": embeddings,
                            }
                        )

        except Exception as e:
            # Best effort: report the failing document and move on to the next.
            print(f"Error processing document {doc_path}: {str(e)}")

    if not all_chunks:
        raise ValueError("No chunks were created from the documents")

    return pd.DataFrame(all_chunks)

Let's create embeddings and an index from the provided documents.

In [None]:
# Embed every page chunk of the downloaded PDFs into an in-memory index.
vector_db_mini_vertex = build_index(
    document_paths=documents,
    embedding_client=client,
    embedding_model=text_embedding_model,
)
vector_db_mini_vertex

In [None]:
# Index size: (number of chunks, number of columns) — build_index emits one row per chunk
vector_db_mini_vertex.shape

In [None]:
# Example of what a chunk looks like: the text of the row labelled 0
vector_db_mini_vertex.at[0, "chunk_text"]

#### Retrieval

In [None]:
def get_relevant_chunks(
    query: str,
    vector_db: pd.DataFrame,
    embedding_client: Any,
    embedding_model: str,
    top_k: int = 3,
) -> str:
    """
    Retrieve the most relevant document chunks for a query using similarity search.

    The query is embedded with `get_embeddings` and compared against every
    chunk embedding by cosine similarity; the `top_k` best matches are
    returned most-similar first.

    Args:
        query: The search query string.
        vector_db: A pandas DataFrame containing the vectorized document chunks.
                     It must contain columns named 'embeddings', 'document_name',
                     'page_number', 'chunk_number' and 'chunk_text'.
                     Each 'embeddings' value may be a flat vector or a nested
                     (1, d) list/array.
        embedding_client: The client object used to generate embeddings.
        embedding_model: The name of the embedding model to use.
        top_k: The number of most similar chunks to retrieve. Defaults to 3.

    Returns:
        A string with the top_k most relevant chunks separated by blank lines,
        each formatted as "[<document>, Page <n>, Chunk <m>]: <text>".
        Returns "" when `top_k` is not positive or `vector_db` is empty,
        "Could not process query due to quota issues" when the query could not
        be embedded, and "Error retrieving relevant chunks" on any other error.

    Raises:
        Nothing: errors (e.g. embedding client failures, malformed embeddings)
        are printed to the console and reported via the return value.
    """
    try:
        # Nothing to retrieve: skip the embedding API call entirely. This also
        # avoids argsort(...)[-0:] style slicing returning *every* chunk.
        if top_k <= 0 or vector_db.empty:
            return ""

        query_embedding = get_embeddings(embedding_client, embedding_model, query)
        if query_embedding is None:
            # Sentinel string that generate_answer checks for.
            return "Could not process query due to quota issues"

        # Normalize to 2-D matrices, accepting flat or nested stored vectors.
        query_matrix = np.asarray(query_embedding, dtype=float).reshape(1, -1)
        chunk_matrix = np.vstack(
            [
                np.asarray(embedding, dtype=float).reshape(-1)
                for embedding in vector_db["embeddings"]
            ]
        )
        similarities = cosine_similarity(query_matrix, chunk_matrix)[0]

        # Positions of the top_k most similar chunks, most similar first.
        top_positions = np.argsort(similarities)[::-1][:top_k]
        top_chunks = vector_db.iloc[top_positions]

        return "\n\n".join(
            f"[{row.document_name}, Page {row.page_number}, "
            f"Chunk {row.chunk_number}]: {row.chunk_text}"
            for row in top_chunks.itertuples(index=False)
        )

    except Exception as e:
        print(f"Error getting relevant chunks: {str(e)}")
        return "Error retrieving relevant chunks"

Let's test out our retrieval component

- Let's try the same query that the model could not answer earlier because it lacked context.

In [None]:
query = "What is the price of a basic tune-up at Cymbal Bikes?"
# Retrieve the three chunks most similar to the question.
relevant_context = get_relevant_chunks(
    query,
    vector_db=vector_db_mini_vertex,
    embedding_client=client,
    embedding_model=text_embedding_model,
    top_k=3,
)
relevant_context

### Generation

In [None]:
@retry(
    wait=wait_random_exponential(multiplier=1, max=120),
    stop=stop_after_attempt(4),
    # Retry only quota errors. The body turns every other failure into an
    # error string, so a blanket retry could never fire.
    retry=retry_if_exception(lambda error: "RESOURCE_EXHAUSTED" in str(error)),
    # Once the retries are used up, report the quota problem instead of raising.
    retry_error_callback=lambda retry_state: "Can't Process, Quota Issues",
)
async def generate_answer(
    query: str, context: str, llm_client: Any, modality: str = "text"
) -> str:
    """
    Generate an answer grounded in `context`, retrying on API quota errors.

    Args:
        query: User query.
        context: Relevant text providing context for the query (as returned
            by get_relevant_chunks, including its error sentinel strings).
        llm_client: Client for accessing LLM API. Currently unused: generation
            goes through `generate_content`, which uses the notebook-level client.
        modality: Output modality. Only "text" is implemented in this notebook.

    Returns:
        The generated answer, or one of the status strings
        "Can't Process, Quota Issues" / "Error generating answer".

    Raises:
        Nothing: failures are printed and reported via the return value.
    """
    # Retrieval already failed upstream; there is no context to answer from.
    if context in (
        "Could not process query due to quota issues",
        "Error retrieving relevant chunks",
    ):
        return "Can't Process, Quota Issues"

    if modality != "text":
        print(f"Error generating answer: unsupported modality {modality!r}")
        return "Error generating answer"

    prompt = (
        "Answer the question using only the context below. If the context "
        "does not contain the answer, say that you don't know.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

    try:
        return await generate_content(prompt)
    except Exception as e:
        if "RESOURCE_EXHAUSTED" in str(e):
            raise  # hand quota errors to @retry for backoff
        print(f"Error generating answer: {str(e)}")
        return "Error generating answer"

In [None]:
query = "What is the price of a basic tune-up at Cymbal Bikes?"

generated_answer = await generate_answer(
    query, relevant_context, client, modality="text"
)
display(Markdown(generated_answer))

### Pipeline

In [None]:
async def rag(
    question: str,
    vector_db: pd.DataFrame,
    embedding_client: Any,
    embedding_model: str,
    llm_client: Any,
    top_k: int,
    llm_model: str,
    modality: str = "text",
) -> str | None:
    """
    RAG Pipeline.

    Args:
        question: User query.
        vector_db: DataFrame containing document chunks and embeddings.
        embedding_client: Client for accessing embedding API.
        embedding_model: Name of the embedding model.
        llm_client: Client for accessing LLM API.
        top_k: The number of top relevant chunks to retrieve from the vector database.
        llm_model: Name of the LLM model.
        modality: Output modality (text or audio).

    Returns:
        For text modality, generated answer.
        For audio modality, audio playback widget.

    Raises:
        Exception:  Catches and prints any exceptions during processing. Returns an error message.
    """

    try:
        #!TBD IN WORKSHOP
        pass

    except Exception as e:
        print(f"Error processing question '{question}': {str(e)}")
        return {"question": question, "generated_answer": "Error processing question"}

In [None]:
# Evaluation set: sample customer questions paired with hand-written reference
# answers. NOTE(review): answers presumably come from the two Cymbal Bikes PDFs
# downloaded above — verify against the documents if they change.
question_set = [
    {
        "question": "What is the price of a basic tune-up at Cymbal Bikes?",
        "answer": "A basic tune-up costs $100.",
    },
    {
        "question": "How much does it cost to replace a tire at Cymbal Bikes?",
        "answer": "Replacing a tire at Cymbal Bikes costs $50 per tire.",
    },
    {
        "question": "What does gear repair at Cymbal Bikes include?",
        "answer": "Gear repair includes inspection and repair of the gears, including replacement of chainrings, cogs, and cables as needed.",
    },
    {
        "question": "What is the cost of replacing a tube at Cymbal Bikes?",
        "answer": "Replacing a tube at Cymbal Bikes costs $20.",
    },
    {
        "question": "Can I return clothing items to Cymbal Bikes?",
        "answer": "Clothing can only be returned if it is unworn and in the original packaging.",
    },
    {
        "question": "What is the time frame for returning items to Cymbal Bikes?",
        "answer": "Cymbal Bikes offers a 30-day return policy on all items.",
    },
    {
        "question": "Can I return edible items like energy gels?",
        "answer": "No, edible items are not returnable.",
    },
    {
        "question": "How can I return an item purchased online from Cymbal Bikes?",
        "answer": "Items purchased online can be returned to any Cymbal Bikes store or mailed back.",
    },
    {
        "question": "What should I include when returning an item to Cymbal Bikes?",
        "answer": "Please include the original receipt and a copy of your shipping confirmation when returning an item.",
    },
    {
        "question": "Does Cymbal Bikes offer refunds for shipping charges?",
        "answer": "Cymbal Bikes does not offer refunds for shipping charges, except for defective items.",
    },
    {
        "question": "How do I process a return for a defective item at Cymbal Bikes?",
        "answer": "To process a return for a defective item, please contact Cymbal Bikes first.",
    },
]

In [None]:
# Inspect the first question/answer pair of the evaluation set
question_set[0]

In [None]:
response = await rag(
    question=question_set[0]["question"],
    vector_db=vector_db_mini_vertex,
    embedding_client=client,  # For embedding generation
    embedding_model=text_embedding_model,  # For embedding model
    llm_client=client,  # For answer generation,
    top_k=3,
    llm_model=MODEL,
    modality="text",
)
display(Markdown(response))

# Use a chat UI for a better interface and continuous conversation

In [None]:
import asyncio

import gradio as gr


async def get_agent_response(message, history):
    """Gradio chat handler: answer `message` with the RAG pipeline.

    Args:
        message: The user's latest chat message, used as the RAG question.
        history: Prior chat turns supplied by Gradio. Unused: each question
            is answered independently from freshly retrieved context.

    Yields:
        The pipeline's answer for `message`.
    """
    # Short pause before answering. asyncio.sleep yields control back to the
    # event loop, whereas the previous time.sleep blocked it inside this
    # async handler.
    await asyncio.sleep(0.05)
    yield await rag(
        question=message,
        vector_db=vector_db_mini_vertex,
        embedding_client=client,  # For embedding generation
        embedding_model=text_embedding_model,  # For embedding model
        llm_client=client,  # For answer generation
        top_k=3,
        llm_model=MODEL,
        modality="text",
    )


demo = gr.ChatInterface(
    get_agent_response,
    type="messages",
    flagging_mode="manual",
    flagging_options=["Like", "Spam", "Inappropriate", "Other"],
    save_history=True,
)

demo.launch()