In [6]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from constants import gemini_api_key, tavily_api_key
from langchain_community.tools.tavily_search import TavilySearchResults
import os
import numpy as np

# Retrieval Augmented Generation
Retrieval Augmented Generation (RAG) supplies an LLM with relevant context alongside the question, so the model can understand the question better and give a better-grounded answer.

RAG works by passing the question to a vector database, which compares the embedding of the question with the embeddings of the documents it stores and ranks them by similarity.

The most similar documents are retrieved and passed to the model as context, along with the question.

## Testing the embedding model with test queries

In [2]:
# Chat model used to generate the answers in the chains further down.
model = ChatGoogleGenerativeAI(
    model="gemini-pro",
    google_api_key=gemini_api_key,
)

# Embedding model used for the similarity experiment and for the vector store.
embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=gemini_api_key,
)

In [3]:
# Three "Today is ..." sentences plus one unrelated sentence, so the pairwise
# similarity scores printed later have an obvious outlier to compare against.
queries = [
    "Today is a sunny day",
    "Today is april fools day",
    "Today is a snowy day",
    "Robert Downey is the Iron Man",
]

In [4]:
# Embed every query on its own; the result holds one embedding vector per
# query, in the same order as `queries`.
vectors = list(map(embeddings.embed_query, queries))

In [7]:
# Sanity check: (number of queries, embedding dimension).
np.shape(vectors)

(4, 768)

In [8]:
# Compare the query embeddings pairwise with cosine similarity.
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# (label, index of first query, index of second query), listed in print order.
# NOTE(review): "3nd" looks like a typo for "3rd"; the label is kept verbatim
# so the printed text stays identical to the recorded output.
labelled_pairs = [
    ("1st and 2nd query: ", 0, 1),
    ("1st and 3nd query: ", 0, 2),
    ("2nd and 3rd query: ", 1, 2),
    ("1st and 4th query: ", 0, 3),
    ("2nd and 4th query: ", 1, 3),
    ("3rd and 4th query: ", 2, 3),
]

for label, first, second in labelled_pairs:
    # Each vector is wrapped in a list because cosine_similarity expects 2-D
    # inputs; the result is therefore a 1x1 matrix.
    print(label, cosine_similarity([vectors[first]], [vectors[second]]))

1st and 2nd query:  [[0.87064364]]
1st and 3nd query:  [[0.93163611]]
2nd and 3rd query:  [[0.87995205]]
1st and 4th query:  [[0.77332271]]
2nd and 4th query:  [[0.80163847]]
3rd and 4th query:  [[0.7808282]]


## Using Google embeddings in LangChain to build a RAG system

### 1. Testing with an in-memory vector database

In [9]:
# pip install docarray
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings

In [30]:
# Two short facts to index in the vector store.
tony_stark_facts = [
    "Tony Stark was kidnapped by thugs for making a lethal weapon in a cave",
    "Tony Stark designed a miniature arc reactor in a cave",
]

# Build the in-memory store from the texts, embedding them with `embeddings`.
vectorstore = DocArrayInMemorySearch.from_texts(tony_stark_facts, embedding=embeddings)

In [31]:
# Wrap the vector store as a retriever so it can be used as a step in the chain below.
retriever = vectorstore.as_retriever()

In [33]:
# Prompt that restricts the model to the retrieved context.
template = """Answer the questions based only on the following context:
{context}

Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)
output_parser = StrOutputParser()


def format_docs(docs):
    """Join the text of the retrieved documents into one plain-text block.

    Without this step the list of documents returned by the retriever is
    interpolated into ``{context}`` through its Python repr, so the prompt
    carries ``Document(...)`` wrapper noise instead of just the text.

    Args:
        docs: Iterable of retrieved documents; each must expose ``page_content``.

    Returns:
        The ``page_content`` of every document, separated by blank lines.
        An empty iterable yields an empty string.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the incoming question out to two branches: one retrieves and formats the
# matching documents, the other forwards the question unchanged.
setup_and_retrieval = RunnableParallel(
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
)

# question -> {context, question} -> prompt -> chat model -> plain string
chain = setup_and_retrieval | prompt | model | output_parser

chain.invoke("Why was Tony Stark kidnapped?")

'The provided context does not specify why Tony Stark was kidnapped.'

In [35]:
# Same pipeline without the string output parser, so the model's message
# object (content plus response metadata) is returned instead of a plain string.
raw_chain = setup_and_retrieval | prompt | model
raw_chain.invoke("Who is Tony Stark")

AIMessage(content='The provided context does not mention who Tony Stark is.', response_metadata={'prompt_feedback': {'safety_ratings': [{'category': 9, 'probability': 1, 'blocked': False}, {'category': 8, 'probability': 1, 'blocked': False}, {'category': 7, 'probability': 1, 'blocked': False}, {'category': 10, 'probability': 1, 'blocked': False}], 'block_reason': 0}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]})