## LangChain agent demo for ReAct

In [None]:
!pip install langchain[all]
!pip install langchain-experimental
#!pip install sagemaker --upgrade
!pip install  boto3
!pip install requests_aws4auth
!pip install opensearch-py
!pip install pydantic==1.10.0
#!pip install PyAthena[SQLAlchemy]==1.0.0
#!pip install PyAthena[JDBC]==1.0.0
#!pip install openai
!pip install sqlalchemy-redshift
!pip install redshift_connector
!pip install SQLAlchemy

In [None]:
!pip install SQLAlchemy

## Initialize the SageMaker environment

In [1]:
import os
import sagemaker
import boto3
import json
from typing import Dict
from typing import Any, Dict, List, Optional

# Set up the SageMaker session, default bucket, execution role and the
# runtime client used to call SageMaker endpoints directly.
sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models and logs.
# SageMaker will automatically create this bucket if it does not exist.
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the IAM role: get_execution_role() works inside SageMaker; when it
# raises ValueError (e.g. running elsewhere), look up an IAM role literally
# named 'sagemaker_execution_role' instead.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Recreate the session pinned to the bucket resolved above.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
# Low-level runtime client; passed to get_vector_by_sm_endpoint for embeddings.
sm_client = boto3.client("sagemaker-runtime")

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123
sagemaker bucket: sagemaker-us-west-2-687912291502
sagemaker session region: us-west-2


## Initialize LangChain components

In [66]:
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory,ConversationBufferMemory
from langchain import LLMChain
from typing import Any, Dict, List, Union,Mapping, Optional, TypeVar, Union
import os
import time
import json
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

# set environment OPENAI_API_KEY
#from langchain.embeddings.openai import OpenAIEmbeddings

# OpenSearch (AOS) connection settings for the metadata index.
aos_endpoint = "vpc-llm-rag-aos-seg3mzhpp76ncpxezdqtcsoiga.us-west-2.es.amazonaws.com"
# Replace with your real OpenSearch username and password.
username = ""
passwd = ""
index_name = "metadata-index"
size = 10

# Basic-auth tuple; this is what the OpenSearch clients in this notebook pass
# as http_auth.
pwdauth = (username, passwd)

# SigV4 auth built from the default boto3 credential chain. Credentials and
# region both come from one boto3 session. (awsauth is currently only
# referenced by commented-out code.)
_boto_session = boto3.Session()
credentials = _boto_session.get_credentials()
region = _boto_session.region_name
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    region,
    'es',
    session_token=credentials.token,
)

### for sqlcoder
class TextGenContentHandler2(LLMContentHandler):
    """Serialize prompts for, and deserialize replies from, the sqlcoder endpoint.

    Requests are JSON of the form ``{"inputs": prompt, "parameters": model_kwargs}``.
    The reply's ``outputs`` field is returned untouched. Extracting the SQL
    statement from it is left to the SQL chain that consumes this LLM.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        # `output` is a file-like response body (it is read, not indexed).
        raw_text = output.read().decode("utf-8")
        return json.loads(raw_text)["outputs"]


content_hander2 = TextGenContentHandler2()

## for embedding
class EmbeddingContentHandler(EmbeddingsContentHandler):
    """Serialize texts for, and deserialize vectors from, the embedding endpoint.

    The request body is ``{"inputs": texts, **model_kwargs}`` as JSON. The
    response's ``sentence_embeddings`` field (one vector per text) is returned.
    """

    # NOTE(review): not read by this handler; transform_input spreads
    # model_kwargs instead. Kept as-is for compatibility.
    parameters = {
        "max_new_tokens": 50,
        "temperature": 0,
        "min_length": 10,
        "no_repeat_ngram_size": 2,
    }

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        body = {"inputs": inputs}
        body.update(model_kwargs)
        return json.dumps(body).encode('utf-8')

    def transform_output(self, output: bytes) -> List[List[float]]:
        decoded = json.loads(output.read().decode("utf-8"))
        return decoded["sentence_embeddings"]

embedding_content_handler = EmbeddingContentHandler()
    
# LangChain embeddings wrapper around the bge-zh SageMaker endpoint; requests
# and responses go through EmbeddingContentHandler.
sm_embeddings = SagemakerEndpointEmbeddings(
    # endpoint_name="endpoint-name", 
    # credentials_profile_name="credentials-profile-name", 
    #endpoint_name="huggingface-textembedding-bloom-7b1-fp1-2023-04-17-03-31-12-148", 
    endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint",
    region_name="us-west-2", 
    content_handler=embedding_content_handler
)


# Generation parameters for the sqlcoder endpoint.
# NOTE: this module-level dict is also reused as model_kwargs by other LLMs
# defined later in the notebook.
parameters = {
  "max_new_tokens": 450,
  #"do_sample":False,
  #"temperature" : 0
  #"no_repeat_ngram_size": 2,
}
# Text-to-SQL LLM backed by the sqlcoder SageMaker endpoint; its output is the
# raw "outputs" text returned by TextGenContentHandler2.
sm_sql_llm=SagemakerEndpoint(
        endpoint_name="sqlcoder-2023-09-24-06-31-09-959-endpoint",
        region_name="us-west-2", 
        model_kwargs=parameters,
        content_handler=content_hander2
)


## Helper functions for the agent

In [67]:
import boto3
import json
import requests
import time
from collections import defaultdict
from requests_aws4auth import AWS4Auth
import os
from opensearchpy import OpenSearch, RequestsHttpConnection
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory
from langchain import LLMChain



def _get_aos_client(client):
    """Return `client` if it is an OpenSearch client, otherwise build a default one.

    The default client connects to the module-level `aos_endpoint` on port 443
    over TLS, using basic auth from the module-level `pwdauth`.
    """
    if isinstance(client, OpenSearch):
        return client
    return OpenSearch(
        hosts=[{'host': aos_endpoint, 'port': 443}],
        http_auth=pwdauth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
    )


def _hits_to_records(query_response):
    """Flatten an OpenSearch search response into metadata records.

    Each record is a dict with keys idx (1 when the document has no `idx`
    field), database_name, table_name, query_desc_text and score.
    """
    return [
        {
            'idx': hit['_source'].get('idx', 1),
            'database_name': hit['_source']['database_name'],
            'table_name': hit['_source']['table_name'],
            'query_desc_text': hit['_source']['query_desc_text'],
            "score": hit["_score"],
        }
        for hit in query_response["hits"]["hits"]
    ]


def aos_knn_search(client, field, q_embedding, index, size=1):
    """Run a k-NN vector query against an OpenSearch index.

    :param client: an OpenSearch client; any other value (e.g. None) makes the
        function build a default client via _get_aos_client.
    :param field: name of the knn_vector field to search.
    :param q_embedding: the query vector.
    :param index: index name.
    :param size: number of hits requested (used both as `size` and as `k`).
    :return: list of records as produced by _hits_to_records.
    """
    client = _get_aos_client(client)
    query = {
        "size": size,
        "query": {
            "knn": {
                field: {
                    "vector": q_embedding,
                    "k": size,
                }
            }
        },
    }
    query_response = client.search(body=query, index=index)
    return _hits_to_records(query_response)


def aos_reverse_search(client, index_name, field, query_term, exactly_match=False, size=1):
    """Full-text (inverted index) search against an OpenSearch index.

    :param client: an OpenSearch client; any other value (e.g. None) makes the
        function build a default client via _get_aos_client.
    :param index_name: target index name.
    :param field: text field to search.
    :param query_term: query text.
    :param exactly_match: if True run a phrase match (ik_smart analyzer) on
        `field`; otherwise a query_string query on `field` sorted by score.
    :param size: maximum number of hits to return.
    :return: list of records as produced by _hits_to_records.
    """
    client = _get_aos_client(client)
    if exactly_match:
        # Phrase-match on the requested field; the index mapping has no
        # "doc" field, which this branch used to target regardless of `field`.
        query = {
            "size": size,
            "query": {
                "match_phrase": {
                    field: {
                        "query": query_term,
                        "analyzer": "ik_smart",
                    }
                }
            },
        }
    else:
        query = {
            "size": size,
            "query": {
                "query_string": {
                    "default_field": field,
                    "query": query_term,
                }
            },
            "sort": [{"_score": {"order": "desc"}}],
        }
    query_response = client.search(body=query, index=index_name)
    return _hits_to_records(query_response)




def get_vector_by_sm_endpoint(questions, sm_client, endpoint_name):
    """Embed `questions` by invoking a SageMaker embedding endpoint directly.

    The request is marked as query-side (``is_query`` True) and carries the
    Chinese retrieval instruction ("generate a representation of this sentence
    for retrieving related articles").

    :param questions: text or list of texts to embed (the notebook passes both).
    :param sm_client: a boto3 ``sagemaker-runtime`` client.
    :param endpoint_name: name of the embedding endpoint.
    :return: the ``sentence_embeddings`` field of the JSON reply (callers index
        ``[0]`` to get the vector for a single string).
    """
    request_body = json.dumps(
        {
            "inputs": questions,
            "parameters": {},
            "is_query": True,
            "instruction": "为这个句子生成表示以用于检索相关文章：",
        }
    )
    reply = sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    decoded = json.loads(reply['Body'].read().decode('utf8'))
    return decoded['sentence_embeddings']




def get_topk_items(opensearch_query_response, topk=5):
    """Deduplicate search records and keep at most `topk` of them.

    Records are dicts as returned by aos_knn_search / aos_reverse_search
    (idx, database_name, table_name, query_desc_text, score). When a record
    carries an explicit ``id`` key it is used for deduplication. Otherwise
    records are deduplicated on (database_name, table_name). ``idx`` cannot
    serve as a key because it defaults to 1 for documents indexed without it.
    Input order is preserved and the first occurrence wins, so pass
    score-sorted records to keep the best ones.

    :param opensearch_query_response: iterable of record dicts.
    :param topk: maximum number of records to return.
    :return: list of (score, idx, database_name, table_name, query_desc_text)
        tuples.
    """
    unique_items = []
    seen_keys = set()
    for item in opensearch_query_response:
        if len(unique_items) >= topk:
            break
        key = item['id'] if 'id' in item else (item['database_name'], item['table_name'])
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_items.append((
            item['score'],
            item['idx'],
            item['database_name'],
            item['table_name'],
            item['query_desc_text'],
        ))
    return unique_items



def k_nn_ingestion_by_aos(docs, index, hostname, username, passwd):
    """Index metadata documents (with embeddings) into OpenSearch one at a time.

    :param docs: iterable of dicts with query_desc_embedding, database_name,
        table_name, query_desc_text and, optionally, metadata_type.
    :param index: target index name.
    :param hostname: OpenSearch domain endpoint; falls back to the
        module-level ``aos_endpoint`` when empty or None.
    :param username: basic-auth user name.
    :param passwd: basic-auth password.
    """
    search = OpenSearch(
        hosts=[{'host': hostname or aos_endpoint, 'port': 443}],
        # http_auth=awsauth,  # SigV4 alternative
        http_auth=(username, passwd),
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
    )
    for doc in docs:
        document = {
            "query_desc_embedding": doc['query_desc_embedding'],
            'database_name': doc['database_name'],
            "table_name": doc['table_name'],
            "query_desc_text": doc["query_desc_text"],
        }
        # metadata_type is part of the index mapping; keep it when provided.
        if 'metadata_type' in doc:
            document['metadata_type'] = doc['metadata_type']
        search.index(index=index, body=document)

* For local testing only

In [51]:
# Local smoke test: embed one sample question and look it up in the metadata
# index, first by k-NN on the embedding, then by a full-text query.
query="上个月温度合格的派车单数量"
embedding_endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint"
query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)

# Passing None (not an OpenSearch instance) makes the search helpers build
# their own client.
client=None
# query_embedding is presumably a one-element list of vectors; [0] selects it.
rets=aos_knn_search(client, "query_desc_embedding",query_embedding[0], index_name, size=1)
print(rets)

# Chained assignment: binds the result to both `rets` and `opensearch_query_response`.
rets=opensearch_query_response = aos_reverse_search(client, index_name, "query_desc_text", query)   
print(rets)

[{'idx': 1, 'database_name': 'llm', 'table_name': 'ads_bi_quality_monitor_shipping_detail', 'query_desc_text': '最近一个月温度合格的派车单数量', 'score': 0.94651675}]
[{'idx': 1, 'database_name': 'llm', 'table_name': 'ads_bi_quality_monitor_shipping_detail', 'query_desc_text': '最近一个月温度合格的派车单数量', 'score': 1.4384104}]


## Main chain pipeline

### 0: index 创建

PUT metadata-index
{
    "settings" : {
        "index":{
            "number_of_shards" : 5,
            "number_of_replicas" : 0,
            "knn": "true",
            "knn.algo_param.ef_search": 32
        }
    },
    "mappings": {
        "properties": {
            "metadata_type" : {
                "type" : "keyword"
            },
            "database_name": {
                "type" : "keyword"
            },
            "table_name": {
               "type" : "keyword"
            },
            "query_desc_text": {
                "type": "text",
                "analyzer": "ik_max_word",
                "search_analyzer": "ik_smart"
            },
            "query_desc_embedding": {
                "type": "knn_vector",
                "dimension": 768,
                "method": {
                    "name": "hnsw",
                    "space_type": "l2",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 512,
                        "m": 32
                    }
                }            
            }
        }
    }
}

### 1: Data processing

In [52]:
# Metadata to ingest, one entry per line. The three lists are aligned by
# position: line i of each describes the same question / table / database.
all_querys = """最近一个月温度合格的派车单数量"""
querys = all_querys.split("\n")

all_tables = """ads_bi_quality_monitor_shipping_detail"""
tables=all_tables.split("\n")

all_dbs = """llm"""
dbs=all_dbs.split("\n")

In [53]:
#import sys
#sys.path.append("./code/")
#import func

# Embed every sample question in `querys` with the bge-zh endpoint
# (expected: one vector per entry, in the same order).
embedding_endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint"
##########embedding by llm model##############
sentense_vectors = []  # immediately overwritten by the call below
sentense_vectors=get_vector_by_sm_endpoint(querys,sm_client,embedding_endpoint_name)

In [54]:
# Build one OpenSearch document per embedded question, pairing it with the
# database/table/question at the same position in `dbs`/`tables`/`querys`
# (an IndexError here means those lists are shorter than the vectors).
docs=[]
for index, sentence_vector in enumerate(sentense_vectors):
    #print(index, sentence_vector)
    doc = {
        "metadata_type":"table",
        "database_name":dbs[index],
        "table_name": tables[index],
        "query_desc_text":querys[index],
        "query_desc_embedding": sentence_vector
          }
    docs.append(doc)

#print((doc["database_name"]))
#########ingestion into aos ###################
k_nn_ingestion_by_aos(docs,index_name,aos_endpoint,username,passwd)

### 2: 自定义Agent，定制tools
* 自定义AOS倒排及knn检索tools    
* 自定义中文Sql Agent 的ReAct prompt 前缀   
* 定制CustomerizedSQLDatabaseChain作为db tool做数据库交互
* 将db tools加入之前定义的元数据召回的tools列表

In [68]:
from langchain.tools.base import BaseTool, Tool, tool
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType
from typing import Optional, Type
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
    CallbackManagerForChainRun
)


# NOTE(review): `name` and `description` are not referenced by the tool
# definitions below, which pass their own literal names and descriptions.
# (description: "used for tag search to find the concrete database and table name")
name="embedding knn search"
description="用于标签检索找到具体的数据库和表名"
# Shared OpenSearch client (basic auth via pwdauth) used by the search tools.
aos_client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection)
aos_index="metadata-index"

"""Opensearch 向量检索."""
def customEmbeddingSearch(query: str) -> str:
    """Agent tool (vector search): map a question to the best-matching table name.

    Embeds `query` with the SageMaker embedding endpoint, runs a k-NN search on
    the metadata index and returns the table_name of the top hit.

    :param query: the natural-language question passed in by the agent.
    :return: the best-matching table name, or a "not found" message when the
        index returns no hits, so the planner can move on instead of the tool
        crashing with IndexError.
    """
    query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)
    responses = aos_knn_search(aos_client, "query_desc_embedding", query_embedding[0], aos_index, size=10)
    if not responses:
        return "No matching table found."
    return responses[0]["table_name"]

"""Opensearch 标签检索."""
def customReverseIndexSearch(query: str) -> str:
    """Agent tool (tag / full-text search): map a question to a table name.

    Runs a full-text query on the ``query_desc_text`` field of the metadata
    index and returns the table_name of the top hit.

    :param query: the natural-language question passed in by the agent.
    :return: the best-matching table name, or a "not found" message when the
        index returns no hits, so the planner can fall back to vector search
        instead of the tool crashing with IndexError.
    """
    opensearch_query_response = aos_reverse_search(aos_client, aos_index, "query_desc_text", query)
    if not opensearch_query_response:
        return "No matching table found."
    return opensearch_query_response[0]["table_name"]

 
# Database used by the SQL chain.
# SECURITY: the fallback URI below embeds a plaintext password. Prefer setting
# LLM_DEMO_DB_URI (or loading it from Secrets Manager) instead of committing
# credentials to the notebook.
db = SQLDatabase.from_uri(
    os.environ.get(
        "LLM_DEMO_DB_URI",
        "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    ),
    # No sample rows in the table info handed to the LLM.
    sample_rows_in_table_info=0)
    
class CustomerizedSQLDatabaseChain(SQLDatabaseChain):
    """SQLDatabaseChain variant adapted to the sqlcoder SageMaker endpoint.

    The sqlcoder LLM returns the endpoint's raw ``outputs`` text, which
    presumably echoes the prompt (including the ``SQLQuery: ...`` line of its
    format description) before the generated query. When the query checker is
    disabled, the statement to execute is cut out of that text with
    ``_extract_sqlcoder_sql``.
    """

    @staticmethod
    def _extract_sqlcoder_sql(text: str) -> str:
        """Return the SQL statement embedded in raw sqlcoder output.

        Prefers the second ``SQLQuery: ...`` line (the first presumably comes
        from the echoed prompt). If there is only one match it is used as-is,
        and if nothing matches the whole text is returned, instead of failing
        with IndexError.
        """
        import re

        # `(?:\n|$)` also accepts a final SQLQuery line with no trailing newline
        # (e.g. when generation stops right after the query).
        matches = re.findall(r"SQLQuery: (.*?)(?:\n|$)", text)
        if len(matches) >= 2:
            return matches[1]
        if matches:
            return matches[0]
        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, execute it and build the answer.

        :param inputs: chain inputs; ``inputs[self.input_key]`` is the question.
            An optional ``table_names_to_use`` restricts the table info sent
            to the LLM (None means all tables).
        :param run_manager: optional callback manager for the run.
        :return: ``{self.output_key: ...}`` holding the SQL (``return_sql``),
            the raw SQL result (``return_direct``) or the LLM's final answer,
            plus the intermediate steps when ``return_intermediate_steps``.
        :raises Exception: any error from generation or execution is re-raised
            with the collected ``intermediate_steps`` attached to it.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        print("table name==")
        print(self.database._include_tables)
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            #"stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            print("orginal sql_cmd=="+sql_cmd)
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                # sqlcoder returns raw model text; pull out the actual statement.
                sql_cmd = self._extract_sqlcoder_sql(sql_cmd)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                # Imported lazily: only needed when the query checker is enabled.
                from langchain.tools.sql_database.prompt import QUERY_CHECKER

                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                from langchain_experimental.sql.base import INTERMEDIATE_STEPS_KEY

                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

def run_query(query):
    """Agent tool: answer `query` with a sqlcoder-backed SQL chain.

    The table is currently fixed to ads_bi_quality_monitor_shipping_detail; the
    commented-out line shows the intent of parsing it from the tool input.
    Restricting the tables sets ``_include_tables`` on the chain's database
    object, which may be the shared global ``db``.
    """
    print("db tool input=="+query)
    #table_name=json.loads(query)["table_name"]
    restricted_tables = ["ads_bi_quality_monitor_shipping_detail"]
    chain = CustomerizedSQLDatabaseChain.from_llm(
        sm_sql_llm,
        db,
        verbose=True,
        return_intermediate_steps=False,
    )
    chain.database._include_tables = restricted_tables
    return chain.run(query)


# Tools exposed to the plan-and-execute agent. The `name` values must match
# the tool names used in the planner prompt (combo_template).
custom_tool_list = [
    Tool.from_function(
        func=customReverseIndexSearch,
        name="reverse index search",
        # "used for tag search to find the concrete database and table name"
        description="用于标签检索找到具体的数据库和表名",
    ),
    Tool.from_function(
        func=customEmbeddingSearch,
        name="embedding knn search",
        # "used for vector search to find the concrete database and table name"
        description="用于向量检索找到具体的数据库和表名",
    ),
    Tool.from_function(
        name="Db Querying",
        func=run_query,
        # "used for database interaction and SQL generation"
        description="""用于数据库交互及生成sql""",
    ),
]

#reverseIndexSearchTool=CustomReverseIndexSearchTool()
#embeddingSearchTool=CustomEmbeddingSearchTool()

#custom_tool_list=[]
#custom_tool_list=[CustomReverseIndexSearchTool(),CustomEmbeddingSearchTool()]
#custom_tool_list.append(
#   Tool.from_function(
#        name="Db Querying Tool",
#        func=run_query,
#        description="""用于数据库交互及生成sql"""
#    ))


#print(custom_tool_list)


### for llama2
class TextGenContentHandler(LLMContentHandler):
    """Serialize prompts for, and deserialize replies from, the TGI Llama-2 endpoint.

    Requests are JSON ``{"inputs": prompt, "parameters": model_kwargs}``. The
    reply's ``generated_text`` field is returned.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        request = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(request).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        body_text = output.read().decode("utf-8")
        return json.loads(body_text)["generated_text"]


content_hander = TextGenContentHandler()

# Llama-2 LLM on SageMaker. It reuses the module-level `parameters` dict
# defined for sqlcoder. NOTE: `sm_llm` is reassigned to a Bedrock model later
# in this cell.
sm_llm = SagemakerEndpoint(
    endpoint_name="tgi-llama2-2023-09-27-14-25-28-609-endpoint",
    region_name="us-west-2",
    model_kwargs=parameters,
    content_handler=content_hander,
)


## for bedrock
def get_bedrock_aksk(secret_name='chatbot_bedrock', region_name = "us-west-2"):
    """Fetch the Bedrock access-key pair stored in AWS Secrets Manager.

    :param secret_name: secret id; its SecretString must be a JSON object with
        BEDROCK_ACCESS_KEY and BEDROCK_SECRET_KEY.
    :param region_name: region the secret lives in.
    :return: (access_key, secret_key) tuple.
    :raises botocore.exceptions.ClientError: when the secret cannot be read;
        see the GetSecretValue API reference for the possible error codes.
    """
    client = boto3.session.Session().client(
        service_name='secretsmanager',
        region_name=region_name,
    )
    # No try/except: errors propagate unchanged. A bare re-raise added nothing,
    # and catching ClientError would require importing it from botocore.
    response = client.get_secret_value(SecretId=secret_name)

    # Secrets Manager returns the decrypted value; parse the JSON payload.
    secret = json.loads(response['SecretString'])
    return secret['BEDROCK_ACCESS_KEY'], secret['BEDROCK_SECRET_KEY']

ACCESS_KEY, SECRET_KEY=get_bedrock_aksk()

from langchain.llms import Bedrock

# Bedrock client authenticated with the keys fetched above. The signing region
# must match the regional endpoint, or SigV4 signature validation fails.
# NOTE(review): "bedrock" + invoke_model is the preview-era SDK shape; newer
# boto3 versions expose invoke_model on "bedrock-runtime". Confirm against the
# installed boto3.
boto3_bedrock = boto3.client(
    service_name="bedrock",
    region_name="us-west-2",
    endpoint_url="https://bedrock.us-west-2.amazonaws.com",
    aws_access_key_id=ACCESS_KEY,
    aws_secret_access_key=SECRET_KEY
)

# Claude inference parameters. Claude on Bedrock expects max_tokens_to_sample
# and rejects the TGI-style max_new_tokens used by the SageMaker endpoints.
parameters_bedrock = {
    "max_tokens_to_sample": 450,
    #"stop_sequences":STOP,
    #"temperature":0.5,
    # "top_p":0.9
}

sm_llm = Bedrock(model_id="anthropic.claude-v2", client=boto3_bedrock, model_kwargs=parameters_bedrock)



####目前SqlAgent只支持OpenAI Function类型，不适合SM endpoint######
#toolkit = SQLDatabaseToolkit(db=db, llm=sm_sql_llm)

#custom_suffix = """
#我应该先利用标签检索，找到具体的数据库和表名，
#如果找不到，则利用向量检索查找，
#然后使用数据库工具查看我刚才找到的应该查询的库和表的详细schema元数据。
#"""
#agent = create_sql_agent(llm=sm_llm,
#                         toolkit=toolkit,
#                         verbose=True,
#                         agent_type=AgentType.SELF_ASK_WITH_SEARCH,
#                         extra_tools=custom_tool_list,
#                         suffix=custom_suffix
#                        )
#print(agent.agent.llm_chain.prompt.template)
#agent.run({"input":"我需要知道销售报表中，下单金额最大的客户id","agent_scratchpad":""})

## llm_chain自定义 Agent test

* 使用llm chain做chat的planner

In [69]:
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.chains.base import Chain
from langchain.prompts.base import BasePromptTemplate
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from langchain.agents.tools import Tool


# MySQL text-to-SQL prompt with {table_info} and {question} placeholders.
# NOTE(review): not referenced by active code; the PromptTemplate that would
# use it is commented out in run_query.
sql_prompt_template = """
You are a MySQL expert. Given an input question, first create a syntactically correct MySql query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. 
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {question}
SQLQuery:
"""


# Conversation memory handed to the PlanAndExecute agent.
memory = ConversationBufferMemory(memory_key="chat_history")

def build_chat_chain():
    """Return a verbose ConversationChain driven by the module-level `sm_llm`.

    No memory is passed explicitly, so the chain uses whatever default memory
    ConversationChain provides.
    """
    return ConversationChain(llm=sm_llm, verbose=True)

* 使用llm chain做chat的planner
* 使用LangChain PlanAndExecute

In [None]:
from langchain.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)



# Planner system prompts. Only the LAST assignment of combo_template is used
# (by the planner below); the earlier versions are kept as alternatives, and
# SYSTEM_PROMPT is currently unused. Tool names quoted in the prompts must
# match the registered tool names: "reverse index search",
# "embedding knn search" and "Db Querying".

# Chinese version: plan first; use tag search, then vector search if empty,
# then the DB tool on the previous result; finally return the result.
combo_template = """
先了解问题并制定解决问题的计划。
请输出以标题"plan:"开头的计划，然后输出步骤的编号列表。在计划步骤时，不要使用过去的对话历史记录。
请将计划制定为准确完成任务所需的最少步骤数。

以下是使用工具解决任务的步骤，请严格遵守：
先使用"reverse index search" 标签检索工具查找输入问题。
如果结果为空，则使用"embedding knn search" 向量检索工具查找输入问题。
然后使用"Db Querying" 数据库工具执行上一步骤结果输出。
最后返回结果。
"""

combo_template = """
First understand the problem and devise a plan to solve the problem. 
Please output the plan starting with the header 'Plan:' and then followed by a numbered list of steps. Do not use past conversation history when you are planning the steps.
Please make the plan the minimum number of steps required to accurately complete the task. 
    
These are guidance on when to use a tool to solve a task, follow them strictly: 

first use "reverse index search" tool to get table name 
if previously step not find anyting，then use "embedding knn search" tool to get table name 
if previously step still can't find the answer, say "I can't find the answer for this question." 
otherwise use "Db Querying" tool to search table info using table name which previously step return

DO NOT CREATE STEPS THAT ARE NOT NEEDED TO SOLVE A TASK.     
Once you have answers for the question, stop and provide the final answers. The final answers should be a combination of the answers to all the questions, not just the last one.

Please make sure you have a plan to answer all the questions in the input, not just the last one. 
Please use these to construct an answer to the question , as though you were answering the question directly. Ensure that your answer is accurate and doesn’t contain any information not directly supported by the summary and quotes. 
If there are no data or information in this document that seem relevant to this question, please just say "I can’t find any relevant quotes".
"""


SYSTEM_PROMPT = (
    "First understand the problem and devise a plan to solve the problem."
    " Please output the plan starting with the header 'Plan:' "
    "and then followed by a numbered list of steps. "
    "Please make the plan the minimum number of steps required "
    "to accurately complete the task. If the task is a question, "
    "the final step should almost always be 'Given the above steps taken, "
    "please respond to the users original question'. "
    """These are guidance to use a tool to solve the task, follow them strictly: 
first use "reverse index search" tool to search  
if previously step not find anyting，then use "embedding knn search" tool to search   
if previously step still can't find the answer, say "I can't find the answer for this question." 
otherwise use "Db Querying" tool execute and return 
"""
    "At the end of your plan, say '<END_OF_PLAN>'"
)


# Effective planner prompt.
combo_template = """
First understand the problem and devise a plan to solve the problem. 
Please output the plan starting with the header 'Plan:' and then followed by a numbered list of steps. Do not use past conversation history when you are planning the steps.
Please make the plan the minimum number of steps required to accurately complete the task. 
    
These are guidance on when to use a tool to solve a task, follow them strictly: 

first use "reverse index search" tool to get table name 
if the previous step does not find anything, then use "embedding knn search" tool to get table name 
if the previous step still can't find the answer, say "I can't find the answer for this question." 
otherwise use "Db Querying" tool to generate sql query using the table name returned by the previous step

DO NOT CREATE STEPS THAT ARE NOT NEEDED TO SOLVE A TASK.     
Once you have answers for the question, stop and provide the final answers. The final answers should be a combination of the answers to all the questions, not just the last one.

Please make sure you have a plan to answer all the questions in the input, not just the last one. 
Please use these to construct an answer to the question , as though you were answering the question directly. Ensure that your answer is accurate and doesn’t contain any information not directly supported by the summary and quotes. 
If there are no data or information in this document that seem relevant to this question, please just say "I can’t find any relevant quotes". 
"""

# Planner: start from LangChain's default chat planner and replace its system
# prompt with combo_template, keeping the planner's own human-message template
# (assumed to be messages[1] of its default chat prompt; confirm if LangChain
# changes).
planner = load_chat_planner(sm_llm)
system_message_prompt = SystemMessagePromptTemplate.from_template(combo_template)
human_message_prompt = planner.llm_chain.prompt.messages[1]
print(system_message_prompt)
print(human_message_prompt)
planner.llm_chain.prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

# Executor: agent executor over the retrieval tools and the DB tool.
executor = load_agent_executor(sm_llm, custom_tool_list, verbose=True)
# NOTE(review): confirm PlanAndExecute actually defines `max_iterations`; if the
# chain ignores unknown fields, this argument has no effect.
agent = PlanAndExecute(planner=planner, executor=executor, verbose=True, max_iterations=1,memory=memory)


# Alternative input: explicitly ask for the table name and the generated SQL.
#output = agent({"input":"查询‘最近一个月温度合格的派车单数量’这条语句的table name并生成sql"})
# Sample question: "number of dispatch orders with qualified temperature in the last month".
output = agent({"input":"最近一个月温度合格的派车单数量"})