Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export OPENAI_API_KEY='your-api-key-here'
Then simply run:

```bash
$ pai
pai
```

This will start an interactive session where you can chat with the AI model. Special commands available in interactive mode:
Expand All @@ -52,3 +52,11 @@ You can specify which model to use with the `--model` flag:
```bash
$ pai --model=openai:gpt-4 "What's the capital of France?"
```

### Usage with `uvx`

If you have [uv](https://docs.astral.sh/uv/) installed, the quickest way to run the CLI is with `uvx`:

```bash
uvx --from pydantic-ai pai
```
79 changes: 47 additions & 32 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,21 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma:
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
parser.add_argument(
arg = parser.add_argument(
'--model',
nargs='?',
help='Model to use, it should be "<provider>:<model>" e.g. "openai:gpt-4o". If omitted it will default to "openai:gpt-4o"',
default='openai:gpt-4o',
).completer = argcomplete.ChoicesCompleter(list(get_literal_values(KnownModelName))) # type: ignore[reportPrivateUsage]
)
# we don't want to autocomplete or list models that don't include the provider,
# e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
qualified_model_names = [n for n in get_literal_values(KnownModelName) if ':' in n]
arg.completer = argcomplete.ChoicesCompleter(qualified_model_names) # type: ignore[reportPrivateUsage]
parser.add_argument(
'--list-models',
action='store_true',
help='List all available models and exit',
)
parser.add_argument('--no-stream', action='store_true', help='Whether to stream responses from OpenAI')
parser.add_argument('--version', action='store_true', help='Show version and exit')

Expand All @@ -81,6 +90,11 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma:
console.print(f'pai - PydanticAI CLI v{__version__}', style='green bold', highlight=False)
if args.version:
return 0
if args.list_models:
console.print('Available models:', style='green bold')
for model in qualified_model_names:
console.print(f' {model}', highlight=False)
return 0

now_utc = datetime.now(timezone.utc)
tzname = now_utc.astimezone().tzinfo.tzname(now_utc) # type: ignore
Expand Down Expand Up @@ -121,37 +135,38 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma:
continue

ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ')
if ident_prompt == '/markdown':
try:
parts = messages[-1].parts
except IndexError:
console.print('[dim]No markdown output available.[/dim]')
continue
for part in parts:
if part.part_kind == 'text':
last_content = part.content
console.print('[dim]Last markdown output of last question:[/dim]\n')
console.print(Syntax(last_content, lexer='markdown', background_color='default'))

continue
if ident_prompt == '/multiline':
multiline = not multiline
if multiline:
console.print(
'Enabling multiline mode. '
'[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
)
if ident_prompt.startswith('/'):
if ident_prompt == '/markdown':
try:
parts = messages[-1].parts
except IndexError:
console.print('[dim]No markdown output available.[/dim]')
continue
for part in parts:
if part.part_kind == 'text':
last_content = part.content
console.print('[dim]Last markdown output of last question:[/dim]\n')
console.print(Syntax(last_content, lexer='markdown', background_color='default'))

elif ident_prompt == '/multiline':
multiline = not multiline
if multiline:
console.print(
'Enabling multiline mode. '
'[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
)
else:
console.print('Disabling multiline mode.')
elif ident_prompt == '/exit':
console.print('[dim]Exiting…[/dim]')
return 0
else:
console.print('Disabling multiline mode.')
continue
if ident_prompt == '/exit':
console.print('[dim]Exiting…[/dim]')
return 0

try:
messages = asyncio.run(ask_agent(agent, text, stream, console, messages))
except KeyboardInterrupt:
return 0
console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
else:
try:
messages = asyncio.run(ask_agent(agent, text, stream, console, messages))
except KeyboardInterrupt:
return 0


async def ask_agent(
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ Changelog = "https://github.com/pydantic/pydantic-ai/releases"
examples = ["pydantic-ai-examples==0.0.46"]
logfire = ["logfire>=3.11.0"]

[project.scripts]
pai = "pydantic_ai._cli:app"

[tool.uv.sources]
pydantic-ai-slim = { workspace = true }
pydantic-graph = { workspace = true }
Expand Down
25 changes: 24 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_cli_help(capfd: CaptureFixture[str]):

assert capfd.readouterr().out.splitlines() == snapshot(
[
'usage: pai [-h] [--model [MODEL]] [--no-stream] [--version] [prompt]',
'usage: pai [-h] [--model [MODEL]] [--list-models] [--no-stream] [--version] [prompt]',
'',
IsStr(),
'',
Expand All @@ -38,6 +38,7 @@ def test_cli_help(capfd: CaptureFixture[str]):
IsStr(),
' -h, --help show this help message and exit',
' --model [MODEL] Model to use, it should be "<provider>:<model>" e.g. "openai:gpt-4o". If omitted it will default to "openai:gpt-4o"',
' --list-models List all available models and exit',
' --no-stream Whether to stream responses from OpenAI',
' --version Show version and exit',
]
Expand All @@ -47,3 +48,25 @@ def test_cli_help(capfd: CaptureFixture[str]):
def test_invalid_model(capfd: CaptureFixture[str]):
assert cli(['--model', 'invalid_model']) == 1
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), 'Invalid model "invalid_model"'])


def test_list_models(capfd: CaptureFixture[str]):
assert cli(['--list-models']) == 0
output = capfd.readouterr().out.splitlines()
assert output[:2] == snapshot(['pai - PydanticAI CLI v0.0.46', 'Available models:'])

providers = (
'openai',
'anthropic',
'bedrock',
'google-vertex',
'google-gla',
'groq',
'mistral',
'cohere',
'deepseek',
)
models = {line.strip().split(' ')[0] for line in output[2:]}
for provider in providers:
models = models - {model for model in models if model.startswith(provider)}
assert models == set(), models