Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/history_aware_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Example: history-aware RAG over a MongoDB Atlas vector index using flo_ai.

Requires MONGO_DB_URL and OPENAI_API_KEY in the environment (or a .env file).
"""
import os
import logging

from dotenv import load_dotenv
from pymongo import MongoClient
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from flo_ai import FloSession
from flo_ai.retrievers.flo_retriever import FloRagBuilder

# Load .env before reading any configuration from the environment.
load_dotenv()

db_url = os.getenv("MONGO_DB_URL")

# Generous connect/socket timeouts (milliseconds): Atlas vector-search
# clusters can be slow to respond on cold connections.
connection_timeout = 60000
mongo_client = MongoClient(db_url,
                           connectTimeoutMS=connection_timeout,
                           socketTimeoutMS=connection_timeout)
mongo_embedding_collection = (mongo_client
                              .get_database("dohabank")
                              .get_collection("products"))

# Vector store backed by the pre-built Atlas search index over "products".
store = MongoDBAtlasVectorSearch(
    collection=mongo_embedding_collection,
    embedding_key="embedding",
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    index_name="bank-products-index",
)

llm = ChatOpenAI(temperature=0, model_name='gpt-4o')
session = FloSession(llm)
rag_builder = FloRagBuilder(session, store.as_retriever())

# Surface the alternative queries generated by the multi-query retriever.
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

rag = rag_builder.with_multi_query().build()
print(rag.invoke({"question": "Tell me about corporate loans"}))

35 changes: 35 additions & 0 deletions flo_ai/retrievers/flo_compression_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter

class FloCompressionPipeline():
    """Builder for an ordered list of document transformers/compressors.

    The accumulated list is meant to be handed to LangChain's
    DocumentCompressorPipeline (via get()); transformers run in the order
    they were added.
    """

    def __init__(self, embeddings: Embeddings) -> None:
        # Embeddings model shared by the embedding-based filters below.
        self.__embeddings = embeddings
        # Ordered list of transformers/compressors to apply.
        self.__pipeline = []

    def add_chunking(self, chunk_size=300, chunk_overlap=0):
        """Split documents into chunks at sentence boundaries (". ")."""
        splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=". ")
        self.__pipeline.append(splitter)

    # Backward-compatible alias: the original (misspelled) public name.
    add_chuncking = add_chunking

    def add_embedding_redundant_filter(self):
        """Drop documents whose embeddings are near-duplicates of each other."""
        redundant_filter = EmbeddingsRedundantFilter(embeddings=self.__embeddings)
        self.__pipeline.append(redundant_filter)

    # Backward-compatible alias: the original (misspelled) public name.
    add_embedding_reduntant_filter = add_embedding_redundant_filter

    def add_embedding_relevant_filter(self, threshold: float = 0.76):
        """Keep only documents whose query similarity exceeds `threshold`."""
        relevant_filter = EmbeddingsFilter(embeddings=self.__embeddings, similarity_threshold=threshold)
        self.__pipeline.append(relevant_filter)

    def add_flashrank_reranking(self, model_name="ms-marco-MultiBERT-L-12"):
        """Append a FlashRank reranker (import deferred: optional dependency)."""
        from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank
        compressor = FlashrankRerank(model=model_name)
        self.__pipeline.append(compressor)

    def add_cohere_reranking(self, model_name="rerank-english-v3.0"):
        """Append a Cohere reranker (import deferred: optional dependency)."""
        from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
        compressor = CohereRerank(model=model_name)
        self.__pipeline.append(compressor)

    def get(self):
        """Return the accumulated transformer list, in insertion order."""
        return self.__pipeline
54 changes: 54 additions & 0 deletions flo_ai/retrievers/flo_multi_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import List, Union

from langchain.chains.llm import LLMChain
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field
from flo_ai.state.flo_session import FloSession
from langchain.retrievers.multi_query import MultiQueryRetriever

class LineList(BaseModel):
    """Container for the LLM's newline-separated alternative queries."""

    # One generated query variant per list entry.
    lines: List[str] = Field(description="Lines of text")


class LineListOutputParser(PydanticOutputParser):
    """Parses an LLM completion into a LineList, one query per non-empty line."""

    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        """Split the raw completion on newlines into a LineList.

        Fix: blank lines (e.g. trailing newlines or double spacing in the
        completion) are now dropped so they do not become empty queries
        sent to the retriever.
        """
        lines = [line for line in text.strip().split("\n") if line.strip()]
        return LineList(lines=lines)

class FloMultiQueryRetriever():
    """Thin wrapper holding a built LangChain MultiQueryRetriever."""

    def __init__(self, retriever) -> None:
        # The underlying retriever, exposed for downstream chain composition.
        self.retriever = retriever

class FloMultiQueryRetriverBuilder():
    """Builder that wraps a base retriever in a LangChain MultiQueryRetriever.

    NOTE(review): the "Retriver"/"retriver" spellings are kept as-is because
    they are part of the public interface used by callers.
    """

    def __init__(self,
                 session: FloSession,
                 retriver: VectorStoreRetriever,
                 query_prompt: Union[str, None] = None) -> None:
        # Session supplies the LLM used to generate alternative queries.
        self.session = session
        # Base retriever that each generated query variant is run against.
        self.retriver = retriver
        # Parser for the LLM's newline-separated query list.
        # NOTE(review): stored but never passed to MultiQueryRetriever.from_llm
        # in build(); from_llm applies its own default parsing — confirm intent.
        self.output_parser = LineListOutputParser()

        # Use the caller-supplied prompt template verbatim when given;
        # otherwise fall back to the default three-variant expansion prompt.
        self.prompt = PromptTemplate(
            input_variables=["question"],
            template="""You are an AI language model assistant. Your task is to generate three
            different versions of the given user question to retrieve relevant documents from a vector
            database. By generating multiple perspectives on the user question, your goal is to help
            the user overcome some of the limitations of the distance-based similarity search.
            Provide these alternative questions separated by newlines.
            Original question: {question}""" if query_prompt is None else query_prompt,
        )

    def build(self):
        """Construct the MultiQueryRetriever and wrap it in FloMultiQueryRetriever."""
        multi_query_retriever = MultiQueryRetriever.from_llm(
            retriever=self.retriver,
            llm=self.session.llm,
            prompt=self.prompt
        )
        return FloMultiQueryRetriever(multi_query_retriever)

102 changes: 97 additions & 5 deletions flo_ai/retrievers/flo_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,98 @@
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.runnables import RunnableParallel
from flo_ai.state.flo_session import FloSession
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from flo_ai.retrievers.flo_multi_query import FloMultiQueryRetriverBuilder
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from flo_ai.retrievers.flo_compression_pipeline import FloCompressionPipeline

class FloRetriever():
    """Convenience wrapper pairing a vector store with its default retriever."""

    def __init__(self, vector_store: VectorStore) -> None:
        # Keep a handle on the store itself so callers can reach it directly.
        self.vector_store = vector_store
        # NOTE: attribute keeps the original spelling "retriver" — it is part
        # of the public interface.
        self.retriver = vector_store.as_retriever()
class FloRagBuilder():
    """Fluent builder for a history-aware RAG chain.

    Chain behavior: when the input contains a non-empty "chat_history", the
    question is first rewritten into a standalone question by an LLM chain;
    the (possibly rewritten) question is then sent to the retriever, and the
    retrieved context, question and chat history feed the QA prompt.
    """

    def __init__(self, session: FloSession, retriever: VectorStoreRetriever) -> None:
        # Session supplies the LLM used for both rewriting and answering.
        self.session = session
        # Base retriever; with_* methods may wrap it with more capable ones.
        self.retriever = retriever
        # Default QA prompt; override via with_prompt().
        self.default_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", """You are an assistant for question-answering tasks.
                Use the following pieces of retrieved context to answer the question.
                If you don't know the answer, just say that you don't know.
                Use three sentences maximum and keep the answer concise."""),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )

    def with_prompt(self, prompt: ChatPromptTemplate):
        """Replace the default QA prompt.

        Fix: now returns self like the other with_* methods (previously
        returned None, which broke fluent chaining such as
        builder.with_prompt(p).build()).
        """
        self.default_prompt = prompt
        return self

    def with_multi_query(self, prompt=None):
        """Wrap the current retriever in a multi-query retriever.

        `prompt` optionally overrides the query-expansion prompt template.
        """
        builder = FloMultiQueryRetriverBuilder(session=self.session,
                                               retriver=self.retriever,
                                               query_prompt=prompt)
        multi_query_retriever = builder.build()
        self.retriever = multi_query_retriever.retriever
        return self

    def with_compression(self, pipeline: FloCompressionPipeline):
        """Wrap the current retriever in a contextual-compression retriever.

        Fix: DocumentCompressorPipeline expects the list of transformers, so
        pass pipeline.get() rather than the FloCompressionPipeline object
        itself (which is not a transformer sequence).
        """
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=pipeline.get()
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=self.retriever
        )
        self.retriever = compression_retriever
        return self

    def __create_history_aware_retriever(self):
        """Build and return the question-contextualization chain.

        Fix: returns the chain (prompt | llm | parser) instead of self; the
        original returned the builder, so __build_history_aware_rag ended up
        storing the builder object in self.history_aware_retriever.
        """
        contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{question}"),
            ]
        )
        return contextualize_q_prompt | self.session.llm | StrOutputParser()

    def __get_retriever(self):
        """Route the input: rewrite the question via the history-aware chain
        when chat history is present, otherwise pass the raw question on."""
        def __precontext_retriver(input_prompt: dict):
            if input_prompt.get("chat_history"):
                # Returning a runnable makes LangChain invoke it as a
                # sub-chain before handing its output to the retriever.
                return self.history_aware_retriever
            return input_prompt["question"]
        return __precontext_retriver | self.retriever

    def __format_docs(self, docs):
        """Join retrieved documents into a single context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def __get_optional_chat_history(self, x):
        """Return the chat history from the input, defaulting to an empty list."""
        return x["chat_history"] if "chat_history" in x else []

    def __build_history_aware_rag(self):
        """Assemble the full RAG chain, attaching source context to the answer."""
        self.history_aware_retriever = self.__create_history_aware_retriever()
        rag_chain = (
            RunnablePassthrough.assign(
                context=(lambda x: x["context"]),
            )
            | self.default_prompt
            | self.session.llm
        )

        # NOTE(review): "question" receives the whole input mapping via
        # RunnablePassthrough(), matching the original behavior — confirm the
        # prompt is meant to see the full dict rather than just the question.
        rag_chain_with_source = RunnableParallel(
            {
                "context": self.__get_retriever() | self.__format_docs,
                "question": RunnablePassthrough(),
                "chat_history": lambda x: self.__get_optional_chat_history(x)
            }
        ).assign(answer=rag_chain)
        return rag_chain_with_source

    def build(self):
        """Return the runnable history-aware RAG chain."""
        return self.__build_history_aware_rag()
31 changes: 10 additions & 21 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,22 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10, <3.13"
langchain = "0.2.1"
langchain = "0.2.11"
langgraph = "0.0.55"
ipython = "8.24.0"
httpx = "0.27.0"
langchain-community = "0.2.0"
langchain-experimental = "0.0.59"
google-cloud-bigquery = "^3.23.1"
openai = "1.30.5"
langchain-openai = "^0.1.8"
pillow = "^10.3.0"
langchain-mongodb = "^0.1.5"
langchain-mistralai = "^0.1.7"
langchain-chroma = "^0.1.1"
python-dotenv = "^1.0.1"


[tool.poetry.group.dev.dependencies]
langchain-mongodb = "^0.1.5"
langchain-chroma = "^0.1.1"
langchainhub = "^0.1.17"
torch = "2.0.0"
sentence-transformers = "^3.0.0"
Expand Down