"""Example: multi-query RAG over a MongoDB Atlas vector index using flo_ai."""
import logging
import os

from dotenv import load_dotenv
from langchain import hub
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pymongo import MongoClient

# The original script loaded .env before importing flo_ai; that order is kept.
# NOTE(review): presumably flo_ai or a dependency reads environment variables
# at import time - confirm before moving these imports to the top.
load_dotenv()

from flo_ai import FloSession  # noqa: E402
from flo_ai.retrievers.flo_retriever import FloRagBuilder  # noqa: E402

# Used for both the connect and the socket timeout of the Mongo client.
CONNECTION_TIMEOUT_MS = 60000


def main() -> None:
    """Build the RAG chain and print the answer to one sample question.

    Raises:
        SystemExit: if the MONGO_DB_URL environment variable is not set.
    """
    db_url = os.getenv("MONGO_DB_URL")
    if not db_url:
        # Without this check the unset value would be handed to MongoClient
        # as None and the failure would surface far from its cause.
        raise SystemExit("MONGO_DB_URL is not set; export it or add it to a .env file")

    # Configure logging first so the generated query variants are visible.
    logging.basicConfig()
    logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

    mongo_client = MongoClient(
        db_url,
        connectTimeoutMS=CONNECTION_TIMEOUT_MS,
        socketTimeoutMS=CONNECTION_TIMEOUT_MS,
    )
    mongo_embedding_collection = (
        mongo_client
        .get_database("dohabank")
        .get_collection("products")
    )

    store = MongoDBAtlasVectorSearch(
        collection=mongo_embedding_collection,
        embedding_key="embedding",
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
        index_name="bank-products-index",
    )

    llm = ChatOpenAI(temperature=0, model_name='gpt-4o')
    session = FloSession(llm)
    rag_builder = FloRagBuilder(session, store.as_retriever())

    rag = rag_builder.with_multi_query().build()
    print(rag.invoke({"question": "Tell me about corporate loans"}))


if __name__ == "__main__":
    main()
class FloCompressionPipeline():
    """Ordered collection of document transformers/compressors.

    Steps are stored in the order they are added. Every ``add_*`` method
    returns the pipeline itself so calls can be chained, e.g.
    ``FloCompressionPipeline(emb).add_chunking().add_embedding_relevant_filter()``.
    """

    def __init__(self, embeddings: "Embeddings") -> None:
        """Create an empty pipeline.

        Args:
            embeddings: embedding model handed to the embedding-based filters.
        """
        self.__embeddings = embeddings
        self.__pipeline = []

    def _append(self, step):
        """Append one step to the pipeline and return ``self`` for chaining."""
        self.__pipeline.append(step)
        return self

    def add_chunking(self, chunk_size=300, chunk_overlap=0):
        """Add a ``CharacterTextSplitter`` that splits on ". ".

        Args:
            chunk_size: value passed as the splitter's ``chunk_size``.
            chunk_overlap: value passed as the splitter's ``chunk_overlap``.
        """
        splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=". ",
        )
        return self._append(splitter)

    # Original (misspelled) name kept so existing callers keep working.
    add_chuncking = add_chunking

    def add_embedding_redundant_filter(self):
        """Add an ``EmbeddingsRedundantFilter`` built on the pipeline's embeddings."""
        return self._append(EmbeddingsRedundantFilter(embeddings=self.__embeddings))

    # Original (misspelled) name kept so existing callers keep working.
    add_embedding_reduntant_filter = add_embedding_redundant_filter

    def add_embedding_relevant_filter(self, threshold: float = 0.76):
        """Add an ``EmbeddingsFilter`` using ``threshold`` as its similarity threshold."""
        relevant_filter = EmbeddingsFilter(
            embeddings=self.__embeddings,
            similarity_threshold=threshold,
        )
        return self._append(relevant_filter)

    def add_flashrank_reranking(self, model_name="ms-marco-MultiBERT-L-12"):
        """Add a ``FlashrankRerank`` compressor for the given model name."""
        # Imported lazily, as in the original, so the module loads without it.
        from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank
        return self._append(FlashrankRerank(model=model_name))

    def add_cohere_reranking(self, model_name="rerank-english-v3.0"):
        """Add a ``CohereRerank`` compressor for the given model name."""
        # Imported lazily, as in the original, so the module loads without it.
        from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
        return self._append(CohereRerank(model=model_name))

    def get(self):
        """Return the configured steps, in insertion order, as a new list.

        A copy is returned so callers cannot mutate the pipeline's own state.
        """
        return list(self.__pipeline)
class LineList(BaseModel):
    """Container for the text lines extracted from an LLM response."""
    lines: List[str] = Field(description="Lines of text")


class LineListOutputParser(PydanticOutputParser):
    """Output parser that turns newline-separated text into a ``LineList``."""

    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        """Split ``text`` into lines, trimming each and dropping blank ones.

        Blank lines are skipped so that an LLM answer containing empty lines
        between alternatives does not produce empty entries.
        """
        stripped = (line.strip() for line in text.strip().splitlines())
        return LineList(lines=[line for line in stripped if line])


class FloMultiQueryRetriever():
    """Thin holder exposing the built retriever as ``.retriever``."""

    def __init__(self, retriever) -> None:
        self.retriever = retriever


class FloMultiQueryRetriverBuilder():
    """Builds a ``MultiQueryRetriever`` around a base retriever and the session LLM."""

    def __init__(self,
                 session: FloSession,
                 retriver: VectorStoreRetriever,
                 query_prompt: Union[str, None] = None) -> None:
        """Store the inputs and prepare the query-generation prompt.

        Args:
            session: session whose ``llm`` generates the question variants.
            retriver: base retriever that each generated question is run against.
            query_prompt: optional template text using a ``{question}``
                variable; when ``None`` the built-in template below is used.
        """
        self.session = session
        self.retriver = retriver
        # NOTE(review): this parser is not passed to MultiQueryRetriever.from_llm
        # in build(); it is kept as an attribute for callers that read it.
        self.output_parser = LineListOutputParser()

        self.prompt = PromptTemplate(
            input_variables=["question"],
            template="""You are an AI language model assistant. Your task is to generate three
            different versions of the given user question to retrieve relevant documents from a vector
            database. By generating multiple perspectives on the user question, your goal is to help
            the user overcome some of the limitations of the distance-based similarity search.
            Provide these alternative questions separated by newlines.
            Original question: {question}""" if query_prompt is None else query_prompt,
        )

    def build(self):
        """Create the multi-query retriever and return it wrapped in ``FloMultiQueryRetriever``."""
        multi_query_retriever = MultiQueryRetriever.from_llm(
            retriever=self.retriver,
            llm=self.session.llm,
            prompt=self.prompt
        )
        return FloMultiQueryRetriever(multi_query_retriever)
class FloRagBuilder():
    """Fluent builder for a retrieval-augmented question-answering chain.

    The built chain takes a dict with a ``"question"`` key and an optional
    ``"chat_history"`` key and yields a dict with ``"context"``,
    ``"question"``, ``"chat_history"`` and ``"answer"``.
    """

    def __init__(self, session: "FloSession", retriever: "VectorStoreRetriever") -> None:
        """Store the session and base retriever and set up the default prompt.

        Args:
            session: session whose ``llm`` is used for answering and for
                rewriting follow-up questions.
            retriever: retriever used to fetch context documents.
        """
        self.session = session
        self.retriever = retriever
        # The system message must carry {context}; without it the retrieved
        # documents are computed but never shown to the model.
        self.default_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", """You are an assistant for question-answering tasks.
                Use the following pieces of retrieved context to answer the question.
                If you don't know the answer, just say that you don't know.
                Use three sentences maximum and keep the answer concise.

                {context}"""),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )

    def with_prompt(self, prompt: "ChatPromptTemplate"):
        """Replace the answer prompt and return ``self`` for chaining.

        The chain supplies ``context``, ``question`` and ``chat_history`` to
        whatever prompt is set here.
        """
        self.default_prompt = prompt
        return self

    def with_multi_query(self, prompt=None):
        """Wrap the current retriever in a multi-query retriever.

        Args:
            prompt: optional query-generation template forwarded to
                ``FloMultiQueryRetriverBuilder`` as ``query_prompt``.
        """
        builder = FloMultiQueryRetriverBuilder(session=self.session,
                                               retriver=self.retriever,
                                               query_prompt=prompt)
        self.retriever = builder.build().retriever
        return self

    def with_compression(self, pipeline: "FloCompressionPipeline"):
        """Wrap the current retriever in a contextual-compression retriever.

        Args:
            pipeline: a ``FloCompressionPipeline`` (its ``get()`` list is
                used) or any iterable of transformer/compressor objects.
        """
        # DocumentCompressorPipeline is given the list of steps, not the
        # FloCompressionPipeline wrapper object.
        if isinstance(pipeline, FloCompressionPipeline):
            transformers = pipeline.get()
        else:
            transformers = list(pipeline)
        pipeline_compressor = DocumentCompressorPipeline(transformers=transformers)
        self.retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=self.retriever
        )
        return self

    def __create_history_aware_retriever(self):
        """Return a chain that rewrites a follow-up question as a standalone one."""
        contextualize_q_system_prompt = """Given a chat history and the latest user question \
        which might reference context in the chat history, formulate a standalone question \
        which can be understood without the chat history. Do NOT answer the question, \
        just reformulate it if needed and otherwise return it as is."""

        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{question}"),
            ]
        )
        # Return the chain itself; returning the builder here would make
        # history_aware_retriever point at a non-runnable object.
        return contextualize_q_prompt | self.session.llm | StrOutputParser()

    def __get_retriever(self):
        """Return a runnable mapping the input dict to retrieved documents."""
        def _route_question(input_prompt: dict):
            # With chat history, hand back the rewriting chain so the
            # retriever receives a standalone question; otherwise pass the
            # raw question string through.
            if input_prompt.get("chat_history"):
                return self.history_aware_retriever
            return input_prompt["question"]
        return _route_question | self.retriever

    def __format_docs(self, docs):
        """Join the ``page_content`` of ``docs`` with blank lines in between."""
        return "\n\n".join(doc.page_content for doc in docs)

    def __get_optional_chat_history(self, x):
        """Return ``x["chat_history"]`` when present, otherwise an empty list."""
        return x["chat_history"] if "chat_history" in x else []

    def __build_history_aware_rag(self):
        """Assemble the retrieval and answer steps into the final chain."""
        self.history_aware_retriever = self.__create_history_aware_retriever()
        answer_chain = self.default_prompt | self.session.llm

        return RunnableParallel(
            {
                "context": self.__get_retriever() | self.__format_docs,
                # Pass only the question text on; forwarding the whole input
                # dict would render the dict into the human message.
                "question": lambda x: x["question"],
                "chat_history": lambda x: self.__get_optional_chat_history(x),
            }
        ).assign(answer=answer_chain)

    def build(self):
        """Build and return the RAG chain."""
        return self.__build_history_aware_rag()
"sha256:5758a315e1ac92eb26dafec5ad0fafa03cafa686aba197d5bb0b1dd28cc03ebe"}, + {file = "langchain-0.2.11-py3-none-any.whl", hash = "sha256:5a7a8b4918f3d3bebce9b4f23b92d050699e6f7fb97591e8941177cf07a260a2"}, + {file = "langchain-0.2.11.tar.gz", hash = "sha256:d7a9e4165f02dca0bd78addbc2319d5b9286b5d37c51d784124102b57e9fd297"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} -langchain-core = ">=0.2.0,<0.3.0" +langchain-core = ">=0.2.23,<0.3.0" langchain-text-splitters = ">=0.2.0,<0.3.0" langsmith = ">=0.1.17,<0.2.0" -numpy = ">=1,<2" +numpy = [ + {version = ">=1,<2", markers = "python_version < \"3.12\""}, + {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, +] pydantic = ">=1,<3" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<9.0.0" - -[package.extras] -azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"] -clarifai = ["clarifai (>=9.1.0)"] -cli = ["typer (>=0.9.0,<0.10.0)"] -cohere = ["cohere (>=4,<6)"] -docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"] -embeddings = ["sentence-transformers (>=2,<3)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<6)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser 
(>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "langchain-openai (>=0.1,<0.2)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] -javascript = ["esprima (>=4.0.1,<5.0.0)"] -llms = ["clarifai (>=9.1.0)", "cohere (>=4,<6)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] -openai = ["openai (<2)", "tiktoken (>=0.7,<1.0)"] -qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] -text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-chroma" @@ -5312,4 +5301,4 @@ test = ["big-O", 
"importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.13" -content-hash = "0b28e33eee19ad9a3ae742d93df7b4cdc9c161e99340a3fee7b5c201ff287fd6" +content-hash = "3bf0df13c61e484f49eefa2e16e5b895e2cee7359b9afd4eb5ea49f61b4a25a4" diff --git a/pyproject.toml b/pyproject.toml index 190b8a65..602f0973 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,23 +8,22 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.10, <3.13" -langchain = "0.2.1" +langchain = "0.2.11" langgraph = "0.0.55" ipython = "8.24.0" httpx = "0.27.0" langchain-community = "0.2.0" langchain-experimental = "0.0.59" -google-cloud-bigquery = "^3.23.1" openai = "1.30.5" langchain-openai = "^0.1.8" pillow = "^10.3.0" -langchain-mongodb = "^0.1.5" langchain-mistralai = "^0.1.7" -langchain-chroma = "^0.1.1" python-dotenv = "^1.0.1" [tool.poetry.group.dev.dependencies] +langchain-mongodb = "^0.1.5" +langchain-chroma = "^0.1.1" langchainhub = "^0.1.17" torch = "2.0.0" sentence-transformers = "^3.0.0"