## LangChain agent demo for ReAct

In [None]:
#!pip install langchain[all]
#!pip install sagemaker --upgrade
#!pip install  boto3
#!pip install requests_aws4auth
#!pip install opensearch-py
#!pip install pydantic==1.10.0
#!pip install PyAthena[SQLAlchemy]==1.0.0
#!pip install PyAthena[JDBC]==1.0.0
#!pip install openai
!pip install sqlalchemy-redshift
!pip install redshift_connector

## Initialize the SageMaker environment

In [29]:
import json
import os
import time
from typing import Dict
from typing import Any, Dict, List, Optional

import boto3
import sagemaker
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.tools import BaseTool

# Bootstrap session: used only to look up the default bucket; `sess` is rebuilt
# further down with that bucket passed explicitly.
sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models and logs.
# SageMaker will automatically create this bucket if it does not exist.
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # Fall back to the session's default bucket when no bucket name is given.
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the execution role ARN. If get_execution_role() raises ValueError,
# fall back to looking up the IAM role named 'sagemaker_execution_role'.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Final session bound to the chosen bucket, plus the runtime client used to
# invoke SageMaker endpoints later in the notebook.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
sm_client = boto3.client("sagemaker-runtime")

# Echo the resolved environment so the reader can verify account/region.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")



sagemaker role arn: arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123
sagemaker bucket: sagemaker-us-west-2-687912291502
sagemaker session region: us-west-2


## Initialize the LangChain library

In [130]:
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory,ConversationBufferMemory
from langchain import LLMChain
from typing import Dict

# SECURITY: never commit API keys or passwords to a notebook, not even inside a
# comment. Export them in the shell that launches Jupyter instead.
#os.environ["OPENAI_API_KEY"] = "<set OPENAI_API_KEY in your environment>"
#from langchain.embeddings.openai import OpenAIEmbeddings

# Amazon OpenSearch connection settings used by the search tools further down.
aos_endpoint="vpc-llm-rag-aos-seg3mzhpp76ncpxezdqtcsoiga.us-west-2.es.amazonaws.com"
region='us-west-2'
username="admin"
# The password is read from the AOS_PASSWORD environment variable (empty string
# when unset) instead of being hard-coded in the notebook.
passwd=os.environ.get("AOS_PASSWORD", "")
index_name="qa_index"
# Number of hits requested per search.
size=10

## for chatglm
class TextGenContentHandler(LLMContentHandler):
    """Request/response adapter for the ChatGLM endpoint (JSON in, JSON out)."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the request body.

        The prompt is sent under the "ask" key and the model kwargs under
        "parameters"; the JSON document is returned UTF-8 encoded.
        """
        payload = {"ask": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint response and return its "answer" field.

        `output` must expose a ``read()`` method returning the raw bytes
        (despite the ``bytes`` annotation).
        """
        raw_body = output.read()
        parsed = json.loads(raw_body.decode("utf-8"))
        return parsed["answer"]

### for vicuna/llama
class TextGenContentHandler2(LLMContentHandler):
    """Request/response adapter for the vicuna/llama endpoint (JSON in, JSON out)."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the request body.

        The prompt is sent under the "inputs" key and the model kwargs under
        "parameters"; the JSON document is returned UTF-8 encoded.
        """
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint response and return data[0]["generated_text"].

        `output` must expose a ``read()`` method returning the raw bytes
        (despite the ``bytes`` annotation).
        """
        raw_body = output.read()
        parsed = json.loads(raw_body.decode("utf-8"))
        first_candidate = parsed["data"][0]
        return first_candidate["generated_text"]



# Generation parameters; the content handler forwards them to the endpoint
# under the "parameters" key of the request body.
parameters = {
  "early_stopping": False,
  #"length_penalty": 2.0,
  #"max_new_tokens": 500,
  "temperature": 0.6,
  "max_tokens": 300,
  #"no_repeat_ngram_size": 2,
}
# LLM backed by a SageMaker real-time endpoint.
# content_handler must be an *instance*: its transform_input/transform_output
# are instance methods, so passing the class itself cannot work.
sm_llm=SagemakerEndpoint(
        #endpoint_name="chatglm-inference-0524-2023-06-01-07-11-27-379",
        endpoint_name="vicuna-7B-2023-06-04-13-07-39-746-endpoint",
        region_name="us-west-2",
        model_kwargs=parameters,
        content_handler=TextGenContentHandler2()
)


## major chain pipeline ################

### 直接使用 langchain SQLDatabaseChain 
定制SqlDataBase对接其他数据源(e.g StarRocks)   
SqlDatabaseChain使用sagemaker endpoint llm

In [140]:
# Scratch check of the "SQLQuery: ... SQLResult" extraction used by the custom chain.
# `re` is imported here because, in notebook order, no earlier cell imports it;
# without this the cell fails with NameError on a fresh kernel.
import re

# Sample completion from a model that echoes the prompt before answering.
sql_cmd="""
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 3 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:

Question: 我需要知道销售报表中，下单金额最大的客户id
SQLQuery: SELECT c_customer_id FROM web_sales ORDER BY ws_sold_price DESC LIMIT 1
SQLResult
SQL Query to run"""
pattern = r"SQLQuery: (.*?)\nSQLResult"
matches = re.findall(pattern, sql_cmd)
# matches[0] is the placeholder from the echoed format section ("SQL Query to run");
# matches[1] is the query the model actually generated.
match = matches[1]
sql_cmd = match
# Defensive clean-up of leftover markers and backslash escapes.
sql_cmd=sql_cmd.replace("SQLQuery:","")
sql_cmd=sql_cmd.replace("SQLResult","")
sql_cmd=sql_cmd.replace("\\","")
print(sql_cmd)

SELECT c_customer_id FROM web_sales ORDER BY ws_sold_price DESC LIMIT 1


In [136]:
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations

import warnings
from typing import Any, Iterable, List, Optional
import sqlalchemy
import re
from langchain.chains.base import Chain
from langchain import OpenAI, SQLDatabase,SQLDatabaseChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable

from langchain import utils


class CustomerizedSQLDatabase(SQLDatabase):
    """SQLDatabase subclass with its own copy of ``get_table_info``.

    Having the method here gives a single place to adapt the schema
    description when the chain is pointed at another data source.
    """

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Describe the requested tables as CREATE TABLE statements.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        Args:
            table_names: Tables to describe; ``None`` means every usable table.

        Returns:
            One description per table, separated by blank lines. A table with
            an entry in the custom table info uses that entry verbatim.
            Otherwise its CREATE TABLE statement is emitted, followed (when
            enabled on the instance) by a ``/* ... */`` comment holding index
            information and/or sample rows.

        Raises:
            ValueError: If any requested name is not a usable table.
        """
        usable_names = self.get_usable_table_names()
        if table_names is not None:
            unknown = set(table_names).difference(usable_names)
            if unknown:
                raise ValueError(f"table_names {unknown} not found in database")
            usable_names = table_names

        # Keep the metadata ordering, dropping tables that were not requested
        # and SQLite's internal "sqlite_*" tables.
        wanted = set(usable_names)
        selected = []
        for tbl in self._metadata.sorted_tables:
            if tbl.name not in wanted:
                continue
            if self.dialect == "sqlite" and tbl.name.startswith("sqlite_"):
                continue
            selected.append(tbl)

        descriptions = []
        for tbl in selected:
            custom_info = self._custom_table_info
            if custom_info and tbl.name in custom_info:
                # A user-supplied description replaces the generated one.
                descriptions.append(custom_info[tbl.name])
                continue

            # Start from the CREATE TABLE statement compiled for this engine.
            ddl = str(CreateTable(tbl).compile(self._engine))
            parts = [ddl.rstrip()]

            with_indexes = self._indexes_in_table_info
            with_samples = self._sample_rows_in_table_info
            wants_extras = with_indexes or with_samples
            if wants_extras:
                parts.append("\n\n/*")
            if with_indexes:
                parts.append(f"\n{self._get_table_indexes(tbl)}\n")
            if with_samples:
                parts.append(f"\n{self._get_sample_rows(tbl)}\n")
            if wants_extras:
                parts.append("*/")
            descriptions.append("".join(parts))

        return "\n\n".join(descriptions)




class CustomerizedSQLDatabaseChain(SQLDatabaseChain):
    """SQLDatabaseChain that recovers the SQL statement from chatty LLM output.

    Self-hosted models (chatglm, bloomz, vicuna) often wrap the query in a
    fenced code block or echo the whole prompt back. This subclass extracts
    the executable SQL from the completion before running it.
    """

    @staticmethod
    def _extract_sql(llm_output: str) -> str:
        """Pull the executable SQL statement out of a raw LLM completion.

        Two output shapes are handled, in this order:

        1. chatglm style: the SQL sits inside a triple-backtick fence,
           optionally tagged ``sql``. The last fenced block is used.
        2. bloomz style: the completion echoes the prompt, so the SQL sits
           between "SQLQuery: " and a following "SQLResult" line. With two or
           more such pairs the first is the prompt's format placeholder and
           the second is the generated query; with exactly one pair that one
           is used.

        If neither shape is found, the text is returned unchanged (stripped).
        """
        sql_cmd = llm_output

        # chatglm: take the content of the last fenced block. DOTALL lets the
        # block span lines and contain dots (e.g. `table.column`).
        fenced_blocks = re.findall(r"```(.*?)```", sql_cmd, flags=re.DOTALL)
        if fenced_blocks:
            # Drop only a leading language tag, never "sql" inside identifiers.
            sql_cmd = re.sub(
                r"^\s*sql\b", "", fenced_blocks[-1], count=1, flags=re.IGNORECASE
            )

        # bloomz: take the text between "SQLQuery: " and "\nSQLResult".
        matches = re.findall(r"SQLQuery: (.*?)\nSQLResult", sql_cmd)
        if matches:
            candidate = matches[1] if len(matches) > 1 else matches[0]
            candidate = candidate.replace("SQLQuery:", "")
            candidate = candidate.replace("SQLResult", "")
            # Some models escape characters with backslashes; remove them.
            sql_cmd = candidate.replace("\\", "")

        return sql_cmd.strip()

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, run it, and optionally phrase an answer.

        Args:
            inputs: Chain inputs. ``inputs[self.input_key]`` is the question;
                the optional ``"table_names_to_use"`` entry restricts the
                tables described in the prompt.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict holding the final result under ``self.output_key`` and,
            when ``self.return_intermediate_steps`` is set, the recorded
            steps under ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            Exception: Any error is re-raised with the steps collected so far
                attached as ``exc.intermediate_steps``.
        """
        # These names live in LangChain modules the notebook never imports at
        # the top level; import them here so both branches below can run.
        from langchain.chains.sql_database.base import INTERMEDIATE_STEPS_KEY
        from langchain.tools.sql_database.prompt import QUERY_CHECKER

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        table_names_to_use = inputs.get("table_names_to_use")
        if table_names_to_use is None:
            # Fall back to the tables configured on the database. When that
            # collection is empty (presumably no include_tables was given),
            # pass None so get_table_info describes every usable table
            # instead of returning an empty description.
            table_names_to_use = self.database._include_tables or None
        table_info = self.database.get_table_info(
            table_names=None if table_names_to_use is None else list(table_names_to_use)
        )
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            #"stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                # Reduce the raw completion to the statement to execute, and
                # record that statement (not the raw text) as the exec input.
                sql_cmd = self._extract_sql(sql_cmd)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec

                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer

            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

In [122]:
## langchain agent demo test##########
from langchain import OpenAI, SQLDatabase,SQLDatabaseChain
import json
import os

# NOTE(review): the credential values in this cell are redacted placeholders.
# As written, these assignments overwrite any real PGPASSWORD / OPENAI_API_KEY
# already present in the environment — supply real values outside the notebook.
os.environ["PGPASSWORD"] = "*******"
os.environ["LANGCHAIN_HANDLER"] = "langchain"
os.environ["OPENAI_API_KEY"] = "sk-3lTxHcynfeJ4*******lbkFJz4k968DnZNADqgT583TF"
# Athena connection string. It is built here, but its use in from_uri below is
# commented out, so the active connection is the MySQL URI.
conn_str = "awsathena+rest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/"\
           "{schema_name}?s3_staging_dir={s3_staging_dir}"
conn_str = conn_str.format(
    aws_access_key_id="**********",
    aws_secret_access_key="***********",
    region_name="us-west-2",
    schema_name="specturmdb",
    s3_staging_dir="s3://tangqy-athenaoutput-us-west-2")

# Exactly one URI is active; the commented alternatives target Athena/Redshift.
db = SQLDatabase.from_uri(
    #conn_str,
    #"redshift+redshift_connector://admin@redshift-cluster-1.cp1kgq7oikv3.ap-southeast-1.redshift.amazonaws.com:5439/dev",
    "mysql+pymysql://admin:******@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    #"jdbc:awsathena://athena.us-west-2.amazonaws.com:443/tpcds_bin_partitioned_orc_300?s3_staging_dir=s3://tangqy-athenaoutput/&aws_credentials_provider_class=com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    include_tables=['web_sales','customer'], # only these two tables are included, to save tokens in the prompt
    #include_tables=["tpcds_text_1000.web_sales","tpcds_text_1000.customer"],
    sample_rows_in_table_info=0)
# OpenAI model; it is not used by the chain built in this cell, which runs on sm_llm.
llm = OpenAI(temperature=0, verbose=True)
#db_chain = CustomerizedSQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)
db_chain = SQLDatabaseChain(llm=sm_llm, database=db, verbose=True, top_k=3)

# Question (Chinese): "In the sales report, which customer id has the highest average order amount?"
db_chain.run("我需要知道销售报表中，下单金额的平均数最大的客户id")
#db_chain.run("I need to know the max sales customer's id in sales report")





[1m> Entering new SQLDatabaseChain chain...[0m
我需要知道销售报表中，下单金额的平均数最大的客户id
SQLQuery:[32;1m[1;3mSELECT c_customer_id, AVG(ws_net_paid) AS avg_net_paid 
FROM customer 
INNER JOIN web_sales ON c_customer_sk = ws_bill_customer_sk 
GROUP BY c_customer_id 
ORDER BY avg_net_paid DESC 
LIMIT 3;[0m
SQLResult: [33;1m[1;3m[('AAAAAAAAOCIDIAAA', Decimal('7751.353750')), ('AAAAAAAAGFKBHAAA', Decimal('7730.946250')), ('AAAAAAAAOJKHGAAA', Decimal('7603.959000'))][0m
Answer:



[32;1m[1;3m最大的客户id是AAAAAAAAOCIDIAAA[0m
[1m> Finished chain.[0m


'最大的客户id是AAAAAAAAOCIDIAAA'

In [None]:
# Run the same kind of question through the customized chain on the SageMaker LLM.
# The chain is built with from_llm, matching the zero-shot cell further down,
# instead of passing `llm=` straight to the constructor.
db_chain = CustomerizedSQLDatabaseChain.from_llm(llm=sm_llm, db=db, verbose=True, top_k=3)
# Question (Chinese): "In the sales report, which customer id has the largest order amount?"
db_chain.run("我需要知道销售报表中，下单金额最大的客户id")
#db_chain.run("I need to know the max sales customer's id in web_sales report")

In [None]:
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.chat_models import ChatOpenAI
from langchain.utilities import BingSearchAPIWrapper
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)

# Build the auth object that the OpenSearch clients below receive as `http_auth`,
# using the credentials of the default boto3 session.
# NOTE(review): AWS4Auth is not imported anywhere in the visible notebook
# (presumably `from requests_aws4auth import AWS4Auth`) — confirm and add the import.
credentials = boto3.Session().get_credentials()
# Rebinds the `region` variable set in the configuration cell to the session's region.
region = boto3.Session().region_name
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, 'es', session_token=credentials.token)

# NOTE(review): OpenSearch, RequestsHttpConnection, get_vector_by_sm_endpoint,
# sm_embeddings, aos_knn_search, aos_reverse_search and get_topk_item are not
# defined or imported anywhere in the visible notebook — they must be provided
# before this cell can run (OpenSearch/RequestsHttpConnection presumably come
# from the `opensearchpy` package).
class CustomEmbeddingSearchTool(BaseTool):
    """Tool that looks up database/table metadata with an OpenSearch k-NN (vector) search."""

    name = "custom_knn_search"
    # Chinese: "vector search used to find the concrete database and table names".
    description = "用于向量检索找到具体的数据库和表名"
    aos_client = OpenSearch(
                hosts=[{'host': aos_endpoint, 'port': 443}],
                http_auth = awsauth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection)
    aos_index="metadata_labels"

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Embed `query`, run a k-NN search on the index, and return the top 2 items."""
        start = time.time()
        query_embedding = get_vector_by_sm_endpoint(query, sm_client, sm_embeddings)
        # Use the embedding just computed together with this tool's own client
        # and index; the timed span covers both the embedding and the search.
        knn_response = aos_knn_search(
            self.aos_client, query_embedding, self.aos_index, size=10
        )
        elpase_time = time.time() - start
        print(f'runing time of opensearch_knn : {elpase_time}s seconds')
        return get_topk_item(knn_response, 2)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("custom_knn_search does not support async")


class CustomReverseIndexSearchTool(BaseTool):
    """Tool that looks up database/table metadata with an OpenSearch label (inverted-index) search."""

    name = "custom_reverse_search"
    # Chinese: "label search used to find the concrete database and table names".
    description = "用于标签检索找到具体的数据库和表名"
    aos_client = OpenSearch(
                hosts=[{'host': aos_endpoint, 'port': 443}],
                http_auth = awsauth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection)
    aos_index="metadata_labels"

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Run a label search for `query` on the "doc" field and return the top 2 items."""
        start = time.time()
        # Class attributes are reached through `self`, and the search term is
        # the `query` argument.
        opensearch_query_response = aos_reverse_search(
            self.aos_client, self.aos_index, "doc", query
        )
        elpase_time = time.time() - start
        # print is used (as in the k-NN tool) because no logger is configured here.
        print(f'runing time of opensearch_query : {elpase_time}s seconds')
        return get_topk_item(opensearch_query_response, 2)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("custom_reverse_search does not support async")


# Tools handed to the agent. `func` must be the bound `run` of a tool instance;
# the class attribute would receive the query string as `self`.
# Each description matches what its tool does: label search for the
# reverse-index tool, vector search for the k-NN tool.
custom_tool_list = [
    Tool(
        func=CustomReverseIndexSearchTool().run,
        name="reverse index search",
        description="用于标签检索找到具体的数据库和表名",
    ),
    Tool(
        func=CustomEmbeddingSearchTool().run,
        name="embedding knn search",
        description="用于向量检索找到具体的数据库和表名",
    ),
]

# Database handle for the agent: no include_tables filter is passed here, and
# sample rows are disabled to keep prompts short.
db = SQLDatabase.from_uri(
    "mysql+pymysql://admin:******@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    sample_rows_in_table_info=0)

# SageMaker-hosted vicuna endpoint. The previously referenced
# `text_gen_content_handler3` is not defined anywhere in the notebook; the
# vicuna/llama handler defined earlier is used instead.
sm_llm=SagemakerEndpoint(
        #endpoint_name="chatglm-inference-0524-2023-06-01-07-11-27-379",
        endpoint_name="vicuna-7B-2023-06-04-13-07-39-746-endpoint",
        region_name="us-west-2",
        model_kwargs=parameters,
        content_handler=TextGenContentHandler2()
)


toolkit = SQLDatabaseToolkit(db=db, llm=sm_llm)

# `suffix` replaces the SQL agent's default suffix, so it must still carry the
# {input} and {agent_scratchpad} placeholders; without them the prompt never
# contains the user's question or the agent's previous steps.
# The Chinese "Thought" tells the agent to try the label search first, fall
# back to the vector search, then inspect the schema of the tables it found.
custom_suffix = """Begin!

Question: {input}
Thought: 我应该先利用标签检索，找到具体的数据库和表名，
如果找不到，则利用向量检索查找，
然后使用数据库工具查看我刚才找到的应该查询的库和表的详细schema元数据。
{agent_scratchpad}"""
# NOTE(review): `llm` here is the OpenAI model created in an earlier cell, while
# the toolkit above uses `sm_llm` — confirm this mix is intended.
agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                        )

# Question (Chinese): "In the sales report, which customer id has the largest order amount?"
agent.run("我需要知道销售报表中，下单金额最大的客户id")


* zero shot Agent test

In [None]:
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain import OpenAI, LLMChain
from pydantic import BaseModel, Field
from langchain import PromptTemplate



# Chinese-language MySQL prompt for ChatGLM. It tells the model to act as a MySQL
# expert, limit results to {top_k} rows, use only columns from {table_info}, and
# answer in the Question / SQLQuery / SQLResult / Answer format.
chatglm_db_prompt_template = """你是MySQL的专家。给定一个输入问题，创建一个语法正确的MySQL查询语句。
除非用户在问题中指定了要获得的特定数量的示例，否则使用LIMIT子句查询最多{top_k}个结果。您可以对结果进行排序，以返回数据库中信息量最大的数据。您必须仅查询回答问题所需的列。将每个列名用反引号（`）括起来，表示为分隔的标识符。
请注意，仅可以使用在{table_info}这些表中看到的列名，不要查询不存在的列。此外，还要注意哪个列在哪个表中。如果问题涉及”今天”，请注意使用CURDATE()函数获取当前日期.
使用以下格式：
Question:此处为问题
SQLQuery:要运行的SQL查询
SQLResult:SQLQuery的结果
Answer:此处为最终答案

"""


# Appended to the template above; carries the user's question.
PROMPT_SUFFIX = """Question:{input}"""

chatglm_db_prompt = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=chatglm_db_prompt_template+PROMPT_SUFFIX,
)


#from langchain.chains.question_answering import load_qa_chain
#chain = LLMChain(llm=sm_llm,prompt=chatglm_db_prompt)
#chain.run({"input":"我需要知道销售报表中，下单金额最大的客户id","table_info":"'web_sales','customer'","top_k":"3"})


# NOTE: chatglm_db_prompt is only passed in the commented-out variant below; the
# active call does not pass `prompt`, so from_llm uses its own default.
#db_chain = CustomerizedSQLDatabaseChain.from_llm(llm=sm_llm, prompt=chatglm_db_prompt,db=db, verbose=True, top_k=3)
db_chain = CustomerizedSQLDatabaseChain.from_llm(llm=sm_llm, db=db, verbose=True, top_k=3)
# Question (Chinese): "In the sales report, which customer id has the largest order amount?"
db_chain.run("我需要知道销售报表中，下单金额最大的客户id")
#db_chain.run("帮我查下销售报表中，最大的销售金额")

#tools = []
#prefix = """尽你所能回答以下问题。您可以访问以下工具"""
#suffix = """开始! 
#
#问题: {input}
#{agent_scratchpad}"""
#
#prompt = ZeroShotAgent.create_prompt(
#    tools, 
#    prefix=prefix, 
#    suffix=suffix, 
#    input_variables=["input", "agent_scratchpad"]
#)
#
##class AnalyzeInput(BaseModel):
##    query: str = Field()
#
#print(prompt.template)
#llm_chain = LLMChain(llm=llm, prompt=prompt)
#agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
#agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
#agent_executor.run("我需要知道llm数据库的web_sales销售报表中，ws_quantity和ws_ext_sales_price乘积最大的客户的c_customer_id")