From 1a21313d73230a96a4458051fc76c6a214ebc625 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 3 Nov 2024 21:35:31 +0000 Subject: [PATCH 1/5] working on chap app example --- .gitignore | 1 + pydantic_ai/models/openai.py | 10 +- pydantic_ai_examples/README.md | 26 +++++- pydantic_ai_examples/__main__.py | 5 +- pydantic_ai_examples/chat_app.html | 143 +++++++++++++++++++++++++++++ pydantic_ai_examples/chat_app.py | 71 ++++++++++++++ pydantic_ai_examples/sql_gen.py | 7 +- pyproject.toml | 7 +- tests/test_agent.py | 58 ++++++++++++ uv.lock | 56 +++++++++++ 10 files changed, 374 insertions(+), 10 deletions(-) create mode 100644 pydantic_ai_examples/chat_app.html create mode 100644 pydantic_ai_examples/chat_app.py diff --git a/.gitignore b/.gitignore index 11f2837c1b..99f4125d24 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ env*/ /TODO.md /postgres-data/ .DS_Store +/pydantic_ai_examples/.chat_app_messages.json diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 4c4b423d7b..7fb41a7b5b 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -6,7 +6,7 @@ from typing import Literal from httpx import AsyncClient as AsyncHTTPClient -from openai import AsyncOpenAI +from openai import NOT_GIVEN, AsyncOpenAI from openai.types import ChatModel, chat from typing_extensions import assert_never @@ -104,7 +104,7 @@ def process_response(response: chat.ChatCompletion) -> LLMMessage: async def completions_create(self, messages: list[Message]) -> chat.ChatCompletion: # standalone function to make it easier to override if not self.tools: - tool_choice: Literal['none', 'required', 'auto'] = 'none' + tool_choice: Literal['none', 'required', 'auto'] | None = None elif not self.allow_text_result: tool_choice = 'required' else: @@ -115,9 +115,9 @@ async def completions_create(self, messages: list[Message]) -> chat.ChatCompleti model=self.model_name, messages=openai_messages, n=1, - parallel_tool_calls=True, - tools=self.tools, - tool_choice=tool_choice, + parallel_tool_calls=True if self.tools else NOT_GIVEN, + tools=self.tools or NOT_GIVEN, + tool_choice=tool_choice or NOT_GIVEN, ) @staticmethod diff --git a/pydantic_ai_examples/README.md b/pydantic_ai_examples/README.md index 6a106f34b1..96335acfa9 100644 --- a/pydantic_ai_examples/README.md +++ b/pydantic_ai_examples/README.md @@ -35,6 +35,21 @@ But you'll probably want to edit examples as well as just running them, so you c python -m pydantic_ai_examples --copy-to examples/ ``` +### Setting model environment variables + +All these examples will need to set either: + +* `OPENAI_API_KEY` to use OpenAI models, go to [platform.openai.com](https://platform.openai.com/) and follow your nose until you find how to generate an API key +* `GEMINI_API_KEY` to use Gemini/Google models, go to [aistudio.google.com](https://aistudio.google.com/) and do the same to generate an API key + +Then set the API key as an env variable with: + +```bash +export OPENAI_API_KEY=your-api-key +# or +export GEMINI_API_KEY=your-api-key +``` + ## Examples ### `pydantic_model.py` @@ -62,6 +77,15 @@ PYDANTIC_AI_MODEL=gemini-1.5-pro (uv run/python) -m pydantic_ai_examples.pydanti Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. +The resulting SQL is validated by running it as an `EXPLAIN` query on postgres. To run the example, you first need to run postgres, e.g. 
via Docker: + +```bash +docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres +``` +_(we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running)_ + +Then to run the code + ```bash (uv run/python) -m pydantic_ai_examples.sql_gen ``` @@ -116,7 +140,7 @@ mkdir postgres-data docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17 ``` -We run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running. +As above, we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running. We also mount the postgresql `data` directory locally to persist the data if you need to stop and restart the container. With that running, we can build the search database with (**WARNING**: this requires the `OPENAI_API_KEY` env variable and will calling the OpenAI embedding API around 300 times to generate embeddings for each section of the documentation): diff --git a/pydantic_ai_examples/__main__.py b/pydantic_ai_examples/__main__.py index fe64846468..49a37a3258 100644 --- a/pydantic_ai_examples/__main__.py +++ b/pydantic_ai_examples/__main__.py @@ -2,6 +2,7 @@ See README.md for more information. """ + import argparse import sys from pathlib import Path @@ -10,7 +11,9 @@ def cli(): this_dir = Path(__file__).parent - parser = argparse.ArgumentParser(prog='pydantic_ai_examples', description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + parser = argparse.ArgumentParser( + prog='pydantic_ai_examples', description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) parser.add_argument('-v', '--version', action='store_true', help='show the version and exit') parser.add_argument('--copy-to', dest='DEST', help='Copy all examples to a new directory') diff --git a/pydantic_ai_examples/chat_app.html b/pydantic_ai_examples/chat_app.html new file mode 100644 index 0000000000..7b9e53c817 --- /dev/null +++ b/pydantic_ai_examples/chat_app.html @@ -0,0 +1,143 @@ + + + + + + Chat App + + + +
+    <!-- page body lost in extraction; the surviving text is a "Chat App"
+         heading, an "Ask me anything..." intro, and the error notice
+         "Error occurred, check the console for more information." -->
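To make the page's `fetch` calls concrete, here is a rough sketch of the round trip against the two endpoints defined in `chat_app.py` (the next file in this patch), as they behave at this stage of the series. The `httpx` client and the default `localhost:8000` address are assumptions (any HTTP client works), and a later commit switches the responses to newline-delimited JSON:

```python
import httpx

# GET /chat/ returns the stored conversation as a JSON array of messages
history = httpx.get('http://localhost:8000/chat/').json()

# POST /chat/ submits the prompt as form data; the response is a JSON array
# of the non-system messages the agent run produced
reply = httpx.post('http://localhost:8000/chat/', data={'prompt': 'Hello'}).json()

print(len(history), [m['role'] for m in reply])
```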
+ + + diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py new file mode 100644 index 0000000000..67046b09c4 --- /dev/null +++ b/pydantic_ai_examples/chat_app.py @@ -0,0 +1,71 @@ +from collections.abc import Iterator +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated + +import fastapi +from fastapi.responses import HTMLResponse +from pydantic import Field, TypeAdapter + +from pydantic_ai import Agent +from pydantic_ai.messages import Message, MessagesTypeAdapter + +agent = Agent('openai:gpt-4o', deps=None) + +app = fastapi.FastAPI() + + +@app.get('/') +async def index() -> HTMLResponse: + return HTMLResponse((THIS_DIR / 'chat_app.html').read_bytes()) + + +@app.get('/chat/') +async def get_chat() -> fastapi.Response: + messages = list(database.get_messages()) + messages = MessagesTypeAdapter.dump_json(messages) + return fastapi.Response(content=messages, media_type='application/json') + + +@app.post('/chat/') +async def post_chat(prompt: Annotated[str, fastapi.Form()]) -> fastapi.Response: + messages = list(database.get_messages()) + response = await agent.run(prompt, message_history=messages) + response_messages: list[Message] = [] + for message in response.message_history: + if message.role != 'system': + database.add_message(message) + response_messages.append(message) + messages = MessagesTypeAdapter.dump_json(response_messages) + return fastapi.Response(content=messages, media_type='application/json') + + +THIS_DIR = Path(__file__).parent +MessageTypeAdapter: TypeAdapter[Message] = TypeAdapter(Annotated[Message, Field(discriminator='role')]) + + +@dataclass +class Database: + """Very rudimentary database to store chat messages in a JSON lines file.""" + + file: Path = THIS_DIR / '.chat_app_messages.json' + + def add_message(self, message: Message): + with self.file.open('ab') as f: + f.write(MessageTypeAdapter.dump_json(message) + b'\n') + + def get_messages(self) -> Iterator[Message]: + if self.file.exists(): + with self.file.open('rb') as f: + for line in f: + if line: + yield MessageTypeAdapter.validate_json(line) + + +database = Database() + + +if __name__ == '__main__': + import uvicorn + + uvicorn.run('pydantic_ai_examples.chat_app:app', reload=True, reload_dirs=[str(THIS_DIR)]) diff --git a/pydantic_ai_examples/sql_gen.py b/pydantic_ai_examples/sql_gen.py index 624f67bbe2..c9bb8dac8b 100644 --- a/pydantic_ai_examples/sql_gen.py +++ b/pydantic_ai_examples/sql_gen.py @@ -1,5 +1,10 @@ """Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. 
+Run postgres with: + + mkdir postgres-data + docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres + Run with: uv run -m pydantic_ai_examples.sql_gen "show me logs from yesterday, with level 'error'" @@ -121,7 +126,7 @@ async def main(): else: prompt = sys.argv[1] - async with database_connect('postgresql://postgres@localhost', 'pydantic_ai_sql_gen') as conn: + async with database_connect('postgresql://postgres:postgres@localhost:54320', 'pydantic_ai_sql_gen') as conn: deps = Deps(conn) result = await agent.run(prompt, deps=deps) debug(result.response) diff --git a/pyproject.toml b/pyproject.toml index abd61157c3..a845726b1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,10 @@ logfire = [ ] examples = [ "asyncpg>=0.30.0", + "fastapi>=0.115.4", "logfire[asyncpg]>=1.2.0", + "python-multipart>=0.0.17", + "uvicorn>=0.32.0", ] [dependency-groups] @@ -77,7 +80,7 @@ target-version = "py39" include = [ "pydantic_ai/**/*.py", "tests/**/*.py", - "examples/**/*.py", + "pydantic_ai_examples/**/*.py", ] [tool.ruff.lint] @@ -110,7 +113,7 @@ quote-style = "single" [tool.ruff.lint.per-file-ignores] "tests/**.py" = ["D"] -"examples/**.py" = ["D103"] +"pydantic_ai_examples/**.py" = ["D103"] [tool.pyright] typeCheckingMode = "strict" diff --git a/tests/test_agent.py b/tests/test_agent.py index 552e4797bc..e2ef59c8cb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -9,16 +9,20 @@ from pydantic_ai import Agent, CallContext, ModelRetry from pydantic_ai.messages import ( ArgsJson, + ArgsObject, LLMMessage, LLMResponse, LLMToolCalls, Message, RetryPrompt, + SystemPrompt, ToolCall, + ToolReturn, UserPrompt, ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.shared import Cost, RunResult from tests.conftest import IsNow @@ -348,3 +352,57 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any: result = agent.run_sync('Hello', model=TestModel(seed=1)) assert result.response == mod.Bar(b='b') assert got_tool_call_name == snapshot('final_result_Bar') + + +def test_run_with_history(): + m = TestModel() + + agent = Agent(m, deps=None, system_prompt='Foobar') + + @agent.retriever_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + result = agent.run_sync('Hello') + assert result == snapshot( + RunResult( + response='{"ret_a":"a-apple"}', + message_history=[ + SystemPrompt(content='Foobar'), + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), + LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), + ], + cost=Cost(), + ) + ) + + result = agent.run_sync('Hello again', message_history=result.message_history) + assert result == snapshot( + RunResult( + response='{"ret_a":"a-apple"}', + message_history=[ + SystemPrompt(content='Foobar'), + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), + LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), + # second call, notice no repeated system prompt + UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)), + 
LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), + LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), + ], + cost=Cost(), + ) + ) diff --git a/uv.lock b/uv.lock index d73c5301f8..70f4674ce2 100644 --- a/uv.lock +++ b/uv.lock @@ -398,6 +398,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/fd/afcd0496feca3276f509df3dbd5dae726fcc756f1a08d9e25abe1733f962/executing-2.1.0-py2.py3-none-any.whl", hash = "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf", size = 25805 }, ] +[[package]] +name = "fastapi" +version = "0.115.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/db/5781f19bd30745885e0737ff3fdd4e63e7bc691710f9da691128bb0dc73b/fastapi-0.115.4.tar.gz", hash = "sha256:db653475586b091cb8b2fec2ac54a680ac6a158e07406e1abae31679e8826349", size = 300737 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/f6/af0d1f58f86002be0cf1e2665cdd6f7a4a71cdc8a7a9438cdc9e3b5375fe/fastapi-0.115.4-py3-none-any.whl", hash = "sha256:0b504a063ffb3cf96a5e27dc1bc32c80ca743a2528574f9cdc77daa2d31b4742", size = 94732 }, +] + [[package]] name = "googleapis-common-protos" version = "1.65.0" @@ -899,7 +913,10 @@ dependencies = [ [package.optional-dependencies] examples = [ { name = "asyncpg" }, + { name = "fastapi" }, { name = "logfire", extra = ["asyncpg"] }, + { name = "python-multipart" }, + { name = "uvicorn" }, ] logfire = [ { name = "logfire" }, @@ -923,6 +940,7 @@ dev = [ requires-dist = [ { name = "asyncpg", marker = "extra == 'examples'", specifier = ">=0.30.0" }, { name = "eval-type-backport", specifier = ">=0.2.0" }, + { name = "fastapi", marker = "extra == 'examples'", specifier = ">=0.115.4" }, { name = "griffe", specifier = ">=1.3.2" }, { name = "httpx", specifier = ">=0.27.2" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=1.2.0" }, @@ -930,6 +948,8 @@ requires-dist = [ { name = "logfire-api", specifier = ">=1.2.0" }, { name = "openai", specifier = ">=1.51.2" }, { name = "pydantic", specifier = ">=2.9.2" }, + { name = "python-multipart", marker = "extra == 'examples'", specifier = ">=0.0.17" }, + { name = "uvicorn", marker = "extra == 'examples'", specifier = ">=0.32.0" }, ] [package.metadata.requires-dev] @@ -1085,6 +1105,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/fe/d44d391312c1b8abee2af58ee70fabb1c00b6577ac4e0bdf25b70c1caffb/pytest_pretty-1.2.0-py3-none-any.whl", hash = "sha256:6f79122bf53864ae2951b6c9e94d7a06a87ef753476acd4588aeac018f062036", size = 6180 }, ] +[[package]] +name = "python-multipart" +version = "0.0.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/40/22/edea41c2d4a22e666c0c7db7acdcbf7bc8c1c1f7d3b3ca246ec982fec612/python_multipart-0.0.17.tar.gz", hash = "sha256:41330d831cae6e2f22902704ead2826ea038d0419530eadff3ea80175aec5538", size = 36452 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/fb/275137a799169392f1fa88fff2be92f16eee38e982720a8aaadefc4a36b2/python_multipart-0.0.17-py3-none-any.whl", hash = "sha256:15dc4f487e0a9476cc1201261188ee0940165cffc94429b6fc565c4d3045cb5d", size = 24453 }, +] + [[package]] name = "requests" version = "2.32.3" @@ 
-1166,6 +1195,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, ] +[[package]] +name = "starlette" +version = "0.41.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/da/1fb4bdb72ae12b834becd7e1e7e47001d32f91ec0ce8d7bc1b618d9f0bd9/starlette-0.41.2.tar.gz", hash = "sha256:9834fd799d1a87fd346deb76158668cfa0b0d56f85caefe8268e2d97c3468b62", size = 2573867 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/43/f185bfd0ca1d213beb4293bed51d92254df23d8ceaf6c0e17146d508a776/starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d", size = 73259 }, +] + [[package]] name = "toml" version = "0.10.2" @@ -1223,6 +1265,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338 }, ] +[[package]] +name = "uvicorn" +version = "0.32.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/fc/1d785078eefd6945f3e5bab5c076e4230698046231eb0f3747bc5c8fa992/uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e", size = 77564 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/14/78bd0e95dd2444b6caacbca2b730671d4295ccb628ef58b81bee903629df/uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82", size = 63723 }, +] + [[package]] name = "wrapt" version = "1.16.0" From dd890f91757a8e740586a10747a0dce909b97564 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 3 Nov 2024 21:40:34 +0000 Subject: [PATCH 2/5] hyperlint suggestions Co-authored-by: hyperlint-ai[bot] <154288675+hyperlint-ai[bot]@users.noreply.github.com> --- pydantic_ai_examples/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_examples/README.md b/pydantic_ai_examples/README.md index 96335acfa9..f2ac0974a0 100644 --- a/pydantic_ai_examples/README.md +++ b/pydantic_ai_examples/README.md @@ -42,7 +42,7 @@ All these examples will need to set either: * `OPENAI_API_KEY` to use OpenAI models, go to [platform.openai.com](https://platform.openai.com/) and follow your nose until you find how to generate an API key * `GEMINI_API_KEY` to use Gemini/Google models, go to [aistudio.google.com](https://aistudio.google.com/) and do the same to generate an API key -Then set the API key as an env variable with: +Then set the API key as an environment variable with: ```bash export OPENAI_API_KEY=your-api-key @@ -77,7 +77,7 @@ PYDANTIC_AI_MODEL=gemini-1.5-pro (uv run/python) -m pydantic_ai_examples.pydanti Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. -The resulting SQL is validated by running it as an `EXPLAIN` query on postgres. To run the example, you first need to run postgres, e.g. 
via Docker: +The resulting SQL is validated by running it as an `EXPLAIN` query on PostgreSQL. To run the example, you first need to run PostgreSQL, e.g. via Docker: ```bash docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres From 10362a33176dafd3297a2a2d4f2bb7ab1cb6fcc6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 4 Nov 2024 12:08:52 +0000 Subject: [PATCH 3/5] adding new_messages methods --- .gitignore | 2 +- preview.README.md | 4 +-- pydantic_ai/agent.py | 8 +++-- pydantic_ai/shared.py | 23 +++++++++++--- pydantic_ai_examples/chat_app.html | 47 ++++++++++++++++++++--------- pydantic_ai_examples/chat_app.py | 44 +++++++++++++++------------ tests/models/test_gemini.py | 6 ++-- tests/models/test_model_function.py | 10 +++--- tests/models/test_model_test.py | 2 +- tests/models/test_openai.py | 4 +-- tests/test_agent.py | 18 +++++++---- 11 files changed, 107 insertions(+), 61 deletions(-) diff --git a/.gitignore b/.gitignore index 99f4125d24..1fdae50508 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ env*/ /TODO.md /postgres-data/ .DS_Store -/pydantic_ai_examples/.chat_app_messages.json +/pydantic_ai_examples/.chat_app_messages.jsonl diff --git a/preview.README.md b/preview.README.md index bc4d3f806f..3d27bf5760 100644 --- a/preview.README.md +++ b/preview.README.md @@ -60,7 +60,7 @@ result = weather_agent.run_sync('What is the weather like in West London and in print(result.response) # > 'The weather in West London is raining, while in Wiltshire it is sunny.' -# `result.message_history` details of messages exchanged, useful if you want to continue +# `result.all_messages` includes details of messages exchanged, useful if you want to continue # the conversation later, via the `message_history` argument of `run_sync`. -print(result.message_history) +print(result.all_messages) ``` diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 55850f7ebd..cb52475fb1 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -97,12 +97,16 @@ async def run( if deps is None: deps = self._default_deps - if message_history is not None: + # if message history includes system prompts, we don't want to regenerate them + if message_history and any(m.role == 'system' for m in message_history): # shallow copy messages messages = message_history.copy() else: messages = await self._init_messages(deps) + if message_history: + messages += message_history + new_message_index = len(messages) messages.append(_messages.UserPrompt(user_prompt)) result_tools = list(self._result_schema.tools.values()) if self._result_schema else None @@ -135,7 +139,7 @@ async def run( run_span.set_attribute('cost', cost) handle_span.set_attribute('result', left.value) handle_span.message = 'handle model response -> final result' - return shared.RunResult(left.value, messages, cost=cost) + return shared.RunResult(left.value, messages, new_message_index, cost) else: tool_responses = either.right handle_span.set_attribute('tool_responses', tool_responses) diff --git a/pydantic_ai/shared.py b/pydantic_ai/shared.py index f0ada51847..f3bb1bc61b 100644 --- a/pydantic_ai/shared.py +++ b/pydantic_ai/shared.py @@ -6,8 +6,9 @@ from pydantic import ValidationError +from . import messages + if TYPE_CHECKING: - from . 
import messages from .models import Model __all__ = ( @@ -57,12 +58,24 @@ class RunResult(Generic[ResultData]): """Result of a run.""" response: ResultData - message_history: list[messages.Message] + all_messages: list[messages.Message] + new_message_index: int cost: Cost - def message_history_json(self) -> str: - """Return the history of messages as a JSON string.""" - return messages.MessagesTypeAdapter.dump_json(self.message_history).decode() + def all_messages_json(self) -> bytes: + """Return the history of messages as JSON bytes.""" + return messages.MessagesTypeAdapter.dump_json(self.all_messages) + + def new_messages(self) -> list[messages.Message]: + """Return new messages associated with this run. + + System prompts and any messages from older runs are excluded. + """ + return self.all_messages[self.new_message_index :] + + def new_messages_json(self) -> bytes: + """Return new messages from [new_messages][] as JSON bytes.""" + return messages.MessagesTypeAdapter.dump_json(self.new_messages()) @dataclass diff --git a/pydantic_ai_examples/chat_app.html b/pydantic_ai_examples/chat_app.html index 7b9e53c817..544f75fb65 100644 --- a/pydantic_ai_examples/chat_app.html +++ b/pydantic_ai_examples/chat_app.html @@ -19,6 +19,7 @@ margin-top: 10px; border: 1px solid #ccc; border-radius: 4px; + font-size: 1rem; } button { margin: 10px 0 0 auto; @@ -66,7 +67,6 @@ #spinner.active { opacity: 1; } - @keyframes rotation { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } @@ -85,7 +85,7 @@

+    <!-- two further markup hunks lost in extraction; only the page text
+         "Chat App" survives from their context lines -->
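The `chat_app.py` hunk below leans on the `RunResult` helpers added to `shared.py` above. A minimal sketch of how they behave (the model name is arbitrary; note that at this point in the series `all_messages` is still a plain attribute and only becomes a method in the next commit):

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', deps=None, system_prompt='Be concise.')
result = agent.run_sync('Hello')

# only the messages this run produced: the system prompt is excluded, so
# appending them to a store never duplicates it across runs
for message in result.new_messages():
    print(message.role)

# the same messages serialized as JSON bytes, ready for a JSON-lines file
blob = result.new_messages_json()
```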
diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py index 67046b09c4..ac772bcdbf 100644 --- a/pydantic_ai_examples/chat_app.py +++ b/pydantic_ai_examples/chat_app.py @@ -4,11 +4,11 @@ from typing import Annotated import fastapi -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, Response, StreamingResponse from pydantic import Field, TypeAdapter from pydantic_ai import Agent -from pydantic_ai.messages import Message, MessagesTypeAdapter +from pydantic_ai.messages import Message, MessagesTypeAdapter, UserPrompt agent = Agent('openai:gpt-4o', deps=None) @@ -21,23 +21,27 @@ async def index() -> HTMLResponse: @app.get('/chat/') -async def get_chat() -> fastapi.Response: - messages = list(database.get_messages()) - messages = MessagesTypeAdapter.dump_json(messages) - return fastapi.Response(content=messages, media_type='application/json') +async def get_chat() -> Response: + msgs = database.get_messages() + return Response(b'\n'.join(MessageTypeAdapter.dump_json(m) for m in msgs), media_type='text/plain') @app.post('/chat/') -async def post_chat(prompt: Annotated[str, fastapi.Form()]) -> fastapi.Response: - messages = list(database.get_messages()) - response = await agent.run(prompt, message_history=messages) - response_messages: list[Message] = [] - for message in response.message_history: - if message.role != 'system': - database.add_message(message) - response_messages.append(message) - messages = MessagesTypeAdapter.dump_json(response_messages) - return fastapi.Response(content=messages, media_type='application/json') +async def post_chat(prompt: Annotated[str, fastapi.Form()]) -> StreamingResponse: + async def stream_messages(): + """Streams new line delimited JSON `Message`s to the client.""" + # stream the user prompt so that can be displayed straight away + yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n' + # get the chat history so far to pass as context to the agent + messages = list(database.get_messages()) + response = await agent.run(prompt, message_history=messages) + # add new messages (e.g. 
the user prompt and the agent response in this case) to the database + database.add_messages(response.new_messages_json()) + # stream the last message which will be the agent response, we can't just yield `new_messages_json()` + # since we already stream the user prompt + yield MessageTypeAdapter.dump_json(response.all_messages[-1]) + b'\n' + + return StreamingResponse(stream_messages(), media_type='text/plain') THIS_DIR = Path(__file__).parent @@ -48,18 +52,18 @@ async def post_chat(prompt: Annotated[str, fastapi.Form()]) -> fastapi.Response: class Database: """Very rudimentary database to store chat messages in a JSON lines file.""" - file: Path = THIS_DIR / '.chat_app_messages.json' + file: Path = THIS_DIR / '.chat_app_messages.jsonl' - def add_message(self, message: Message): + def add_messages(self, messages: bytes): with self.file.open('ab') as f: - f.write(MessageTypeAdapter.dump_json(message) + b'\n') + f.write(messages + b'\n') def get_messages(self) -> Iterator[Message]: if self.file.exists(): with self.file.open('rb') as f: for line in f: if line: - yield MessageTypeAdapter.validate_json(line) + yield from MessagesTypeAdapter.validate_json(line) database = Database() diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 6d41ee4ae3..2ce8351619 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -360,7 +360,7 @@ async def test_request_simple_success(get_gemini_client: GetGeminiClient): result = await agent.run('Hello') assert result.response == 'Hello world' - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)), @@ -381,7 +381,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): result = await agent.run('Hello') assert result.response == [1, 2, 123] - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -424,7 +424,7 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.response == 'final response' - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ SystemPrompt(content='this is the system prompt'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 1e0aed15b4..028bec3ec8 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -37,7 +37,7 @@ def test_simple(): agent = Agent(FunctionModel(return_last), deps=None) result = agent.run_sync('Hello') assert result.response == snapshot("content='Hello' role='user' message_count=1") - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt( content='Hello', @@ -52,9 +52,9 @@ def test_simple(): ] ) - result2 = agent.run_sync('World', message_history=result.message_history) + result2 = agent.run_sync('World', message_history=result.all_messages) assert result2.response == snapshot("content='World' role='user' message_count=3") - assert result2.message_history == snapshot( + assert result2.all_messages == snapshot( [ UserPrompt( content='Hello', @@ -127,7 +127,7 @@ async def get_weather(_: CallContext[None], lat: int, lng: int): def test_weather(): result = weather_agent.run_sync('London') assert result.response == 'Raining in 
London' - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt( content='London', @@ -322,7 +322,7 @@ def f(messages: list[Message], info: AgentInfo) -> LLMMessage: def test_call_all(): result = agent_all.run_sync('Hello', model=TestModel()) assert result.response == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}') - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ SystemPrompt(content='foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 42fee6c34e..2736de6bd7 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -82,7 +82,7 @@ async def my_ret(x: int) -> str: result = agent.run_sync('Hello', model=TestModel()) assert call_count == 2 assert result.response == snapshot('{"my_ret":"1"}') - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index f0d162fa2a..e19c32dc15 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -113,7 +113,7 @@ async def test_request_structured_response(): result = await agent.run('Hello') assert result.response == [1, 2, 123] - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)), LLMToolCalls( @@ -185,7 +185,7 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.response == 'final response' - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ SystemPrompt(content='this is the system prompt'), UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)), diff --git a/tests/test_agent.py b/tests/test_agent.py index e2ef59c8cb..724dcb09dc 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -70,7 +70,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> LLMMessage: result = agent.run_sync('Hello') assert isinstance(result.response, Foo) assert result.response.model_dump() == {'a': 42, 'b': 'foo'} - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -95,6 +95,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> LLMMessage: ), ] ) + assert result.all_messages_json().startswith(b'[{"content":"Hello"') def test_result_validator(): @@ -119,7 +120,7 @@ def validate_result(ctx: CallContext[None], r: Foo) -> Foo: result = agent.run_sync('Hello') assert isinstance(result.response, Foo) assert result.response.model_dump() == {'a': 42, 'b': 'foo'} - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -152,7 +153,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage: result = agent.run_sync('Hello') assert result.response == ('foo', 'bar') assert call_index == 2 - assert result.message_history == snapshot( + assert result.all_messages == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='hello', timestamp=IsNow(tz=timezone.utc)), @@ -367,7 +368,7 @@ async def ret_a(x: str) -> str: assert result == snapshot( RunResult( 
response='{"ret_a":"a-apple"}', - message_history=[ + all_messages=[ SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -377,15 +378,16 @@ async def ret_a(x: str) -> str: ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ], + new_message_index=1, cost=Cost(), ) ) - result = agent.run_sync('Hello again', message_history=result.message_history) + result = agent.run_sync('Hello again', message_history=result.all_messages) assert result == snapshot( RunResult( response='{"ret_a":"a-apple"}', - message_history=[ + all_messages=[ SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -403,6 +405,10 @@ async def ret_a(x: str) -> str: ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ], + new_message_index=5, cost=Cost(), ) ) + new_msg_roles = [msg.role for msg in result.new_messages()] + assert new_msg_roles == snapshot(['user', 'llm-tool-calls', 'tool-return', 'llm-response']) + assert result.new_messages_json().startswith(b'[{"content":"Hello again",') From 3c0f4d70cb0e9476d0e5b307396119e88e1da437 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 4 Nov 2024 12:16:24 +0000 Subject: [PATCH 4/5] make all_messages() a method --- preview.README.md | 2 +- pydantic_ai/agent.py | 2 +- pydantic_ai/shared.py | 13 ++++--- pydantic_ai_examples/chat_app.py | 2 +- tests/models/test_gemini.py | 6 ++-- tests/models/test_model_function.py | 10 +++--- tests/models/test_model_test.py | 2 +- tests/models/test_openai.py | 4 +-- tests/test_agent.py | 55 +++++++++++++++++++++-------- 9 files changed, 63 insertions(+), 33 deletions(-) diff --git a/preview.README.md b/preview.README.md index 3d27bf5760..6f453ddbb5 100644 --- a/preview.README.md +++ b/preview.README.md @@ -62,5 +62,5 @@ print(result.response) # `result.all_messages` includes details of messages exchanged, useful if you want to continue # the conversation later, via the `message_history` argument of `run_sync`. 
-print(result.all_messages) +print(result.all_messages()) ``` diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index cb52475fb1..eaffc780fb 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -139,7 +139,7 @@ async def run( run_span.set_attribute('cost', cost) handle_span.set_attribute('result', left.value) handle_span.message = 'handle model response -> final result' - return shared.RunResult(left.value, messages, new_message_index, cost) + return shared.RunResult(left.value, cost, messages, new_message_index) else: tool_responses = either.right handle_span.set_attribute('tool_responses', tool_responses) diff --git a/pydantic_ai/shared.py b/pydantic_ai/shared.py index f3bb1bc61b..73aa9c7ee3 100644 --- a/pydantic_ai/shared.py +++ b/pydantic_ai/shared.py @@ -58,20 +58,25 @@ class RunResult(Generic[ResultData]): """Result of a run.""" response: ResultData - all_messages: list[messages.Message] - new_message_index: int cost: Cost + _all_messages: list[messages.Message] + _new_message_index: int + + def all_messages(self) -> list[messages.Message]: + """Return the history of messages.""" + # this is a method to be consistent with the other methods + return self._all_messages def all_messages_json(self) -> bytes: """Return the history of messages as JSON bytes.""" - return messages.MessagesTypeAdapter.dump_json(self.all_messages) + return messages.MessagesTypeAdapter.dump_json(self.all_messages()) def new_messages(self) -> list[messages.Message]: """Return new messages associated with this run. System prompts and any messages from older runs are excluded. """ - return self.all_messages[self.new_message_index :] + return self.all_messages()[self._new_message_index :] def new_messages_json(self) -> bytes: """Return new messages from [new_messages][] as JSON bytes.""" diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py index ac772bcdbf..b8c69d9551 100644 --- a/pydantic_ai_examples/chat_app.py +++ b/pydantic_ai_examples/chat_app.py @@ -39,7 +39,7 @@ async def stream_messages(): database.add_messages(response.new_messages_json()) # stream the last message which will be the agent response, we can't just yield `new_messages_json()` # since we already stream the user prompt - yield MessageTypeAdapter.dump_json(response.all_messages[-1]) + b'\n' + yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n' return StreamingResponse(stream_messages(), media_type='text/plain') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 2ce8351619..53de14d921 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -360,7 +360,7 @@ async def test_request_simple_success(get_gemini_client: GetGeminiClient): result = await agent.run('Hello') assert result.response == 'Hello world' - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)), @@ -381,7 +381,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): result = await agent.run('Hello') assert result.response == [1, 2, 123] - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -424,7 +424,7 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.response == 'final response' - assert result.all_messages == snapshot( + 
assert result.all_messages() == snapshot( [ SystemPrompt(content='this is the system prompt'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 028bec3ec8..36d141f785 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -37,7 +37,7 @@ def test_simple(): agent = Agent(FunctionModel(return_last), deps=None) result = agent.run_sync('Hello') assert result.response == snapshot("content='Hello' role='user' message_count=1") - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt( content='Hello', @@ -52,9 +52,9 @@ def test_simple(): ] ) - result2 = agent.run_sync('World', message_history=result.all_messages) + result2 = agent.run_sync('World', message_history=result.all_messages()) assert result2.response == snapshot("content='World' role='user' message_count=3") - assert result2.all_messages == snapshot( + assert result2.all_messages() == snapshot( [ UserPrompt( content='Hello', @@ -127,7 +127,7 @@ async def get_weather(_: CallContext[None], lat: int, lng: int): def test_weather(): result = weather_agent.run_sync('London') assert result.response == 'Raining in London' - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt( content='London', @@ -322,7 +322,7 @@ def f(messages: list[Message], info: AgentInfo) -> LLMMessage: def test_call_all(): result = agent_all.run_sync('Hello', model=TestModel()) assert result.response == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}') - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ SystemPrompt(content='foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 2736de6bd7..3150d8ad7c 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -82,7 +82,7 @@ async def my_ret(x: int) -> str: result = agent.run_sync('Hello', model=TestModel()) assert call_count == 2 assert result.response == snapshot('{"my_ret":"1"}') - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index e19c32dc15..60e9b9a92d 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -113,7 +113,7 @@ async def test_request_structured_response(): result = await agent.run('Hello') assert result.response == [1, 2, 123] - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)), LLMToolCalls( @@ -185,7 +185,7 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.response == 'final response' - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ SystemPrompt(content='this is the system prompt'), UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)), diff --git a/tests/test_agent.py b/tests/test_agent.py index 724dcb09dc..b225f2493f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -70,7 +70,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> LLMMessage: result = agent.run_sync('Hello') assert isinstance(result.response, Foo) assert result.response.model_dump() == {'a': 42, 
'b': 'foo'} - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -120,7 +120,7 @@ def validate_result(ctx: CallContext[None], r: Foo) -> Foo: result = agent.run_sync('Hello') assert isinstance(result.response, Foo) assert result.response.model_dump() == {'a': 42, 'b': 'foo'} - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -153,7 +153,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage: result = agent.run_sync('Hello') assert result.response == ('foo', 'bar') assert call_index == 2 - assert result.all_messages == snapshot( + assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='hello', timestamp=IsNow(tz=timezone.utc)), @@ -355,7 +355,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') -def test_run_with_history(): +def test_run_with_history_new(): m = TestModel() agent = Agent(m, deps=None, system_prompt='Foobar') @@ -364,11 +364,25 @@ def test_run_with_history(): async def ret_a(x: str) -> str: return f'{x}-apple' - result = agent.run_sync('Hello') - assert result == snapshot( + result1 = agent.run_sync('Hello') + assert result1.new_messages() == snapshot( + [ + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), + LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), + ] + ) + + # if we pass new_messages, system prompt is inserted before the message_history messages + result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) + assert result2 == snapshot( RunResult( response='{"ret_a":"a-apple"}', - all_messages=[ + _all_messages=[ SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -377,17 +391,31 @@ async def ret_a(x: str) -> str: ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), + # second call, notice no repeated system prompt + UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), + LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ], - new_message_index=1, + _new_message_index=5, cost=Cost(), ) ) + new_msg_roles = [msg.role for msg in result2.new_messages()] + assert new_msg_roles == snapshot(['user', 'llm-tool-calls', 'tool-return', 'llm-response']) + assert result2.new_messages_json().startswith(b'[{"content":"Hello again",') - result = agent.run_sync('Hello again', message_history=result.all_messages) - assert result == snapshot( + # if we pass all_messages, system prompt is NOT inserted before the message_history messages, + # so only one system prompt + result3 = agent.run_sync('Hello again', message_history=result1.all_messages()) + # same as result2 except for datetimes + 
assert result3 == snapshot( RunResult( response='{"ret_a":"a-apple"}', - all_messages=[ + _all_messages=[ SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), LLMToolCalls( @@ -405,10 +433,7 @@ async def ret_a(x: str) -> str: ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ], - new_message_index=5, + _new_message_index=5, cost=Cost(), ) ) - new_msg_roles = [msg.role for msg in result.new_messages()] - assert new_msg_roles == snapshot(['user', 'llm-tool-calls', 'tool-return', 'llm-response']) - assert result.new_messages_json().startswith(b'[{"content":"Hello again",') From b539e6205a1f51bc0ac96f3f41d453565b3fa3fb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 4 Nov 2024 12:49:34 +0000 Subject: [PATCH 5/5] fix Gemini, add docs --- pydantic_ai/models/__init__.py | 2 +- pydantic_ai/models/gemini.py | 9 ++-- pydantic_ai_examples/README.md | 20 +++++++++ pydantic_ai_examples/chat_app.html | 70 ++++++++---------------------- pydantic_ai_examples/chat_app.py | 7 +++ 5 files changed, 52 insertions(+), 56 deletions(-) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 81ae41033c..212c3f4648 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -99,4 +99,4 @@ def cached_async_http_client() -> AsyncHTTPClient: described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing examples, it's very useful not to, this allows multiple Agents to use a single client. """ - return AsyncHTTPClient() + return AsyncHTTPClient(timeout=30) diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 638db390e7..5ae51499a6 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -346,14 +346,17 @@ class _GeminiResponse: @dataclass class _GeminiCandidates: + """See .""" + content: _GeminiContent finish_reason: Annotated[Literal['STOP'], Field(alias='finishReason')] """ - See https://ai.google.dev/api/generate-content#FinishReason, lots of other values are possible, + See , lots of other values are possible, but let's wait until we see them and know what they mean to add them here. """ - index: int - safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')] + avg_log_probs: Annotated[float, Field(alias='avgLogProbs')] | None = None + index: int | None = None + safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')] | None = None @dataclass diff --git a/pydantic_ai_examples/README.md b/pydantic_ai_examples/README.md index f2ac0974a0..01ed2efc63 100644 --- a/pydantic_ai_examples/README.md +++ b/pydantic_ai_examples/README.md @@ -156,3 +156,23 @@ You can then ask the agent a question with: ```bash (uv run/python) -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?" ``` + +### `chat_app.py` + +(Demonstrates: reusing chat history, serializing messages) + +**TODO**: stream responses + +Simple chat app example build with FastAPI. + +This demonstrates storing chat history between requests and using it to give the model context for new responses. + +Most of the complex logic here is in `chat_app.html` which includes the page layout and JavaScript to handle the chat. + +Run the app with: + +```bash +(uv run/python) -m pydantic_ai_examples.chat_app +``` + +Then open the app at [localhost:8000](http://localhost:8000). 
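The history reuse the README describes reduces to the pattern below, a hedged sketch condensed from `chat_app.py` with FastAPI and the persistence layer stripped out (the prompts are illustrative):

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', deps=None)

first = agent.run_sync('Tell me a joke.')
print(first.response)

# pass the earlier messages back in so the model sees the whole conversation
second = agent.run_sync('Explain it.', message_history=first.all_messages())
print(second.response)
```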
diff --git a/pydantic_ai_examples/chat_app.html b/pydantic_ai_examples/chat_app.html
index 544f75fb65..25d8cf056f 100644
--- a/pydantic_ai_examples/chat_app.html
+++ b/pydantic_ai_examples/chat_app.html
@@ -4,39 +4,10 @@
+    <!-- rewritten head and body largely lost in extraction; the surviving
+         text is the "Chat App" heading, the "Ask me anything..." intro and
+         the error notice "Error occurred, check the console for more
+         information." -->
@@ -103,7 +66,7 @@
const parent = document.getElementById('conversation'); for (const message of messages) { let msgDiv = document.createElement('div'); - msgDiv.classList.add(message.role); + msgDiv.classList.add('border-top', 'pt-2', message.role); msgDiv.innerHTML = marked.parse(message.content); parent.appendChild(msgDiv); } @@ -111,7 +74,8 @@
function onError(error) { console.error(error); - document.getElementById('error').style.display = 'block'; + document.getElementById('error').classList.remove('d-none'); + document.getElementById('spinner').classList.remove('active'); } async function fetchResponse(response) { @@ -145,16 +109,18 @@
e.preventDefault(); const spinner = document.getElementById('spinner'); spinner.classList.add('active'); + const body = new FormData(e.target); let input = document.getElementById('prompt-input') input.value = ''; input.disabled = true; - const response = await fetch('/chat/', {method: 'POST', body: new FormData(e.target)}); + const response = await fetch('/chat/', {method: 'POST', body}); await fetchResponse(response); spinner.classList.remove('active'); } + // call onSubmit when form is submitted (e.g. user clicks the send button or hits Enter) document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError)); // load messages on page load diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py index b8c69d9551..9e2dfb107c 100644 --- a/pydantic_ai_examples/chat_app.py +++ b/pydantic_ai_examples/chat_app.py @@ -1,3 +1,10 @@ +"""Simple chat app example built with FastAPI. + +Run with: + + uv run -m pydantic_ai_examples.chat_app +""" + from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path
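Putting the final pieces together: `Database` persists each run as one JSON array per line, and `MessagesTypeAdapter` validates those lines back into `Message` objects. A self-contained sketch of that round trip (the file name mirrors the example's, the prompt is illustrative, and the run itself needs an `OPENAI_API_KEY`):

```python
from pathlib import Path

from pydantic_ai import Agent
from pydantic_ai.messages import MessagesTypeAdapter

agent = Agent('openai:gpt-4o', deps=None)
file = Path('.chat_app_messages.jsonl')

# append: one JSON array of this run's new messages per line
result = agent.run_sync('Hello')
with file.open('ab') as f:
    f.write(result.new_messages_json() + b'\n')

# read back: each line validates into a list of Message objects
messages = []
with file.open('rb') as f:
    for line in f:
        if line:
            messages.extend(MessagesTypeAdapter.validate_json(line))
print([m.role for m in messages])
```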