In [73]:
import os
from functools import lru_cache
import json
import json5
from datetime import datetime
from typing import List

from langchain_core.messages.tool import ToolCall
from langchain_core.messages import (
    AnyMessage,
    SystemMessage,
    HumanMessage,
    AIMessage,
)
from langgraph.graph import MessagesState

from langchain_openai import ChatOpenAI
from langchain_community.embeddings import InfinityEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

from dotenv import load_dotenv

load_dotenv()

True

# Setup components

In [2]:
LLM_BASE_URL=os.getenv("LLM_BASE_URL")
LLM_MODEL=os.getenv("LLM_MODEL")
LLM_API_KEY=os.getenv("LLM_API_KEY")

EMBED_BASE_URL=os.getenv("EMBED_BASE_URL")
EMBED_MODEL=os.getenv("EMBED_MODEL")


@lru_cache()
def get_llm_model():
    return ChatOpenAI(
        model=LLM_MODEL,
        base_url=LLM_BASE_URL,
        api_key=LLM_API_KEY,
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1,
        extra_body = {
            'chat_template_kwargs': {'enable_thinking': False},
            "top_k": 20,
            "mip_p": 0,
        },
    )

@lru_cache()
def get_thinking_llm_model():
    return ChatOpenAI(
        model=LLM_MODEL,
        base_url=LLM_BASE_URL,
        api_key=LLM_API_KEY,
        temperature=0.6,
        top_p=0.95,
        presence_penalty=1,
        extra_body = {
            'chat_template_kwargs': {'enable_thinking': True},
            "top_k": 20,
            "mip_p": 0,
        },
    )

@lru_cache()
def get_embedding_model():
    return InfinityEmbeddings(
        model=EMBED_MODEL,
        infinity_api_url=EMBED_BASE_URL,
    )


@lru_cache()
def get_vector_store():
    client = QdrantClient(
        url="http://localhost",
        grpc_port=6334,
        prefer_grpc=True,
    )
    embedding_model = get_embedding_model()
    client.create_collection(
        collection_name="demo",
        vectors_config=VectorParams(
            size=len(embedding_model.embed_query("Hello")), 
            distance=Distance.COSINE
        ),
    )
    return QdrantVectorStore(
        client=client,
        collection_name="demo",
        embedding=embedding_model,
    )


# @lru_cache()
# def get_sqlite_db():
#     return SQLDatabase.from_uri("sqlite:////Users/vinhnguyen/Projects/ext-chatbot/resources/database/batdongsan.db")


# Process data

## Excel

In [None]:
import glob
from src.tools.table import create_sqlite_table

In [2]:
tables = []
for filepath in glob.glob("/Users/vinhnguyen/Projects/ext-chatbot/resources/processed_data/batdongsan_1/*.json"):
    table_name = ".".join(filepath.split("/")[-1].split(".")[:-1])
    with open(filepath, "r") as f:
        table = json.load(f)
        table["pydantic_schema"]["title"] = table_name
        if len(table["transformed_data"]) > 100:
            tables.append(table)

print([table["pydantic_schema"]["title"] for table in tables])
print(2)

['BĐS Cho thuê 500', 'BĐS Bán 500']
2


In [3]:
for table in tables:
    create_sqlite_table(
        schema=table["pydantic_schema"],
        column_groups=table["column_groups"],
        data=table["transformed_data"],
        db_path="/Users/vinhnguyen/Projects/ext-chatbot/resources/database/batdongsan.db",
    )

In [38]:
# Request more files from PO to test

# Utils

In [3]:
def extract_fn(text: str) -> tuple[str, str]:
    """Extract function name and arguments from tool call text."""
    fn_name, fn_args = '', ''
    fn_name_s = '"name": "'
    fn_name_e = '", "'
    fn_args_s = '"arguments": '
    
    i = text.find(fn_name_s)
    k = text.find(fn_args_s)
    
    if i > 0:
        _text = text[i + len(fn_name_s):]
        j = _text.find(fn_name_e)
        if j > -1:
            fn_name = _text[:j]
    
    if k > 0:
        fn_args = text[k + len(fn_args_s):]
    
    fn_args = fn_args.strip()
    if len(fn_args) > 2:
        fn_args = fn_args[:-1]
    else:
        fn_args = ''
    
    return fn_name, fn_args


def postprocess_ai_message(
    ai_message: AIMessage,
) -> List[AIMessage]:
    """
    Convert AIMessage with <tool_call> tags to proper LangChain message with tool calls and leave it in a list to integrate with MessagesState.
    Assumes all content is text (no multimodal).
    """
    tool_id = 1
    
    content: str = ai_message.content if isinstance(ai_message.content, str) else str(ai_message.content)
    
    # Handle <think> tags - skip tool call parsing inside thinking
    if '<think>' in content:
        if '</think>' not in content:
            # Incomplete thinking, add as regular message
            return [ai_message]
        
        # Split thinking from rest of content
        parts = content.split('</think>')
        content = parts[-1]
        
    
    # Find tool calls in content
    if '<tool_call>' not in content:
        # No tool calls, add as regular message
        return [AIMessage(content=content.strip())]
    
    # Split content by tool calls
    tool_call_list = content.split('<tool_call>')
    pre_text = tool_call_list[0].strip()
    tool_calls: List[ToolCall] = []
    
    # Process each tool call
    for txt in tool_call_list[1:]:
        if not txt.strip():
            continue
        
        # Handle incomplete tool calls (no closing tag)
        if '</tool_call>' not in txt:
            fn_name, fn_args = extract_fn(txt)
            if fn_name:
                tool_calls.append(
                    ToolCall(
                        name=fn_name,
                        args=json.loads(fn_args) if fn_args else {},
                        id=str(tool_id),
                    )
                )
                tool_id += 1
                # new_messages.append(AIMessage(content='', tool_calls=tool_calls))
            continue
        
        # Handle complete tool calls
        one_tool_call_txt = txt.split('</tool_call>')[0].strip()
        
        try:
            # Try to parse as JSON
            fn = json5.loads(one_tool_call_txt)
            if 'name' in fn and 'arguments' in fn:
                tool_calls.append(
                    ToolCall(
                        name=fn['name'],
                        args=fn['arguments'],
                        id=str(tool_id),
                    )
                )
                tool_id += 1
                # new_messages.append(AIMessage(content='', tool_calls=tool_calls))
        except Exception:
            # Fallback to manual extraction
            fn_name, fn_args = extract_fn(one_tool_call_txt)
            if fn_name:
                tool_calls.append(
                    ToolCall(
                        name=fn_name,
                        args=json.loads(fn_args) if fn_args else {},
                        id=str(tool_id),
                    )
                )
                tool_id += 1
                # new_messages.append(AIMessage(content='', tool_calls=tool_calls))
        
    if tool_calls:
        return [AIMessage(content=pre_text, tool_calls=tool_calls)]
    elif pre_text:
        return [AIMessage(content=pre_text)]
    else:
        return [AIMessage(content=content)]

In [4]:
def get_today_date() -> str:
    """Get today's date formatted for system message."""
    today = datetime.today()
    day_names = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]
    day_of_week = day_names[today.weekday()]
    month_name_full = today.strftime("%B")
    if today.day % 10 == 1 and today.day != 11:
        day_suffix = "st"
    elif today.day % 10 == 2 and today.day != 12:
        day_suffix = "nd"
    elif today.day % 10 == 3 and today.day != 13:
        day_suffix = "rd"
    else:
        day_suffix = "th"
    return f"{day_of_week}, {month_name_full} {today.day}{day_suffix}, {today.year}"


def preprocess_messages(
    state: MessagesState,
    system_prompt: str,
) -> List[AnyMessage]:
    """
    Convert LangChain messages with tool calls to plaintext format with <tool_call> tags.
    Converts ToolMessages to <tool_response> tags.
    Assumes all content is text (no multimodal).
    """
    messages: List[AnyMessage] = state["messages"]
    new_messages = []

    if messages[0].type == "system":
        new_messages.append(messages[0])
    else:
        date_info = "Hôm nay là {date}.\n".format(date=get_today_date())
        new_messages.append(SystemMessage(
            content=date_info + system_prompt
        ))

    for msg in messages[1:]:
        # Pass through human messages as-is
        if msg.type == "human":
            new_messages.append(msg)
            continue
        # Handle AI messages with tool calls
        elif msg.type == "ai":
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
            
            # Convert tool calls to plaintext format
            if msg.tool_calls:
                for tool_call in msg.tool_calls:
                    fc = {
                        'name': tool_call['name'],
                        'arguments': tool_call['args']
                    }
                    fc_str = json.dumps(fc, ensure_ascii=False)
                    tool_call_text = f'<tool_call>\n{fc_str}\n</tool_call>'
                    
                    # Append to content
                    if content:
                        content += '\n' + tool_call_text
                    else:
                        content = tool_call_text
            
            # Merge consecutive AI messages
            if new_messages and new_messages[-1].type == "ai":
                prev_content = new_messages[-1].content
                if prev_content and not prev_content.endswith('\n'):
                    prev_content += '\n'
                new_messages[-1] = AIMessage(content=prev_content + content)
            else:
                new_messages.append(AIMessage(content=content))
            continue
        # Handle tool messages - convert to <tool_response> wrapped in HumanMessage
        elif msg.type == "tool":
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
            response_text = f'<tool_response>\n{content}\n</tool_response>'
            if new_messages and new_messages[-1].type == "human":
                prev_content = new_messages[-1].content
                prev_content += '\n' + response_text
                new_messages[-1] = HumanMessage(content=prev_content)
            else:
                new_messages.append(HumanMessage(content=response_text))
            continue
        
    return new_messages


# SQLite Database

In [88]:
from typing import Any, Dict, Iterable, List, Literal, Sequence, Tuple, Union, Optional
from sqlalchemy import (
    MetaData,
    Table,
    create_engine,
    inspect,
    text,
    Column,
)
from sqlalchemy.engine import Engine, Result
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.types import NullType

In [122]:
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string to a certain number of words, based on the max string length."""
    if not isinstance(content, str) or length <= 0:
        return content
    if len(content) <= length:
        return content
    return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix


class SQLiteDatabase:
    """SQLAlchemy wrapper around a SQLite database with column comments support."""

    def __init__(
        self,
        engine: Engine,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        indexes_in_table_info: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """
        Create SQLite database wrapper.
        
        Args:
            engine: SQLAlchemy engine connected to SQLite database
            ignore_tables: List of table names to ignore
            include_tables: List of table names to include (mutually exclusive with ignore_tables)
            indexes_in_table_info: Whether to include index information in table info
            max_string_length: Maximum string length for truncating values
            lazy_table_reflection: Whether to lazily reflect tables
        """
        self._engine = engine
        if self._engine.dialect.name != "sqlite":
            raise ValueError("SQLiteDatabase only supports SQLite databases")
        
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)
        self._all_tables = set(self._inspector.get_table_names())

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(f"include_tables {missing_tables} not found in database")
        
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(f"ignore_tables {missing_tables} not found in database")
        
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        self._indexes_in_table_info = indexes_in_table_info
        self._max_string_length = max_string_length

        self._metadata = MetaData()
        if not lazy_table_reflection:
            self._metadata.reflect(
                bind=self._engine,
                only=list(self._usable_tables),
            )


    @classmethod
    def from_uri(
        cls,
        database_uri: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> "SQLiteDatabase":
        """Construct a SQLiteDatabase from URI."""
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)


    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return "sqlite"


    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            base = set(self._include_tables)
        else:
            base = self._all_tables - self._ignore_tables

        # filter out metadata tables (companion EAV tables)
        base = {tbl for tbl in base if not tbl.endswith("__metadata")}
        return sorted(base)


    def get_table_info(
        self,
        table_name: str,
        get_col_comments: bool = False,
        allowed_col_names: Optional[List[str]] = None,
        sample_count: Optional[int] = None,
    ) -> str:
        """
        Get information about a specified table.

        Args:
            table_name: Name of the table to get info for
            get_col_comments: Whether to include column comments in the output
            allowed_col_names: If provided, only include these columns in the output.
                              If None, include all columns.
            sample_count: Number of distinct example values to include for each column.
                          If None, no example values are included.

        Returns:
            String containing table schema (CREATE TABLE statement) and optionally
            column comments and sample rows.
        """
        all_table_names = list(self.get_usable_table_names())
        if table_name not in all_table_names:
            raise ValueError(f"Table '{table_name}' not found in database. Available tables: {all_table_names}")

        # Ensure table is reflected
        metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
        if table_name not in metadata_table_names:
            self._metadata.reflect(
                bind=self._engine,
                only=[table_name],
            )

        # Find the table object
        table = None
        for tbl in self._metadata.sorted_tables:
            if tbl.name == table_name:
                table = tbl
                break

        if table is None:
            raise ValueError(f"Table '{table_name}' could not be reflected")

        # Remove NullType columns
        try:
            for _, v in table.columns.items():
                if type(v.type) is NullType:
                    table._columns.remove(v)
        except AttributeError:
            for _, v in dict(table.columns).items():
                if type(v.type) is NullType:
                    table._columns.remove(v)

        # Filter columns if allowed_col_names is specified
        display_columns = list(table.columns) if not allowed_col_names else [col for col in table.columns if col.name in allowed_col_names]
        if not display_columns:
            raise ValueError(f"No matching columns found. Requested: {allowed_col_names}")

        # Get sample values for columns if requested
        column_sample_values: Dict[str, List[str]] = {}
        if sample_count:
            column_sample_values = self._get_sample_values(
                table, display_columns, sample_count
            )

        # Build custom CREATE TABLE statement with filtered columns
        col_defs = []
        column_descriptions = (
            self._get_column_descriptions_from_metadata(table_name)
            if get_col_comments
            else {}
        )
        for col in display_columns:
            col_type = str(col.type) if not isinstance(col.type, NullType) else "TEXT"
            col_def = f'\t"{col.name}" {col_type}'
            
            # Build comment with description and example values
            comment_parts = []
            col_cmt = column_descriptions.get(col.name, "")
            if col_cmt:
                comment_parts.append(col_cmt)
            
            # Add sample values if available
            if col.name in column_sample_values and column_sample_values[col.name]:
                sample_values = column_sample_values[col.name]
                examples_str = ", ".join(str(v) for v in sample_values)
                comment_parts.append(f"ví dụ: {examples_str},...")
            
            if comment_parts:
                comment_text = " | ".join(comment_parts)
                col_def = f"{col_def}\t/* {comment_text} */"
            
            col_defs.append(col_def)

        col_defs.sort()        
        create_table = f'CREATE TABLE "{table_name}" (\n' + ", \n".join(col_defs) + "\n)"

        table_info = f"{create_table.rstrip()}"
            
        # Add indexes if needed
        if self._indexes_in_table_info:
            table_info += "\n\n/*"
            table_info += f"\n{self._get_table_indexes(table)}\n"
            table_info += "*/"

        return table_info


    def _get_column_descriptions_from_metadata(
        self, table_name: str
    ) -> Dict[str, str]:
        """
        Fetch column descriptions from the metadata EAV table created alongside the data table.

        Expects a companion table named "{table_name}__metadata" with rows:
            entity = column name
            attribute = "description"
            value = description text
        """
        metadata_table = f"{table_name}__metadata"
        if metadata_table not in self._all_tables:
            return {}

        try:
            query = text(
                f'SELECT entity, value FROM "{metadata_table}" WHERE attribute = :attr'
            )
            with self._engine.connect() as connection:
                result: Result = connection.execute(query, {"attr": "description"})
                return {row[0]: row[1] for row in result if row[1] is not None}
        except (ProgrammingError, SQLAlchemyError):
            return {}


    def get_column_groups(self, table_name: str) -> List[List[str]]:
        """
        Return column groups for a table based on its metadata companion table.

        Reads rows where attribute == "group" from "{table_name}__metadata" and
        builds a list of column-name lists, ordered by group id.
        """
        metadata_table = f"{table_name}__metadata"
        if metadata_table not in self._all_tables:
            return []

        groups: Dict[int, List[str]] = {}
        try:
            query = text(
                f'SELECT entity, value FROM "{metadata_table}" WHERE attribute = :attr'
            )
            with self._engine.connect() as connection:
                result: Result = connection.execute(query, {"attr": "group"})
                for entity, value in result:
                    if value is None:
                        continue
                    try:
                        group_id = int(value)
                    except (TypeError, ValueError):
                        continue
                    groups.setdefault(group_id, []).append(entity)
        except (ProgrammingError, SQLAlchemyError):
            return []

        if not groups:
            return []

        return [groups[idx] for idx in sorted(groups.keys())]


    def _get_table_indexes(self, table: Table) -> str:
        """Get formatted index information for a table."""
        indexes = self._inspector.get_indexes(table.name)
        indexes_formatted = "\n".join(
            f"Name: {idx['name']}, Unique: {idx['unique']}, Columns: {idx['column_names']}"
            for idx in indexes
        )
        return f"Table Indexes:\n{indexes_formatted}"


    def _get_sample_values(
        self,
        table: Table,
        columns: List[Column],
        sample_count: int,
    ) -> Dict[str, List[str]]:
        """
        Get up to sample_count distinct example values per column.

        Strings are quoted to reflect their type. Values longer than 100 chars are skipped.
        """
        if sample_count <= 0:
            return {}

        column_sample_values: Dict[str, List[str]] = {col.name: [] for col in columns}
        for col in columns:
            query = text(
                f'SELECT DISTINCT "{col.name}" '
                f'FROM "{table.name}" '
                f'WHERE "{col.name}" IS NOT NULL '
                f"LIMIT {sample_count}"
            )

            try:
                with self._engine.connect() as connection:
                    result = connection.execute(query)
                    remaining_length = 1000
                    for val, in result:
                        val_str = str(val)
                        # Represent type: quote strings, leave others as-is
                        display_val = f'"{val_str}"' if isinstance(val, str) else val_str
                        column_sample_values[col.name].append(display_val)
                        remaining_length -= len(display_val)
                        if remaining_length <= 0:
                            break

            except ProgrammingError:
                continue

        return column_sample_values


    def _execute(
        self,
        command: str,
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """Execute SQL command through underlying engine."""
        parameters = parameters or {}
        execution_options = execution_options or {}
        
        with self._engine.begin() as connection:
            cursor = connection.execute(
                text(command),
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    return cursor
                else:
                    raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")
                return result
        return []


    def run(
        self,
        command: str,
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Sequence[Tuple[Any, ...]], Result[Any]]:
        """Execute a SQL command and return a string representing the results."""
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        if include_columns:
            return [
                {
                    column: truncate_word(value, length=self._max_string_length)
                    for column, value in r.items()
                }
                for r in result
            ]
        else:
            return [
                tuple(
                    truncate_word(value, length=self._max_string_length)
                    for value in r.values()
                )
                for r in result
            ]


    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Execute a SQL command and return results or error message."""
        try:
            res = self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
            return {
                "result": res,
                "error": None,
            }
        except SQLAlchemyError as e:
            return {
                "result": [],
                "error": f"Error: {e}",
            }


    def get_table_info_no_throw(
        self,
        table_name: str,
        get_col_comments: bool = False,
        allowed_col_names: Optional[List[str]] = None,
        sample_count: Optional[int] = None,
    ) -> str:
        """Get table info without throwing exceptions."""
        try:
            return self.get_table_info(
                table_name,
                get_col_comments=get_col_comments,
                allowed_col_names=allowed_col_names,
                sample_count=sample_count,
            )
        except ValueError as e:
            return f"Error: {e}"


    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt."""
        table_names = list(self.get_usable_table_names())
        # Get info for all tables
        table_infos = []
        for tbl in table_names:
            table_infos.append(self.get_table_info_no_throw(tbl))
        table_info = "\n\n".join(table_infos)
        return {"table_info": table_info, "table_names": ", ".join(table_names)}

In [123]:
@lru_cache()
def get_sqlite_db(business_name: str):
    return SQLiteDatabase.from_uri(
        f"sqlite:////Users/vinhnguyen/Projects/ext-chatbot/resources/database/{business_name}.db",
    )

In [124]:
db = get_sqlite_db("batdongsan")

In [125]:
# db.run("""
# select * from "BĐS Bán 500" limit 5;
# """, include_columns=True)

In [126]:
db.get_usable_table_names()

['BĐS Bán 500', 'BĐS Cho thuê 500']

In [127]:
# db.get_column_groups("BĐS Bán 500")

In [129]:
print(db.get_table_info_no_throw(
    "BĐS Bán 500",
    get_col_comments=True,
    allowed_col_names=["Bãi đỗ xe", "Chiều dài (m)"],
    sample_count=5
))

CREATE TABLE "BĐS Bán 500" (
	"Bãi đỗ xe" TEXT	/* Thông tin về khả năng đỗ xe (bao gồm số lượng và loại phương tiện). | ví dụ: "Có", "Không", "Nhiều xe máy", "1 ô tô", "2 ô tô",... */, 
	"Chiều dài (m)" REAL	/* Độ dài chiều dài của bất động sản tính theo mét. | ví dụ: 3.9, 8.6, 16.4, 7.7, 32.6,... */
)


# Chains

In [79]:
db.get_usable_table_names()

['BĐS Bán 500', 'BĐS Cho thuê 500']

In [80]:
import random

print(db.get_table_info_no_throw(
    "BĐS Cho thuê 500",
    get_col_comments=True,
    allowed_col_names=None,
    sample_row_ids=random.sample(range(1, 50), 5)
))

CREATE TABLE "BĐS Cho thuê 500" (
	"An ninh" TEXT	/* Loại hình an ninh tại bất động sản (ví dụ: camera, bảo vệ 24/7). */, 
	"Ban công/Sân thượng" TEXT	/* Loại không gian ngoại thất (ban công, sân thượng, cả hai). */, 
	"Bãi đỗ xe" TEXT	/* Thông tin về bãi đỗ xe (loại xe, phí, số lượng). */, 
	"Bếp" TEXT	/* Mô tả về tiện ích bếp (riêng/chung/khu bếp đầy đủ). */, 
	"Cho phép nuôi thú cưng" TEXT	/* Thông tin về việc cho phép nuôi thú cưng (Có/Không/Thỏa thuận). */, 
	"Diện tích (m²)" INTEGER	/* Diện tích của bất động sản được đo bằng mét vuông. */, 
	"Dự án" TEXT	/* Tên dự án bất động sản (nếu có), hoặc 'Không thuộc dự án' nếu không thuộc bất kỳ dự án nào. */, 
	"Email" TEXT	/* Địa chỉ email liên hệ. */, 
	"Ghi chú" TEXT	/* Thông tin bổ sung về yêu cầu thuê. */, 
	"Giá thuê (triệu/tháng)" REAL	/* Giá thuê hàng tháng tính theo triệu đồng. */, 
	"Giá/m²/tháng" TEXT	/* Giá thuê theo mét vuông mỗi tháng, bao gồm ký tự tiền tệ và dấu phẩy. */, 
	"Giá/m²/tháng_số" REAL	/* Giá thuê theo mét vuôn

In [None]:
import time
from pydantic import BaseModel, Field
import functools
# import asyncio
from tqdm.asyncio import tqdm_asyncio

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import StateGraph, START, END, MessagesState
from langchain.tools import tool

In [15]:
llm = get_llm_model()

## SQL Chain

### Schema Linking Chain

In [None]:
# table selection (run async for multi table / run one time for all tables)
# column selection (run async for multi column group / run one time for all columns)
# query generation
# run query
# return message

In [57]:
SCHEMA_LINKING_TEMPLATE = """
You are an expert in SQL schema linking. 
Given a {dialect} table schema (DDL) and a user query, determine if the table is relevant to the query.

Your task:
1. Analyze the table columns and the user query to decide if they are related.
2. Answer "Y" (Yes) or "N" (No).
3. If the answer is "Y", list ALL columns that are semantically related to the query topics. 
   - You do NOT need to identify the exact columns for the final SQL query. 
   - You SHOULD include any columns that provide context, identifiers, or potential join keys related to the entities in the query.

Output must be a valid JSON object inside a ```json code block using this format:
```json
{{
    "think": "Brief reasoning step by step",
    "is_related": "Y or N",
    "columns": ["col_name1", "col_name2"]
}}
```

Table Schema (DDL):
{table_info}

User Query:
{query}
""".strip()

schema_linking_chain = (
    ChatPromptTemplate([("human", SCHEMA_LINKING_TEMPLATE)])
    | llm
    | JsonOutputParser()
)

async def _link_schema_one(
    query: str,
    table_name: str,
    allowed_col_names: Optional[List[str]] = None
) -> Dict[str, Any]:
    try:
        table_info = db.get_table_info_no_throw(
            table_name,
            get_col_comments=True,
            allowed_col_names=allowed_col_names,
        )
        result = await schema_linking_chain.ainvoke(
            {"table_info": table_info, "query": query, "dialect": db.dialect}
        )
        if "is_related" not in result:
            raise ValueError("Invalid response from schema linking chain")
        if result["is_related"] == "N":
            return {
                "input_item": {"table_name": table_name, "query": query, "allowed_col_names": allowed_col_names},
                "result": {},
                "error": None
            }
        else:
            return {
                "input_item": {"table_name": table_name, "query": query, "allowed_col_names": allowed_col_names},
                "result": {
                    "table_name": table_name,
                    "col_names": result["columns"]
                },
                "error": None
            }
    except Exception as e:
        return {
            "input_item": {"table_name": table_name, "query": query},
            "result": None,
            "error": str(e)
        }


async def link_schema(
    query: str,
    max_retries: int = 3
) -> Dict[str, str]:
    # queue = []
    # for table in  db.get_usable_table_names():
    #     for col_group in db.get_column_groups(table):
    #         queue.append({
    #             "table_name": table,
    #             "allowed_col_names": col_group,
    #             "query": query
    #         })
    queue = [{"table_name": table_name, "query": query} for table_name in db.get_usable_table_names()]
    successful_results = []
    for _ in range(max_retries):
        tasks = [_link_schema_one(**input_item) for input_item in queue]
        results = await tqdm_asyncio.gather(*tasks)
        successful_results.extend([
            res for res in results if res["error"] is None
        ])
        failed_items = [
            res["input_item"] for res in results if res["error"] is not None
        ]
        queue = failed_items
        if not queue:
            break
    
    return [result["result"] for result in successful_results if result["result"]]

In [60]:
query = "Tìm danh sách nhà cho thuê ở trên đường Láng"
results = await link_schema(query)
results

100%|██████████| 2/2 [00:05<00:00,  2.82s/it]


[{'table_name': 'BĐS Bán 500',
  'col_names': ['Địa chỉ_Tên đường',
   'Loại BĐS',
   'Địa chỉ',
   'Phường/Xã',
   'Quận/Huyện',
   'Tỉnh/TP',
   'Tiêu đề']},
 {'table_name': 'BĐS Cho thuê 500', 'col_names': ['Địa chỉ', 'Địa chỉ_đường']}]

### SQL Generation

In [35]:
import random

print(db.get_table_info_no_throw(
    results[0]["table_name"],
    allowed_col_names=results[0]["col_names"],
    get_col_comments=True,
    sample_row_ids=random.sample(range(1, 50), 5)
))

CREATE TABLE "BĐS Cho thuê 500" (
	"Ngày có thể chuyển vào" TEXT	/* Ngày mà người thuê có thể chuyển vào */, 
	"Tình trạng" TEXT	/* Trạng thái của bất động sản (ví dụ: trống, đang cho thuê, mới sửa chữa). */
)

/*
5 sample rows from BĐS Cho thuê 500 table (by ROWID):
Ngày có thể chuyển vào	Tình trạng
03/11/2025	Trống
12/12/2025	Đang có người thuê
27/11/2025	Trống
07/11/2025	Sắp trống
04/12/2025	Trống
*/


## Orchestrator

In [None]:
ORCHESTRATOR_PROMPT = """
Bạn là một trợ lý ảo chuyên trả lời các câu hỏi về thị trường bất động sản dựa trên một bộ dữ liệu nội bộ. Cơ sở dữ liệu này bao gồm danh sách các bất động sản bán và cho thuê. Khi người dùng hỏi về các chủ đề này, nhiệm vụ của bạn là gọi một trợ lý ảo khác (SQL Assistant) và đưa ra mệnh lệnh (bằng ngôn ngữ tự nhiên) để trợ lý ảo thực hiện viết câu truy vấn và thu thập thông tin liên quan. Đối với các câu hỏi nằm ngoài lĩnh vực bất động sản này, hãy sử dụng kiến thức của bạn để trả lời trực tiếp.
""".strip()


class SQLAssistantInput(BaseModel):
    command: str = Field(description="The command in natural language")


@tool("call_sql_assistant",args_schema=SQLAssistantInput)
def call_sql_assistant(command: str) -> str:
    """Call the SQL Assistant to get the information"""
    return "Hello, world!"


orchestrator_chain = RunnablePassthrough.assign(
    messages=functools.partial(
        preprocess_messages, system_prompt=ORCHESTRATOR_PROMPT)
    | get_llm_model().bind_tools([call_sql_assistant])
    | postprocess_ai_message
)

In [148]:
builder = StateGraph(MessagesState)
builder.add_node("orchestrator", orchestrator_chain)
builder.add_edge(START, "orchestrator")
builder.add_edge("orchestrator", END)

orchestrator = builder.compile()

In [149]:
start_time = time.time()
state = orchestrator.invoke({"messages": [
    {"role": "user", "content": "Hello, hôm qua ngày bao nhiêu nhể"},
    {"role": "assistant", "content": "Hôm qua là ngày 7 tháng 12 năm 2025"},
    {"role": "user", "content": "Có bao nhiêu nhà đang được cho thuê nhỉ"}
    # {"role": "user", "content": "1 + 2 + 3 + ... + 100 = ?"}
]})
for message in state["messages"]:
    print(message)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

content='Hello, hôm qua ngày bao nhiêu nhể' additional_kwargs={} response_metadata={} id='7f890688-38f2-4df2-b5f5-df069beacc23'
content='Hôm qua là ngày 7 tháng 12 năm 2025' additional_kwargs={} response_metadata={} id='cfbdbf8b-aaea-4b3b-bd37-335f02d81651'
content='Có bao nhiêu nhà đang được cho thuê nhỉ' additional_kwargs={} response_metadata={} id='3263ef3c-db6c-4c6d-bf56-b666c0e1d0a5'
content='' additional_kwargs={} response_metadata={} id='69cdb966-86fe-4c14-a3fc-5c7d989c8b9a' tool_calls=[{'name': 'call_sql_assistant', 'args': {'command': 'Hãy đếm số lượng nhà đang được cho thuê trong cơ sở dữ liệu.'}, 'id': '1', 'type': 'tool_call'}]
Time taken: 1.0824270248413086 seconds
