## langchain agent demo for ReAct

In [None]:
!pip install langchain[all]
!pip install langchain-experimental
#!pip install sagemaker --upgrade
!pip install  boto3
!pip install requests_aws4auth
!pip install opensearch-py
!pip install pydantic==1.10.0
#!pip install PyAthena[SQLAlchemy]==1.0.0
#!pip install PyAthena[JDBC]==1.0.0
#!pip install openai
!pip install sqlalchemy-redshift
!pip install redshift_connector
!pip install SQLAlchemy

In [None]:
# Repeats the SQLAlchemy install from the dependency cell above.
!pip install SQLAlchemy

## initial sagemaker env

In [None]:
import os
import sagemaker
import boto3
import json
from typing import Dict
from typing import Any, Dict, List, Optional

# Default SageMaker session; used first only to discover the default bucket.
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the execution role ARN used by this notebook.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    # Fallback when get_execution_role() raises ValueError: look up the IAM
    # role literally named 'sagemaker_execution_role' and use its ARN.
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Re-create the session bound to the chosen bucket, and build the runtime
# client that later cells use to invoke SageMaker endpoints.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
sm_client = boto3.client("sagemaker-runtime")

# Echo the resolved environment for a quick sanity check.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## initial langchain lib

In [None]:
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory,ConversationBufferMemory
from langchain import LLMChain
from typing import Any, Dict, List, Union,Mapping, Optional, TypeVar, Union

# NOTE(security): a literal OpenAI API key used to sit in a commented-out line
# here. It has been removed from the notebook; treat that key as leaked and
# revoke it. If OpenAI embeddings are needed, export OPENAI_API_KEY in the
# environment instead of writing it into a cell.
#from langchain.embeddings.openai import OpenAIEmbeddings

# Amazon OpenSearch Service (AOS) connection settings shared by later cells.
aos_endpoint = "vpc-llm-rag-aos-seg3mzhpp76ncpxezdqtcsoiga.us-west-2.es.amazonaws.com"
region = 'us-west-2'
# Credentials can be overridden through the environment; the literals are kept
# only as fallbacks so the notebook behaves as before when nothing is exported.
# NOTE(security): the fallback password is checked into the notebook — rotate
# it and drop the default once AOS_PASSWORD is provisioned.
username = os.environ.get("AOS_USERNAME", "admin")
passwd = os.environ.get("AOS_PASSWORD", "(OL>0p;/")
# Must be the index created by the "PUT metadata-index" request below (the one
# with the knn_vector mapping) and queried by CustomEmbeddingSearchTool.
# The previous value "metadata_index" (underscore) pointed ingestion at a
# different index than the one that is searched.
index_name = "metadata-index"
# Default number of hits requested from OpenSearch.
size = 10

### for sqlcoder
class TextGenContentHandler2(LLMContentHandler):
    """Content handler for the SQL-generation (sqlcoder) endpoint.

    Requests and responses are both JSON documents.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the prompt and generation parameters as UTF-8 JSON."""
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Read the response stream and return its ``outputs`` field."""
        body_text = output.read().decode("utf-8")
        return json.loads(body_text)["outputs"]


content_hander2=TextGenContentHandler2()

## for embedding
class EmbeddingContentHandler(EmbeddingsContentHandler):
    """Content handler for the sentence-embedding endpoint (JSON in / JSON out)."""

    # Generation-style parameters kept on the class; the transform methods
    # below do not read them.
    parameters = {
        "max_new_tokens": 50,
        "temperature": 0,
        "min_length": 10,
        "no_repeat_ngram_size": 2,
    }

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        """Merge the texts with the model kwargs into one UTF-8 JSON payload."""
        payload = {"inputs": inputs, **model_kwargs}
        return json.dumps(payload).encode('utf-8')

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Read the response stream and return its ``sentence_embeddings`` field."""
        body_text = output.read().decode("utf-8")
        return json.loads(body_text)["sentence_embeddings"]

embedding_content_handler=EmbeddingContentHandler()
    
# LangChain embeddings wrapper around the sentence-embedding SageMaker endpoint.
sm_embeddings = SagemakerEndpointEmbeddings(
    content_handler=embedding_content_handler,
    region_name="us-west-2",
    endpoint_name="st-paraphrase-mpnet-base-v2-2023-04-17-10-05-10-718-endpoint",
)

# Generation parameters forwarded to the text-generation endpoints.
parameters = {"max_new_tokens": 300}

# LangChain LLM wrapper around the sqlcoder SageMaker endpoint.
sm_sql_llm = SagemakerEndpoint(
    content_handler=content_hander2,
    model_kwargs=parameters,
    region_name="us-west-2",
    endpoint_name="sqlcoder-2023-09-24-06-31-09-959-endpoint",
)


## func for agent

In [None]:
import boto3
import json
import requests
import time
from collections import defaultdict
from requests_aws4auth import AWS4Auth
import os
from opensearchpy import OpenSearch, RequestsHttpConnection
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory
from langchain import LLMChain



def aos_knn_search(client, q_embedding, index, size=10):
    """Run a k-NN query on the ``query_desc_embedding`` field of an index.

    :param client: object exposing ``search(body=..., index=...)`` (an
        OpenSearch client)
    :param q_embedding: query vector
    :param index: name of the index to search
    :param size: number of neighbours requested (used for both ``size`` and ``k``)
    :return: list of dicts with keys ``idx``, ``database_name``,
        ``table_name``, ``query_desc_text`` and ``score``, in hit order
    """
    query = {
        "size": size,
        "query": {
            "knn": {
                "query_desc_embedding": {
                    "vector": q_embedding,
                    "k": size
                }
            }
        }
    }
    query_response = client.search(
        body=query,
        index=index
    )

    results = []
    for hit in query_response["hits"]["hits"]:
        # Document fields live under "_source"; only "_score" sits on the hit
        # envelope. Reading query_desc_text from the envelope raised KeyError.
        source = hit['_source']
        results.append({
            'idx': source.get('idx', 1),
            'database_name': source['database_name'],
            'table_name': source['table_name'],
            'query_desc_text': source['query_desc_text'],
            'score': hit['_score'],
        })
    return results

def aos_reverse_search(client, index_name, field, query_term, exactly_match=False, size=10):
    """
    Run a full-text (inverted-index) search against OpenSearch.

    :param client: an ``OpenSearch`` client, or an AOS endpoint host name from
        which a SigV4-authenticated client is built (uses the module-level
        ``awsauth``)
    :param index_name: Target Index Name
    :param field: name of the text field to match against
    :param query_term: query term
    :param exactly_match: when True, issue a ``match_phrase`` query analysed
        with ``ik_smart``; otherwise a ``match`` query restricted to
        documents whose ``doc_type`` is "Question"
    :param size: maximum number of hits to return
    :return: list of dicts with keys ``idx``, ``database_name``,
        ``table_name``, ``query_desc_text`` and ``score``, in hit order
    """
    if not isinstance(client, OpenSearch):
        client = OpenSearch(
            hosts=[{'host': client, 'port': 443}],
            http_auth=awsauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection
        )

    if exactly_match:
        query = {
            "size": size,
            "query": {
                "match_phrase": {
                    field: {
                        "query": query_term,
                        "analyzer": "ik_smart"
                    }
                }
            }
        }
    else:
        query = {
            "size": size,
            "query": {
                "bool": {
                    "must": [
                        {"term": {"doc_type": "Question"}},
                        {"match": {field: query_term}}
                    ]
                }
            },
            # "sort" is a top-level search option. It used to be nested inside
            # the "query" clause, which is not valid query DSL.
            "sort": [{
                "_score": {
                    "order": "desc"
                }
            }]
        }

    query_response = client.search(
        body=query,
        index=index_name
    )

    result_arr = []
    for hit in query_response["hits"]["hits"]:
        # Document fields live under "_source"; only "_score" is on the hit.
        source = hit['_source']
        result_arr.append({
            'idx': source.get('idx', 1),
            'database_name': source['database_name'],
            'table_name': source['table_name'],
            'query_desc_text': source['query_desc_text'],
            'score': hit['_score'],
        })
    return result_arr




def get_vector_by_sm_endpoint(questions, sm_client, endpoint_name):
    """Embed ``questions`` by invoking a SageMaker endpoint.

    :param questions: text(s) sent as the ``inputs`` field of the request
    :param sm_client: client exposing ``invoke_endpoint`` (SageMaker runtime)
    :param endpoint_name: name of the endpoint to invoke
    :return: the ``sentence_embeddings`` field of the JSON response
    """
    request_body = json.dumps(
        {
            "inputs": questions,
            "parameters": {},
            "is_query" : True,
            "instruction" :  "为这个句子生成表示以用于检索相关文章："
        }
    )
    response = sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    payload = json.loads(response['Body'].read().decode('utf8'))
    return payload['sentence_embeddings']




def get_topk_items(opensearch_query_response, topk=5):
    """Return the first ``topk`` distinct hits as flat tuples.

    :param opensearch_query_response: iterable of hit dicts with the keys
        ``score``, ``idx``, ``database_name``, ``table_name`` and
        ``query_desc_text`` (the shape built by the search helpers in this file)
    :param topk: maximum number of distinct hits to return
    :return: list of ``(score, idx, database_name, table_name,
        query_desc_text)`` tuples in input order
    """
    top_items = []
    seen_keys = set()
    for item in opensearch_query_response:
        if len(top_items) >= topk:
            break
        # Hits are de-duplicated on "id" when present. The search helpers in
        # this file only emit "idx", so fall back to that.
        # NOTE(review): the helpers default idx to 1 when the document has no
        # "idx" field, which makes all such hits collapse into one — confirm
        # that indexed documents carry a unique idx.
        key = item['id'] if 'id' in item else item['idx']
        if key in seen_keys:
            continue
        seen_keys.add(key)
        # list.append takes a single object, so the fields go in as one tuple.
        top_items.append((
            item['score'],
            item['idx'],
            item['database_name'],
            item['table_name'],
            item['query_desc_text'],
        ))
    return top_items



def k_nn_ingestion_by_aos(docs,index,hostname,username,passwd):
    """Index embedded metadata documents into OpenSearch, one request per doc.

    :param docs: iterable of dicts with ``query_desc_embedding``,
        ``database_name``, ``table_name`` and ``query_desc_text``;
        ``metadata_type`` is indexed too when present
    :param index: name of the target index
    :param hostname: OpenSearch endpoint host (connected to on port 443 over TLS)
    :param username: basic-auth user name
    :param passwd: basic-auth password
    """
    search = OpenSearch(
        # Connect to the host the caller asked for. The global aos_endpoint was
        # used here before, silently ignoring the hostname argument.
        hosts=[{'host': hostname, 'port': 443}],
        http_auth=(username, passwd),
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection
    )
    for doc in docs:
        document = {
            "query_desc_embedding": doc['query_desc_embedding'],
            'database_name': doc['database_name'],
            "table_name": doc['table_name'],
            "query_desc_text": doc["query_desc_text"],
        }
        # The index mapping declares metadata_type and callers supply it, so
        # keep it instead of dropping it on the floor.
        if 'metadata_type' in doc:
            document['metadata_type'] = doc['metadata_type']
        search.index(index=index, body=document)

## major chain pipeline ################

### 0: index 创建

PUT metadata-index
{
    "settings" : {
        "index":{
            "number_of_shards" : 5,
            "number_of_replicas" : 0,
            "knn": "true",
            "knn.algo_param.ef_search": 32
        }
    },
    "mappings": {
        "properties": {
            "metadata_type" : {
                "type" : "keyword"
            },
            "database_name": {
                "type" : "keyword"
            },
            "table_name": {
               "type" : "keyword"
            },
            "query_desc_text": {
                "type": "text",
                "analyzer": "ik_max_word",
                "search_analyzer": "ik_smart"
            },
            "query_desc_embedding": {
                "type": "knn_vector",
                "dimension": 768,
                "method": {
                    "name": "hnsw",
                    "space_type": "l2",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 512,
                        "m": 32
                    }
                }            
            }
        }
    }
}

### 1: data process

In [None]:
# Newline-separated sample metadata: the i-th question describes the i-th
# table, which lives in the i-th database.
all_querys = "最近一个月温度合格的派车单数量"
all_tables = "ads_bi_quality_monitor_shipping_detail"
all_dbs = "llm"

# Split each blob into its per-line entries.
querys, tables, dbs = (
    blob.split("\n") for blob in (all_querys, all_tables, all_dbs)
)

In [None]:
#import sys
#sys.path.append("./code/")
#import func

# SageMaker endpoint used to embed the question texts.
endpoint_name = "bge-zh-15-2023-09-25-07-02-01-080-endpoint"

# Embed all questions in a single endpoint invocation.
sentense_vectors = []
sentense_vectors = get_vector_by_sm_endpoint(
    questions=querys,
    sm_client=sm_client,
    endpoint_name=endpoint_name,
)



In [None]:
# Pair each embedding with the database / table / question at the same
# position and shape it like the documents k_nn_ingestion_by_aos expects.
docs = [
    {
        "metadata_type": "table",
        "database_name": dbs[i],
        "table_name": tables[i],
        "query_desc_text": querys[i],
        "query_desc_embedding": sentence_vector,
    }
    for i, sentence_vector in enumerate(sentense_vectors)
]

#print(docs)
# (A stray bare `23` expression — leftover cell output — was removed here.)
######### ingestion into aos ###################
k_nn_ingestion_by_aos(docs,index_name,aos_endpoint,username,passwd)

### 2:自定义Agent ，定制context
1:自定义AOS倒排及knn检索tools    
2:自定义中文Sql Agent 的ReAct prompt 前缀   
3:使用customerSqlDatabaseChain+Sql Agent触发   

In [None]:
from langchain.tools.base import BaseTool, Tool, tool
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
import os
import time
import json
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from typing import Optional, Type
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)

# Sign OpenSearch ('es') requests with SigV4 using the current AWS identity.
# One boto3 session is enough for both the credentials and the region.
_boto_session = boto3.Session()
credentials = _boto_session.get_credentials()
if credentials is None:
    # get_credentials() returns None when no credential source is configured;
    # fail here with a clear message rather than with an AttributeError below.
    raise RuntimeError("No AWS credentials found for signing OpenSearch requests")
# Session.region_name is None when no default region is configured; in that
# case keep the region chosen earlier in the notebook instead of clobbering it.
region = _boto_session.region_name or region
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, 'es', session_token=credentials.token)

class CustomEmbeddingSearchTool(BaseTool):
    """LangChain tool: find candidate databases/tables by vector (k-NN) search.

    The question is embedded through a SageMaker endpoint and matched against
    the ``query_desc_embedding`` vectors stored in ``aos_index``.
    """

    name = "custom_knn_search"
    # BaseTool declares `description` as a required field; without it the tool
    # cannot be instantiated.
    description = "用于向量检索找到具体的数据库和表名"
    aos_client = OpenSearch(
                hosts=[{'host': aos_endpoint, 'port': 443}],
                http_auth = awsauth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection)
    aos_index="metadata-index"
    # Query vectors must come from the same model that embedded the indexed
    # descriptions, i.e. the endpoint used by the ingestion cell.
    embedding_endpoint_name = endpoint_name

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Embed ``query``, run the k-NN search and return the top 2 hits as JSON."""
        start = time.time()
        # The embedding helper returns one vector per input text; send a
        # one-element batch and take its only vector. The endpoint *name* is
        # passed here — previously the SagemakerEndpointEmbeddings object was.
        query_embedding = get_vector_by_sm_endpoint(
            [query], sm_client, self.embedding_endpoint_name
        )[0]
        hits = aos_knn_search(self.aos_client, query_embedding, self.aos_index, size=10)
        # Time the whole lookup (embedding + search), as the message implies.
        elpase_time = time.time() - start
        print(f'runing time of opensearch_knn : {elpase_time}s seconds')
        # Tools return text; ensure_ascii=False keeps Chinese readable for the LLM.
        return json.dumps(get_topk_items(hits, 2), ensure_ascii=False)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("custom_knn_search does not support async")
         
        
   

class CustomReverseIndexSearchTool(BaseTool):
    """LangChain tool: find candidate databases/tables by label (full-text) search.

    The question is matched against the ``doc`` field of ``aos_index``.
    """

    name = "custom_reverse_search"
    # BaseTool declares `description` as a required field; without it the tool
    # cannot be instantiated.
    description = "用于标签检索找到具体的数据库和表名"
    aos_client = OpenSearch(
                hosts=[{'host': aos_endpoint, 'port': 443}],
                http_auth = awsauth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection)
    aos_index="metadata_labels"

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Run the label search for ``query`` and return the top 2 hits as JSON."""
        start = time.time()
        # Client and index are attributes of the tool, and the search term is
        # the `query` argument (bare aos_client / aos_index / query_input were
        # undefined names).
        opensearch_query_response = aos_reverse_search(
            self.aos_client, self.aos_index, "doc", query
        )
        elpase_time = time.time() - start
        # No logger is configured in this notebook; print like the k-NN tool does.
        print(f'runing time of opensearch_query : {elpase_time}s seconds')
        # Tools return text; ensure_ascii=False keeps Chinese readable for the LLM.
        return json.dumps(get_topk_items(opensearch_query_response, 2), ensure_ascii=False)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("custom_reverse_search does not support async")
        


def _reverse_index_search(query: str) -> str:
    """Run the label (inverted-index) search tool on ``query``."""
    # `run` is an instance method: it needs a tool instance, not the bare class.
    return CustomReverseIndexSearchTool().run(query)


def _embedding_knn_search(query: str) -> str:
    """Run the vector (k-NN) search tool on ``query``."""
    return CustomEmbeddingSearchTool().run(query)


# Metadata-recall tools handed to the agents. The descriptions are what the
# LLM reads to choose a tool, so each must describe its own tool: they were
# previously swapped (label search described as vector search and vice versa).
custom_tool_list = [
    Tool(
        func=_reverse_index_search,
        name="reverse index search",
        description="用于标签检索找到具体的数据库和表名"
    ),
    Tool(
        func=_embedding_knn_search,
        name="embedding knn search",
        description="用于向量检索找到具体的数据库和表名"
    ),
]

# SQLAlchemy-backed handle on the `llm` MySQL database (PyMySQL driver).
# sample_rows_in_table_info=0 keeps sample rows out of the table info.
# NOTE(security): the user name and password are embedded in this URI and so
# are stored with the notebook — move them to the environment and rotate them.
db = SQLDatabase.from_uri(
    "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    sample_rows_in_table_info=0)


### for llama2
class TextGenContentHandler(LLMContentHandler):
    """Content handler for the llama2 text-generation endpoint (JSON in / JSON out)."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the prompt and generation parameters as UTF-8 JSON."""
        request = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(request).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Read the response stream and return its ``outputs`` field."""
        decoded = output.read().decode("utf-8")
        parsed = json.loads(decoded)
        return parsed["outputs"]


content_hander=TextGenContentHandler()

# LangChain LLM wrapper around the llama2 SageMaker endpoint; it reuses the
# generation `parameters` defined earlier in the notebook.
sm_llm = SagemakerEndpoint(
    content_handler=content_hander,
    model_kwargs=parameters,
    region_name="us-west-2",
    endpoint_name="lmi-model-2023-09-25-09-56-14-425",
)



# These names were used below without ever being imported in the notebook.
from langchain.agents import AgentType, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit

# Standard SQL tools (list tables, inspect schema, run/check a query) over `db`.
toolkit = SQLDatabaseToolkit(db=db, llm=sm_llm)

# The suffix replaces the SQL agent's default one, so it must keep the
# {input} and {agent_scratchpad} placeholders of the ReAct prompt. Only the
# opening "Thought" is customized, to steer the agent toward the metadata
# search tools before it touches the database.
custom_suffix = """Begin!

Question: {input}
Thought: 我应该先利用标签检索，找到具体的数据库和表名，
如果找不到，则利用向量检索查找，
然后使用数据库工具查看我刚才找到的应该查询的库和表的详细schema元数据。
{agent_scratchpad}"""

agent = create_sql_agent(
    # The LLM defined in this notebook is `sm_llm`; a bare `llm` does not exist.
    llm=sm_llm,
    toolkit=toolkit,
    verbose=True,
    # create_sql_agent only supports ZERO_SHOT_REACT_DESCRIPTION and
    # OPENAI_FUNCTIONS; the chat variant is rejected with a ValueError.
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=custom_tool_list,
    suffix=custom_suffix,
)

agent.run("我需要知道销售报表中，下单金额最大的客户id")

## llm_chain自定义 Agent test

* 使用llm chain做chat的planner

In [None]:
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.chains.base import Chain
from langchain.prompts.base import BasePromptTemplate
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from langchain.agents.tools import Tool


# Prompt template for the SQL database chain.
# Placeholders: {dialect}, {table_info}, {question}.
# The text ends on "SQLQuery:" so that the model's completion starts with the
# SQL statement itself.
sql_prompt_template = """
You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. 
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {question}
SQLQuery:
"""


# Conversation memory shared by the chains below; the history is exposed to
# prompts under the "chat_history" variable.
memory = ConversationBufferMemory(memory_key="chat_history")

# Prompt object wrapping the SQL template and its three placeholders.
prompt = PromptTemplate(
    input_variables=["table_info", "question","dialect"],
    template=sql_prompt_template,
)


def build_chat_chain():
    """Build a verbose ConversationChain over ``sm_llm`` and the shared ``memory``.

    :return: the configured ``ConversationChain``
    """
    # ConversationChain validates that the prompt's variables equal the memory
    # variables plus "input". Its default prompt expects "history", while the
    # shared memory exposes "chat_history", so construction used to fail with
    # "Got unexpected prompt input variables". Build the same conversation
    # prompt around whatever key the memory actually uses.
    history_key = memory.memory_key
    chat_prompt = PromptTemplate(
        input_variables=[history_key, "input"],
        template=(
            "The following is a friendly conversation between a human and an AI. "
            "The AI is talkative and provides lots of specific details from its context. "
            "If the AI does not know the answer to a question, it truthfully says it does not know."
            "\n\nCurrent conversation:\n{" + history_key + "}\nHuman: {input}\nAI:"
        ),
    )
    conversation_with_summary_chain = ConversationChain(
        llm=sm_llm,
        memory=memory,
        prompt=chat_prompt,
        verbose=True
    )
    return conversation_with_summary_chain

* 使用SqlDatabaseChain作为db tools做数据库交互
* 将db tools加入之前定义的元数据召回的tools列表

In [None]:
def run_query(query):
    """Answer ``query`` through a SQLDatabaseChain over ``db`` using ``sm_sql_llm``.

    A fresh chain is built on every call from ``sql_prompt_template``.

    :param query: natural-language question
    :return: whatever ``SQLDatabaseChain.run`` returns for the question
    """
    sql_prompt = PromptTemplate(
        template=sql_prompt_template,
        input_variables=["question", "table_info", "dialect"],
    )
    chain = SQLDatabaseChain.from_llm(
        sm_sql_llm,
        db,
        prompt=sql_prompt,
        verbose=True,
        return_intermediate_steps=False,
    )
    return chain.run(query)

# Register the database chain as one more tool next to the metadata-recall tools.
custom_tool_list.extend([
    Tool(
        func=run_query,
        name="Db Querying Tool",
        description="""用于连接数据库获取schema元数据信息""",
    ),
])

* 使用llm chain做chat的planner
* 使用langchain plannerAndExecutor

In [None]:
from langchain.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# System prompt for the planner: asks for a numbered plan and prescribes the
# order in which the tools should be used. Tool names are quoted exactly as
# registered in custom_tool_list — the last one is "Db Querying Tool" (it was
# previously written "db Querying Tool", which matches no registered tool).
combo_template = """
让我们先了解问题并制定解决问题的计划。
请输出以标题“plan:”开头的计划，然后输出步骤的编号列表。在计划步骤时，不要使用过去的对话历史记录。
请将计划制定为准确完成任务所需的最少步骤数。

以下是关于使用工具解决任务的指导，请严格遵守：
先利用"reverse index search" 标签检索工具
如果结果为空，则利用"embedding knn search"向量检索工具查找
然后使用"Db Querying Tool"数据库工具查看我刚才找到的应该查询的库和表的详细schema元数据。
"""
planner = load_chat_planner(sm_llm)

# Swap the planner's stock system message for the custom one while keeping its
# original human-message template (second message of the stock prompt).
system_message_prompt = SystemMessagePromptTemplate.from_template(combo_template)
human_message_prompt = planner.llm_chain.prompt.messages[1]
planner.llm_chain.prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

# The executor agent carries out each planned step with the custom tools.
executor = load_agent_executor(sm_llm, custom_tool_list, verbose=True)
# NOTE(review): max_iterations is not a documented PlanAndExecute field —
# confirm that it has any effect here.
agent = PlanAndExecute(planner=planner, executor=executor, verbose=True, max_iterations=2,memory=memory)


output = agent("最近一个月温度合格的派车单数量")