In [1]:
!docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres

The files belonging to this database system will be owned by user "postgres".
This user must also own the server process.

The database cluster will be initialized with locale "en_US.utf8".
The default database encoding has accordingly been set to "UTF8".
The default text search configuration will be set to "english".

Data page checksums are disabled.

fixing permissions on existing directory /var/lib/postgresql/data ... ok
creating subdirectories ... ok
selecting dynamic shared memory implementation ... posix
selecting default "max_connections" ... 100
selecting default "shared_buffers" ... 128MB
selecting default time zone ... Etc/UTC
creating configuration files ... ok
running bootstrap script ... ok
performing post-bootstrap initialization ... ok
initdb: hint: You can change this by editing pg_hba.conf or using the option -A, or --auth-local and --auth-host, the next time you run initdb.
ok


Success. You can now start the database server using:

    pg_ctl -D /var/lib/postgresql/

In [10]:
# One-time environment setup: add the PostgreSQL driver imported below (psycopg2)
# to the Poetry project; the recorded output shows it resolved to ^2.9.10.
!poetry add psycopg2-binary

Using version ^2.9.10 for psycopg2-binary

Updating dependencies
Resolving dependencies... (0.7s)

Package operations: 1 install, 0 updates, 0 removals

  - Installing psycopg2-binary (2.9.10): Pending...
  - Installing psycopg2-binary (2.9.10): Downloading... 0%
  - Installing psycopg2-binary (2.9.10): Downloading... 10%
  - Installing psycopg2-binary (2.9.10): Downloading... 20%
  - Installing psycopg2-binary (2.9.10):

In [48]:
import os
from dataclasses import dataclass
from datetime import date
from typing import Any, Annotated, Union

import psycopg2
from annotated_types import MinLen
from psycopg2 import sql
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext, ModelRetry
from pydantic_ai.format_as_xml import format_as_xml
from typing_extensions import TypeAlias

In [51]:
DB_SCHEMA="""
CREATE TABLE records (
    created_at timestamptz,
    start_timestamp timestamptz,
    end_timestamp timestamptz,
    trace_id text,
    span_id text,
    parent_span_id text,
    level log_level,
    span_name text,
    message text,
    attributes_json_schema text,
    attributes jsonb,
    tags text[],
    is_exception boolean,
    otel_status_message text,
    service_name text
);
"""

In [52]:
# Few-shot request -> SQL pairs rendered into the system prompt (via format_as_xml)
# to show the model the expected query style for the `records` table. Every
# response must itself be valid PostgreSQL, or the model learns to emit broken SQL.
SQL_EXAMPLES = [
    {
        'request': 'show me records where foobar is false',
        # `->>` yields text; PostgreSQL has no `text = boolean` operator, so cast first.
        'response': "SELECT * FROM records WHERE (attributes->>'foobar')::boolean = false",
    },
    {
        'request': 'show me records where attributes include the key "foobar"',
        'response': "SELECT * FROM records WHERE attributes ? 'foobar'",
    },
    {
        'request': 'show me records from yesterday',
        # Match the calendar day before today (date - integer = date). Comparing the
        # date against "now minus 24 hours" would actually select today's rows.
        'response': "SELECT * FROM records WHERE start_timestamp::date = CURRENT_DATE - 1",
    },
    {
        'request': 'show me error records with the tag "foobar"',
        'response': "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)",
    },
]

In [53]:
@dataclass
class Deps:
    """Dependencies handed to every agent run.

    The result validator reads ``conn`` to open a cursor and EXPLAIN the
    generated query against the real schema.
    """

    conn: Any  # database connection (the notebook passes the one from database_connect)
    
class Success(BaseModel):
    """Response when SQL could be successfully generated."""

    # The generated query; MinLen(1) makes validation reject an empty string.
    sql_query: Annotated[str, MinLen(1)]
    # Optional human-readable explanation of the query (empty when omitted).
    explanation: str = Field(
        default='',
        description='Explanation of the SQL query, as markdown',
    )
    
class InvalidRequest(BaseModel):
    """Response when the user input didn't include enough information to generate SQL."""

    # Explanation for the user of why no SQL query could be produced.
    error_msg: str

In [102]:
# The model must answer with exactly one of the two structured response types.
Response: TypeAlias = Union[Success, InvalidRequest]

# Gemini-backed agent. `retries=4` is pydantic-ai's retry budget; presumably it also
# bounds ModelRetry round-trips from the result validator — confirm for the installed version.
agent: Agent[Deps, Response] = Agent(
    'gemini-1.5-flash',
    deps_type=Deps,
    result_type=Response,
    retries=4,
)

In [None]:
@agent.system_prompt
def system_prompt() -> str:
    """Assemble the system prompt: table DDL, today's date and few-shot examples.

    NOTE(review): ``date.today()`` is the notebook host's local date, while the
    database container runs in UTC — the two can disagree around midnight.
    """
    prompt = f"""\
        Given the following PostgreSQL table of records, your job is to
        write a SQL query that suits the user's request.

        Database schema:

        {DB_SCHEMA}

        today's date = {date.today()}

        {format_as_xml(SQL_EXAMPLES)}
    """
    return prompt
    
@agent.result_validator
def validate_results(ctx: RunContext[Deps], result: Response) -> Response:
    """Check a generated SQL query against the database before accepting it.

    ``InvalidRequest`` results are returned untouched. For ``Success`` results the
    query is cleaned up, restricted to a single ``SELECT`` statement, and checked
    with ``EXPLAIN`` (which plans the query without executing it). Problems are
    reported back to the model via ``ModelRetry`` so it can try again.

    Args:
        ctx: Run context whose ``deps.conn`` is an open psycopg2 connection.
        result: The model's structured response.

    Returns:
        The result (with ``sql_query`` normalised) when the query is valid.

    Raises:
        ModelRetry: If the query is not a single SELECT statement, or if
            PostgreSQL rejects it.
    """
    if isinstance(result, InvalidRequest):
        return result

    # The model sometimes adds extraneous backslashes to the SQL it returns.
    query = result.sql_query.replace("\\", "")
    # Tolerate surrounding whitespace and a trailing statement terminator.
    query = query.strip().rstrip(";").rstrip()
    result.sql_query = query

    if not query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SQL Query.")
    # psycopg2 runs every `;`-separated statement passed to one execute() call, so
    # "SELECT 1; COMMIT; DROP TABLE records" would really drop the table despite
    # the rollback below. Only accept a single statement (this also rejects a `;`
    # inside a string literal, which is an acceptable trade-off here).
    if ";" in query:
        raise ModelRetry("Please create a single SQL query without semicolons.")

    conn = ctx.deps.conn
    try:
        with conn.cursor() as cur:
            try:
                cur.execute(sql.SQL("EXPLAIN {}").format(sql.SQL(query)))
            finally:
                # End the implicit transaction (including the aborted state left by
                # a failed EXPLAIN) so the connection stays usable for the next check.
                conn.rollback()
    except psycopg2.Error as e:
        raise ModelRetry(f"Invalid Query: {e}") from e
    return result

In [64]:
def database_connect(server_dsn: str, database: str) -> Any:
    """Ensure ``database`` exists with the ``records`` schema, then connect to it.

    Args:
        server_dsn: Server-level PostgreSQL URL without a database path, e.g.
            ``postgresql://user:pass@host:port``; ``/<database>`` is appended
            to it for the second connection.
        database: Name of the database to create (if missing) and connect to.

    Returns:
        An open psycopg2 connection to ``database``; the caller owns it.

    Raises:
        psycopg2.Error: If connecting or any setup statement fails.
    """
    # CREATE DATABASE cannot run inside a transaction block, so this short-lived
    # maintenance connection uses autocommit.
    admin_conn = psycopg2.connect(server_dsn)
    try:
        admin_conn.autocommit = True
        with admin_conn.cursor() as cur:
            cur.execute(
                "SELECT 1 from pg_database WHERE datname = %s", (database,)
            )
            if cur.fetchone() is None:
                cur.execute(
                    sql.SQL("CREATE DATABASE {}").format(sql.Identifier(database))
                )
    finally:
        admin_conn.close()

    conn = psycopg2.connect(f"{server_dsn}/{database}")
    try:
        # `with conn` commits on success and rolls back on error; it does not close.
        with conn:
            with conn.cursor() as cur:
                # Both steps check what actually exists (rather than whether the
                # database was just created), so a database left half-initialised
                # by an earlier failed run still gets its schema.
                cur.execute("""
                    DO $$
                    BEGIN
                        IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'log_level') THEN
                            CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical');
                        END IF;
                    END $$;
                """)
                # The table name must match the CREATE TABLE in DB_SCHEMA.
                cur.execute("SELECT to_regclass('records')")
                if cur.fetchone()[0] is None:
                    cur.execute(DB_SCHEMA)
    except Exception:
        # Don't leak the connection when schema setup fails.
        conn.close()
        raise
    print("Schema setup completed successfully.")
    return conn

In [94]:
# Defaults match the `docker run` cell (POSTGRES_PASSWORD=postgres, host port 54320).
# Set PG_SERVER_DSN to target another server without hard-coding its credentials here.
conn = database_connect(
    os.environ.get('PG_SERVER_DSN', 'postgresql://postgres:postgres@localhost:54320'),
    'pydantic_ai_sql_gen',
)

Schema setup completed successfully.


In [104]:
deps = Deps(conn)

prompt = 'show me logs from yesterday, with level "error"'

result = await agent.run(prompt, deps=deps)

In [105]:
# Structured output of the run: a Success (sql_query + explanation) or an InvalidRequest.
result.data

Success(sql_query="SELECT * FROM records WHERE level = 'error' AND created_at::date = CURRENT_DATE - INTERVAL '1 day'", explanation="This query selects all records with log level 'error' from yesterday.")