diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index 7c67b978c5..104914dc99 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -3,19 +3,20 @@ import argparse import asyncio import sys +from asyncio import CancelledError from collections.abc import Sequence from contextlib import ExitStack from datetime import datetime, timezone from importlib.metadata import version from pathlib import Path -from typing import cast +from typing import Any, cast from typing_inspection.introspection import get_literal_values from pydantic_ai.agent import Agent from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta -from pydantic_ai.models import KnownModelName +from pydantic_ai.models import KnownModelName, infer_model try: import argcomplete @@ -47,7 +48,7 @@ class SimpleCodeBlock(CodeBlock): This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix. """ - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: # pragma: no cover + def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: code = str(self.text).rstrip() yield Text(self.lexer_name, style='dim') yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True) @@ -57,7 +58,7 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR class LeftHeading(Heading): """Customised headings in markdown to stop centering and prepend markdown style hashes.""" - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: # pragma: no cover + def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True)) @@ -68,7 +69,21 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR ) -def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma: no cover +cli_agent = Agent() + + +@cli_agent.system_prompt +def cli_system_prompt() -> str: + now_utc = datetime.now(timezone.utc) + tzinfo = now_utc.astimezone().tzinfo + tzname = tzinfo.tzname(now_utc) if tzinfo else '' + return f"""\ +Help the user by responding to their request, the output should be concise and always written in markdown. +The current date and time is {datetime.now()} {tzname}. +The user is running {sys.platform}.""" + + +def cli(args_list: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser( prog='pai', description=f"""\ @@ -124,18 +139,10 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma: console.print(f' {model}', highlight=False) return 0 - now_utc = datetime.now(timezone.utc) - tzname = now_utc.astimezone().tzinfo.tzname(now_utc) # type: ignore try: - agent = Agent( - model=args.model, - system_prompt=f"""\ - Help the user by responding to their request, the output should be concise and always written in markdown. - The current date and time is {datetime.now()} {tzname}. - The user is running {sys.platform}.""", - ) - except UserError: - console.print(f'[red]Invalid model "{args.model}"[/red]') + cli_agent.model = infer_model(args.model) + except UserError as e: + console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]') return 1 stream = not args.no_stream @@ -148,67 +155,44 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma: if prompt := cast(str, args.prompt): try: - asyncio.run(ask_agent(agent, prompt, stream, console, code_theme)) + asyncio.run(ask_agent(cli_agent, prompt, stream, console, code_theme)) except KeyboardInterrupt: pass return 0 history = Path.home() / '.pai-prompt-history.txt' - session = PromptSession(history=FileHistory(str(history))) # type: ignore + # doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests + session: PromptSession[Any] = PromptSession(history=FileHistory(str(history))) + try: + return asyncio.run(run_chat(session, stream, cli_agent, console, code_theme)) + except KeyboardInterrupt: # pragma: no cover + return 0 + + +async def run_chat(session: PromptSession[Any], stream: bool, agent: Agent, console: Console, code_theme: str) -> int: multiline = False messages: list[ModelMessage] = [] while True: try: auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit']) - text = cast(str, session.prompt('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline)) - except (KeyboardInterrupt, EOFError): + text = await session.prompt_async('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline) + except (KeyboardInterrupt, EOFError): # pragma: no cover return 0 if not text.strip(): continue - ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ') + ident_prompt = text.lower().strip().replace(' ', '-') if ident_prompt.startswith('/'): - if ident_prompt == '/markdown': - try: - parts = messages[-1].parts - except IndexError: - console.print('[dim]No markdown output available.[/dim]') - continue - console.print('[dim]Markdown output of last question:[/dim]\n') - for part in parts: - if part.part_kind == 'text': - console.print( - Syntax( - part.content, - lexer='markdown', - theme=code_theme, - word_wrap=True, - background_color='default', - ) - ) - - elif ident_prompt == '/multiline': - multiline = not multiline - if multiline: - console.print( - 'Enabling multiline mode. ' - '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]' - ) - else: - console.print('Disabling multiline mode.') - elif ident_prompt == '/exit': - console.print('[dim]Exiting…[/dim]') - return 0 - else: - console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]') + exit_value, multiline = handle_slash_command(ident_prompt, messages, multiline, console, code_theme) + if exit_value is not None: + return exit_value else: try: - messages = asyncio.run(ask_agent(agent, text, stream, console, code_theme, messages)) - except KeyboardInterrupt: + messages = await ask_agent(agent, text, stream, console, code_theme, messages) + except CancelledError: # pragma: no cover console.print('[dim]Interrupted[/dim]') - messages = [] async def ask_agent( @@ -218,7 +202,7 @@ async def ask_agent( console: Console, code_theme: str, messages: list[ModelMessage] | None = None, -) -> list[ModelMessage]: # pragma: no cover +) -> list[ModelMessage]: status = Status('[dim]Working on it…[/dim]', console=console) if not stream: @@ -248,7 +232,7 @@ async def ask_agent( class CustomAutoSuggest(AutoSuggestFromHistory): - def __init__(self, special_suggestions: list[str] | None = None): # pragma: no cover + def __init__(self, special_suggestions: list[str] | None = None): super().__init__() self.special_suggestions = special_suggestions or [] @@ -264,5 +248,44 @@ def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | Non return suggestion +def handle_slash_command( + ident_prompt: str, messages: list[ModelMessage], multiline: bool, console: Console, code_theme: str +) -> tuple[int | None, bool]: + if ident_prompt == '/markdown': + try: + parts = messages[-1].parts + except IndexError: + console.print('[dim]No markdown output available.[/dim]') + else: + console.print('[dim]Markdown output of last question:[/dim]\n') + for part in parts: + if part.part_kind == 'text': + console.print( + Syntax( + part.content, + lexer='markdown', + theme=code_theme, + word_wrap=True, + background_color='default', + ) + ) + + elif ident_prompt == '/multiline': + multiline = not multiline + if multiline: + console.print( + 'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]' + ) + else: + console.print('Disabling multiline mode.') + return None, multiline + elif ident_prompt == '/exit': + console.print('[dim]Exiting…[/dim]') + return 0, multiline + else: + console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]') + return None, multiline + + def app(): # pragma: no cover sys.exit(cli()) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 529b234539..e351a3d173 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -106,6 +106,7 @@ 'google-gla:gemini-2.0-flash', 'google-gla:gemini-2.0-flash-lite-preview-02-05', 'google-gla:gemini-2.0-pro-exp-02-05', + 'google-gla:gemini-2.5-pro-exp-03-25', 'google-vertex:gemini-1.0-pro', 'google-vertex:gemini-1.5-flash', 'google-vertex:gemini-1.5-flash-8b', @@ -116,6 +117,7 @@ 'google-vertex:gemini-2.0-flash', 'google-vertex:gemini-2.0-flash-lite-preview-02-05', 'google-vertex:gemini-2.0-pro-exp-02-05', + 'google-vertex:gemini-2.5-pro-exp-03-25', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo-0301', diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 1176536c5d..8e7662a6c9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -57,6 +57,7 @@ 'gemini-2.0-flash', 'gemini-2.0-flash-lite-preview-02-05', 'gemini-2.0-pro-exp-02-05', + 'gemini-2.5-pro-exp-03-25', ] """Latest Gemini models.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 12bf4072cd..663e1b0be9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,14 +1,26 @@ import sys +from io import StringIO +from typing import Any import pytest -from dirty_equals import IsStr +from dirty_equals import IsInstance, IsStr from inline_snapshot import snapshot from pytest import CaptureFixture +from pytest_mock import MockerFixture +from rich.console import Console -from .conftest import try_import +from pydantic_ai import Agent +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart +from pydantic_ai.models.test import TestModel + +from .conftest import TestEnv, try_import with try_import() as imports_successful: - from pydantic_ai._cli import cli + from prompt_toolkit.input import create_pipe_input + from prompt_toolkit.output import DummyOutput + from prompt_toolkit.shortcuts import PromptSession + + from pydantic_ai._cli import cli, cli_agent, handle_slash_command pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests') @@ -52,8 +64,10 @@ def test_cli_help(capfd: CaptureFixture[str]): def test_invalid_model(capfd: CaptureFixture[str]): - assert cli(['--model', 'invalid_model']) == 1 - assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), 'Invalid model "invalid_model"']) + assert cli(['--model', 'potato']) == 1 + assert capfd.readouterr().out.splitlines() == snapshot( + [IsStr(), 'Error initializing potato:', 'Unknown model: potato'] + ) def test_list_models(capfd: CaptureFixture[str]): @@ -76,3 +90,102 @@ def test_list_models(capfd: CaptureFixture[str]): for provider in providers: models = models - {model for model in models if model.startswith(provider)} assert models == set(), models + + +def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv): + env.set('OPENAI_API_KEY', 'test') + with cli_agent.override(model=TestModel(custom_result_text='# result\n\n```py\nx = 1\n```')): + assert cli(['hello']) == 0 + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py']) + assert cli(['--no-stream', 'hello']) == 0 + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py']) + + +def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv): + env.set('OPENAI_API_KEY', 'test') + with create_pipe_input() as inp: + inp.send_text('\n') + inp.send_text('hello\n') + inp.send_text('/markdown\n') + inp.send_text('/exit\n') + session = PromptSession[Any](input=inp, output=DummyOutput()) + m = mocker.patch('pydantic_ai._cli.PromptSession', return_value=session) + m.return_value = session + m = TestModel(custom_result_text='goodbye') + with cli_agent.override(model=m): + assert cli([]) == 0 + assert capfd.readouterr().out.splitlines() == snapshot( + [ + IsStr(), + IsStr(regex='goodbye *Markdown output of last question:'), + '', + 'goodbye', + 'Exiting…', + ] + ) + + +def test_handle_slash_command_markdown(): + io = StringIO() + assert handle_slash_command('/markdown', [], False, Console(file=io), 'default') == (None, False) + assert io.getvalue() == snapshot('No markdown output available.\n') + + messages: list[ModelMessage] = [ModelResponse(parts=[TextPart('[hello](# hello)'), ToolCallPart('foo', '{}')])] + io = StringIO() + assert handle_slash_command('/markdown', messages, True, Console(file=io), 'default') == (None, True) + assert io.getvalue() == snapshot("""\ +Markdown output of last question: + +[hello](# hello) +""") + + +def test_handle_slash_command_multiline(): + io = StringIO() + assert handle_slash_command('/multiline', [], False, Console(file=io), 'default') == (None, True) + assert io.getvalue() == snapshot( + 'Enabling multiline mode. Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.\n' + ) + + io = StringIO() + assert handle_slash_command('/multiline', [], True, Console(file=io), 'default') == (None, False) + assert io.getvalue() == snapshot('Disabling multiline mode.\n') + + +def test_handle_slash_command_exit(): + io = StringIO() + assert handle_slash_command('/exit', [], False, Console(file=io), 'default') == (0, False) + assert io.getvalue() == snapshot('Exiting…\n') + + +def test_handle_slash_command_other(): + io = StringIO() + assert handle_slash_command('/foobar', [], False, Console(file=io), 'default') == (None, False) + assert io.getvalue() == snapshot('Unknown command `/foobar`\n') + + +def test_code_theme_unset(mocker: MockerFixture, env: TestEnv): + env.set('OPENAI_API_KEY', 'test') + mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') + cli([]) + mock_run_chat.assert_awaited_once_with( + IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'monokai' + ) + + +def test_code_theme_light(mocker: MockerFixture, env: TestEnv): + env.set('OPENAI_API_KEY', 'test') + mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') + cli(['--code-theme=light']) + mock_run_chat.assert_awaited_once_with( + IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'default' + ) + + +def test_code_theme_dark(mocker: MockerFixture, env: TestEnv): + env.set('OPENAI_API_KEY', 'test') + mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') + cli(['--code-theme=dark']) + mock_run_chat.assert_awaited_once_with( + IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'monokai' + )