## 依赖安装

In [None]:
# If you are in mainland China, point pip at the Tsinghua PyPI mirror
!pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
!pip install langchain[all]
!pip install requests_aws4auth
!pip install opensearch-py
!pip install pydantic==1.10.0
!pip install PyAthena[SQLAlchemy]==1.0.0
!pip install PyAthena[JDBC]==1.0.0
!pip install sqlalchemy-redshift
!pip install redshift_connector
!pip install pymysql
!pip install langchain_experimental

### 初始化rds元数据，aos index
 * 如果已经有RDS实例，修改setup.sh，skip create db instance 步骤
 * 确保你的aws configure正确设置aksk及region
 * 确保网络在同一VPC且入站规则互联互通

In [None]:
!chmod 777 ./setup.sh
!./setup.sh

* AOS domain creation 

In [36]:
# Create a public (internet-facing) AOS domain
#!aws opensearch create-domain --domain-name llm-rag-aos --engine-version OpenSearch_2.3  --ebs-options EBSEnabled=true,VolumeType=gp2,VolumeSize=10

# Create an AOS domain inside a VPC.
# NOTE(review): the access policy below grants es:* on every resource to every
# AWS principal ("*"); access is then limited only by the VPC subnet/security
# group given in --vpc-options. Replace the subnet and security-group ids with
# your own before running.
!aws opensearch create-domain   --domain-name my-domain --engine-version OpenSearch_2.3 --cluster-config InstanceType=r6g.xlarge.search,InstanceCount=2 --ebs-options EBSEnabled=true,VolumeType=gp2,VolumeSize=10 --access-policies '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":["*"]},"Action":["es:*"],"Resource":"*"}]}' --vpc-options SubnetIds=subnet-0135e88c8e8da7369,SecurityGroupIds=sg-0aa3d61256d687a0c
# Print the VPC endpoint of the domain.
# NOTE(review): presumably the endpoint is only populated once domain creation
# has finished, so this may need to be re-run later - confirm.
!aws opensearch describe-domain --domain-name my-domain| jq -r '.DomainStatus.Endpoints.vpc'

vpc-my-domain-g63shn6r3volwzhs2gt7rzy7bq.us-west-2.es.amazonaws.com


* metadata ingestion
* 如果aos domain启用了admin账户，请将上述脚本创建的aos domain的管理员账户改为下文代码中的username/password，并采用password auth认证方式创建index
* 如果是没有开admin账户的aos domain，用下面的AWSV4SignerAuth认证方式创建index
* 注意：如果开了aos的精细权限控制，则需要用账户密码或者AWSV4SignerAuth签名认证
* 确保你的vpc网络和安全组与notebook instance互联互通
* 修改aos_endpoint为上述创建的domain endpoint

## initial sagemaker env

In [38]:
import os
import sagemaker
import boto3
import json
from typing import Dict
from typing import Any, Dict, List, Optional

# Bootstrap session: only used to look up the account's default bucket.
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the execution role. get_execution_role() raises ValueError in the
# case handled below; then the role named 'sagemaker_execution_role' is looked
# up through IAM instead.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Re-create the session bound to the chosen bucket; `sess` and `sm_client`
# are notebook globals used by later cells.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
sm_client = boto3.client("sagemaker-runtime")

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")



sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker role arn: arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123
sagemaker bucket: sagemaker-us-west-2-687912291502
sagemaker session region: us-west-2


In [None]:
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, helpers
# --- OpenSearch (AOS) connection settings ---------------------------------
region='us-west-2'
username="admin"
# FIXME(security): the admin password should not live in the notebook. It can
# be overridden through the AOS_PASSWORD environment variable; the literal is
# kept only as a backward-compatible default and should be rotated/removed.
passwd=os.environ.get("AOS_PASSWORD", "(OL>0p;/")
size=10  # not used in this cell; kept because it is a notebook global

# Basic-auth credentials (for domains with an admin master user). Previously
# this tuple was assigned to `auth` and immediately overwritten by the SigV4
# signer below, so it was never usable. It is now kept under `pwdauth`, the
# name that aos_knn_search / aos_reverse_search pass as http_auth.
pwdauth = (username, passwd)

# SigV4 request signing with the notebook's AWS credentials. This is the
# auth object actually used to create the index below.
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region)
index_name="prompt-optimal-index"
aos_endpoint="vpc-my-domain-g63shn6r3volwzhs2gt7rzy7bq.us-west-2.es.amazonaws.com"

# Index definition: keyword metadata fields, an IK-analyzed text field and a
# 1024-dimension k-NN vector field (HNSW / faiss / l2).
schema={
    "settings" : {
        "index":{
            "number_of_shards" : 5,
            "number_of_replicas" : 0,
            "knn": "true",
            "knn.algo_param.ef_search": 32
        }
    },
    "mappings": {
        "properties": {
            "metadata_type" : {
                "type" : "keyword"
            },
            "database_name": {
                "type" : "keyword"
            },
            "table_name": {
               "type" : "keyword"
            },
            "exactly_query_text": {
                "type": "text",
                "analyzer": "ik_max_word",
                "search_analyzer": "ik_smart"
            },
            "exactly_query_embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "space_type": "l2",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 512,
                        "m": 32
                    }
                }
            }
        }
    }
}
search = OpenSearch(
    hosts = [{'host': aos_endpoint, 'port': 443}],
    http_auth = auth ,
    use_ssl = True,
    verify_certs = True,
    connection_class = RequestsHttpConnection
)
# Creating an index that already exists raises, which made this cell fail on
# every re-run of the notebook; only create it when it is missing.
if search.indices.exists(index=index_name):
    print(f"index '{index_name}' already exists, skipping creation")
else:
    print(search.indices.create(index=index_name, body=schema))

## initial langchain lib

In [None]:
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory,ConversationBufferMemory
from langchain import LLMChain
from typing import Dict


## for chatglm
class TextGenContentHandler(LLMContentHandler):
    """Content handler for a JSON endpoint that takes ``ask`` and returns ``answer``.

    Requests are sent as ``{"ask": <prompt>, "parameters": <model_kwargs>}``
    and the reply text is read from the ``answer`` key of the JSON response.
    """
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the prompt and generation parameters to UTF-8 JSON bytes."""
        request_payload = {"ask": prompt, "parameters": model_kwargs}
        return json.dumps(request_payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Read the response stream, parse it as JSON and return its ``answer``."""
        decoded_body = output.read().decode("utf-8")
        return json.loads(decoded_body)["answer"]

### for vicuna/llama
class TextGenContentHandler2(LLMContentHandler):
    """Content handler for a JSON endpoint that takes ``inputs`` and returns ``outputs``.

    Requests are sent as ``{"inputs": <prompt>, "parameters": <model_kwargs>}``
    and the generated text is returned unmodified from the ``outputs`` key.
    """
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the prompt and generation parameters to UTF-8 JSON bytes."""
        request_payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(request_payload).encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Read the response stream, parse it as JSON and return its ``outputs``."""
        decoded_body = output.read().decode("utf-8")
        return json.loads(decoded_body)["outputs"]


# Handler instance that formats requests/responses for the endpoint below.
content_hander2=TextGenContentHandler2()

# Generation parameters forwarded to the endpoint as "parameters".
# Commented-out entries are alternatives that are currently disabled.
parameters = {
  #"early_stopping": True,
  #"length_penalty": 2.0,
  "max_new_tokens": 300,
  #"temperature": 0.6,
  #"max_tokens": 300,
  #"no_repeat_ngram_size": 2,
}
# LangChain LLM wrapper around the SageMaker endpoint. Swap endpoint_name
# (and the content handler, if the payload format differs) to target another
# deployed model.
sm_llm=SagemakerEndpoint(
        #endpoint_name="chatglm-inference-0524-2023-06-01-07-11-27-379",
        endpoint_name="sqlcoder-2023-09-23-00-54-24-198-endpoint",
        #endpoint_name="codellame-2023-09-22-09-25-02-063-endpoint",
        region_name="us-west-2", 
        model_kwargs=parameters,
        content_handler=content_hander2,
        #endpoint_kwargs={'CustomAttributes':'accept_eula=true'}
)


## major chain pipeline ################

### Opt1:直接使用 langchain SQLDatabaseChain 
* text2sql prompt见如下langchain SQLDatabaseChain 所示  
* SqlDatabaseChain可以使用sagemaker endpoint llm，也可以使用BedRock
* 此处使用langchain BedRock

In [None]:
import re
text2sq_prompt_template="""
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:

CREATE TABLE ads_bi_quality_monitor_shipping_detail (
	shipping_order_code VARCHAR(100) COMMENT '派车单编码', 
	license_plate VARCHAR(100) COMMENT '车牌号', 
	truck_type VARCHAR(50) COMMENT '车辆类型', 
	tenant_id VARCHAR(100) COMMENT '租户编码', 
	tenant_name VARCHAR(200) COMMENT '租户名称', 
	father_company_code VARCHAR(100) COMMENT '分子公司编码', 
	father_company_name VARCHAR(200) COMMENT '分子公司名称', 
	father_company_short_name VARCHAR(200) COMMENT '分子公司简称', 
	start_transport_time VARCHAR(50) COMMENT '运输出发时间', 
	signing_time VARCHAR(50) COMMENT '派车单签收时间', 
	plan_start_time VARCHAR(50) COMMENT '派车单计划取货时间', 
	plan_end_time VARCHAR(50) COMMENT '派车单计划送达时间', 
	frist_fence_time VARCHAR(50) COMMENT '第一次碰撞装货地电子围栏时间', 
	leave_load_station_2km_time VARCHAR(50) COMMENT '离开最后一个装货地电子围栏2km的时间', 
	frist_arrive_unload_time VARCHAR(50) COMMENT '首次到达卸货点电子围栏时间', 
	transport_type VARCHAR(10) COMMENT '运输类型(干线/城配)', 
	plan_mileage VARCHAR(50) COMMENT '计划里程', 
	driver_accept_time VARCHAR(50) COMMENT '司机接单时间', 
	driver_name VARCHAR(50) COMMENT '司机姓名', 
	driver_phone VARCHAR(50) COMMENT '司机电话', 
	driver_type VARCHAR(10) COMMENT '司机类型', 
	transport_tenant_id VARCHAR(100) COMMENT '运力承运商id', 
	transport_tenant_name VARCHAR(200) COMMENT '运力承运商名称', 
	warm_area VARCHAR(50) COMMENT '温区信息', 
	waybill_count INTEGER COMMENT '运单数量', 
	gps_device_list VARCHAR(100) COMMENT 'gps设备', 
	gps_report_dot_num INTEGER COMMENT 'gps上报点数', 
	temp_substandard_min DECIMAL(20, 8) COMMENT '温度不达标时长(分钟)', 
	temp_substandard_min_n12 DECIMAL(20, 8) COMMENT '温度不达标时长(分钟)_-12', 
	temp_substandard_min_n16 DECIMAL(20, 8) COMMENT '温度不达标时长(分钟)_-16', 
	shipping_order_transport_min DECIMAL(20, 8) COMMENT '派车单运输总时长(分钟)', 
	shipping_order_temp_standard_rate DECIMAL(20, 8) COMMENT '派车单温度达标率', 
	shipping_order_temp_standard_rate_n12 DECIMAL(20, 8) COMMENT '派车单温度达标率_-12', 
	shipping_order_temp_standard_rate_n16 DECIMAL(20, 8) COMMENT '派车单温度达标率_-16', 
	dep_fence_match_num INTEGER COMMENT '出发地电子围栏匹配数量', 
	dep_total_num INTEGER COMMENT '出发地总数量', 
	dep_fence_match_rate DECIMAL(20, 8) COMMENT '出发地电子围栏匹配率', 
	des_fence_match_num INTEGER COMMENT '目的地电子围栏匹配数量', 
	des_total_num INTEGER COMMENT '目的地总数量', 
	des_fence_match_rate DECIMAL(20, 8) COMMENT '目的地电子围栏匹配率', 
	exception_shipping_order_type VARCHAR(10) COMMENT '异常派车单情况,0-非异常', 
	settlement_tenant_id VARCHAR(100) COMMENT '结算主体租户id', 
	settlement_tenant_name VARCHAR(100) COMMENT '结算主体租户名称', 
	truck_ownership VARCHAR(100) COMMENT '车辆所有权(自有/外请/临时)', 
	load_finish_time VARCHAR(50) COMMENT '派车单点击装货完成时间', 
	last_leave_unload_time VARCHAR(50) COMMENT '最后离开卸货点时间', 
	temp_right_tag VARCHAR(20) COMMENT '温度是否合格(合格/不合格/不参与评估)', 
	warm_area_type VARCHAR(20) COMMENT '温区类型。常温/冷链/自定义区间', 
	source_dt VARCHAR(20) COMMENT '派车单来源的增量表dt', 
	cur_shipping_odr_cust TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci COMMENT '当前派车单对应的客户', 
	is_app_operation VARCHAR(10) COMMENT '是否APP操作卡控', 
	leave_load_station_time_app VARCHAR(50) COMMENT '离开第一个装货地时间_app', 
	frist_arrive_unload_time_app VARCHAR(50) COMMENT '到达第一个卸货点时间_app', 
	last_arrive_unload_time_app VARCHAR(50) COMMENT '到达最后一个卸货点时间_app', 
	temp_calc_start_time VARCHAR(50) COMMENT '温度计算开始时间', 
	temp_calc_end_time VARCHAR(50) COMMENT '温度计算结束时间', 
	exception_shipping_order_type_desc VARCHAR(100) COMMENT '异常派车单情况描述', 
	root_shipping_odr_cust TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci COMMENT '根派车单对应的客户', 
	is_use_custom_temp_range VARCHAR(10) COMMENT '是否使用自定义温度范围', 
	temp_eval_lowest INTEGER COMMENT '温度考核最低值', 
	temp_eval_highest INTEGER COMMENT '温度考核最高值', 
	is_gps_cover VARCHAR(10) COMMENT 'GPS是否覆盖', 
	shipping_create_time VARCHAR(50) COMMENT '派车单创建时间', 
	gps_device_list_plan VARCHAR(100) COMMENT 'GPS设备_预估', 
	is_gps_cover_plan VARCHAR(10) COMMENT 'GPS是否覆盖_预估', 
	first_station_is_on_time VARCHAR(10) COMMENT '首店是否准时', 
	is_many_warm VARCHAR(10) COMMENT '是否多温区', 
	many_temp_standard_rate DECIMAL(20, 8) COMMENT '多温区温度达标率', 
	lc_standard_ratio DECIMAL(20, 8) COMMENT '冷藏温度达标率', 
	ld_standard_ratio DECIMAL(20, 8) COMMENT '冷冻温度达标率', 
	probe_warm_list TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci COMMENT '探头温区'
)ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='BI看板_品质监控_派车单明细'

Question: 最近一个月温度合格的派车单数量
SQLQuery: 
"""

* 定制langchain SqlDataBase
* 可以实现对sql生成更个性化控制（如，去掉冗余信息，sql改写/优化等）
* 此处示例对输出去掉SQLQuery的前缀，直接取sql语句

In [None]:
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations

import warnings
from typing import Any, Iterable, List, Optional
import sqlalchemy
import re
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains import create_sql_query_chain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable



def _extract_sql_from_llm_output(llm_output: str) -> str:
    """Pull the SQL statement to execute out of the raw LLM completion.

    The original customization took the *second* capture of the
    ``SQL执行结果: ...`` marker and raised ``IndexError`` whenever the completion
    contained fewer than two such lines. That behavior is preserved when two
    or more markers are present; otherwise this falls back gracefully:

    * exactly one marker line -> use it;
    * no marker line -> strip an optional ``SQLQuery:`` prefix and anything
      from ``\\nSQLResult:`` onwards, and use the remaining text.

    NOTE(review): why the second marker (index 1) is the wanted one is not
    visible in this notebook - confirm against the prompt actually used.
    """
    marker_hits = re.findall(r"SQL执行结果: (.*?)\n", llm_output)
    if len(marker_hits) >= 2:
        return marker_hits[1]
    if marker_hits:
        return marker_hits[0]
    candidate = llm_output
    if "SQLQuery:" in candidate:
        candidate = candidate.split("SQLQuery:", 1)[1]
    candidate = candidate.split("\nSQLResult:", 1)[0]
    return candidate.strip()


class CustomerizedSQLDatabaseChain(SQLDatabaseChain):
    """SQLDatabaseChain variant that post-processes the generated SQL text.

    ``_call`` follows the same flow as the parent chain (generate SQL,
    optionally check it with the LLM, run it, optionally ask the LLM for a
    final answer). The difference is in the "no query checker" path, where
    the raw completion is passed through ``_extract_sql_from_llm_output``
    before it is executed.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the text-to-SQL pipeline for one question.

        Args:
            inputs: chain inputs; ``inputs[self.input_key]`` is the question and
                the optional ``table_names_to_use`` restricts the tables whose
                schema is put into the prompt.
            run_manager: callback manager; a no-op manager is used when omitted.

        Returns:
            ``{self.output_key: <sql text | query result | final answer>}``
            depending on ``return_sql`` / ``return_direct``, plus the
            intermediate steps when ``return_intermediate_steps`` is set.

        Raises:
            Exception: any error is re-raised with the collected
                ``intermediate_steps`` attached to the exception object.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            #"stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                # Caller only wants the generated text; nothing is executed.
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                # Customization: reduce the raw completion to the SQL statement.
                sql_cmd = _extract_sql_from_llm_output(sql_cmd)

                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                # These names were referenced without being imported in this
                # cell; import them where they are needed so this branch no
                # longer fails with NameError.
                from langchain import LLMChain, PromptTemplate
                from langchain.tools.sql_database.prompt import QUERY_CHECKER

                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                # Same key the parent chain uses; it was not imported before.
                from langchain_experimental.sql.base import INTERMEDIATE_STEPS_KEY

                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

* 初始化Bedrock

In [None]:
import os
from typing import Optional

# External Dependencies:
import boto3
from botocore.config import Config
from langchain.llms.bedrock import Bedrock

def get_bedrock_client(
    assumed_role: Optional[str] = None,
    region: Optional[str] = None,
    runtime: Optional[bool] = True,
):
    """Create a boto3 client for Amazon Bedrock.

    Args:
        assumed_role: optional IAM role ARN. When given, the role is assumed
            through STS and the client is created with the temporary
            credentials returned by STS.
        region: AWS region for the client. When None, AWS_REGION and then
            AWS_DEFAULT_REGION are read from the environment.
        runtime: True creates a 'bedrock-runtime' client, False a 'bedrock'
            client.

    Returns:
        The boto3 client for the selected Bedrock service.

    Credential selection:
        1. ``assumed_role`` given -> the STS temporary credentials are used.
        2. otherwise, if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are
           set (non-empty) in the environment, they are passed explicitly.
        3. otherwise boto3's normal credential resolution is used.

        Previously the environment values were applied unconditionally: they
        overwrote the assumed-role key pair while leaving its session token in
        place, and passed empty strings as credentials when the variables were
        unset.
    """
    if region is None:
        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
    else:
        target_region = region

    print(f"Create new client\n  Using region: {target_region}")
    session_kwargs = {"region_name": target_region}
    client_kwargs = {**session_kwargs}

    profile_name = os.environ.get("AWS_PROFILE")
    if profile_name:
        print(f"  Using profile: {profile_name}")
        session_kwargs["profile_name"] = profile_name

    retry_config = Config(
        region_name=target_region,
        retries={
            "max_attempts": 10,
            "mode": "standard",
        },
    )
    session = boto3.Session(**session_kwargs)

    if assumed_role:
        print(f"  Using role: {assumed_role}", end='')
        sts = session.client("sts")
        response = sts.assume_role(
            RoleArn=str(assumed_role),
            RoleSessionName="langchain-llm-1"
        )
        print(" ... successful!")
        client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
        client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
        client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
    else:
        # Explicit static keys from the environment, only when both are present.
        env_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
        env_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
        if env_access_key and env_secret_key:
            client_kwargs["aws_access_key_id"] = env_access_key
            client_kwargs["aws_secret_access_key"] = env_secret_key

    if runtime:
        service_name='bedrock-runtime'
    else:
        service_name='bedrock'

    bedrock_client = session.client(
        service_name=service_name,
        config=retry_config,
        **client_kwargs
    )

    print("boto3 Bedrock client successfully created!")
    print(bedrock_client._endpoint)
    return bedrock_client



## for aksk bedrock
def get_bedrock_aksk(secret_name='chatbot_bedrock', region_name = "us-west-2"):
    """Fetch the Bedrock access key / secret key pair from AWS Secrets Manager.

    Args:
        secret_name: id of the secret; its SecretString must be a JSON object
            containing the keys BEDROCK_ACCESS_KEY and BEDROCK_SECRET_KEY.
        region_name: region of the Secrets Manager endpoint.

    Returns:
        Tuple ``(access_key, secret_key)``.

    Raises:
        botocore.exceptions.ClientError: when the secret cannot be read.
        KeyError: when the secret has no SecretString or lacks one of the two
            expected keys.
    """
    # Local imports: ClientError was referenced in the except clause without
    # ever being imported (NameError as soon as the call failed), and json is
    # only imported by an earlier cell.
    import json
    from botocore.exceptions import ClientError

    # Create a Secrets Manager client
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    try:
        get_secret_value_response = client.get_secret_value(
            SecretId=secret_name
        )
    except ClientError as e:
        # For a list of exceptions thrown, see
        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        print(f"Failed to read secret '{secret_name}' in {region_name}: {e}")
        raise

    # The secret payload is a JSON document holding both keys.
    secret = json.loads(get_secret_value_response['SecretString'])
    return secret['BEDROCK_ACCESS_KEY'],secret['BEDROCK_SECRET_KEY']

# Load the Bedrock access key / secret key from Secrets Manager.
ACCESS_KEY, SECRET_KEY=get_bedrock_aksk()

# Environment consumed by get_bedrock_client(): region, profile and the
# static key pair. These must be set BEFORE the client is created below.
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"  # E.g. "us-east-1"
os.environ["AWS_PROFILE"] = "default"
#os.environ["BEDROCK_ASSUME_ROLE"] = "arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123"  # E.g. "arn:aws:..."
os.environ["AWS_ACCESS_KEY_ID"]=ACCESS_KEY
os.environ["AWS_SECRET_ACCESS_KEY"]=SECRET_KEY


# Author's note: newer boto3 SDK versions can only initialize the Bedrock
# client through a session, hence the helper.
boto3_bedrock = get_bedrock_client(
    #assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE", None),
    region=os.environ.get("AWS_DEFAULT_REGION", None)
)

# Inference parameters passed to the model as model_kwargs; temperature 0 is
# used here, and the commented-out entries are disabled alternatives.
parameters_bedrock = {
    "max_tokens_to_sample": 2048,
    #"temperature": 0.5,
    "temperature": 0,
    #"top_k": 250,
    #"top_p": 1,
    "stop_sequences": ["\n\nHuman"],
}


# LangChain LLM wrapper bound to the client created above.
bedrock_llm = Bedrock(model_id="anthropic.claude-v2", 
                      client=boto3_bedrock, 
                      model_kwargs=parameters_bedrock)

In [16]:
# Smoke test: send one free-form question to confirm the Bedrock LLM responds.
bedrock_llm.predict("windows远程链接关闭后，如何不进入锁屏状态")

' 对于Windows远程桌面连接,当连接断开后系统默认会进入锁屏状态,这是为了保护远程桌面不被其他人访问。但是如果需要避免锁屏,可以通过以下几种方法:\n\n1. 在远程桌面连接窗口底部,点击"选项" - "远程" - 勾选"断开连接时不锁定计算机"。这样断开连接后就不会锁屏了。\n\n2. 在注册表中设置相关键值也可以避免锁屏:\n\n(1)打开注册表编辑器,定位到HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Winlogon 键。\n\n(2)右击新建一个DWORD值,命名为"DisableLockWorkstation",并设置值为1。\n\n(3)重新启动计算机后设置才会生效。\n\n3. 使用第三方工具,例如DisableLockWorkstation,可以设置热键来禁用或启用锁屏功能。\n\n4. 通过组策略也可以设置远程桌面断开后不锁定屏幕。\n\n需要注意的是,关闭锁屏功能会降低安全性,如果不是非常必要,不建议这样操作。正确的做法是在断开连接前保存工作并妥善退出。'

* Langchain SqlDatabase chain支持多种数据源
* 不同数据源的连接url及传参不一样
* 此处使用rds mysql 

In [None]:
## langchain agent demo test##########
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains import create_sql_query_chain
import json
import os

os.environ["PGPASSWORD"] = "*******"
os.environ["LANGCHAIN_HANDLER"] = "langchain"
# Athena connection string template. It is built here but only referenced by
# the commented-out `conn_str` argument below; fill in real values to use it.
conn_str = "awsathena+rest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/"\
           "{schema_name}?s3_staging_dir={s3_staging_dir}"
conn_str = conn_str.format(
    aws_access_key_id="**********",
    aws_secret_access_key="***********",
    region_name="cn-northwest-1",
    schema_name="specturmdb",
    s3_staging_dir="s3://tangqy-athenaoutput-us-west-2")

# Active data source: RDS MySQL. The other URIs are kept as examples for
# Athena / Redshift.
# FIXME(security): the MySQL URI embeds the database user and password in
# clear text; move them to environment variables or a secrets manager.
db = SQLDatabase.from_uri(
    #conn_str,
    #"redshift+redshift_connector://admin@redshift-cluster-1.cp1kgq7oikv3.ap-southeast-1.redshift.amazonaws.com:5439/dev",
    "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
    #"jdbc:awsathena://athena.us-west-2.amazonaws.com:443/tpcds_bin_partitioned_orc_300?s3_staging_dir=s3://tangqy-athenaoutput/&aws_credentials_provider_class=com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    include_tables=['dws_truck_portrait_index_sum_da','dws_ots_waybill_info_da','ads_bi_quality_monitor_shipping_detail','dim_pub_truck_info'], # only these four tables are exposed, to save tokens in the prompt
    #include_tables=["customer"],
    sample_rows_in_table_info=0)
# Alternative: plain SQL-generation chain on the SageMaker-hosted model.
#chain = create_sql_query_chain(sm_llm, db)
#response = chain.invoke({"question":"最近一个月温度合格的派车单数量"})
#response = chain.invoke({"question":"最近一个月下单最大的客户邮件地址"})

#print(response)
#db_chain = CustomerizedSQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)
#db_chain.run("最近一个月温度合格的派车单数量")
#db_chain.run("I need to know the max sales customer's id in sales report")
# Bedrock-backed chain. With return_sql=True the chain returns the generated
# SQL text and does not execute it against the database.
db_chain = CustomerizedSQLDatabaseChain.from_llm(llm=bedrock_llm, db=db, verbose=False, return_sql=True)
#db_chain.run("2023年7月派车单数量超过26次的4.2米车辆一共有多少辆?")
#db_chain.run("历史累计派车单数量、干线派车单数量、城配派车单数量")
db_chain.run("2022年的运输总量是多少吨?")
#db_chain.run("成都市的车辆资源累计有多少？")
#db_chain.run("I need to know the max sales customer's id in sales report")

### Opt2:元数据召回+langchain SQLDatabasechain

#### func for agent

In [None]:
import boto3
import json
import requests
import time
from collections import defaultdict
from requests_aws4auth import AWS4Auth
import os
from opensearchpy import OpenSearch, RequestsHttpConnection
from langchain.vectorstores import OpenSearchVectorSearch
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferWindowMemory
from langchain import LLMChain



def aos_knn_search(client, field,q_embedding, index, size=1):
    """Run a k-NN vector query and return the hits as plain dicts.

    :param client: OpenSearch client; any other value triggers creation of a
        new client against the global ``aos_endpoint``
    :param field: name of the vector field to search
    :param q_embedding: query vector
    :param index: index to search
    :param size: used both as the k-NN ``k`` and the result ``size``
    :return: list of dicts with keys idx (1 when absent), database_name,
        table_name, query_desc_text and score
    """
    if not isinstance(client, OpenSearch):
        # NOTE(review): `pwdauth` is a notebook global that must be defined
        # before this fallback path is taken - confirm where it is set.
        connection_settings = {
            "hosts": [{'host': aos_endpoint, 'port': 443}],
            "http_auth": pwdauth,
            "use_ssl": True,
            "verify_certs": True,
            "connection_class": RequestsHttpConnection,
        }
        client = OpenSearch(**connection_settings)

    knn_clause = {field: {"vector": q_embedding, "k": size}}
    query_response = client.search(
        body={"size": size, "query": {"knn": knn_clause}},
        index=index,
    )

    results = []
    for hit in query_response["hits"]["hits"]:
        source = hit['_source']
        results.append({
            'idx': source.get('idx', 1),
            'database_name': source['database_name'],
            'table_name': source['table_name'],
            'query_desc_text': source['query_desc_text'],
            "score": hit["_score"],
        })
    return results

def aos_knn_search_v2(client, field,q_embedding, index, size=1):
    """Run a k-NN vector query and return the hits as plain dicts.

    :param client: OpenSearch client; any other value triggers creation of a
        new client against the global ``aos_endpoint`` using the global ``auth``
    :param field: name of the vector field to search
    :param q_embedding: query vector
    :param index: index to search
    :param size: used both as the k-NN ``k`` and the result ``size``
    :return: list of dicts with keys idx (1 when absent), database_name,
        table_name, exactly_query_text and score
    """
    if not isinstance(client, OpenSearch):
        connection_settings = {
            "hosts": [{'host': aos_endpoint, 'port': 443}],
            "http_auth": auth,
            "use_ssl": True,
            "verify_certs": True,
            "connection_class": RequestsHttpConnection,
        }
        client = OpenSearch(**connection_settings)

    knn_clause = {field: {"vector": q_embedding, "k": size}}
    query_response = client.search(
        body={"size": size, "query": {"knn": knn_clause}},
        index=index,
    )

    results = []
    for hit in query_response["hits"]["hits"]:
        source = hit['_source']
        results.append({
            'idx': source.get('idx', 1),
            'database_name': source['database_name'],
            'table_name': source['table_name'],
            'exactly_query_text': source['exactly_query_text'],
            "score": hit["_score"],
        })
    return results


def aos_reverse_search(client, index_name, field, query_term, exactly_match=False, size=1):
    """
    Full-text (inverted index) search against OpenSearch.

    Fixes over the previous version: ``size`` is now honored in the
    exact-match branch (it was omitted, so the server default applied), and
    ``field`` is now honored in the query_string branch (it was hard-coded to
    "exactly_query_text").

    :param client: OpenSearch client; any other value triggers creation of a
        new client against the global ``aos_endpoint``
    :param index_name: Target Index Name
    :param field: search field
    :param query_term: query term
    :param exactly_match: True -> ``match_phrase`` query analyzed with
        ik_smart; False -> ``query_string`` query sorted by score descending
    :param size: maximum number of hits to return
    :return: list of dicts with keys idx (1 when absent), database_name,
        table_name, exactly_query_text and score
    """
    if not isinstance(client, OpenSearch):
        # NOTE(review): `pwdauth` is a notebook global that must be defined
        # before this fallback path is taken - confirm where it is set.
        client = OpenSearch(
            hosts=[{'host': aos_endpoint, 'port': 443}],
            http_auth = pwdauth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection
        )
    if exactly_match:
        query = {
            "size": size,
            "query": {
                "match_phrase": {
                    field: {
                        "query": query_term,
                        "analyzer": "ik_smart"
                    }
                }
            }
        }
    else:
        query = {
            "size": size,
            "query": {
                "query_string": {
                    "default_field": field,
                    "query": query_term
                }
            },
            "sort": [{
                "_score": {
                    "order": "desc"
                }
            }]
        }
    query_response = client.search(
        body=query,
        index=index_name
    )
    # The returned text is always read from 'exactly_query_text', whichever
    # field was searched.
    result_arr = [
        {
            'idx': item['_source'].get('idx', 1),
            'database_name': item['_source']['database_name'],
            'table_name': item['_source']['table_name'],
            'exactly_query_text': item['_source']['exactly_query_text'],
            "score": item["_score"],
        }
        for item in query_response["hits"]["hits"]
    ]
    return result_arr




def get_vector_by_sm_endpoint(questions, sm_client, endpoint_name):
    """Embed ``questions`` by invoking a SageMaker inference endpoint.

    The request body is JSON with the inputs, an empty ``parameters`` object,
    ``is_query`` set to True and a fixed retrieval instruction string.

    :param questions: value sent as "inputs" in the request body
    :param sm_client: client object exposing ``invoke_endpoint``
    :param endpoint_name: name of the embedding endpoint
    :return: the ``sentence_embeddings`` entry of the JSON response
    """
    request_payload = {
        "inputs": questions,
        "parameters": {},
        "is_query": True,
        "instruction": "为这个句子生成表示以用于检索相关文章：",
    }
    response_model = sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(request_payload),
        ContentType="application/json",
    )
    raw_body = response_model['Body'].read()
    return json.loads(raw_body.decode('utf8'))['sentence_embeddings']


def k_nn_ingestion_by_aos(docs,index,hostname,username,passwd):
    """Index documents (query description text + embedding) into OpenSearch.

    Fix over the previous version: the ``hostname`` argument is now used as
    the target host. It was ignored and the global ``aos_endpoint`` was
    always used; that global remains the fallback when ``hostname`` is empty.

    :param docs: iterable of dicts with keys query_desc_embedding,
        database_name, table_name and query_desc_text
    :param index: target index name
    :param hostname: OpenSearch endpoint host (port 443 is used)
    :param username: basic-auth user
    :param passwd: basic-auth password
    """
    target_host = hostname or aos_endpoint
    auth = (username, passwd)
    search = OpenSearch(
        hosts = [{'host': target_host, 'port': 443}],
        http_auth = auth ,
        use_ssl = True,
        verify_certs = True,
        connection_class = RequestsHttpConnection
    )
    for doc in docs:
        # One index request per document, restricted to the known fields.
        document = {
            "query_desc_embedding": doc['query_desc_embedding'],
            'database_name': doc['database_name'],
            "table_name": doc['table_name'],
            "query_desc_text": doc["query_desc_text"],
        }
        search.index(index=index, body=document)
        
def k_nn_ingestion_by_aos_v2(docs,index,hostname,username,passwd):
    """Index exact-question documents into an OpenSearch index.

    Args:
        docs: Iterable of dicts. Each must contain ``exactly_query_embedding``,
            ``database_name``, ``table_name`` and ``exactly_query_text``; any
            other keys are ignored.
        index: Name of the target index.
        hostname: OpenSearch endpoint host; the client connects on port 443
            with SSL and certificate verification enabled.
        username: User name for HTTP basic authentication.
        passwd: Password for HTTP basic authentication.

    Raises:
        KeyError: If a document lacks one of the required fields.
    """
    search = OpenSearch(
        # Connect to the host the caller passed in; previously this parameter
        # was ignored in favour of the notebook-level ``aos_endpoint`` global.
        hosts=[{'host': hostname, 'port': 443}],
        http_auth=(username, passwd),
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
    )
    for doc in docs:
        # Only these four fields are written to the index.
        document = {
            "exactly_query_embedding": doc['exactly_query_embedding'],
            'database_name': doc['database_name'],
            "table_name": doc['table_name'],
            "exactly_query_text": doc["exactly_query_text"],
        }
        search.index(index=index, body=document)

#### 元数据 ingestion：写入 AOS（Amazon OpenSearch Service）

In [None]:
## Data preparation
# Three parallel, newline-separated lists: the i-th question is answered from
# the i-th table, which lives in the i-th database.

# Sample natural-language questions, one per line.
all_querys = """2023年7月派车单数量超过26次的4.2米车辆一共有多少辆
请统计历史累计派车单数量、干线派车单数量、城配派车单数量。历史累计的意思是不限定时间范围
奶茶品牌的站点数量和运输货品数量统计
2022年的运输总量是多少吨？请注意：traff_weight的单位是千克，请把单位转换为吨
车牌为'黑RG6696'的车辆的GPS最近定位上传时间、GPS最近定位省份、GPS最近定位城市、GPS最近定位区县、APP最近定位上传时间、APP最近定位省份、APP最近定位城市、APP最近定位区县。给出sql中字段名不要带上库名
车牌归属城市为'成都'的车辆累计有多少？
取货地城市名称为'北京市'的历史累计不重复的车牌有多少？请注意车牌号有可能有重复
货主-行业列表相似于'西餐连锁'的一共有多少个品牌？多少个客户？
品牌名称为'星巴克'的不重复的站点一共有多少个？
查看租户简称为云南,车辆的车厢长为9.6米和15米的外廓车长、核定载重"""
querys = all_querys.split("\n")

# Table(s) for each question; an entry may hold several comma-separated names.
all_tables = """ads_bi_quality_monitor_shipping_detail
dws_ots_waybill_info_da
dws_station_portrait_index_sum_da
dws_ots_waybill_info_da
dws_truck_portrait_index_sum_da
dws_truck_portrait_index_sum_da
dws_ots_waybill_info_da
ads_customer_portrait_index_sum_da
dim_customer_enterprise_station_base_info
dim_pub_truck_tenant,dim_pub_truck_info"""
tables=all_tables.split("\n")

# Database for each question.
all_dbs = """llm
llm
llm
llm
llm
llm
llm
llm
llm
llm"""
dbs=all_dbs.split("\n")

# The lists are consumed by position later in the notebook; fail here, next to
# the data, if an edit leaves them misaligned.
assert len(querys) == len(tables) == len(dbs), (
    f"parallel metadata lists differ in length: "
    f"{len(querys)} querys, {len(tables)} tables, {len(dbs)} dbs"
)

In [None]:
# OpenSearch index and SageMaker embedding endpoint used by the ingestion and
# search cells of this notebook.
index_name="prompt-optimal-index"
embedding_endpoint_name="bge-zh-15-2023-09-25-07-02-01-080-endpoint"

# Embed all sample questions with a single endpoint call. The former
# ``sentense_vectors = []`` pre-initialisation was a dead store and also turned
# a failed embedding call into a silently empty ingestion, so it is gone.
# (Variable name, including its spelling, is kept because the ingestion cell
# reads it.)
sentense_vectors=get_vector_by_sm_endpoint(querys,sm_client,embedding_endpoint_name)

In [None]:
# Pair every embedding with the database, table and question text found at the
# same position in the parallel lists.
# NOTE: k_nn_ingestion_by_aos_v2 writes only database_name, table_name,
# exactly_query_text and exactly_query_embedding, so "metadata_type" is built
# here but not indexed.
docs = [
    {
        "metadata_type": "table",
        "database_name": dbs[position],
        "table_name": tables[position],
        "exactly_query_text": querys[position],
        "exactly_query_embedding": embedding,
    }
    for position, embedding in enumerate(sentense_vectors)
]

# Write the assembled documents into the OpenSearch index.
k_nn_ingestion_by_aos_v2(docs, index_name, aos_endpoint, username, passwd)

In [None]:
# OpenSearch client for the search checks; credentials come from ``pwdauth``,
# which is defined elsewhere in the notebook.
client = OpenSearch(
    hosts=[dict(host=aos_endpoint, port=443)],
    http_auth=pwdauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

# Smoke test: embed one question and run a k-NN lookup against the index.
# The trailing 1 is presumably the number of hits requested; confirm against
# the aos_knn_search_v2 signature.
query = "上个月温度合格的派车单数量"
query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)
rets = aos_knn_search_v2(
    client,
    "exactly_query_embedding",
    query_embedding[0],
    index_name,
    1,
)
print(rets)

In [None]:


# Runs the same k-NN lookup as the preceding cell (same question, same index);
# relies on the ``client`` object created there.
query = "上个月温度合格的派车单数量"
query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)
rets = aos_knn_search_v2(
    client,
    "exactly_query_embedding",
    query_embedding[0],
    index_name,
    1,
)
print(rets)

#### e2e pipeline
* 先使用倒排索引检索（代码中的 `aos_reverse_search`）召回表元数据
* 如果倒排检索召回为空，则改用 embedding 向量检索（k-NN）召回
* 将召回的表名（table name）传入 SQLDatabaseChain（代码中为 `CustomerizedSQLDatabaseChain`）

In [None]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains import create_sql_query_chain
import json
import os


def _first_table_name(hits):
    """Return the stripped ``table_name`` of the first hit, or None if unreadable."""
    try:
        return hits[0]["table_name"].strip()
    except Exception as e:
        # Deliberately best-effort: an empty or malformed result simply means
        # that this recall stage found nothing.
        print(e)
        return None


query="2023年7月派车单数量超过26次的4.2米车辆一共有多少辆?"
aos_client = OpenSearch(
    hosts=[{'host': aos_endpoint, 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

#### Stage 1: inverted-index (full-text) recall ############
opensearch_query_response = aos_reverse_search(aos_client, index_name, "exactly_query_text", query)
table_name = _first_table_name(opensearch_query_response)

#### Stage 2: embedding-vector (k-NN) recall, only if stage 1 found nothing ############
if not table_name:
    query_embedding = get_vector_by_sm_endpoint(query, sm_client, embedding_endpoint_name)
    opensearch_query_response = aos_knn_search_v2(aos_client, "exactly_query_embedding",query_embedding[0], index_name, size=10)
    table_name = _first_table_name(opensearch_query_response)

# Stop with a clear message instead of handing ``[None]`` to SQLDatabase.
if not table_name:
    raise ValueError(f"No table metadata could be recalled for query: {query!r}")

# An ingested ``table_name`` may list several tables separated by commas
# (e.g. "dim_pub_truck_tenant,dim_pub_truck_info"); SQLDatabase needs them as
# separate entries.
include_tables = [name.strip() for name in table_name.split(",") if name.strip()]

##### Run the SQL chain restricted to the recalled table(s) #######
# NOTE(review): the connection string embeds a user name and password. It can
# be overridden through the TEXT2SQL_MYSQL_URI environment variable; the literal
# default only preserves existing behaviour and should be removed (and the
# password rotated) once the variable is configured.
mysql_uri = os.environ.get(
    "TEXT2SQL_MYSQL_URI",
    "mysql+pymysql://admin:admin12345678@database-us-west-2-demo.cluster-c1qvx9wzmmcz.us-west-2.rds.amazonaws.com/llm",
)
db = SQLDatabase.from_uri(
    mysql_uri,
    include_tables=include_tables,  # only the recalled table(s), to save tokens in the prompt
    sample_rows_in_table_info=0)

db_chain = CustomerizedSQLDatabaseChain.from_llm(llm=bedrock_llm, db=db, verbose=False, return_sql=True)
db_chain.run(query)
