fix llama2 in query engines (#6969)
logan-markewich committed Jul 19, 2023
1 parent 576eba7 commit 1033303
Showing 6 changed files with 301 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@
- Added a `SentenceTransformerRerank` node post-processor for fast local re-ranking (#6934)
- Add numpy support for evaluating queries in pandas query engine (#6935)
- Add metadata filtering support for Postgres Vector Storage integration (#6968)
- Proper llama2 support for agents and query engines (#6969)

### Bug Fixes / Nits
- Added `model_name` to LLMMetadata (#6911)
4 changes: 1 addition & 3 deletions benchmarks/agent/agent_utils.py
@@ -50,9 +50,7 @@ def get_model(model: str) -> LLM:
llm = Replicate(
model=replicate_model,
temperature=0.01,
-# override max tokens since it's interpreted
-# as context window instead of max tokens
-max_tokens=4096,
+context_window=4096,
# override message representation for llama 2
messages_to_prompt=messages_to_prompt,
)
1 change: 1 addition & 0 deletions docs/core_modules/model_modules/llms/modules.md
@@ -51,6 +51,7 @@ maxdepth: 1
---
/examples/llm/llama_2.ipynb
/examples/llm/vicuna.ipynb
/examples/vector_stores/SimpleIndexDemoLlama2.ipynb
```

## LangChain
277 changes: 277 additions & 0 deletions docs/examples/vector_stores/SimpleIndexDemoLlama2.ipynb
@@ -0,0 +1,277 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "9c48213d-6e6a-4c10-838a-2a7c710c3a05",
"metadata": {},
"source": [
"# Llama2 + VectorStoreIndex\n",
"\n",
"This notebook walks through the proper setup to use llama-2 with LlamaIndex. Specifically, we look at using a vector store index."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "91f09a23",
"metadata": {},
"source": [
"## Setup"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ba765302",
"metadata": {},
"source": [
"### Keys"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3d8cab38",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"OPENAI_API_KEY\"\n",
"os.environ[\"REPLICATE_API_TOKEN\"] = \"REPLICATE_API_TOKEN\"\n",
"\n",
"# currently needed for notebooks\n",
"import openai\n",
"\n",
"openai.api_key = os.environ[\"OPENAI_API_KEY\"]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "50d3b817-b70e-4667-be4f-d3a0fe4bd119",
"metadata": {},
"source": [
"### Load documents, build the VectorStoreIndex"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "690a6918-7c75-4f95-9ccc-d2c4a1fe00d7",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:numexpr.utils:Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
"Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
"INFO:numexpr.utils:NumExpr defaulting to 8 threads.\n",
"NumExpr defaulting to 8 threads.\n"
]
}
],
"source": [
"import logging\n",
"import sys\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
"logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n",
"\n",
"from llama_index import (\n",
" VectorStoreIndex,\n",
" SimpleDirectoryReader,\n",
")\n",
"\n",
"from IPython.display import Markdown, display"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "be92665d",
"metadata": {},
"outputs": [],
"source": [
"from llama_index.llms import Replicate\n",
"from llama_index import ServiceContext, set_global_service_context\n",
"from llama_index.llms.llama_utils import messages_to_prompt, completion_to_prompt\n",
"\n",
"# The replicate endpoint\n",
"LLAMA_13B_V2_CHAT = \"a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5\"\n",
"\n",
"# inject custom system prompt into llama-2\n",
"def custom_completion_to_prompt(completion: str) -> str:\n",
" return completion_to_prompt(\n",
" completion,\n",
" system_prompt=(\n",
" \"You are a Q&A assistant. Your goal is to answer questions as \"\n",
" \"accurately as possible is the instructions and context provided.\"\n",
" ),\n",
" )\n",
"\n",
"\n",
"llm = Replicate(\n",
" model=LLAMA_13B_V2_CHAT,\n",
" temperature=0.01,\n",
" # override max tokens since it's interpreted\n",
" # as context window instead of max tokens\n",
" context_window=4096,\n",
" # override completion representation for llama 2\n",
" completion_to_prompt=custom_completion_to_prompt,\n",
" # if using llama 2 for data agents, also override the message representation\n",
" messages_to_prompt=messages_to_prompt,\n",
")\n",
"\n",
"# set a global service context\n",
"ctx = ServiceContext.from_defaults(llm=llm)\n",
"set_global_service_context(ctx)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "03d1691e-544b-454f-825b-5ee12f7faa8a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# load documents\n",
"documents = SimpleDirectoryReader(\n",
" \"../../../examples/paul_graham_essay/data\"\n",
").load_data()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ad144ee7-96da-4dd6-be00-fd6cf0c78e58",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"index = VectorStoreIndex.from_documents(documents)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b6caf93b-6345-4c65-a346-a95b0f1746c4",
"metadata": {},
"source": [
"## Querying"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "85466fdf-93f3-4cb1-a5f9-0056a8245a6f",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"# set Logging to DEBUG for more detailed outputs\n",
"query_engine = index.as_query_engine()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bdda1b2c-ae46-47cf-91d7-3153e8d0473b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/markdown": [
"<b> Based on the context information provided, the author's activities growing up were:\n",
"1. Writing short stories, which were \"awful\" and lacked a strong plot.\n",
"2. Programming on an IBM 1401 computer in 9th grade, using an early version of Fortran.\n",
"3. Building a microcomputer with a friend, and writing simple games, a program to predict the height of model rockets, and a word processor.\n",
"4. Studying philosophy in college, but finding it boring and switching to AI.\n",
"5. Writing essays online, which became a turning point in their career.</b>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"response = query_engine.query(\"What did the author do growing up?\")\n",
"display(Markdown(f\"<b>{response}</b>\"))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "24935a47",
"metadata": {},
"source": [
"### Streaming Support"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "446406f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Based on the context information provided, it appears that the author worked at Interleaf, a company that made software for creating and managing documents. The author mentions that Interleaf was \"on the way down\" and that the company's Release Engineering group was large compared to the group that actually wrote the software. It is inferred that Interleaf was experiencing financial difficulties and that the author was nervous about money. However, there is no explicit mention of what specifically happened at Interleaf."
]
}
],
"source": [
"query_engine = index.as_query_engine(streaming=True)\n",
"response = query_engine.query(\"What happened at interleaf?\")\n",
"for token in response.response_gen:\n",
" print(token, end=\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cb2c4c0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
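A note on the service context setup above: instead of registering the llama-2 LLM globally with set_global_service_context, the same ServiceContext can be passed to a single index. The snippet below is a minimal editor's sketch rather than a cell from the committed notebook; it reuses the llm object and document path defined above and assumes the ServiceContext / VectorStoreIndex APIs of this llama_index version.

# Sketch (assumed API for this llama_index version): scope the llama-2 LLM to
# one index by passing the ServiceContext explicitly instead of setting it
# globally. `llm` is the Replicate instance configured in the notebook above.
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex

service_context = ServiceContext.from_defaults(llm=llm)
documents = SimpleDirectoryReader("../../../examples/paul_graham_essay/data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
query_engine = index.as_query_engine()
print(query_engine.query("What did the author do growing up?"))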
17 changes: 14 additions & 3 deletions llama_index/llms/llama_utils.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Optional, Sequence

from llama_index.llms.base import ChatMessage, MessageRole

@@ -18,13 +18,15 @@
"""


-def messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
+def messages_to_prompt(
+messages: Sequence[ChatMessage], system_prompt: Optional[str] = None
+) -> str:
string_messages = []
if messages[0].role == MessageRole.SYSTEM:
system_message_str = messages[0].content or ""
messages = messages[1:]
else:
-system_message_str = DEFAULT_SYSTEM_PROMPT
+system_message_str = system_prompt or DEFAULT_SYSTEM_PROMPT

system_message_str = B_SYS + system_message_str + E_SYS

@@ -49,3 +51,12 @@ def messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
assert last_message.role == MessageRole.USER
string_messages.append(f"{B_INST} {last_message.content} {E_INST}")
return "".join(string_messages)


+def completion_to_prompt(completion: str, system_prompt: Optional[str] = None) -> str:
+system_prompt_str = system_prompt or DEFAULT_SYSTEM_PROMPT
+
+return (
+f"{BOS}{B_INST} {B_SYS}{system_prompt_str.strip()}{E_SYS}"
+f"{completion.strip()} {E_INST}"
+)
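For reference, here is a short, hedged sketch of how the helpers in this file are called directly; the import paths mirror the ones in the diff above, and the exact output depends on the module's llama-2 constants (BOS, B_INST, B_SYS, and friends). completion_to_prompt wraps a bare completion in llama-2's instruction format with an optional custom system prompt, and messages_to_prompt does the same for a chat history whose last message comes from the user.

# Sketch: format a raw completion and a chat history for llama-2 using the
# helpers shown in the diff above. ChatMessage / MessageRole imports match
# the import at the top of this file.
from llama_index.llms.base import ChatMessage, MessageRole
from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt

# Bare completion with a custom system prompt.
print(completion_to_prompt("What is a vector store?", system_prompt="Answer briefly."))

# Chat history; a leading system message overrides the default system prompt.
chat_history = [
    ChatMessage(role=MessageRole.SYSTEM, content="Answer briefly."),
    ChatMessage(role=MessageRole.USER, content="What is a vector store?"),
]
print(messages_to_prompt(chat_history))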
11 changes: 7 additions & 4 deletions llama_index/llms/replicate.py
@@ -22,34 +22,36 @@ def __init__(
self,
model: str,
temperature: float = 0.75,
-max_tokens: int = DEFAULT_NUM_OUTPUTS,
additional_kwargs: Optional[Dict[str, Any]] = None,
context_window: int = DEFAULT_CONTEXT_WINDOW,
prompt_key: str = "prompt",
messages_to_prompt: Optional[Callable] = None,
+completion_to_prompt: Optional[Callable] = None,
) -> None:
self._model = model
self._context_window = context_window
self._prompt_key = prompt_key
self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
+self._completion_to_prompt = completion_to_prompt or (lambda x: x)

# model kwargs
self._temperature = temperature
-self._max_tokens = max_tokens
self._additional_kwargs = additional_kwargs or {}

@property
def metadata(self) -> LLMMetadata:
"""LLM metadata."""
return LLMMetadata(
-context_window=self._context_window, num_output=self._max_tokens
+context_window=self._context_window,
+num_output=DEFAULT_NUM_OUTPUTS,
+model_name=self._model,
)

@property
def _model_kwargs(self) -> Dict[str, Any]:
base_kwargs = {
"temperature": self._temperature,
"max_length": self._max_tokens,
"max_length": self._context_window,
}
model_kwargs = {
**base_kwargs,
@@ -88,6 +90,7 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
"Please install replicate with `pip install replicate`"
)

+prompt = self._completion_to_prompt(prompt)
input_dict = self._get_input_dict(prompt, **kwargs)
response_iter = replicate.run(self._model, input=input_dict)

