In [1]:
from typing import List, Literal, Union, Dict, Any
from pydantic import BaseModel
import psycopg2
import pandas as pd
import json

In [2]:
# ---------- 1. Pydantic Schema：白名单表 & 列 ----------

class MetricFilter(BaseModel):
    """One WHERE condition of the form ``<column> <op> <value>``.

    The column and operator are restricted to ``Literal`` whitelists, so a
    payload naming anything else fails validation.
    """

    # Only these column names are allowed.
    column: Literal["id", "store_id", "product_sku", "date", "price", "promo_flag", "units_sold"]
    # Kept simple: equality is the only operator (>, <, >=, <= could be added).
    op: Literal["="] = "="
    # Right-hand side of the comparison; metric_query_to_sql passes it as a
    # bound query parameter rather than interpolating it into the SQL text.
    value: Union[int, float, str]


class MetricQuery(BaseModel):
    """Description of a whitelisted, single-table SELECT request.

    Table and column names are constrained with ``Literal`` types, so the
    only identifiers that pass validation are the ones listed here.
    """

    # Only this one table may be queried.
    table: Literal["Store_Sales_Price_Elasticity_Promotions_Data"]

    # Columns to select (either "*" or specific column names).
    select: List[Literal["*", "id", "store_id", "product_sku", "date", "price", "promo_flag", "units_sold"]] = ["*"]

    # List of WHERE conditions (joined with AND).
    filters: List[MetricFilter] = []

    # Maximum number of rows to return.
    # NOTE(review): plain int with no bounds — negative or very large values
    # pass validation; consider constraining this.
    limit: int = 100


# ---------- 2. SQL Builder：参数化 SQL ----------

def metric_query_to_sql(mq: MetricQuery) -> tuple[str, Dict[str, Any]]:
    """Translate a validated ``MetricQuery`` into parameterized SQL.

    Identifiers (table and column names) are interpolated into the SQL text,
    which is acceptable only because the ``MetricQuery`` / ``MetricFilter``
    schemas restrict them to ``Literal`` whitelists. Values (filter values and
    the row limit) are never interpolated; they are returned separately as
    named pyformat parameters (``%(name)s``).

    Args:
        mq: The query description. Filters are combined with ``AND``.

    Returns:
        A ``(sql, params)`` pair: the SQL text with ``%(p0)s``, ``%(p1)s``, ...
        and ``%(limit)s`` placeholders, and the dict of values to bind to them.

    Raises:
        ValueError: If ``mq.limit`` is negative (PostgreSQL rejects a negative
            LIMIT, so fail before building the statement).
    """
    if mq.limit < 0:
        raise ValueError(f"limit must be non-negative, got {mq.limit}")

    # SELECT clause. An empty list would render "SELECT  FROM ...", which is
    # invalid SQL, so fall back to "*" (the schema's own default). Column
    # names are double-quoted, same as in the WHERE clause, to avoid keyword
    # clashes; "*" must stay unquoted.
    select_cols = ", ".join(
        "*" if col == "*" else f'"{col}"' for col in (mq.select or ["*"])
    )
    table_name = mq.table

    where_clauses = []
    params: Dict[str, Any] = {}

    for i, f in enumerate(mq.filters):
        # One uniquely named placeholder per filter: p0, p1, ...
        param_name = f"p{i}"
        where_clauses.append(f'"{f.column}" {f.op} %({param_name})s')
        params[param_name] = f.value

    where_sql = ""
    if where_clauses:
        where_sql = "WHERE " + " AND ".join(where_clauses)

    sql = f'''
    SELECT {select_cols}
    FROM "{table_name}"
    {where_sql}
    LIMIT %(limit)s;
    '''
    params["limit"] = mq.limit

    return sql, params


# ---------- 3. Postgres 执行：拿 rows ----------

def run_query(metric_json_str: str):
    """Validate a MetricQuery JSON payload, run it on PostgreSQL, print the rows.

    Steps: parse and validate the JSON against ``MetricQuery``, build the
    parameterized SQL with ``metric_query_to_sql``, execute it, and print the
    resulting DataFrame. The parsed query, SQL text and parameters are also
    printed for inspection.

    Args:
        metric_json_str: JSON text (e.g. produced by an LLM or a front end)
            describing the query.

    Errors from validation, connecting, or executing the query propagate to
    the caller; the connection is closed in every case once it was opened.
    """
    # 1. JSON from the LLM / front end -> Pydantic validation.
    mq = MetricQuery.model_validate_json(metric_json_str)
    print(mq)

    # 2. Build SQL + bound parameters.
    sql, params = metric_query_to_sql(mq)
    print("SQL:\n", sql)
    print("Params:", params)

    # 3. Connect to PostgreSQL. Adjust the DSN for your local setup; if you
    #    need user/password/host/port, pass them as keyword arguments and read
    #    the password from the environment instead of hardcoding it.
    conn = psycopg2.connect("dbname=mydb")

    # try/finally (not `with conn:`): psycopg2's connection context manager
    # only ends the transaction, it does not close the connection. Without
    # this, a failing query would leak the connection.
    try:
        # 4. Read the result into a DataFrame (a cursor would work too).
        df = pd.read_sql(sql, conn, params=params)
        print("\nResult rows:\n", df)
    finally:
        conn.close()

In [3]:
if __name__ == "__main__":
    # Stand-in for a MetricQuery JSON document generated by an LLM.
    # Request: return every row where store_id = 1320.
    store_filter = {"column": "store_id", "op": "=", "value": 1320}
    query_spec = {
        "table": "Store_Sales_Price_Elasticity_Promotions_Data",
        "select": ["*"],
        "filters": [store_filter],
        "limit": 100,
    }
    metric_json = json.dumps(query_spec)

    run_query(metric_json)

table='Store_Sales_Price_Elasticity_Promotions_Data' select=['*'] filters=[MetricFilter(column='store_id', op='=', value=1320)] limit=100
SQL:
 
    SELECT *
    FROM "Store_Sales_Price_Elasticity_Promotions_Data"
    WHERE "store_id" = %(p0)s
    LIMIT %(limit)s;
    
Params: {'p0': 1320, 'limit': 100}

Result rows:
    id  store_id product_sku        date  price  promo_flag  units_sold
0   1      1320         A12  2021-11-01   4.99           0          32
1   2      1320         A12  2021-11-01   4.49           1          65
2   3      1320         A12  2021-11-01   4.99           0          28


  df = pd.read_sql(sql, conn, params=params)
