## langcain agent demo for ReAct

In [None]:
!pip config unset global.index-url

In [None]:
!pip install langchain[all]
!pip install langchain-experimental
!pip install requests_aws4auth
!pip install opensearch-py
!pip install pydantic==1.10.0
!pip install sqlalchemy-redshift
!pip install redshift_connector
!pip install SQLAlchemy
!pip install pymysql
!pip install langchainhub

## initial sagemaker env

In [None]:
import os
import sagemaker
import boto3
import json
from typing import Dict
from typing import Any, Dict, List, Optional

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it not exists
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
sm_client = boto3.client("sagemaker-runtime")

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## intial lanchain lib

In [None]:
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory,ConversationBufferMemory
from langchain import LLMChain
from typing import Any, Dict, List, Union,Mapping, Optional, TypeVar, Union
import os
import time
import json
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

# set environment OPENAI_API_KEY
#from langchain.embeddings.openai import OpenAIEmbeddings

aos_endpoint="vpc-llm-rag-aos-seg3mzhpp76ncpxezdqtcsoiga.us-west-2.es.amazonaws.com"
embedding_endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint"
region='us-west-2'
# replace to your real username and passsword
username=""
passwd=""
#index_name="prompt-optimal-index"
index_name="metadata-index"
size=10

credentials = boto3.Session().get_credentials()
region = boto3.Session().region_name
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, 'es', session_token=credentials.token)
pwdauth = (username, passwd)

### for sqlcoder
class TextGenContentHandler2(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({
                "inputs": prompt,
                "parameters": model_kwargs
            })
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        #print(response_json)
        #sql_result=response_json["outputs"].split("```sql")[-1].split("```")[0].split(";")[0].strip().replace("\\n"," ") + ";"
        sql_result=response_json["outputs"]
        return sql_result


content_hander2=TextGenContentHandler2()

## for embedding
class EmbeddingContentHandler(EmbeddingsContentHandler):
    parameters = {
        "max_new_tokens": 50,
        "temperature": 0,
        "min_length": 10,
        "no_repeat_ngram_size": 2,
    }
    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["sentence_embeddings"]

embedding_content_handler=EmbeddingContentHandler()
    
sm_embeddings = SagemakerEndpointEmbeddings(
    # endpoint_name="endpoint-name", 
    # credentials_profile_name="credentials-profile-name", 
    #endpoint_name="huggingface-textembedding-bloom-7b1-fp1-2023-04-17-03-31-12-148", 
    endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint",
    region_name="us-west-2", 
    content_handler=embedding_content_handler
)


parameters = {
  "max_new_tokens": 350,
  #"do_sample":False,
  #"temperatual" : 0
  #"no_repeat_ngram_size": 2,
}
sm_sql_llm=SagemakerEndpoint(
        endpoint_name="sqlcoder-2023-10-07-01-50-46-950-endpoint",
        region_name="us-west-2", 
        model_kwargs=parameters,
        content_handler=content_hander2
)


## func for agent

In [None]:
import boto3
import json
import requests
import time
from collections import defaultdict
from requests_aws4auth import AWS4Auth
import os
from opensearchpy import OpenSearch, RequestsHttpConnection
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory
from langchain import LLMChain



def aos_knn_search(client, field,q_embedding, index, size=1):
    if not isinstance(client, OpenSearch):   
        client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection
        )
    query = {
        "size": size,
        "query": {
            "knn": {
                field: {
                    "vector": q_embedding,
                    "k": size
                }
            }
        }
    }
    opensearch_knn_respose = []
    query_response = client.search(
        body=query,
        index=index
    )
    opensearch_knn_respose = [{'idx':item['_source'].get('idx',1),'database_name':item['_source']['database_name'],'table_name':item['_source']['table_name'],'query_desc_text':item['_source']['query_desc_text'],"score":item["_score"]}  for item in query_response["hits"]["hits"]]
    return opensearch_knn_respose

def aos_knn_search_v2(client, field,q_embedding, index, size=1):
    if not isinstance(client, OpenSearch):   
        client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection
        )
    query = {
        "size": size,
        "query": {
            "knn": {
                field: {
                    "vector": q_embedding,
                    "k": size
                }
            }
        }
    }
    opensearch_knn_respose = []
    query_response = client.search(
        body=query,
        index=index
    )
    opensearch_knn_respose = [{'idx':item['_source'].get('idx',1),'database_name':item['_source']['database_name'],'table_name':item['_source']['table_name'],'exactly_query_text':item['_source']['exactly_query_text'],"score":item["_score"]}  for item in query_response["hits"]["hits"]]
    return opensearch_knn_respose


def aos_reverse_search(client, index_name, field, query_term, exactly_match=False, size=1):
    """
    search opensearch with query.
    :param host: AOS endpoint
    :param index_name: Target Index Name
    :param field: search field
    :param query_term: query term
    :return: aos response json
    """
    if not isinstance(client, OpenSearch):   
        client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection
        )
    query = None
    if exactly_match:
        query =  {
            "query" : {
                "match_phrase":{
                    "doc": {
                        "query": query_term,
                        "analyzer": "ik_smart"
                      }
                }
            }
        }
    else:
        query = {
            "size": size,
            "query": {
                "query_string": {
                "default_field": "query_desc_text",  
                "query": query_term         
              }
            },
           "sort": [{
               "_score": {
                   "order": "desc"
               }
           }]
    }        
    query_response = client.search(
        body=query,
        index=index_name
    )
    result_arr = [{'idx':item['_source'].get('idx',1),'database_name':item['_source']['database_name'],'table_name':item['_source']['table_name'],'query_desc_text':item['_source']['query_desc_text'],"score":item["_score"]}  for item in query_response["hits"]["hits"]]
    return result_arr




def get_vector_by_sm_endpoint(questions, sm_client, endpoint_name):
    parameters = {
    }

    response_model = sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(
            {
                "inputs": questions,
                "parameters": parameters,
                "is_query" : True,
                "instruction" :  "为这个句子生成表示以用于检索相关文章："
            }
        ),
        ContentType="application/json",
    )
    json_str = response_model['Body'].read().decode('utf8')
    json_obj = json.loads(json_str)
    embeddings = json_obj['sentence_embeddings']
    return embeddings




def get_topk_items(opensearch_query_response, topk=5):
    opensearch_knn_nodup = []
    unique_ids = set()
    for item in opensearch_query_response:
        if item['id'] not in unique_ids:
            opensearch_knn_nodup.append(item['score'], item['idx'],item['database_name'],item['table_name'],item['query_desc_text'])
            unique_ids.add(item['id'])
    return opensearch_knn_nodup



def k_nn_ingestion_by_aos(docs,index,hostname,username,passwd):
    auth = (username, passwd)
    search = OpenSearch(
        hosts = [{'host': aos_endpoint, 'port': 443}],
        ##http_auth = awsauth ,
        http_auth = auth ,
        use_ssl = True,
        verify_certs = True,
        connection_class = RequestsHttpConnection
    )
    for doc in docs:
        query_desc_embedding = doc['query_desc_embedding']
        database_name = doc['database_name']
        table_name = doc['table_name']
        query_desc_text = doc["query_desc_text"]
        document = { "query_desc_embedding": query_desc_embedding, 'database_name':database_name, "table_name": table_name,"query_desc_text":query_desc_text}
        search.index(index=index, body=document)
        
def k_nn_ingestion_by_aos_v2(docs,index,hostname,username,passwd):
    auth = (username, passwd)
    search = OpenSearch(
        hosts = [{'host': aos_endpoint, 'port': 443}],
        ##http_auth = awsauth ,
        http_auth = auth ,
        use_ssl = True,
        verify_certs = True,
        connection_class = RequestsHttpConnection
    )
    for doc in docs:
        exactly_query_embedding = doc['exactly_query_embedding']
        database_name = doc['database_name']
        table_name = doc['table_name']
        exactly_query_text = doc["exactly_query_text"]
        document = { "exactly_query_embedding": exactly_query_embedding, 'database_name':database_name, "table_name": table_name,"exactly_query_text":exactly_query_text}
        search.index(index=index, body=document)

* for local test only

In [None]:
query="上个月温度合格的派车单数量"
query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)

client=None
rets=aos_knn_search(client, "query_desc_embedding",query_embedding[0], index_name, size=1)
print(rets)

rets=opensearch_query_response = aos_reverse_search(client, index_name, "query_desc_text", query)   
print(rets)

rets=aos_knn_search_v2(client, "exactly_query_embedding",query_embedding[0],index_name,1)   
print(rets)


## major chain pipeline ################

### 0: index 创建

PUT prompt-optimal-index
{
    "settings" : {
        "index":{
            "number_of_shards" : 5,
            "number_of_replicas" : 0,
            "knn": "true",
            "knn.algo_param.ef_search": 32
        }
    },
    "mappings": {
        "properties": {
            "metadata_type" : {
                "type" : "keyword"
            },
            "database_name": {
                "type" : "keyword"
            },
            "table_name": {
               "type" : "keyword"
            },
            "exactly_query_text": {
                "type": "text",
                "analyzer": "ik_max_word",
                "search_analyzer": "ik_smart"
            },
            "exactly_query_embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "space_type": "l2",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 512,
                        "m": 32
                    }
                }            
            }
        }
    }
}

### 1: data process

In [None]:
all_querys = """2023年7月派车单数量超过26次的4.2米车辆一共有多少辆
请统计历史累计派车单数量、干线派车单数量、城配派车单数量。历史累计的意思是不限定时间范围
奶茶品牌的站点数量和运输货品数量统计
2022年的运输总量是多少吨？请注意：traff_weight的单位是千克，请把单位转换为吨
车牌为'黑RG6696'的车辆的GPS最近定位上传时间、GPS最近定位省份、GPS最近定位城市、GPS最近定位区县、APP最近定位上传时间、APP最近定位省份、APP最近定位城市、APP最近定位区县。给出sql中字段名不要带上库名
车牌归属城市为'成都'的车辆累计有多少？
取货地城市名称为'北京市'的历史累计不重复的车牌有多少？请注意车牌号有可能有重复
货主-行业列表相似于'西餐连锁'的一共有多少个品牌？多少个客户？
品牌名称为'星巴克'的不重复的站点一共有多少个？
查看租户简称为云南,车辆的车厢长为9.6米和15米的外廓车长、核定载重"""
querys = all_querys.split("\n")

all_tables = """ads_bi_quality_monitor_shipping_detail
dws_ots_waybill_info_da
dws_station_portrait_index_sum_da
dws_ots_waybill_info_da
dws_truck_portrait_index_sum_da
dws_truck_portrait_index_sum_da
dws_ots_waybill_info_da
ads_customer_portrait_index_sum_da
dim_customer_enterprise_station_base_info
dim_pub_truck_tenant,dim_pub_truck_info"""
tables=all_tables.split("\n")

all_dbs = """llm
llm
llm
llm
llm
llm
llm
llm
llm
llm"""
dbs=all_dbs.split("\n")

In [None]:
#import sys
#sys.path.append("./code/")
#import func
index_name="prompt-optimal-index"
embedding_endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint"
##########embedding by llm model##############
sentense_vectors = []
sentense_vectors=get_vector_by_sm_endpoint(querys,sm_client,embedding_endpoint_name)

In [None]:
docs=[]
for index, sentence_vector in enumerate(sentense_vectors):
    #print(index, sentence_vector)
    #doc = {
    #    "metadata_type":"table",
    #    "database_name":dbs[index],
    #    "table_name": tables[index],
    #    "query_desc_text":querys[index],
    #    "query_desc_embedding": sentence_vector
    #      }
    doc = {
        "metadata_type":"table",
        "database_name":dbs[index],
        "table_name": tables[index],
        "exactly_query_text":querys[index],
        "exactly_query_embedding": sentence_vector
          }
    docs.append(doc)

#print((doc["database_name"]))
#########ingestion into aos ###################
k_nn_ingestion_by_aos_v2(docs,index_name,aos_endpoint,username,passwd)

### 2:自定义Agent ，定制tools

* langchain sm endpoint设置

In [None]:
from langchain.tools.base import BaseTool, Tool, tool
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType
from typing import Optional, Type
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
    CallbackManagerForChainRun
)
from langchain.llms.bedrock import Bedrock

### for llama2
class TextGenContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({
                "inputs": prompt,
                "parameters": model_kwargs
            })
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_text"]


content_hander=TextGenContentHandler()

sm_llm=SagemakerEndpoint(
        endpoint_name="tgi-llama2-2023-09-27-14-25-28-609-endpoint",
        region_name="us-west-2", 
        model_kwargs=parameters,
        content_handler=content_hander
)


* initial bedrock

In [None]:
import os
from typing import Optional

# External Dependencies:
import boto3
from botocore.config import Config


def get_bedrock_client(
    assumed_role: Optional[str] = None,
    region: Optional[str] = None,
    runtime: Optional[bool] = True,
):
  
    if region is None:
        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
    else:
        target_region = region

    print(f"Create new client\n  Using region: {target_region}")
    session_kwargs = {"region_name": target_region}
    client_kwargs = {**session_kwargs}

    profile_name = os.environ.get("AWS_PROFILE")
    if profile_name:
        print(f"  Using profile: {profile_name}")
        session_kwargs["profile_name"] = profile_name

    retry_config = Config(
        region_name=target_region,
        retries={
            "max_attempts": 10,
            "mode": "standard",
        },
    )
    session = boto3.Session(**session_kwargs)

    if assumed_role:
        print(f"  Using role: {assumed_role}", end='')
        sts = session.client("sts")
        response = sts.assume_role(
            RoleArn=str(assumed_role),
            RoleSessionName="langchain-llm-1"
        )
        print(" ... successful!")
        client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
        client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
        client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
        

    if runtime:
        service_name='bedrock-runtime'
    else:
        service_name='bedrock'

    client_kwargs["aws_access_key_id"] = os.environ.get("AWS_ACCESS_KEY_ID","")
    client_kwargs["aws_secret_access_key"] = os.environ.get("AWS_SECRET_ACCESS_KEY","")
    
    bedrock_client = session.client(
        service_name=service_name,
        config=retry_config,
        **client_kwargs
    )

    print("boto3 Bedrock client successfully created!")
    print(bedrock_client._endpoint)
    return bedrock_client



## for aksk bedrock
def get_bedrock_aksk(secret_name='chatbot_bedrock', region_name = "us-west-2"):
    # Create a Secrets Manager client
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    try:
        get_secret_value_response = client.get_secret_value(
            SecretId=secret_name
        )
    except ClientError as e:
        # For a list of exceptions thrown, see
        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        raise e

    # Decrypts secret using the associated KMS key.
    secret = json.loads(get_secret_value_response['SecretString'])
    return secret['BEDROCK_ACCESS_KEY'],secret['BEDROCK_SECRET_KEY']

ACCESS_KEY, SECRET_KEY=get_bedrock_aksk()

###aksk intial client#######
#boto3_bedrock = boto3.client(
#    service_name="bedrock",
#    region_name="us-west-2",
#    endpoint_url="https://bedrock.us-west-2.amazonaws.com",
#    aws_access_key_id="",
#    aws_secret_access_key="*******"
#)

#role based initial client#######
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"  # E.g. "us-east-1"
os.environ["AWS_PROFILE"] = "default"
#os.environ["BEDROCK_ASSUME_ROLE"] = "arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123"  # E.g. "arn:aws:..."
os.environ["AWS_ACCESS_KEY_ID"]=ACCESS_KEY
os.environ["AWS_SECRET_ACCESS_KEY"]=SECRET_KEY


#新boto3 sdk只能session方式初始化bedrock
boto3_bedrock = get_bedrock_client(
    #assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE", None),
    region=os.environ.get("AWS_DEFAULT_REGION", None)
)

parameters_bedrock = {
    "max_tokens_to_sample": 2048,
    #"temperature": 0.5,
    "temperature": 0,
    #"top_k": 250,
    #"top_p": 1,
    "stop_sequences": ["\n\nHuman"],
}

bedrock_llm = Bedrock(model_id="anthropic.claude-v2", client=boto3_bedrock, model_kwargs=parameters_bedrock)
###test the bedrock langchain integration###
#bedrock_llm.predict("Human:how do you describe LLM?\n"+
#           "Assistant:")

* 自定义AOS倒排及knn检索tools    
* 自定义中文Sql Agent 的ReAct prompt 前缀   
* 定制CustomerizedSqlDatabaseChain作为db tools做数据库交互
* 将db tools加入之前定义的元数据召回的tools列表

In [None]:
aos_client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection)
aos_index="metadata-index"

"""Opensearch 向量检索."""
def customEmbeddingSearch(query: str) -> str:
    start = time.time()
    query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)
    elpase_time = time.time() - start
    #print(f'runing time of opensearch_knn : {elpase_time}s seconds')
    responses = aos_knn_search(aos_client, "query_desc_embedding",query_embedding[0], aos_index, size=10)
    return "The database table is "+responses[0]["table_name"].strip()+"\n"
    #return responses[0]["table_name"] 
    #return "表名是 "+responses[0]["table_name"].strip()

"""Opensearch 标签检索."""
def customReverseIndexSearch(query: str) -> str:
    start = time.time()
    opensearch_query_response = aos_reverse_search(aos_client, aos_index, "query_desc_text", query)
    elpase_time = time.time() - start
    #print(f'runing time of opensearch_query : {elpase_time}s seconds')
    return "The database table is "+opensearch_query_response[0]["table_name"].strip()+"\n"
    #return opensearch_query_response[0]["table_name"] 
    #return "表名是 "+opensearch_query_response[0]["table_name"].strip()

 
db = SQLDatabase.from_uri(
    "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    sample_rows_in_table_info=0)    
    
class CustomerizedSQLDatabaseChain(SQLDatabaseChain):
    
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        if table_names_to_use is None:
            table_names_to_use=self.database._include_tables
            print("table_names_to_use==")
            print(table_names_to_use)
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            #"stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            print("orginal sql_cmd=="+sql_cmd)
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                #########定制sqlcoder 模型输出##############
                pattern = r"SQLQuery: (.*?)\n"
                matches = re.findall(pattern, sql_cmd)
                match = matches[1]
                sql_cmd = match
                #print("query sql=="+sql_cmd) 
                
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

def run_query(query):
    infos=""
    table_name=""
    question=""
    ####post process agent action output######
    if "\n" in query:
        infos = query.split("\n")
    elif "," in query:
        infos = query.split(",") 

    if ":" in infos[0]:
        table_name=infos[0].split(":")[1]
    else:
        table_name=infos[0]
    
    if ":" in infos[1]:
        question=infos[1].split(":")[1]
    else:
        question=infos[1]
    
    table_name=table_name.strip()
    question=question.strip()
    
    #db_chain = CustomerizedSQLDatabaseChain.from_llm(sm_sql_llm, db, verbose=True, return_sql=True,return_intermediate_steps=False)
    db_chain = CustomerizedSQLDatabaseChain.from_llm(bedrock_llm, db, verbose=True, return_sql=True,return_intermediate_steps=False)

    if table_name is not None:
        db_chain.database._include_tables=[table_name]
    response=db_chain.run(question)
    return response


custom_tool_list=[
    Tool.from_function(
        func=customReverseIndexSearch,
        name="reverse index search",
        description="use for keywords search to get the database table name"
    ),
    Tool.from_function(
        func=customEmbeddingSearch,
        name="embedding knn search",
        description="use for semantic level search to get the database table name"
    ),
    Tool.from_function(
        name="db utility",
        func=run_query,
        description="""use for generate sql statement"""
    )]

#reverseIndexSearchTool=CustomReverseIndexSearchTool()
#embeddingSearchTool=CustomEmbeddingSearchTool()

#custom_tool_list=[]
#custom_tool_list=[CustomReverseIndexSearchTool(),CustomEmbeddingSearchTool()]
#custom_tool_list.append(
#   Tool.from_function(
#        name="Db Querying Tool",
#        func=run_query,
#        description="""用于数据库交互及生成sql"""
#    ))


#print(custom_tool_list)




### 2.0 sqlAgent 测试

* 使用langchain sqlAgent
* 初始化agent并执行

In [None]:
####目前SqlAgent只支持OpenAI Function类型，不适合SM endpoint######
#toolkit = SQLDatabaseToolkit(db=db, llm=sm_sql_llm)

#custom_suffix = """
#我应该先利用标签检索工具，找到具体的数据库和表名，
#如果找不到，则利用向量检索工具查找，
#然后使用数据库工具查看我刚才找到的应该查询的库和表的详细schema元数据。
#"""
#agent = create_sql_agent(llm=sm_llm,
#                         toolkit=toolkit,
#                         verbose=True,
#                         agent_type=AgentType.SELF_ASK_WITH_SEARCH,
#                         extra_tools=custom_tool_list,
#                         suffix=custom_suffix
#                        )
#print(agent.agent.llm_chain.prompt.template)
#agent.run({"input":"我需要知道销售报表中，下单金额最大的客户id","agent_scratchpad":""})

### 2.1 PlannerAndExecutor agent type

* 使用llm chain做chat的planner

In [None]:
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.chains.base import Chain
from langchain.prompts.base import BasePromptTemplate
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from langchain.agents.tools import Tool
from typing import Optional, List, Any
from langchain.callbacks.manager import CallbackManagerForLLMRun

class BedrockModelWrapper(Bedrock):
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        prompt = "\nHuman: \n" + prompt + "\nAssistant:"   ## Satisfy Bedrock-Claude prompt requirements
        return super()._call(prompt, stop, run_manager, **kwargs)


sql_prompt_template = """
You are a MySQL expert. Given an input question, first create a syntactically correct MySql query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. 
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {question}
SQLQuery:
"""
bedrock_llm_additional = BedrockModelWrapper(model_id="anthropic.claude-v2", 
                                          client=boto3_bedrock, 
                                          model_kwargs=parameters_bedrock)


memory = ConversationBufferWindowMemory(k=2)


def build_chat_chain():
    conversation_with_summary_chain = ConversationChain(
        llm=bedrock_llm_additional, 
        verbose=True
    )
    return conversation_with_summary_chain

* 使用llm chain做chat的planner
* 使用langchain plannerAndExecutor

In [49]:
from langchain.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)



combo_template = """
先了解问题并制定解决问题的计划。
请输出以标题"plan:"开头的计划，然后输出步骤的编号列表。在计划步骤时，不要使用过去的对话历史记录。
请将计划制定为准确完成任务所需的最少步骤数。

以下是使用工具解决任务的步骤，请严格遵守：
先使用"reverse index search" 标签检索工具查找输入问题。
如果结果为空，则使用"embedding knn search" 向量检索工具查找输入问题。
然后使用"db Querying Tool" 数据库工具执行上一步骤结果输出。
最后返回结果。
"""



combo_template = """
First understand the problem and devise a plan to solve the problem. 
Please output the plan starting with the header 'Plan:' and then followed by a numbered list of steps. Do not use past conversation history when you are planning the steps.
Please make the plan the minimum number of steps required to accurately complete the task. 


These are guidance on when to use a tool to solve a task, follow them strictly:

first use "reverse index search" tool to get database table name base on user's original question
if previously step not return anyting，then use "embedding knn search" tool to get database table name base on user's original question
then use the "db utility" tool to generate the sql statement, use that database table name which previously step returns and user's original question

DO NOT GENERATE SQL STATEMENT YOURSELF,ONLY USE TOOL TO DO.
DO NOT CREATE STEPS THAT ARE NOT NEEDED TO SOLVE A TASK.     
Once you have answers for the question, stop and provide the final answers. The final answers should be a combination of the answers to all the questions, not just the last one.

Please make sure you have a plan to answer all the questions in the input, not just the last one. 
Please use these to construct an answer to the question , as though you were answering the question directly. Ensure that your answer is accurate and doesn’t contain any information not directly supported by the summary and quotes. 
If there are no data or information in this document that seem relevant to this question, please just say "I can’t find any relevant quotes". 
"""

planner = load_chat_planner(bedrock_llm)
system_message_prompt = SystemMessagePromptTemplate.from_template(combo_template)
human_message_prompt = planner.llm_chain.prompt.messages[1]

#print(planner.llm_chain.prompt)
#planner.llm_chain.prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
#print(planner.llm_chain.prompt.messages[1].prompt.template)
executor = load_agent_executor(bedrock_llm, custom_tool_list, verbose=True)
#print(executor.chain.agent.llm_chain.prompt.messages)
agent = PlanAndExecute(planner=planner, executor=executor, verbose=True, max_iterations=2,memory=memory)

#output = agent({"input":"我想去米亚罗"})
output = agent({"input":"最近一个月温度合格的派车单数量"})
#output = agent({"input":"what's the number of qualified temperature delivery orders in the last month" })



[1m> Entering new PlanAndExecute chain...[0m
steps=[Step(value='从派车单数据库中提取最近一个月内的所有派车单记录。'), Step(value='对每个派车单记录检查温度监控数据,判断该单温度是否在合格范围内。'), Step(value='统计满足温度合格的派车单总数量。\n\n')]

[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m 您的目标是从派车单数据库中提取最近一个月内的所有派车单记录。

为了实现这个目标,我们可以按以下思路操作:

首先需要确定派车单数据库中的表名,可以使用反向索引搜索工具来搜索关键词获取表名:

Action:
```
{
  "action": "reverse index search",
  "action_input": "派车单 数据库 表名"
}
```

[0m
Observation: [36;1m[1;3mThe database table is ads_bi_quality_monitor_shipping_detail
[0m
Thought:[32;1m[1;3m Here is an example response that follows the requested format:

Question: How can I extract all shipping order records from the past month?
Thought: I need to identify the relevant database table. 
Action:
```
{
  "action": "reverse index",
  "action_input": "shipping database table name"  
}
```
Observation: The shipping order table is ads_bi_quality_monitor_shipping_detail.

Thought: I need to write a SQL query to retrieve records from the past mon

* clearup agent and executor

In [28]:
del executor
del agent
del planner

### 2.2 Conversationval Zero Shot Agent type

* Set up the base template

In [44]:
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from langchain.utilities import SerpAPIWrapper
from langchain.chains import LLMChain
from typing import List, Union
from langchain.schema import AgentAction, AgentFinish, OutputParserException
import re
flexible_tools_guide="""
 - use "reverse index search" tool to get table name if need to search for table name by keywords
 - use "embedding knn search" tool to get table name if need to search for table name by semantic level
 - use "db utility" tool to generate sql query based on orignal question
 - use tool to generate sql query, Do Not generate sql query yourself
"""


strict_tools_guide="""
 - first use "reverse index search" tool to get table name 
 - if previously step does not return anyting，then use "embedding knn search" tool to get table name 
 - if previously step still can't return anyting, say "I can't find the answer for this question." 
 - otherwise use "db utility" tool to generate database sql using table name which previously step return, and original question
"""

# Set up the base template
template = """Answer the following questions as best you can.
You have access to the following tools:

{tools}


These are guidance on when to use a tool to solve a task, follow them strictly:
 - use "reverse index search" tool to get database table name if need to search for table name by keywords
 - use "embedding knn search" tool to get database table name if need to search for table name by semantic level
 - use "db utility" tool to generate database sql statement
 - use tool to generate sql statement, DO NOT GENERATE SQL STATEMENT YOURSELF

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation:
the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question


Begin! 

Previous conversation history:
{chat_history}

Question: {input}
{agent_scratchpad}
"""

simple_react_prompt="""Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do, Also try to follow steps mentioned above
Action: the action to take, should be one of ["reverse index search", "embedding knn search","db utility"]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

These are guidance on when to use a tool to solve a task, follow them strictly:
 - use "reverse index search" tool to get database table name if need to search for table name by keywords
 - use "embedding knn search" tool to get database table name if need to search for table name by semantic level
 - use "db utility" tool to generate database sql statement
 - use tool to generate sql statement, Do Not generate sql statement yourself

Previous conversation history:
{chat_history}

Question: {input}
{agent_scratchpad}"""

* Set up customerized templateFormat
* Set up customerized outputParse to get individual tool

In [45]:
class CustomPromptTemplate(StringPromptTemplate):
    # The template to use
    template: str
    # The list of tools available
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        # Get the intermediate steps (AgentAction, Observation tuples)
        # Format them in a particular way
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        # Set the agent_scratchpad variable to that value
        kwargs["agent_scratchpad"] = thoughts
        # Create a tools variable from the list of tools provided
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
        # Create a list of tool names for the tools provided
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        final_prompt = "Human:"+self.template.format(**kwargs)+"\nAssistant:"
        return final_prompt
        #return self.template.format(**kwargs)
    
prompt = CustomPromptTemplate(
    template=template,
    tools=custom_tool_list,
    input_variables=["input", "intermediate_steps","chat_history"]
)

class CustomOutputParser(AgentOutputParser):

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        #print("cur step's llm_output ==="+llm_output)
        # Check if agent should finish
        if "Final Answer:" in llm_output:
            return AgentFinish(
                # Return values is generally always a dictionary with a single `output` key
                # It is not recommended to try anything else at the moment :)
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        # Parse out the action and action input
        regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        # Return the action and action input
        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)


* add chat memeory
* Set up the agent
* Run the agent

In [48]:
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import initialize_agent

memory=ConversationBufferWindowMemory(memory_key="chat_history",return_messages=True,k=3)
tool_names = [tool.name for tool in custom_tool_list]


###using LCEL (TBD)####
#from langchain.tools.render import render_text_description
#from langchain.agents.output_parsers import ReActSingleInputOutputParser
#from langchain.agents.format_scratchpad import format_log_to_str
#from langchain import hub
#prompt = hub.pull("hwchase17/react-chat")
#chat_prompt = prompt.partial(
#    tools=render_text_description(custom_tool_list),
#    tool_names=", ".join([t.name for t in custom_tool_list]),
#)
#llm_with_stop = bedrock_llm.bind(stop=["\nObservation"])
#agent = {
#    "input": lambda x: x["input"],
#    "agent_scratchpad": lambda x: format_log_to_str(x['intermediate_steps']),
#    "chat_history": lambda x: x["chat_history"]
#} | chat_prompt | llm_with_stop | ReActSingleInputOutputParser()
#agent_executor = AgentExecutor(agent=agent, tools=custom_tool_list,handle_parsing_errors=True, verbose=True, memory=memory)
#query="\n\nHuman:检索'最近一个月温度合格的派车单数量'这个问题涉及的表名\n"
#agent_executor.invoke({"input":query})
#query="\n\nHuman:根据刚才得到的表名和原始问题生成sql语句'\n"
#agent_executor.invoke({"input":query})


####simple react agent####################
agent_executor = initialize_agent(custom_tool_list, bedrock_llm, agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION, 
                                  verbose=True,max_iterations=3,handle_parsing_errors=True,memory=memory)
agent_executor.agent.llm_chain.prompt.template=simple_react_prompt
agent_executor.run("成都市的车辆资源累计有多少？")
#agent_executor.run("\n\nHuman:检索'最近一个月温度合格的派车单数量'这个问题涉及的表名，并根据这个表名和'最近一个月温度合格的派车单数量'的问题，使用数据库工具生成sql语句\n")
#query="\n\nHuman:检索'最近一个月温度合格的派车单数量'这个问题涉及的表名\n"
#agent_executor.run(query)
#query="\n\nHuman:根据刚才得到的表名和原始问题生成sql语句'\n"
#agent_executor.run(query)

####more flexible customerized agent#####
#output_parser = CustomOutputParser()
#llm_chain = LLMChain(llm=bedrock_llm, prompt=prompt)
#agent = LLMSingleActionAgent(
#    llm_chain=llm_chain,
#    output_parser=output_parser,
#    stop=["\nObservation:"],
#    allowed_tools=tool_names
#)
#agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=custom_tool_list,max_iterations=6,verbose=True,handle_parsing_errors=True, memory=memory)


#output=agent_executor.run("'最近一个月温度合格的派车单数量'这个问题涉及的表名是什么？")
#print(output)
#output=agent_executor.run("根据刚才的表名和‘最近一个月温度合格的派车单数量’这个问题，使用数据库工具生成sql")
#print(output)
#agent_executor.run("检索'最近一个月温度合格的派车单数量'这个问题的数据库表名字，并根据这个数据库表名字和刚才的原始问题，使用数据库工具生成sql语句")

## prompt in english
#query="""search the database table name related to the question of 'Number of qualified temperature delivery orders in the last month', and generate sql statement use the 'db utility' tool based on that table name and the question of 'Number of qualified temperature delivery orders in the last month'"""
#agent_executor.run(query)






[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m Here is my attempt to answer your question following the provided format:

Thought: The question is asking for a statistic about the total number of vehicles in Chengdu city. I should follow the steps to determine which tool to use.

Action: reverse index search
Action Input: 成都市 车辆资源 累计 数量[0m
Observation: [36;1m[1;3mThe database table is ads_bi_quality_monitor_shipping_detail
[0m
Thought:[32;1m[1;3m Here is my response following the provided format:

Thought: The question is asking for a statistic about the total number of vehicles in Chengdu city. Based on the guidance, I should use the "db utility" tool to generate the SQL statement to get this data.

Action: db utility  
Action Input: Generate SQL statement to get total number of vehicles in Chengdu city from ads_bi_quality_monitor_shipping_detail table[0m

IndexError: string index out of range

### 2.3 agent只处理意图识别，直接调用SqlDatabaseChain执行sql生成

In [None]:
import json
query="最近一个月温度合格的派车单数量"
output=agent_executor("\n\nHuman:'"+query+"'这个问题涉及的数据库表名是什么？")
#output=agent_executor.run("\n\nHuman:what's the database table related to the question of 'Number of qualified temperature delivery orders in the last month'?")
print(output)
###process the agent output , retrave the table name####
table_name=output["output"].split(" ")[-1].strip()
table_name=table_name.split("数据库表名是")[-1].strip()

db = SQLDatabase.from_uri(
    "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    sample_rows_in_table_info=0,
    include_tables=[table_name])

db_chain = CustomerizedSQLDatabaseChain.from_llm(sm_sql_llm, db, verbose=True, return_sql=True,return_intermediate_steps=False)
response=db_chain.run(query)
response

* destroy agent and exectutor

In [None]:
del agent_executor
del db_chain
del agent

In [None]:
!pip install streamlit

In [None]:
!streamlit run --server.maxUploadSize=1024  --server.maxMessageSize 2048 --server.port 8501 ./text2sql_gui.py