1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ env*/
 /TODO.md
 /postgres-data/
 .DS_Store
+/pydantic_ai_examples/.chat_app_messages.jsonl
4 changes: 2 additions & 2 deletions preview.README.md
@@ -60,7 +60,7 @@ result = weather_agent.run_sync('What is the weather like in West London and in
 print(result.response)
 # > 'The weather in West London is raining, while in Wiltshire it is sunny.'

-# `result.message_history` details of messages exchanged, useful if you want to continue
+# `result.all_messages` includes details of messages exchanged, useful if you want to continue
 # the conversation later, via the `message_history` argument of `run_sync`.
-print(result.message_history)
+print(result.all_messages())
 ```
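
For orientation, the renamed API reads like this in use: a minimal sketch based only on the comment above, reusing the `weather_agent` from this README's earlier example (the follow-up prompt is hypothetical):

```python
result = weather_agent.run_sync('What is the weather like in West London and in Wiltshire?')
print(result.response)

# pass the accumulated messages back in so the model sees the prior exchange
follow_up = weather_agent.run_sync(
    'And what about tomorrow?',
    message_history=result.all_messages(),
)
print(follow_up.response)
```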
8 changes: 6 additions & 2 deletions pydantic_ai/agent.py
@@ -97,12 +97,16 @@ async def run(
         if deps is None:
             deps = self._default_deps

-        if message_history is not None:
+        # if message history includes system prompts, we don't want to regenerate them
+        if message_history and any(m.role == 'system' for m in message_history):
             # shallow copy messages
             messages = message_history.copy()
         else:
             messages = await self._init_messages(deps)
+            if message_history:
+                messages += message_history

+        new_message_index = len(messages)
         messages.append(_messages.UserPrompt(user_prompt))

         result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
@@ -135,7 +139,7 @@ async def run(
                 run_span.set_attribute('cost', cost)
                 handle_span.set_attribute('result', left.value)
                 handle_span.message = 'handle model response -> final result'
-                return shared.RunResult(left.value, messages, cost=cost)
+                return shared.RunResult(left.value, cost, messages, new_message_index)
             else:
                 tool_responses = either.right
                 handle_span.set_attribute('tool_responses', tool_responses)
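
Restated as a self-contained sketch, the new message-assembly logic works like this (plain dicts stand in for the real `Message` types, which this illustration assumes rather than imports):

```python
def build_messages(message_history, system_messages, user_prompt):
    """Sketch of the branching in the hunk above."""
    if message_history and any(m['role'] == 'system' for m in message_history):
        # history already carries system prompts: reuse it verbatim (shallow copy)
        messages = message_history.copy()
    else:
        # otherwise generate fresh system prompts, then append any history
        messages = list(system_messages)
        if message_history:
            messages += message_history
    # record where this run's messages start, so RunResult.new_messages() can slice
    new_message_index = len(messages)
    messages.append({'role': 'user', 'content': user_prompt})
    return messages, new_message_index
```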
2 changes: 1 addition & 1 deletion pydantic_ai/models/__init__.py
@@ -99,4 +99,4 @@ def cached_async_http_client() -> AsyncHTTPClient:
     described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
     examples, it's very useful not to, this allows multiple Agents to use a single client.
     """
-    return AsyncHTTPClient()
+    return AsyncHTTPClient(timeout=30)
9 changes: 6 additions & 3 deletions pydantic_ai/models/gemini.py
@@ -346,14 +346,17 @@ class _GeminiResponse:

 @dataclass
 class _GeminiCandidates:
+    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
+
     content: _GeminiContent
     finish_reason: Annotated[Literal['STOP'], Field(alias='finishReason')]
     """
-    See https://ai.google.dev/api/generate-content#FinishReason, lots of other values are possible,
+    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
     """
-    index: int
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    avg_log_probs: Annotated[float, Field(alias='avgLogProbs')] | None = None
+    index: int | None = None
+    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')] | None = None


 @dataclass
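
The new optional fields follow a standard pydantic pattern: an aliased field with a `None` default validates whether or not the API response includes the key. A minimal sketch (assumed toy types, not the real `gemini.py` ones, written in the equivalent `Annotated[float | None, ...]` form):

```python
from typing import Annotated

import pydantic
from pydantic import Field


@pydantic.dataclasses.dataclass
class Candidate:
    content: str
    # populated from the camelCase key when present, None when absent
    avg_log_probs: Annotated[float | None, Field(alias='avgLogProbs')] = None


adapter = pydantic.TypeAdapter(Candidate)
print(adapter.validate_python({'content': 'hi'}))                        # avg_log_probs=None
print(adapter.validate_python({'content': 'hi', 'avgLogProbs': -0.42}))  # alias honoured
```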
10 changes: 5 additions & 5 deletions pydantic_ai/models/openai.py
@@ -6,7 +6,7 @@
 from typing import Literal

 from httpx import AsyncClient as AsyncHTTPClient
-from openai import AsyncOpenAI
+from openai import NOT_GIVEN, AsyncOpenAI
 from openai.types import ChatModel, chat
 from typing_extensions import assert_never

@@ -104,7 +104,7 @@ def process_response(response: chat.ChatCompletion) -> LLMMessage:
     async def completions_create(self, messages: list[Message]) -> chat.ChatCompletion:
         # standalone function to make it easier to override
         if not self.tools:
-            tool_choice: Literal['none', 'required', 'auto'] = 'none'
+            tool_choice: Literal['none', 'required', 'auto'] | None = None
         elif not self.allow_text_result:
             tool_choice = 'required'
         else:
@@ -115,9 +115,9 @@ async def completions_create(self, messages: list[Message]) -> chat.ChatCompletion:
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=True,
-            tools=self.tools,
-            tool_choice=tool_choice,
+            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
         )

     @staticmethod
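
In the openai SDK, `NOT_GIVEN` is the sentinel for omitting a key from the request body entirely, which is not the same as sending an explicit `null`; the hunk above therefore only sends the tool-related keys when there are tools. A hedged illustration (the model name and prompt are placeholders):

```python
from openai import NOT_GIVEN, AsyncOpenAI

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment


async def create_completion(tools: list | None):
    # when `tools` is falsy, none of the tool-related keys appear in the outgoing JSON
    return await client.chat.completions.create(
        model='gpt-4o',
        messages=[{'role': 'user', 'content': 'hi'}],
        tools=tools or NOT_GIVEN,
        tool_choice='auto' if tools else NOT_GIVEN,
        parallel_tool_calls=True if tools else NOT_GIVEN,
    )
```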
28 changes: 23 additions & 5 deletions pydantic_ai/shared.py
@@ -6,8 +6,9 @@

 from pydantic import ValidationError

+from . import messages
+
 if TYPE_CHECKING:
-    from . import messages
     from .models import Model

 __all__ = (
@@ -57,12 +58,29 @@ class RunResult(Generic[ResultData]):
     """Result of a run."""

     response: ResultData
-    message_history: list[messages.Message]
     cost: Cost
+    _all_messages: list[messages.Message]
+    _new_message_index: int

+    def all_messages(self) -> list[messages.Message]:
+        """Return the history of messages."""
+        # this is a method to be consistent with the other methods
+        return self._all_messages
+
+    def all_messages_json(self) -> bytes:
+        """Return the history of messages as JSON bytes."""
+        return messages.MessagesTypeAdapter.dump_json(self.all_messages())
+
+    def new_messages(self) -> list[messages.Message]:
+        """Return new messages associated with this run.
+
+        System prompts and any messages from older runs are excluded.
+        """
+        return self.all_messages()[self._new_message_index :]
+
-    def message_history_json(self) -> str:
-        """Return the history of messages as a JSON string."""
-        return messages.MessagesTypeAdapter.dump_json(self.message_history).decode()
+    def new_messages_json(self) -> bytes:
+        """Return new messages from [new_messages][] as JSON bytes."""
+        return messages.MessagesTypeAdapter.dump_json(self.new_messages())


 @dataclass
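
Taken together, the new `RunResult` API separates everything the model saw from what this particular run added. A usage sketch (assuming an already-configured `agent`; the `messages.jsonl` path is hypothetical):

```python
result = agent.run_sync('Tell me a joke.')

all_msgs = result.all_messages()  # system prompts + prior history + this run
new_msgs = result.new_messages()  # only this run's messages, system prompts excluded

# persist only what this run added, e.g. as one JSON array per line
with open('messages.jsonl', 'ab') as f:
    f.write(result.new_messages_json() + b'\n')
```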
46 changes: 45 additions & 1 deletion pydantic_ai_examples/README.md
@@ -35,6 +35,21 @@ But you'll probably want to edit examples as well as just running them, so you c
 python -m pydantic_ai_examples --copy-to examples/
 ```

+### Setting model environment variables
+
+To run these examples, you'll need to set either:
+
+* `OPENAI_API_KEY` to use OpenAI models: go to [platform.openai.com](https://platform.openai.com/) and follow your nose until you find the place to generate an API key
+* `GEMINI_API_KEY` to use Gemini/Google models: go to [aistudio.google.com](https://aistudio.google.com/) and do the same to generate an API key
+
+Then set the API key as an environment variable with:
+
+```bash
+export OPENAI_API_KEY=your-api-key
+# or
+export GEMINI_API_KEY=your-api-key
+```
+
 ## Examples

 ### `pydantic_model.py`
@@ -62,6 +77,15 @@ PYDANTIC_AI_MODEL=gemini-1.5-pro (uv run/python) -m pydantic_ai_examples.pydantic_model

 Example demonstrating how to use Pydantic AI to generate SQL queries based on user input.

+The resulting SQL is validated by running it as an `EXPLAIN` query on PostgreSQL. To run the example, you first need to run PostgreSQL, e.g. via Docker:
+
+```bash
+docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres
+```
+_(we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running)_
+
+Then to run the code:
+
 ```bash
 (uv run/python) -m pydantic_ai_examples.sql_gen
 ```
@@ -116,7 +140,7 @@ mkdir postgres-data
 docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17
 ```

-We run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running.
+As above, we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running.
 We also mount the postgresql `data` directory locally to persist the data if you need to stop and restart the container.

 With that running, we can build the search database with (**WARNING**: this requires the `OPENAI_API_KEY` env variable and will call the OpenAI embedding API around 300 times to generate embeddings for each section of the documentation):
@@ -132,3 +156,23 @@ You can then ask the agent a question with:
 ```bash
 (uv run/python) -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?"
 ```
+
+### `chat_app.py`
+
+(Demonstrates: reusing chat history, serializing messages)
+
+**TODO**: stream responses
+
+Simple chat app example built with FastAPI.
+
+This demonstrates storing chat history between requests and using it to give the model context for new responses.
+
+Most of the complex logic here is in `chat_app.html`, which includes the page layout and JavaScript to handle the chat.
+
+Run the app with:
+
+```bash
+(uv run/python) -m pydantic_ai_examples.chat_app
+```
+
+Then open the app at [localhost:8000](http://localhost:8000).
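
The stored history presumably lives in the `.chat_app_messages.jsonl` path added to `.gitignore` above. Assuming one JSON-serialized message list per line (the shape `new_messages_json()` produces), reloading it could look like this sketch:

```python
from pathlib import Path

from pydantic_ai import messages

db = Path('pydantic_ai_examples/.chat_app_messages.jsonl')
history: list[messages.Message] = []
for line in db.read_text().splitlines():
    if line:
        # each line is a JSON array of messages, validated back into Message objects
        history += messages.MessagesTypeAdapter.validate_json(line)
```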
5 changes: 4 additions & 1 deletion pydantic_ai_examples/__main__.py
@@ -2,6 +2,7 @@

 See README.md for more information.
 """
+
 import argparse
 import sys
 from pathlib import Path
@@ -10,7 +11,9 @@
 def cli():
     this_dir = Path(__file__).parent

-    parser = argparse.ArgumentParser(prog='pydantic_ai_examples', description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
+    parser = argparse.ArgumentParser(
+        prog='pydantic_ai_examples', description=__doc__, formatter_class=argparse.RawTextHelpFormatter
+    )
     parser.add_argument('-v', '--version', action='store_true', help='show the version and exit')
    parser.add_argument('--copy-to', dest='DEST', help='Copy all examples to a new directory')

128 changes: 128 additions & 0 deletions pydantic_ai_examples/chat_app.html
@@ -0,0 +1,128 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Chat App</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
main {
max-width: 600px;
}
#conversation .user::before {
content: 'You asked: ';
font-weight: bold;
display: block;
}
#conversation .llm-response::before {
content: 'AI Response: ';
font-weight: bold;
display: block;
}
#spinner {
opacity: 0;
transition: opacity 500ms ease-in;
width: 30px;
height: 30px;
border: 3px solid #222;
border-bottom-color: transparent;
border-radius: 50%;
animation: rotation 1s linear infinite;
}
@keyframes rotation {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
#spinner.active {
opacity: 1;
}
</style>
</head>
<body>
<main class="border rounded mx-auto my-5 p-4">
<h1>Chat App</h1>
<p>Ask me anything...</p>
<div id="conversation" class="px-2"></div>
<div class="d-flex justify-content-center mb-3">
<div id="spinner"></div>
</div>
<form method="post">
<input id="prompt-input" name="prompt" class="form-control"/>
<div class="d-flex justify-content-end">
<button class="btn btn-primary mt-2">Send</button>
</div>
</form>
<div id="error" class="d-none text-danger">
Error occurred, check the console for more information.
</div>
</main>
</body>
</html>
<script type="module">
import { marked } from 'https://cdn.jsdelivr.net/npm/marked/lib/marked.esm.js';

function addMessages(lines) {
const messages = lines.filter(line => line.length > 1).map((line) => JSON.parse(line))
const parent = document.getElementById('conversation');
for (const message of messages) {
let msgDiv = document.createElement('div');
msgDiv.classList.add('border-top', 'pt-2', message.role);
msgDiv.innerHTML = marked.parse(message.content);
parent.appendChild(msgDiv);
}
}

function onError(error) {
console.error(error);
document.getElementById('error').classList.remove('d-none');
document.getElementById('spinner').classList.remove('active');
}

async function fetchResponse(response) {
let text = '';
if (response.ok) {
const reader = response.body.getReader();
while (true) {
const {done, value} = await reader.read();
if (done) {
break;
}
text += new TextDecoder().decode(value);
const lines = text.split('\n');
if (lines.length > 1) {
addMessages(lines.slice(0, -1));
text = lines[lines.length - 1];
}
}
addMessages(text.split('\n'));
let input = document.getElementById('prompt-input')
input.disabled = false;
input.focus();
} else {
const text = await response.text();
console.error(`Unexpected response: ${response.status}`, {response, text});
throw new Error(`Unexpected response: ${response.status}`);
}
}

async function onSubmit(e) {
e.preventDefault();
const spinner = document.getElementById('spinner');
spinner.classList.add('active');
const body = new FormData(e.target);

let input = document.getElementById('prompt-input')
input.value = '';
input.disabled = true;

const response = await fetch('/chat/', {method: 'POST', body});
await fetchResponse(response);
spinner.classList.remove('active');
}

// call onSubmit when form is submitted (e.g. user clicks the send button or hits Enter)
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError));

// load messages on page load
fetch('/chat/').then(fetchResponse).catch(onError);
</script>
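
The JavaScript above implies a newline-delimited JSON protocol: both `GET /chat/` and `POST /chat/` return one JSON object per line with `role` and `content` keys, and `role` doubles as the CSS class (`user` or `llm-response`). The real `chat_app.py` is not part of this diff; a minimal FastAPI counterpart under those assumptions might look like:

```python
import json

from fastapi import FastAPI, Form
from fastapi.responses import PlainTextResponse

app = FastAPI()
MESSAGES: list[dict[str, str]] = []  # in-memory stand-in for the JSONL store


@app.get('/chat/')
async def get_chat() -> PlainTextResponse:
    # replay the whole history, one JSON object per line
    return PlainTextResponse('\n'.join(json.dumps(m) for m in MESSAGES))


@app.post('/chat/')
async def post_chat(prompt: str = Form(...)) -> PlainTextResponse:
    user = {'role': 'user', 'content': prompt}
    reply = {'role': 'llm-response', 'content': f'echo: {prompt}'}  # stand-in for the real agent call
    MESSAGES.extend([user, reply])
    return PlainTextResponse(json.dumps(user) + '\n' + json.dumps(reply) + '\n')
```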