In [70]:
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import AzureChatOpenAI

In [71]:
# Database configurations (3 databases)
DATABASES = {
    "MySQL": "mysql+pymysql://sql12755292:uiIV5jjJhR@sql12.freesqldatabase.com:3306/sql12755292",
    "PostgreSQL": "postgresql://costdb_owner:5mneLU6dIRVf@ep-fancy-feather-a1l3r3st.ap-southeast-1.aws.neon.tech/costdb?sslmode=require"
    # "SQLServer": "mssql+pyodbc://@Admin/Cost_Central_Monitor?driver=ODBC+Driver+17+for+SQL+Server&Trusted_Connection=yes"
}

In [72]:
db_instances = {name: SQLDatabase.from_uri(uri) for name, uri in DATABASES.items()}

In [73]:
# Get LLM model
def get_llm():
    return AzureChatOpenAI(
        azure_endpoint="https://ifd-copilot-internship-program.openai.azure.com",
        azure_deployment="gpt-4o-mini",
        api_key="",  # Remember to replace with your actual API key
        api_version="2024-08-01-preview",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2
    )

In [74]:
# Define Semantic Kernel function to select relevant databases
from semantic_kernel import Kernel
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from semantic_kernel.functions.kernel_arguments import KernelArguments

In [84]:
class DataPlugin:
    @kernel_function(name="select_databases", description="Choose relevant databases for the query")
    def select_databases(self, user_query, db_schemas):
        """
        This function uses the Kernel to select relevant databases for the user query.
        The result will be a list of database names.
        """
        llm = get_llm()
        schemas_context = "\n".join([f"{name}: {schema}" for name, schema in db_schemas.items()])
        
        # Cập nhật prompt để yêu cầu danh sách cơ sở dữ liệu phù hợp
        prompt = f"""
        Here are the schemas of the available databases:
        {schemas_context}

        User query: {user_query}

        Please return a comma-separated list of database names that are most relevant for the question. 
        Do not include any explanations or extra text, just the names of the databases.
        """
        
        # Truyền chuỗi vào invoke() và lấy kết quả
        result = llm.invoke(prompt)
        
        # Trả về danh sách cơ sở dữ liệu sau khi tách chuỗi
        return [db.strip() for db in result.content.strip().split(',')]

    def get_sql_chain(self, db):
        template = """
        Based on the schema, write a SQL query to answer the user's question.
        <SCHEMA>{schema}</SCHEMA>

        Question: {question}

        Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
        """
        prompt = ChatPromptTemplate.from_template(template)
        llm = get_llm()

        def get_schema(_):
            return db.get_table_info()

        return (
            RunnablePassthrough.assign(schema=get_schema)
            | prompt
            | llm
            | StrOutputParser()
        )

    def process_results_with_ai(self, raw_results, user_query):
        llm = get_llm()
        prompt = f"Here are the results from the query: {user_query}. Format the results nicely:\n{raw_results}"

        result = llm.invoke(prompt)  
        return result.content.strip()

    def get_sql_response(self, db, user_query):
        sql_query = self.get_sql_chain(db).invoke({"question": user_query})
        raw_results = db.run(sql_query)
        return raw_results

    @kernel_function(name="query_multiple_databases")
    def query_multiple_databases(self, user_query, db_instances):
        # Step 1: Get schemas from all databases
        db_schemas = {name: db.get_table_info() for name, db in db_instances.items()}

        # Step 2: Use LLM to select the most relevant databases
        selected_db_names = self.select_databases(user_query, db_schemas)
        
        # Step 3: Generate SQL and fetch results for each selected database
        results = []
        for db_name in selected_db_names:
            db_name = db_name.strip()
            selected_db = db_instances.get(db_name)
            if selected_db:
                raw_result = self.get_sql_response(selected_db, user_query)
                processed_results = self.process_results_with_ai(raw_result, user_query)
                results.append((db_name, processed_results))
        
        return results

In [85]:
kernel = Kernel()
data_plugin = kernel.add_plugin(DataPlugin(),'DataPlugin')

In [86]:
user_query = "How many project are there?"
argument = KernelArguments(user_query = user_query, db_instances = db_instances)
result = await kernel.invoke(data_plugin['query_multiple_databases'], argument)

In [87]:
result

FunctionResult(function=KernelFunctionMetadata(name='query_multiple_databases', plugin_name='DataPlugin', description=None, parameters=[KernelParameterMetadata(name='user_query', description=None, default_value=None, type_='Any', is_required=True, type_object=<class 'inspect._empty'>, schema_data={'type': 'object', 'properties': {}}, include_in_function_choices=True), KernelParameterMetadata(name='db_instances', description=None, default_value=None, type_='Any', is_required=True, type_object=<class 'inspect._empty'>, schema_data={'type': 'object', 'properties': {}}, include_in_function_choices=True)], is_prompt=False, is_asynchronous=False, return_parameter=KernelParameterMetadata(name='return', description='', default_value=None, type_='Any', is_required=True, type_object=None, schema_data={'type': 'object'}, include_in_function_choices=True), additional_properties={}), value=[('MySQL', 'The query results indicate that there are a total of **100 projects**.'), ('PostgreSQL', 'The quer

In [None]:
dt = DataPlugin()
db_schemas = {name: db.get_table_info() for name, db in db_instances.items()}

['MySQL', 'PostgreSQL']

In [80]:
dt.get_sql_chain(db_instances['PostgreSQL']).invoke({"question": user_query})

'SELECT COUNT(*) FROM tbl_projects;'

In [81]:
data = dt.get_sql_response(db_instances['MySQL'], user_query)

In [82]:
data

'The query results indicate that there is a total of **100 projects**.'

In [83]:
dt.process_results_with_ai(data,user_query)

'The query results indicate that there is a total of **100 projects**.'

In [88]:
dt.query_multiple_databases(user_query,db_instances)

[('MySQL',
  'The query results indicate that there are a total of **100 projects**.'),
 ('PostgreSQL',
  'The query results indicate that there are a total of **5 projects**.')]