In [1]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri(
    'postgresql://postgres:postgres@localhost:5434/geogpt_db')
print(db.dialect)
print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10;")

postgresql
['layer', 'osm_buildings_polygons', 'osm_landuse_polygons', 'osm_natural_points', 'osm_natural_polygons', 'osm_places_of_worship_points', 'osm_places_of_worship_polygons', 'osm_places_points', 'osm_places_polygons', 'osm_points_of_interest_points', 'osm_points_of_interest_polygons', 'osm_railways_lines', 'osm_roads_lines', 'osm_traffic_points', 'osm_traffic_polygons', 'osm_transport_points', 'osm_transport_polygons', 'osm_water_polygons', 'osm_waterways_lines', 'spatial_ref_sys', 'topology']


  self._metadata.reflect(


In [2]:
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

load_dotenv('../../.env')

llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)

In [3]:
from langsmith import Client

client = Client()
run = client.read_run("98b7a0ca-1996-4d2a-912a-5104834e1b8b")
print(run.url)

https://smith.langchain.com/o/1bca6b3d-1f4b-564f-b424-a1a872a82f1b/projects/p/b8fdcc72-9e48-4a81-88b1-715dca9c4bf7/r/98b7a0ca-1996-4d2a-912a-5104834e1b8b?trace_id=98b7a0ca-1996-4d2a-912a-5104834e1b8b&start_time=2024-04-15T10:23:22.952373


In [17]:
from typing import Optional, Type, List

from langchain_core.callbacks import (
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from langchain_community.tools.sql_database.tool import BaseSQLDatabaseTool
from langchain_core.pydantic_v1 import BaseModel, Field
from sqlalchemy.schema import CreateTable
from sqlalchemy.types import NullType
from sqlalchemy import text


def _get_srid(table_name):
    query = text(f"""SELECT auth_name, auth_srid, proj4text
            FROM spatial_ref_sys
            WHERE srid = (SELECT ST_SRID(geom) FROM public.{table_name} LIMIT 1);""")
    with db._engine.connect() as conn:
        result = conn.execute(query).fetchone()
    auth_name, auth_srid, proj4text = result
    return f"{auth_name}:{auth_srid} ({proj4text})"


def _get_column_distributions(table_name: str, column_names: List[str]):
    print(column_names)
    distributions = {}
    for column_name in column_names:
        query = text(f"""
        WITH RandomSubset AS (
            SELECT {column_name}
            FROM {table_name}
            ORDER BY RANDOM() 
            LIMIT 10000
        )
        SELECT {column_name}, (COUNT(*) * 100.0) / 10000 AS percent
        FROM RandomSubset
        GROUP BY {column_name}
        ORDER BY percent DESC
        LIMIT 10;
        """)
        with db._engine.connect() as conn:
            results = conn.execute(query).fetchall()
            print(len(results))

        distributions[column_name] = [
            {'value': value, 'percent': round(percent, 1)} for value, percent in results]

    return distributions


def _create_distribution_string(distributions):
    output_string = ''
    for column_name, distribution_data in distributions.items():
        print(column_name)
        output_string += f"Column: {column_name}\n"
        for data in distribution_data:
            output_string += f"    {data['value']}: {data['percent']}\n"
            output_string += '\n'
    return output_string


table_names = ['osm_buildings_polygons']
table_infos = []
for table_name in table_names:
    table = db._metadata.tables[table_name]

    for k, v in table.columns.items():
        if type(v.type) is NullType:
            table._columns.remove(v)

    create_table = str(CreateTable(
        table).compile(db._engine))
    srid = _get_srid(table_name)

    column_names = table.columns.keys()
    print(column_names)
    exclude_columns = ['id', 'code', 'osm_id',
                        'name', 'ref', 'layer', 'population', 'geom']
    column_distributions = _get_column_distributions(
        table_name=table_name,
        column_names=[
            cname for cname in column_names if not cname in exclude_columns]
    )
    distribution_string = _create_distribution_string(
        column_distributions)

    table_infos.append(
        f'{create_table}{srid}\n\n{distribution_string}\n')

print(f'\n{"-" * 100}\n'.join(table_infos))

['id', 'osm_id', 'code', 'fclass', 'name', 'type']
['fclass', 'type']


1
10
fclass
type

CREATE TABLE osm_buildings_polygons (
	id SERIAL NOT NULL, 
	osm_id VARCHAR(12), 
	code INTEGER, 
	fclass VARCHAR(28), 
	name VARCHAR(100), 
	type VARCHAR(20), 
	CONSTRAINT buildings_polygons_pkey PRIMARY KEY (id)
)

EPSG:4326 (+proj=longlat +datum=WGS84 +no_defs )

Column: fclass
    building: 100.0
Column: type
    house: 26.6
    garage: 21.3
    cabin: 11.0
    shed: 6.7
    barn: 5.4
    farm_auxiliary: 5.0
    farm: 3.5
    None: 3.4
    terrace: 3.3
    semidetached_house: 2.7




In [50]:
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(
    db, 'select distinct name from water_polygons order by name')
# proper_nouns += query_as_list(db, "SELECT Title FROM Album")
# proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
len(proper_nouns)
proper_nouns, len(proper_nouns)

(['0.5 m',
  '1003',
  '1. dam',
  '20.1',
  '2. dam',
  '2. Stampe',
  '3. dam',
  '3. Stampe',
  '45,9',
  '49,0',
  '50,2',
  '5-pottingen',
  '832',
  '964',
  'Aaletajärvi / Alitjávri',
  'Aallejegaejsienjaevrie',
  'Aastjønna',
  'Aasvatnet',
  'Åbårtjønna',
  'Åbbårloken',
  'Åbbårtjønna',
  'Abborrtjärnen',
  'Abborrtjärnet',
  'Abborsjømyra',
  'Abbortjärn',
  'Abbortjenn',
  'Åbbortjenn',
  'Abbortjenna',
  'Abbortjennet',
  'Åbbortjennet',
  'Abbortjennmyra',
  'Abbortjern',
  'Abbortjernet',
  'Åbbortjernet',
  'Åbbortjønn',
  'Abbortjønna',
  'Abbortjønnan',
  'Åbbortjønnet',
  'Åbbortjønnloken',
  'Abelhølet',
  'Åbelvigtjønn',
  'Aberdalstjønn',
  'Ábijávri',
  'Abildsømyra',
  'Åborsjøen',
  'Åbortjenn',
  'Åbortjennet',
  'Abortjern',
  'Åbortjern',
  'Åbortjernet',
  'Åbortjønna',
  'Åbortjønnet',
  'Åbortjønnin',
  'Åborvatnet',
  'Abrahamsmyra',
  'Abrahamtjern',
  'Abrahamvatnet',
  'Åbydammen',
  'Åbydammen nedre',
  'Adalstjern',
  'Adamsvatnet',
  'Adelsbreen',


In [51]:
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})

In [55]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain

system = """You are a SQLite expert. Given an input question, create a syntactically \
correct SQLite query to run. Unless otherwise specificed, do not return more than \
{top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nHere is a non-exhaustive \
list of possible feature values. If filtering on a feature value make sure to check its spelling \
against this list first:\n\n{proper_nouns}"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain

query = chain.invoke(
    {"question": "jonsvatnet trondheim"})
print(query)
# db.run(query)

```sql
SELECT * 
FROM natural_points 
WHERE name = 'Jonsvatnet' AND fclass = 'lake'
UNION
SELECT * 
FROM natural_polygons 
WHERE name = 'Jonsvatnet' AND fclass = 'lake';
```
