In [1]:
import os
import uuid
import json
import numpy as np
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import create_engine, MetaData

from sentence_transformers import SentenceTransformer

from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain_ollama import OllamaEmbeddings
from langchain.schema import Document
from langchain.vectorstores import Qdrant

from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance

from dotenv import load_dotenv
load_dotenv(override=True)

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


## Introspect ด้วย SQLAlchemy

In [2]:
DATABASE_URI = os.getenv("POSTGRES_URI")

In [3]:
engine   = create_engine(DATABASE_URI)
engine.connect()

<sqlalchemy.engine.base.Connection at 0x7dbc91db9210>

In [4]:
metadata = MetaData()
metadata.reflect(bind=engine)   # อ่านโครงสร้างจากฐานข้อมูลตาม ERD ที่ใช้งานจริง

## Generate JSON Docs

In [5]:
# 2. วิ่งวนแต่ละตาราง สร้าง docs list
docs = []
for table in (metadata.sorted_tables):
    text = f"Table: {table.name}\n"
    for col in table.columns:
        pk_flag = " [PK]" if col.primary_key else ""
        text += f"- **{col.name}** ({col.type}){pk_flag}\n"
    # เติม sample rows ถ้าต้องการ
    docs.append({"text": text, "meta": {"table": table.name}})

# 3. เซฟออกไฟล์ JSON (ใช้สำหรับ embed ต่อ)
with open("schema_docs.json", "w") as f:
    json.dump(docs, f, indent=2)

In [6]:
# # 1. โหลด docs ที่สร้างไว้
# with open("schema_docs.json", "r") as f:
#     docs = json.load(f)

# Embed Database Metadata

In [7]:
# --------- 2. Build snippet from SQLAlchemy metadata ---------
def snippet_from_table(table, business_rules=None, snippet_type: str = "schema_overview"):
    """
    1. ไล่ทีละคอลัมน์ในตาราง SQLAlchemy และสร้าง snippet's content ที่มีรายละเอียดของคอลัมน์นั้น ๆ
    2. เพิ่ม business rules ถ้ามี
    3. สร้างตัวอย่าง query ที่ใช้ดึงข้อมูลจากตารางนั้น
    4. คืนค่าเป็น dict ที่มีข้อมูล snippet พร้อมใช้ เพื่อเก็บใน Qdrant หรือระบบอื่น ๆ
    """
    if business_rules is None:
        business_rules = []
    col_lines = []
    for col in table.columns:
        desc = "N/A" 
        example = "N/A"
        name = col.name
        col_lines.append(f" - {name}: ({str(col.type)}, nullable={col.nullable}, primary_key={col.primary_key}) description: {desc}; example: {example};")

    content = f"Table: {table.name}\n" + "\n".join(col_lines)
    if business_rules:
        content += "\nBusiness rules:\n" + "\n".join(f"* {r}" for r in business_rules)

    example_queries = [
        f"SELECT * FROM {table.name} ORDER BY {table.name[0]} DESC LIMIT 10;",
        f"SELECT COUNT(*) FROM {table.name};"
    ]
    
    snippet = {
        "id": f"schema::{table.name}",
        "type": snippet_type,
        "title": f"Schema of {table.name}",
        "content": content,
        "business_rules": business_rules,
        "example_query": example_queries,
    }
    return snippet


In [8]:
docs = []
for i, table in enumerate(metadata.sorted_tables):
    snippets = snippet_from_table(table, business_rules=["No duplicate entries allowed"], snippet_type="schema_overview")
    docs.append(snippets)

with open("snippet_schema.json", "w") as f:
    json.dump(docs, f, indent=2)

In [9]:
# Show a sample database table

for table in metadata.sorted_tables:
    print(f"Table: {table.name}")
    for field in table.columns:
        print(f"Column: {field.name}, Type: {field.type}, Primary Key: {field.primary_key}")
    print("\n---\n")


Table: auth_group
Column: id, Type: INTEGER, Primary Key: True
Column: name, Type: VARCHAR(150), Primary Key: False

---

Table: auth_user
Column: id, Type: INTEGER, Primary Key: True
Column: password, Type: VARCHAR(128), Primary Key: False
Column: last_login, Type: TIMESTAMP, Primary Key: False
Column: is_superuser, Type: BOOLEAN, Primary Key: False
Column: username, Type: VARCHAR(150), Primary Key: False
Column: first_name, Type: VARCHAR(150), Primary Key: False
Column: last_name, Type: VARCHAR(150), Primary Key: False
Column: email, Type: VARCHAR(254), Primary Key: False
Column: is_staff, Type: BOOLEAN, Primary Key: False
Column: is_active, Type: BOOLEAN, Primary Key: False
Column: date_joined, Type: TIMESTAMP, Primary Key: False

---

Table: company_companylimit
Column: id, Type: BIGINT, Primary Key: True
Column: deleted, Type: TIMESTAMP, Primary Key: False
Column: deleted_by_cascade, Type: BOOLEAN, Primary Key: False
Column: created_at, Type: TIMESTAMP, Primary Key: False
Column: 

In [10]:
# 2. เตรียม embedding function จาก Ollama
embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

In [11]:
def prepare_point(docs, embed_vector: list):
    for doc in docs:
        point =  PointStruct(
            id=doc["id"],
            vector=embed_vector,
            payload={
                "type": doc["type"],
                "title": doc.get("title"),
                "content": doc.get("content"),
                "details": doc.get("details", {}),
                "example_query": doc.get("example_query"),
                "source": doc.get("source"),
                "last_updated": doc.get("last_updated")
            }
        )

In [13]:
# # prepare documents as Document objects
# docs_for_qdrant = [
#     Document(page_content=item["text"], metadata=item["meta"])
#     for item in docs
# ]

# docs_for_qdrant

In [14]:
# vectordb = Qdrant.from_documents(
#     force_recreate=True,
#     documents=docs_for_qdrant,
#     embedding=embeddings,           # your OllamaEmbeddings instance
#     url="http://localhost:6333",    # Qdrant REST URL
#     prefer_grpc=False,
#     collection_name="schema_docs"
# )

In [16]:
# retriever = vectordb.as_retriever()

In [17]:
# user_question = "What are the columns in the users table?" # "Show me all 'company_contractor'"
# query_embedding = embeddings(user_question)
# search_result = client.search(
#     collection_name="llm_grounding",
#     query_vector=query_embedding,
#     limit=5,  # ปรับตามความเหมาะสม
#     with_payload=True
# )


# Call LLM

In [18]:
# # ใช้โมเดล llama3 ที่รันบน Ollama
# llm = OllamaLLM(model="llama3.2")

# SQL Agent

In [19]:

# db    = SQLDatabase.from_uri("postgresql://postgres:swd12345@localhost:5432/postgres")
# db.get_usable_table_names()


In [20]:
# # Create an agent that can interact with the SQL database
# agent_executor = create_sql_agent(
#     llm=llm,
#     toolkit=SQLDatabaseToolkit(db=db, llm=llm),
#     verbose=True,
#     agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
#     top_k=100
# )

In [21]:
# # 4) Ask it!
# print(agent.run("Show me all 'company_contractor'"))

#  Hybrid Agent Wrapper

In [22]:
# class HybridDBAgent:
#     def __init__(self, retriever, sql_chain, rag_chain=None):
#         self.retriever = retriever
#         self.sql_chain = sql_chain
#         self.rag_chain = rag_chain

#     def ask(self, query: str) -> str:
#         # 1) ดึง context ที่เกี่ยวข้อง
#         docs = self.retriever.get_relevant_documents(query)
#         context = "\n\n".join(d.page_content for d in docs)

#         # 2) ถ้าอยากใช้ RAG fallback (narrative analytics) ให้เช็คคีย์เวิร์ด
#         if self.rag_chain and any(w in query.lower() for w in ["why", "trend", "recommend"]):
#             return self.rag_chain.run(context + "\n\n" + query)

#         # 3) ปกติให้รัน SQL chain
#         prompt = f"{context}\n\nUser: {query}"
#         return self.sql_chain.run(prompt)


# 1. Extract Schema with SQLAlchemy


In [23]:
from sqlalchemy import create_engine, MetaData, inspect

# 1) connect to your Postgres DB
engine = create_engine(os.getenv("POSTGRES_URI"))
metadata = MetaData()
metadata.reflect(bind=engine, schema="public")  # reflect all tables :contentReference[oaicite:0]{index=0}


In [24]:
metadata.sorted_tables[0]

Table('auth_group', MetaData(), Column('id', INTEGER(), table=<auth_group>, primary_key=True, nullable=False, server_default=Identity(start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1)), Column('name', VARCHAR(length=150), table=<auth_group>, nullable=False), schema='public')

In [25]:
# 2) or use Inspector for more control
insp = inspect(engine)
# tables = insp.get_table_names(schema="public")  # list all tables :contentReference[oaicite:1]{index=1}
tables = insp.get_table_names(schema="public")
len(tables)

186

# 2. Generate Natural-Language Descriptions


In [26]:
import random
from sqlalchemy import text

def describe_table(tbl_name, row_count):
    return f"{tbl_name} table – approx. {int(row_count):,} rows."


def describe_column(tbl_name, col_info):
    # build a safe SQL string—avoid shadowing `text`
    query = (
        f"SELECT {col_info['name']} "
        f"FROM public.{tbl_name} "
        "LIMIT 3"
    )
    try:
        with engine.connect() as conn:
            # Option A: wrap in TextClause
            # result = conn.execute(text(query))
            # Option B: or use the driver-level API
            result = conn.exec_driver_sql(query)
            sample_vals = result.fetchall()
    except:
        sample_vals = []
    samples = ", ".join(str(v[0]) for v in sample_vals)
    comment = col_info.get("comment") or ""
    return (
        f"{tbl_name}.{col_info['name']} ({col_info['type']}) "
        f"{comment} Examples: {samples}"
    )

def describe_fk(tbl_name, fk):
    src = f"{tbl_name}.{fk['constrained_columns'][0]}"
    dst = f"{fk['referred_table']}.{fk['referred_columns'][0]}"
    return f"{src} → {dst} (foreign-key relationship)"


# 3. Embed & Upsert into Qdrant

In [27]:
from qdrant_client import QdrantClient, models as qdrant
# initialize clients
qc  = QdrantClient(url=os.getenv("QDRANT_URL"))
emb = OllamaEmbeddings(model="nomic-embed-text:latest")

In [28]:
points = []
# tables
for i, tbl in enumerate(tables):
    row_est = metadata.tables[f"public.{tbl}"].count() if hasattr(metadata.tables[f"public.{tbl}"], 'count') else random.randint(1e3,1e5)
    text = describe_table(tbl, row_est)
    points.append(qdrant.PointStruct(
        id=i, vector=emb.embed_query(text),
        payload={"type":"table","table":tbl,"db":"CM"}))


In [29]:
# columns
for i, tbl in enumerate(tables):
    for col in insp.get_columns(tbl):
        text = describe_column(tbl, col)
        points.append(qdrant.PointStruct(
            id=i, vector=emb.embed_query(text),
            payload={"type":"column","table":tbl,"col":col['name']}))

In [46]:
for fk in insp.get_foreign_keys(table_name="auth_user"):
    print(fk)
    break

In [None]:
# foreign-keys
# TODO
for tbl in tables:
    for fk in insp.get_foreign_keys(tbl):
        text = describe_fk(tbl, fk)
        points.append(qdrant.PointStruct(
            id=i, vector=emb.embed_query(text),
            payload={"type":"fk","src":f"{tbl}.{fk['constrained_columns'][0]}",
                     "dst":f"{fk['referred_table']}.{fk['referred_columns'][0]}"}))
        i += 1

In [56]:
print(points[300])

id=12 vector=[0.031627126, 0.019274605, -0.22026151, -0.049132593, 0.0596554, -0.03596781, 0.056252, -0.068254806, -0.013091847, -0.009354297, -0.038297314, 0.017328575, 0.059369247, 0.044154536, 0.010272014, -0.054808497, -0.0010474337, -0.037958387, -0.05371522, 0.062408414, 0.049878586, -0.04402047, 0.0037765342, -0.034832776, 0.044711478, 0.06833008, 0.03840778, -0.025770921, -0.059868753, -0.028501475, 0.014887494, 0.073677965, -0.014575863, -0.029238503, 0.003487025, -0.028030664, 0.050771974, 0.027001034, -0.011195963, -0.037167057, 0.057461776, 0.0008591445, -0.010366242, -0.03010581, 0.015573404, -0.012781328, 0.030106664, 0.009957659, 0.062493738, -0.005977949, 0.007268761, -0.025798129, -0.013583751, 0.011243206, 0.0419708, 0.031815533, 0.00195205, 0.035100106, -0.045093335, -0.0070850556, 0.029917143, 0.053604383, 0.016207816, 0.044084534, 0.07166026, 0.002126486, -0.0007333492, 0.071544334, -0.0060309498, -0.0024364712, 0.028381728, -0.034101974, -0.0101115005, 0.001418837

In [74]:
COL_NAME = "schema_cm_db_knowledge"
if not qc.collection_exists(collection_name=COL_NAME): 
    qc.create_collection(
        collection_name=COL_NAME,
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
        # optional: tune shards, replication_factor, etc.
    )
else:
    print("Collection already created")

In [75]:
# 3. (Optional) Verify it exists
col_info = qc.get_collection(collection_name=COL_NAME)
col_info
# print(f"Collection '{col_info.name}' with {col_info.vectors_count} points ready.")

CollectionInfo(status=<CollectionStatus.GREEN: 'green'>, optimizer_status=<OptimizersStatusOneOf.OK: 'ok'>, vectors_count=None, indexed_vectors_count=0, points_count=0, segments_count=8, config=CollectionConfig(params=CollectionParams(vectors=VectorParams(size=768, distance=<Distance.COSINE: 'Cosine'>, hnsw_config=None, quantization_config=None, on_disk=None, datatype=None, multivector_config=None), shard_number=1, sharding_method=None, replication_factor=1, write_consistency_factor=1, read_fan_out_factor=None, on_disk_payload=True, sparse_vectors=None), hnsw_config=HnswConfig(m=16, ef_construct=100, full_scan_threshold=10000, max_indexing_threads=0, on_disk=False, payload_m=None), optimizer_config=OptimizersConfig(deleted_threshold=0.2, vacuum_min_vector_number=1000, default_segment_number=0, max_segment_size=None, memmap_threshold=None, indexing_threshold=10000, flush_interval_sec=5, max_optimization_threads=None), wal_config=WalConfig(wal_capacity_mb=32, wal_segments_ahead=0), quant

In [76]:
# upsert all at once
qc.upsert(collection_name=COL_NAME, points=points)

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)

# 4. Query-Time Retrieval & Prompt Assembly

In [77]:
from qdrant_client.models import Filter, FieldCondition, MatchValue, SearchParams

def fetch_schema_slice(question, top_k=20):
    qvec = emb.embed_query(question)
    res = qc.search(
        collection_name="schema_catalog",
        query_vector=qvec,
        limit=top_k,
        params=SearchParams(sparse=True),  # hybrid dense + BM25 :contentReference[oaicite:10]{index=10}
        query_filter=Filter(
            must=[FieldCondition(
                key="type",
                match=MatchValue(any=["table","column","fk"])
            )]
        )
    )
    return res

# build prompt
points = fetch_schema_slice("latest orders and their customers")
schema_lines = []
for p in points:
    pl = p.payload
    if pl["type"] == "table":
        schema_lines.append(f"- {pl['table']} (table)")
    elif pl["type"] == "column":
        schema_lines.append(f"- {pl['table']}.{pl['col']}")
    else:  # fk
        schema_lines.append(f"- FK: {pl['src']} → {pl['dst']}")

prompt = "### Relevant schema\n" + "\n".join(schema_lines) + \
         "\n\n### Query:\nWrite valid SQL for: latest orders and their customers"


ValidationError: 1 validation error for SearchParams
sparse
  Extra inputs are not permitted [type=extra_forbidden, input_value=True, input_type=bool]
    For further information visit https://errors.pydantic.dev/2.11/v/extra_forbidden