In [1]:
from IPython.display import Markdown, display
def print_large(text):
    """Show ``text`` in the cell output as a level-3 Markdown heading."""
    heading_source = f"### {text}"
    rendered_heading = Markdown(heading_source)
    display(rendered_heading)

In [2]:
import duckdb

# Open a DuckDB connection; no database path is passed to connect().
conn = duckdb.connect()

# DuckDB extensions this notebook installs and then loads. Keeping them in a
# single tuple means the INSTALL and LOAD statements cannot drift apart.
extensions = ("httpfs", "json", "spatial")

# One multi-statement script per phase: "INSTALL <ext>;" / "LOAD <ext>;" lines.
conn.execute("".join(f"INSTALL {ext};\n" for ext in extensions))
conn.execute("".join(f"LOAD {ext};\n" for ext in extensions))

<duckdb.duckdb.DuckDBPyConnection at 0x7f992df5b1b0>

In [3]:
# Natural Earth 1:110m admin-0 countries GeoJSON, fetched from GitHub.
admin_geojson_url = "https://github.com/nvkelso/natural-earth-vector/raw/master/geojson/ne_110m_admin_0_countries.geojson"

# Load the GeoJSON into the `countries` table via ST_Read.
# CREATE OR REPLACE keeps this cell re-runnable: a plain CREATE TABLE raises
# an error on the second execution because `countries` already exists.
conn.execute(f"CREATE OR REPLACE TABLE countries AS SELECT * FROM ST_Read('{admin_geojson_url}')")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

<duckdb.duckdb.DuckDBPyConnection at 0x7f992df5b1b0>

In [4]:
# Build a plain-text description of every table and its columns; this text is
# later handed to the model as the table schema.

# Get the list of tables via SHOW; the table name is taken from the third
# column of each returned row.
show_result = conn.execute("SHOW").fetchall()
tables = [row[2] for row in show_result]

summary_lines = []
for table in tables:
    summary_lines.append(f"Table: {table}\n")
    # Double-quote the identifier (doubling any embedded quotes) so that table
    # names containing spaces, quotes or reserved words still yield valid SQL.
    quoted_table = '"' + table.replace('"', '""') + '"'
    # Convert the DESCRIBE TABLE result to text: the first two values of each
    # row are used as the field name and the field type.
    describe_result = conn.execute(f"DESCRIBE TABLE {quoted_table}").fetchall()
    summary_lines.extend(
        f"  Field: {field_name}, {field_type}\n"
        for field_name, field_type, *_ in describe_result
    )

# Every collected line already ends with a newline, so join without a separator.
summary_of_tables = "".join(summary_lines)
print(summary_of_tables)

Table: countries
  Field: featurecla, VARCHAR
  Field: scalerank, INTEGER
  Field: LABELRANK, INTEGER
  Field: SOVEREIGNT, VARCHAR
  Field: SOV_A3, VARCHAR
  Field: ADM0_DIF, INTEGER
  Field: LEVEL, INTEGER
  Field: TYPE, VARCHAR
  Field: TLC, VARCHAR
  Field: ADMIN, VARCHAR
  Field: ADM0_A3, VARCHAR
  Field: GEOU_DIF, INTEGER
  Field: GEOUNIT, VARCHAR
  Field: GU_A3, VARCHAR
  Field: SU_DIF, INTEGER
  Field: SUBUNIT, VARCHAR
  Field: SU_A3, VARCHAR
  Field: BRK_DIFF, INTEGER
  Field: NAME, VARCHAR
  Field: NAME_LONG, VARCHAR
  Field: BRK_A3, VARCHAR
  Field: BRK_NAME, VARCHAR
  Field: BRK_GROUP, VARCHAR
  Field: ABBREV, VARCHAR
  Field: POSTAL, VARCHAR
  Field: FORMAL_EN, VARCHAR
  Field: FORMAL_FR, VARCHAR
  Field: NAME_CIAWF, VARCHAR
  Field: NOTE_ADM0, VARCHAR
  Field: NOTE_BRK, VARCHAR
  Field: NAME_SORT, VARCHAR
  Field: NAME_ALT, VARCHAR
  Field: MAPCOLOR7, INTEGER
  Field: MAPCOLOR8, INTEGER
  Field: MAPCOLOR9, INTEGER
  Field: MAPCOLOR13, INTEGER
  Field: POP_EST, DOUBLE
  Fie

In [5]:
# Natural-language question (in Japanese) that is passed to the model as the
# user input. In English: "How many countries in the world are larger than Japan?"
input_text = "日本よりも広い国は世界で何ヶ国ありますか？"

In [None]:
import re

from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

# Prepare the model
model = ChatGoogleGenerativeAI(model="gemini-exp-1206", temperature=0)

# Prepare the prompt; {table_schema} and {input} are filled in at invoke time.
# NOTE(review): the prompt asks for PostgreSQL/PostGIS, while the generated
# query is executed through the DuckDB connection — confirm this is intended.
template = """You are an expert of PostgreSQL and PostGIS.
You output the best PostgreSQL query based on given table schema and input text.

You will always reply according to the following rules:
- Output valid PostgreSQL query.
- The query MUST be line delimited and surrounded by just three backquote to indicate that it is a code block.

** Table Schema: **
{table_schema}

User Input:
{input}
"""
prompt = ChatPromptTemplate.from_template(template)

# Pipe the formatted prompt into the chat model.
chain = prompt | model

res = chain.invoke({"input": input_text, "table_schema": summary_of_tables})
result = res.content.strip()

# Extract the body of the first fenced code block: skip the rest of the opening
# fence line (e.g. a language tag) and capture lazily up to the closing fence.
match = re.search(r"```[^\n]*\n(.*?)```", result, re.DOTALL)
if match is None:
    # Without this check a reply lacking a code fence would fail with an
    # opaque AttributeError on `None.group`.
    raise ValueError(
        f"No fenced code block found in the model response:\n{result}"
    )
query = match.group(1).strip()
print(query)

SELECT
  COUNT(*)
FROM countries
WHERE
  ST_Area(geom) > (
    SELECT
      ST_Area(geom)
    FROM countries
    WHERE
      NAME = 'Japan'
  );


In [7]:
# Run the model-generated SQL on the DuckDB connection and collect all rows.
# NOTE(review): `query` is taken from the model response and executed without
# validation — confirm that is acceptable for this notebook's use.
query_cursor = conn.execute(query)
duckdb_result = query_cursor.fetchall()
duckdb_result

[(59,)]

In [8]:
# Assemble the follow-up prompt: an instruction line, a blank line, then the
# question, the SQL query and its result, each on its own newline-terminated line.
prompt = "\n".join(
    [
        "Given the following user question, corresponding SQL query, and SQL result, answer the user question.",
        "",
        f"Question: {input_text}",
        f"SQL Query: {query}",
        f"SQL Result: {duckdb_result}",
        "",
    ]
)

In [9]:
# Ask the model for the final natural-language answer and render it as a heading.
answer_message = model.invoke(prompt)
answer = answer_message.content.strip()
print_large(answer)

### 日本よりも広い国は世界で59ヶ国あります。