diff --git a/docs/examples/retrievers/ensemble_retrieval.ipynb b/docs/examples/retrievers/ensemble_retrieval.ipynb
index 073d6d3b53d0a..c1e1c238260cd 100644
--- a/docs/examples/retrievers/ensemble_retrieval.ipynb
+++ b/docs/examples/retrievers/ensemble_retrieval.ipynb
@@ -102,7 +102,9 @@ "source": [ "# try loading great gatsby\n", "\n", - "documents = SimpleDirectoryReader(input_files=[\"../../../examples/gatsby/gatsby_full.txt\"]).load_data()" + "documents = SimpleDirectoryReader(\n", + " input_files=[\"../../../examples/gatsby/gatsby_full.txt\"]\n", + ").load_data()" ] }, {
@@ -133,23 +135,23 @@ "vector_indices = []\n", "query_engines = []\n", "for chunk_size in chunk_sizes:\n", - " print(f'Chunk Size: {chunk_size}')\n", + " print(f\"Chunk Size: {chunk_size}\")\n", " service_context = ServiceContext.from_defaults(chunk_size=chunk_size, llm=llm)\n", " service_contexts.append(service_context)\n", " nodes = service_context.node_parser.get_nodes_from_documents(documents)\n", - " \n", + "\n", " # add chunk size to nodes to track later\n", " for node in nodes:\n", " node.metadata[\"chunk_size\"] = chunk_size\n", " node.excluded_embed_metadata_keys = [\"chunk_size\"]\n", " node.excluded_llm_metadata_keys = [\"chunk_size\"]\n", - " \n", + "\n", " nodes_list.append(nodes)\n", - " \n", + "\n", " # build vector index\n", " vector_index = VectorStoreIndex(nodes)\n", " vector_indices.append(vector_index)\n", - " \n", + "\n", " # query engines\n", " query_engines.append(vector_index.as_query_engine())" ] }, {
@@ -173,7 +175,7 @@ " retriever=vector_index.as_retriever(),\n", " description=f\"Retrieves relevant context from the Great Gatsby (chunk size {chunk_size})\",\n", " )\n", - " retriever_tools.append(retriever_tool)\n" + " retriever_tools.append(retriever_tool)" ] }, {
@@ -185,15 +187,13 @@ }, "outputs": [], "source": [ - "from llama_index.selectors.pydantic_selectors import (\n", - " PydanticMultiSelector\n", - ")\n", + "from llama_index.selectors.pydantic_selectors import PydanticMultiSelector\n", "from llama_index.retrievers import RouterRetriever\n", "\n", "\n", "retriever = RouterRetriever(\n", " selector=PydanticMultiSelector.from_defaults(llm=llm, max_outputs=4),\n", - " retriever_tools=retriever_tools\n", + " retriever_tools=retriever_tools,\n", ")" ] }, {
@@ -221,7 +221,9 @@ } ], "source": [ - "nodes = await retriever.aretrieve(\"Describe and summarize the interactions between Gatsby and Daisy\")" + "nodes = await retriever.aretrieve(\n", + " \"Describe and summarize the interactions between Gatsby and Daisy\"\n", + ")" ] }, {
@@ -560,7 +562,12 @@ "outputs": [], "source": [ "# define reranker\n", - "from llama_index.indices.postprocessor import LLMRerank, SentenceTransformerRerank, CohereRerank\n", + "from llama_index.indices.postprocessor import (\n", + " LLMRerank,\n", + " SentenceTransformerRerank,\n", + " CohereRerank,\n", + ")\n", + "\n", "# reranker = LLMRerank()\n", "# reranker = SentenceTransformerRerank(top_n=10)\n", "reranker = CohereRerank(top_n=10)" ] }, {
@@ -578,10 +585,7 @@ "# define RetrieverQueryEngine\n", "from llama_index.query_engine import RetrieverQueryEngine\n", "\n", - "query_engine = RetrieverQueryEngine(\n", - " retriever,\n", - " node_postprocessors=[reranker]\n", - ")" + "query_engine = RetrieverQueryEngine(retriever, node_postprocessors=[reranker])" ] }, {
@@ -604,7 +608,9 @@ } ], "source": [ - "response = query_engine.query(\"Describe and summarize the interactions between Gatsby and Daisy\")" + "response = query_engine.query(\n", + " \"Describe and summarize the interactions between Gatsby and Daisy\"\n", + ")" ] }, {
@@ -1003,7 +1009,9 @@ } ], "source": [ - "display_response(response, show_source=True, source_length=500, show_source_metadata=True)" + "display_response(\n", + " response, show_source=True, source_length=500, show_source_metadata=True\n", + ")" ] }, {
@@ -1019,6 +1027,7 @@ "from collections import defaultdict\n", "import pandas as pd\n", "\n", + "\n", "def mrr_all(metadata_values, metadata_key, source_nodes):\n", " # source nodes is a ranked list\n", " # go through each value, find out positioning in source_nodes\n",
@@ -1027,18 +1036,17 @@ " mrr = 0\n", " for idx, source_node in enumerate(source_nodes):\n", " if source_node.node.metadata[metadata_key] == metadata_value:\n", - " mrr = 1 / (idx+1)\n", + " mrr = 1 / (idx + 1)\n", " break\n", " else:\n", " continue\n", - " \n", + "\n", " # normalize AP, set in dict\n", " value_to_mrr_dict[metadata_value] = mrr\n", - " \n", + "\n", " df = pd.DataFrame(value_to_mrr_dict, index=[\"MRR\"])\n", " df.style.set_caption(\"Mean Reciprocal Rank\")\n", - " return df\n", - " " + " return df" ] }, {
@@ -1108,7 +1116,7 @@ "source": [ "# Compute the Mean Reciprocal Rank for each chunk size (higher is better)\n", "# we can see that chunk size of 256 has the highest ranked results.\n", - "print('Mean Reciprocal Rank for each Chunk Size')\n", + "print(\"Mean Reciprocal Rank for each Chunk Size\")\n", "mrr_all(chunk_sizes, \"chunk_size\", response.source_nodes)" ] }, {
@@ -1143,7 +1151,9 @@ }, "outputs": [], "source": [ - "response_1024 = query_engine_1024.query(\"Describe and summarize the interactions between Gatsby and Daisy\")" + "response_1024 = query_engine_1024.query(\n", + " \"Describe and summarize the interactions between Gatsby and Daisy\"\n", + ")" ] }, {