In [1]:
import os
import uuid
import json
import numpy as np
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import create_engine, MetaData

from sentence_transformers import SentenceTransformer

from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain.schema import Document
from langchain.vectorstores import Qdrant

from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance

from dotenv import load_dotenv
load_dotenv(override=True)

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


## Introspect the database schema with SQLAlchemy

In [2]:
# Connection string comes from the environment (.env is loaded in the import cell)
# so database credentials are not committed inside the notebook.
DATABASE_URL = os.environ.get("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL is not set. Add a line such as "
        "DATABASE_URL=postgresql://<user>:<password>@localhost:5432/postgres to your .env file."
    )

engine = create_engine(DATABASE_URL)
metadata = MetaData()
metadata.reflect(bind=engine)  # read the table structure from the live database (the ERD actually in use)

## Generate JSON Docs

In [None]:
# # 2. Iterate over each table and build the docs list
# docs = []
# for table in (metadata.sorted_tables):
#     text = f"Table: {table.name}\n"
#     for col in table.columns:
#         pk_flag = " [PK]" if col.primary_key else ""
#         text += f"- **{col.name}** ({col.type}){pk_flag}\n"
#     # Add sample rows here if needed
#     docs.append({"text": text, "meta": {"table": table.name}})

# # 3. Save to a JSON file (used later for embedding)
# with open("schema_docs.json", "w") as f:
#     json.dump(docs, f, indent=2)

In [None]:
# # 1. Load the previously generated docs
# with open("schema_docs.json", "r") as f:
#     docs = json.load(f)

# Embed Database Metadata

In [73]:
# --------- 2. Build snippet from SQLAlchemy metadata ---------
def snippet_from_table(table, business_rules=None, snippet_type: str = "schema_overview"):
    """Build a retrieval snippet that describes one reflected table.

    1. Describe every column (type, nullability, primary-key flag, and the
       column comment when the database defines one).
    2. Append the business rules, if any.
    3. Add example queries against the table.
    4. Return a JSON-serialisable dict ready to be embedded and stored
       (e.g. in Qdrant).

    Args:
        table: A reflected SQLAlchemy ``Table``. Any object exposing ``.name``
            and ``.columns`` (items with ``name``, ``type``, ``nullable`` and
            ``primary_key``) works.
        business_rules: Optional iterable of free-text rules attached to the snippet.
        snippet_type: Value stored under the ``"type"`` key.

    Returns:
        dict with keys ``id``, ``type``, ``title``, ``content``,
        ``business_rules`` and ``example_query``.
    """
    # Copy so the snippet never aliases the caller's list.
    rules = list(business_rules) if business_rules else []
    columns = list(table.columns)

    col_lines = []
    for col in columns:
        # A reflected column comment is the only description available here.
        desc = getattr(col, "comment", None) or "N/A"
        example = "N/A"
        col_lines.append(
            f" - {col.name}: ({str(col.type)}, nullable={col.nullable}, "
            f"primary_key={col.primary_key}) description: {desc}; example: {example};"
        )

    content = f"Table: {table.name}\n" + "\n".join(col_lines)
    if rules:
        content += "\nBusiness rules:\n" + "\n".join(f"* {r}" for r in rules)

    # Order by the primary key (fallback: the first column). The previous version
    # used table.name[0] -- the first letter of the table name -- which is not a column.
    order_cols = [c.name for c in columns if c.primary_key] or [c.name for c in columns[:1]]
    select_all = f"SELECT * FROM {table.name}"
    if order_cols:
        select_all += " ORDER BY " + ", ".join(f"{name} DESC" for name in order_cols)
    select_all += " LIMIT 10;"

    example_queries = [
        select_all,
        f"SELECT COUNT(*) FROM {table.name};",
    ]

    return {
        "id": f"schema::{table.name}",
        "type": snippet_type,
        "title": f"Schema of {table.name}",
        "content": content,
        "business_rules": rules,
        "example_query": example_queries,
    }


In [74]:
# One schema snippet per reflected table, in metadata.sorted_tables order,
# persisted to JSON so the embedding step can be re-run without the database.
docs = [
    snippet_from_table(
        tbl,
        business_rules=["No duplicate entries allowed"],
        snippet_type="schema_overview",
    )
    for tbl in metadata.sorted_tables
]

with open("snippet_schema.json", "w") as out_file:
    json.dump(docs, out_file, indent=2)

In [72]:
# Print every reflected table with its columns (name, type, primary-key flag).
for tbl in metadata.sorted_tables:
    report = [f"Table: {tbl.name}"]
    report.extend(
        f"Column: {c.name}, Type: {c.type}, Primary Key: {c.primary_key}"
        for c in tbl.columns
    )
    report.append("\n---\n")
    print("\n".join(report))


Table: auth_group
Column: id, Type: INTEGER, Primary Key: True
Column: name, Type: VARCHAR(150), Primary Key: False

---

Table: auth_user
Column: id, Type: INTEGER, Primary Key: True
Column: password, Type: VARCHAR(128), Primary Key: False
Column: last_login, Type: TIMESTAMP, Primary Key: False
Column: is_superuser, Type: BOOLEAN, Primary Key: False
Column: username, Type: VARCHAR(150), Primary Key: False
Column: first_name, Type: VARCHAR(150), Primary Key: False
Column: last_name, Type: VARCHAR(150), Primary Key: False
Column: email, Type: VARCHAR(254), Primary Key: False
Column: is_staff, Type: BOOLEAN, Primary Key: False
Column: is_active, Type: BOOLEAN, Primary Key: False
Column: date_joined, Type: TIMESTAMP, Primary Key: False

---

Table: company_companylimit
Column: id, Type: BIGINT, Primary Key: True
Column: deleted, Type: TIMESTAMP, Primary Key: False
Column: deleted_by_cascade, Type: BOOLEAN, Primary Key: False
Column: created_at, Type: TIMESTAMP, Primary Key: False
Column: 

In [30]:
# Embedding model served by the local Ollama instance; used to embed the
# schema snippets and the user questions.
EMBED_MODEL = "nomic-embed-text:latest"
embeddings = OllamaEmbeddings(model=EMBED_MODEL)

In [None]:
def _to_point_id(raw_id):
    """Map a snippet id onto an id Qdrant accepts (unsigned int or UUID string)."""
    if isinstance(raw_id, int) and raw_id >= 0:
        return raw_id
    try:
        return str(uuid.UUID(str(raw_id)))
    except ValueError:
        # Deterministic mapping: the same snippet id always yields the same
        # point id, so re-upserting overwrites instead of duplicating.
        return str(uuid.uuid5(uuid.NAMESPACE_URL, str(raw_id)))


def prepare_point(docs, embed_vector: list):
    """Convert snippet dicts into Qdrant ``PointStruct`` objects.

    Args:
        docs: Iterable of snippet dicts (e.g. from ``snippet_from_table``);
            each must have ``"id"`` and ``"type"`` keys.
        embed_vector: Either one embedding per doc (a sequence of vectors in
            the same order as ``docs``) or a single vector that is reused for
            every doc (the original calling convention).

    Returns:
        list[PointStruct]: one point per doc. The original snippet id is kept in
        the payload as ``"snippet_id"`` because string ids such as
        ``"schema::auth_group"`` are not valid Qdrant point ids.

    Raises:
        ValueError: if per-doc vectors are given and their count differs from
            the number of docs.
    """
    docs = list(docs)
    # A sequence whose first element is itself sized is treated as per-doc vectors.
    per_doc = len(embed_vector) > 0 and hasattr(embed_vector[0], "__len__")
    if per_doc and len(embed_vector) != len(docs):
        raise ValueError(f"got {len(embed_vector)} vectors for {len(docs)} docs")

    points = []
    for idx, doc in enumerate(docs):
        points.append(
            PointStruct(
                id=_to_point_id(doc["id"]),
                vector=list(embed_vector[idx]) if per_doc else embed_vector,
                payload={
                    "snippet_id": doc["id"],
                    "type": doc["type"],
                    "title": doc.get("title"),
                    "content": doc.get("content"),
                    "business_rules": doc.get("business_rules", []),
                    "details": doc.get("details", {}),
                    "example_query": doc.get("example_query"),
                    "source": doc.get("source"),
                    "last_updated": doc.get("last_updated"),
                },
            )
        )
    return points

In [32]:
# Convert the schema snippets into LangChain Documents.
# `docs` holds snippet dicts from `snippet_from_table` (keys: id, type, title,
# content, ...); the legacy {"text", "meta"} records from schema_docs.json are
# still accepted so either source can be embedded.
def _snippet_to_document(item):
    """Build a Document from a schema snippet or a legacy {"text", "meta"} record."""
    if "content" in item:
        metadata = {k: v for k, v in item.items() if k != "content"}
        return Document(page_content=item["content"], metadata=metadata)
    return Document(page_content=item["text"], metadata=item["meta"])


docs_for_qdrant = [_snippet_to_document(item) for item in docs]

vectordb = Qdrant.from_documents(
    documents=docs_for_qdrant,
    embedding=embeddings,           # the OllamaEmbeddings instance
    url="http://localhost:6333",    # Qdrant REST URL
    prefer_grpc=False,
    collection_name="schema_docs",
    force_recreate=True,            # recreate the collection so re-runs start clean
)

In [33]:
# Expose the Qdrant vector store as a LangChain retriever with default search settings.
retriever = vectordb.as_retriever()

In [None]:
user_question = "What are the columns in the users table?"  # "Show me all 'company_contractor'"

# OllamaEmbeddings instances are not callable; embed_query returns one vector.
query_embedding = embeddings.embed_query(user_question)

# No Qdrant client existed in this notebook; connect to the same server that
# `vectordb` was built on.
client = QdrantClient(url="http://localhost:6333")
search_result = client.search(
    collection_name="schema_docs",  # the collection populated above
    query_vector=query_embedding,
    limit=5,  # tune as needed
    with_payload=True,
)
search_result


# Call LLM

In [None]:
# Llama 3.2 model served by the local Ollama instance.
LLM_MODEL = "llama3.2"
llm = OllamaLLM(model=LLM_MODEL)

# SQL Agent

In [None]:

# Reuse the engine from the introspection cell instead of repeating the
# connection string (which embedded the database password in the notebook).
db = SQLDatabase(engine)
db.get_usable_table_names()


['auth_group',
 'auth_group_permissions',
 'auth_permission',
 'auth_user',
 'auth_user_groups',
 'auth_user_user_permissions',
 'authentication_userdevice',
 'company_company',
 'company_companylimit',
 'company_companyuser',
 'company_contractor',
 'company_contractor_tag_list',
 'company_discipline',
 'company_disciplinegroup',
 'company_formcenter',
 'company_formcenter_response',
 'company_formcenterresponse',
 'company_group',
 'company_locationgroup',
 'company_locationphase',
 'company_locationtype',
 'company_permission',
 'company_permissiongroup',
 'company_process',
 'company_profile',
 'company_projectcolumnsetting',
 'company_projecttype',
 'company_tag',
 'company_tag_tool_list',
 'company_tagprojectgroup',
 'company_toollabels',
 'company_usersetting',
 'company_usersignatures',
 'company_userstamps',
 'company_usertoolnotisetting',
 'customers_chppmilestone',
 'customers_chppprojectdashboard',
 'customers_chpptask',
 'django_admin_log',
 'django_celery_beat_clockedsche

In [35]:
# SQL agent over `db`: the toolkit wraps the database (and the same LLM) as
# tools the zero-shot ReAct agent can call.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    top_k=100,
    verbose=True,
)

In [36]:
# 4) Ask the agent. The executor built above is `agent_executor` (`agent` was
# never defined in this notebook); invoke() returns a dict with the answer
# under "output".
response = agent_executor.invoke({"input": "Show me all 'company_contractor'"})
print(response["output"])

The database does not contain any tables with the column name 'company_contractor'. 

I don't know


#  Hybrid Agent Wrapper

In [27]:
class HybridDBAgent:
    """Answer a question with a SQL chain, or an optional RAG chain, using retrieved context.

    Args:
        retriever: Retriever exposing ``invoke(query)`` (or the legacy
            ``get_relevant_documents(query)``) that returns objects with a
            ``page_content`` attribute.
        sql_chain: Object with ``run(prompt)``; used for ordinary questions.
        rag_chain: Optional object with ``run(prompt)``; used for
            narrative/analytic questions matched by ``rag_keywords``.
        rag_keywords: Substrings (case-insensitive) that route a question to
            ``rag_chain``. Defaults to ``DEFAULT_RAG_KEYWORDS``.
    """

    DEFAULT_RAG_KEYWORDS = ("why", "trend", "recommend")

    def __init__(self, retriever, sql_chain, rag_chain=None, rag_keywords=None):
        self.retriever = retriever
        self.sql_chain = sql_chain
        self.rag_chain = rag_chain
        keywords = self.DEFAULT_RAG_KEYWORDS if rag_keywords is None else rag_keywords
        self.rag_keywords = tuple(k.lower() for k in keywords)

    def _retrieve(self, query: str):
        """Fetch context documents, preferring the Runnable ``invoke`` API."""
        invoke = getattr(self.retriever, "invoke", None)
        if callable(invoke):
            return invoke(query)
        # Fallback for retrievers that only implement the deprecated method.
        return self.retriever.get_relevant_documents(query)

    def _wants_rag(self, query: str) -> bool:
        """True when a RAG chain is configured and the query contains a routing keyword."""
        lowered = query.lower()
        return bool(self.rag_chain) and any(k in lowered for k in self.rag_keywords)

    def ask(self, query: str) -> str:
        """Route ``query`` to the RAG chain or the SQL chain and return its answer."""
        # 1) Retrieve the relevant context.
        docs = self._retrieve(query)
        context = "\n\n".join(d.page_content for d in docs)

        # 2) Narrative-analytics questions go to the RAG chain when one is configured.
        if self._wants_rag(query):
            return self.rag_chain.run(f"{context}\n\n{query}" if context else query)

        # 3) Otherwise run the SQL chain; skip the context preamble when nothing was retrieved.
        prompt = f"{context}\n\nUser: {query}" if context else f"User: {query}"
        return self.sql_chain.run(prompt)
