In [3]:
import os
from dotenv import load_dotenv
import requests

# Load variables from a .env file (if present) into the process environment.
load_dotenv()
# API key for the remote chat endpoint; None when ZISHU_API_KEY is not set.
api_key = os.getenv('ZISHU_API_KEY')
# Remote OpenAI-compatible endpoint; override with ZISHU_BASE_URL without editing the notebook.
base_url = os.getenv("ZISHU_BASE_URL", "http://43.200.7.56:8008/v1")
chat_model = "glm-4-flash"
emb_model = "embedding-3"

# Local OpenAI-compatible server hosting the Qwen model; override with LOCAL_LLM_BASE_URL.
model_name = "Qwen2.5-32B-Instruct-AWQ"
openai_api_base = os.getenv("LOCAL_LLM_BASE_URL", "http://192.168.12.10:8000/v1")

# To make the local model the default chat model instead, set
# chat_model = model_name and base_url = openai_api_base.

In [4]:
from openai import OpenAI
from pydantic import Field  # 导入Field，用于Pydantic模型中定义字段的元数据
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    LLMMetadata,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms.callbacks import llm_completion_callback
from typing import List, Any, Generator

from llama_index.embeddings.openai import OpenAIEmbedding

# Custom LlamaIndex LLM that talks to an OpenAI-compatible chat-completions endpoint.
class OurLLM(CustomLLM):
    """LlamaIndex ``CustomLLM`` backed by an OpenAI-compatible chat-completions API.

    Each prompt is sent as a single ``user`` message to ``model_name`` on ``base_url``.

    Args:
        api_key: Key passed to the ``openai`` client.
        base_url: Base URL of the OpenAI-compatible server (e.g. ``.../v1``).
        model_name: Model identifier sent with every request.
        **data: Extra fields forwarded to ``CustomLLM``.

    NOTE(review): ``metadata`` only sets ``model_name``; presumably LlamaIndex then uses
    ``LLMMetadata``'s default context window — confirm it suits the served model.
    """

    api_key: str = Field(default=api_key)
    base_url: str = Field(default=base_url)
    model_name: str = Field(default=chat_model)
    # Declared explicitly (and excluded from serialization) so pydantic accepts the
    # assignment in __init__.
    client: OpenAI = Field(default=None, exclude=True)

    def __init__(self, api_key: str, base_url: str, model_name: str = chat_model, **data: Any):
        super().__init__(**data)
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        # One client per instance, bound to the given key and endpoint.
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send ``prompt`` as one user message and return the full reply.

        ``kwargs`` are accepted for interface compatibility and ignored.

        Raises:
            Exception: if the response contains no choices.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        # Some servers omit `choices` or return it as None; treat both as malformed.
        choices = getattr(response, "choices", None)
        if choices:
            # `content` can be None (e.g. filtered replies); CompletionResponse needs a str.
            response_text = choices[0].message.content or ""
            return CompletionResponse(text=response_text)
        raise Exception(f"Unexpected response format: {response}")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> Generator[CompletionResponse, None, None]:
        """Stream the reply to ``prompt`` as ``CompletionResponse`` objects.

        Following LlamaIndex's streaming convention, ``text`` holds everything generated
        so far and ``delta`` holds only the newly received piece; print ``delta`` to
        show the stream incrementally.

        Raises:
            Exception: wrapping (and chained to) any error raised while iterating the stream.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            stream=True
        )

        accumulated = ""
        try:
            for chunk in response:
                # Some servers emit chunks with an empty `choices` list (e.g. a final usage chunk).
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if not content:
                    continue
                accumulated += content
                yield CompletionResponse(text=accumulated, delta=content)

        except Exception as e:
            raise Exception(f"Unexpected response format: {e}") from e

# Remote chat model (chat_model on base_url).
llm = OurLLM(api_key=api_key, base_url=base_url, model_name=chat_model)
# Local Qwen model. It gets its own key (LOCAL_LLM_API_KEY, default "EMPTY" — presumably
# the local server does not check keys; confirm) so ZISHU_API_KEY is not sent to a
# different host.
llmlocal = OurLLM(
    api_key=os.getenv("LOCAL_LLM_API_KEY", "EMPTY"),
    base_url=openai_api_base,
    model_name=model_name,
)



In [None]:
import sqlite3
from contextlib import closing

# Path of the local SQLite demo database.
sqllite_path = 'llmdb.db'

# Demo table: department name (部门) and head count (人数).
# IF NOT EXISTS keeps the cell re-runnable once llmdb.db already exists.
sql = """
CREATE TABLE IF NOT EXISTS `section_stats` (
  `部门` varchar(100) DEFAULT NULL,
  `人数` int(11) DEFAULT NULL
);
"""
# closing() guarantees the connection is closed even if the DDL fails.
with closing(sqlite3.connect(sqllite_path)) as con:
    con.execute(sql)
    con.commit()

In [None]:
from contextlib import closing

# Seed the demo rows (department, head count).
data = [
    ["专利部",22],
    ["商标部",25],
]
with closing(sqlite3.connect(sqllite_path)) as con:
    with con:  # one transaction: commits on success, rolls back on error
        # Remove earlier copies of these departments so re-running does not duplicate rows.
        con.executemany(
            "DELETE FROM section_stats WHERE 部门 = ?",
            [(name,) for name, _ in data],
        )
        # Bound parameters instead of string formatting: no quoting/injection issues,
        # and the head count is stored as an integer rather than a quoted string.
        con.executemany(
            "INSERT INTO section_stats (部门, 人数) VALUES (?, ?)",
            data,
        )

In [None]:
response = llm.stream_complete("你是谁？")
for chunk in response:
    # Print only the newly generated piece; in LlamaIndex streaming responses `text`
    # is meant to hold the accumulated output, so printing the chunk would repeat it.
    print(chunk.delta or "", end="", flush=True)

In [None]:
response = llmlocal.stream_complete("你是谁？")
for chunk in response:
    # Print only the newly generated piece; in LlamaIndex streaming responses `text`
    # is meant to hold the accumulated output, so printing the chunk would repeat it.
    print(chunk.delta or "", end="", flush=True)

In [10]:
from llama_index.core.agent import ReActAgent  
from llama_index.core.tools import FunctionTool  
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Settings  
from llama_index.core.tools import QueryEngineTool   
from llama_index.core import SQLDatabase  
from llama_index.core.query_engine import NLSQLTableQueryEngine  
from sqlalchemy import create_engine, select  


# Global default LLM for LlamaIndex components created without an explicit `llm=`
# (the ReActAgent cells below pass none).
# NOTE(review): the original comment said "configure the local model", but `llm` is the
# remote chat_model client; the local Qwen client is `llmlocal` — confirm which is intended.
Settings.llm = llm

In [None]:
from openai import OpenAI
from typing import Any, List
from llama_index.core.embeddings import BaseEmbedding
from pydantic import Field

# Xinference embedding-service settings. They use their own names so this cell does not
# overwrite the `api_key` / `base_url` globals that the chat-model cell uses to build `llm`.
xinference_api_key = os.getenv("XINFERENCE_API_KEY", "your_api_key")  # placeholder is sent when no key is configured
xinference_base_url = "http://192.168.12.10:9997/v1"  # Xinference service address
xinference_model_uid = "bge-m3"  # replace with the actual model_uid


class XinferenceEmbeddings(BaseEmbedding):
    """LlamaIndex embedding model that calls Xinference's OpenAI-compatible
    embeddings endpoint through the ``openai`` client.

    Each text is embedded with a separate request. The async methods delegate to the
    synchronous ones, so they block the event loop while a request is in flight.
    """

    api_key: str = Field(default=xinference_api_key)
    base_url: str = Field(default=xinference_base_url)
    model_uid: str = Field(default=xinference_model_uid)
    # Declared explicitly (and excluded from serialization) so pydantic accepts the
    # assignment in __init__.
    client: OpenAI = Field(default=None, exclude=True)

    def __init__(
        self,
        api_key: str = xinference_api_key,
        base_url: str = xinference_base_url,
        model_uid: str = xinference_model_uid,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.api_key = api_key
        self.base_url = base_url
        self.model_uid = model_uid
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def invoke_embedding(self, query: str) -> List[float]:
        """Return the embedding vector for ``query``.

        Raises:
            ValueError: if the API response contains no embedding.
            Errors raised by the ``openai`` client (network, HTTP status, ...) propagate.
            An empty vector is never returned: it would be stored silently and make
            similarity computations meaningless.
        """
        response = self.client.embeddings.create(model=self.model_uid, input=[query])
        if not response.data:
            raise ValueError("Failed to get embedding from Xinference API")
        return response.data[0].embedding

    def _get_query_embedding(self, query: str) -> List[float]:
        return self.invoke_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return self.invoke_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        # One request per text.
        return [self._get_text_embedding(text) for text in texts]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return self._get_text_embeddings(texts)

# Smoke test: embed one string and show the vector length and type.
embedding = XinferenceEmbeddings(
    api_key=xinference_api_key,
    base_url=xinference_base_url,
    model_uid=xinference_model_uid,
)
emb = embedding.get_text_embedding("你好呀呀")
print(len(emb), type(emb))

# Use this model as the global default embedding model for LlamaIndex.
Settings.embed_model = embedding

In [None]:
## Natural-language SQL query engine over the SQL Server payment-voucher table.
from sqlalchemy.engine import URL

# Credentials come from the environment (e.g. the .env loaded at the top of the notebook)
# instead of being hard-coded. MSSQL_PASSWORD is required; a KeyError names it if unset.
# URL.create also takes care of escaping the non-ASCII database name.
connection_url = URL.create(
    "mssql+pymssql",
    username=os.getenv("MSSQL_USER", "sa"),
    password=os.environ["MSSQL_PASSWORD"],
    host=os.getenv("MSSQL_HOST", "127.0.0.1"),
    port=int(os.getenv("MSSQL_PORT", "1433")),
    database=os.getenv("MSSQL_DATABASE", "甘肃省_CZBZB_2023_202403"),
)
engine = create_engine(connection_url)

# Restrict the engine to the one table the agent should query.
sql_database = SQLDatabase(engine, include_tables=["SJ17国库集中支付凭证表"])
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["SJ17国库集中支付凭证表"],
    llm=llmlocal,
    embed_model=Settings.embed_model,
)
# Create tool functions for the agent.
# The Chinese docstrings are intentionally left as-is: FunctionTool.from_defaults uses a
# function's docstring as the tool description shown to the LLM, so it is runtime text.
# NOTE(review): identical multiply/add definitions are repeated in the SQLite cell below.
def multiply(a: float, b: float) -> float:  
    """将两个数字相乘并返回乘积。"""  
    return a * b  

# Wrapped as a tool, but not included in the agent's tool list in this cell.
multiply_tool = FunctionTool.from_defaults(fn=multiply)  

def add(a: float, b: float) -> float:  
    """将两个数字相加并返回它们的和。"""  
    return a + b

add_tool = FunctionTool.from_defaults(fn=add)

import math


def sum_payments(payments: List[float]) -> float:
    """汇总支付金额。"""
    # The Chinese docstring is kept verbatim: it doubles as the agent tool description.
    # math.fsum is exactly rounded, so totals of many decimal amounts do not pick up
    # float drift (plain sum([0.1] * 10) gives 0.9999999999999999), and an empty
    # list yields 0.0 (a float, matching the annotation) instead of int 0.
    return math.fsum(payments)
# Expose sum_payments to the agent as a tool.
sum_tool = FunctionTool.from_defaults(fn=sum_payments)

# Wrap the SQL query engine as an agent tool. The variable name `staff_tool` is shared
# with the department demo cell; here the tool queries the payment-voucher table.
staff_tool = QueryEngineTool.from_defaults(
    query_engine,
    name="SJ17国库集中支付凭证表",
    description="查询预算单位的支付明细"      
)

# Build the ReAct agent; multiply_tool is defined above but not offered to it.
agent = ReActAgent.from_tools([add_tool,staff_tool,sum_tool], verbose=True)
# Ask the agent to query both budget units and total their payment amounts.
try:
    response = agent.chat("请从数据库表中获取预算单位`省人社厅`和`省财政厅`的支付信息，并将查询出来的支付金额（XPAY_AMT）分别汇总！")
    print(response)

    # The agent reply carries no SQL itself. Each tool call is recorded in
    # `response.sources`; for the query-engine tool, `raw_output` is the engine's
    # Response, whose metadata holds the generated SQL under "sql_query".
    sql_queries = []
    for source in getattr(response, "sources", None) or []:
        metadata = getattr(getattr(source, "raw_output", None), "metadata", None) or {}
        if "sql_query" in metadata:
            sql_queries.append(metadata["sql_query"])

    if sql_queries:
        print("\nGenerated SQL Query:")
        for sql_query in sql_queries:
            print(sql_query)
    else:
        print("\nNo SQL query found in response.")
except Exception as e:
    print(f"Query failed: {e}")



> Running step b01e71c1-3761-4d8f-bb85-3fec40a5338b. Step input: 请从数据库表中获取预算单位`省人力资源和社会保障厅`和`省财政厅`的支付信息，并将查询出来的支付金额（XPAY_AMT）分别汇总！
[1;3;38;5;200mThought: The user wants to get the payment information of two budget units from a database table and summarize the payment amounts. I need to use the SJ17国库集中支付凭证表 tool to query the payment information.
Action: SJ17国库集中支付凭证表
Action Input: {'input': '省人力资源和社会保障厅'}
[0m

In [13]:
## Natural-language SQL query engine over the SQLite demo table.
# Reuse `sqllite_path` from the table-creation cell so both cells point at the same file.
engine = create_engine(f"sqlite:///{sqllite_path}")

# Restrict the engine to the demo table.
sql_database = SQLDatabase(engine, include_tables=["section_stats"])
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["section_stats"],
    llm=llmlocal,
    embed_model=Settings.embed_model,
)

# Create tool functions for the agent (same definitions as in the SQL Server cell,
# repeated so this cell runs without it).
# The Chinese docstrings are intentionally left as-is: FunctionTool.from_defaults uses a
# function's docstring as the tool description shown to the LLM, so it is runtime text.
def multiply(a: float, b: float) -> float:  
    """将两个数字相乘并返回乘积。"""  
    return a * b  

multiply_tool = FunctionTool.from_defaults(fn=multiply)  

def add(a: float, b: float) -> float:  
    """将两个数字相加并返回它们的和。"""  
    return a + b

add_tool = FunctionTool.from_defaults(fn=add)

# Wrap the SQL query engine as an agent tool that answers department head-count questions.
staff_tool = QueryEngineTool.from_defaults(
    query_engine,
    name="section_staff",
    description="查询部门的人数。"  
)

# Assemble the ReAct agent from the arithmetic tools plus the head-count tool,
# with verbose logging of each reasoning step.
agent = ReActAgent.from_tools(
    [multiply_tool, add_tool, staff_tool],
    verbose=True,
)
# Ask for both departments' head counts and their sum.
response = agent.chat("请从数据库表中获取`专利部`和`商标部`的人数，并将这两个部门的人数相加！")


> Running step a1d90491-85c5-423e-b465-55d3dd2db03a. Step input: 请从数据库表中获取`专利部`和`商标部`的人数，并将这两个部门的人数相加！
[1;3;38;5;200mThought: The current language of the user is: Chinese. I need to use a tool to help me answer the question.
Action: section_staff
Action Input: {'input': '专利部'}
[0m[1;3;34mObservation: 根据查询结果，专利部共有22人。
[0m> Running step 7c4cafe8-70dc-4570-a28e-c7bb9c3c6c06. Step input: None
[1;3;38;5;200mThought: I need to use the section_staff tool again to get the number of staff in the Trademark Department.
Action: section_staff
Action Input: {'input': '商标部'}
[0m[1;3;34mObservation: 商标部的当前人数是25人。
[0m> Running step 4ff981e3-ad67-4e86-99ab-3a79f7831c6d. Step input: None
[1;3;38;5;200mThought: Now that I have the number of staff in the Patent Department (22) and the Trademark Department (25), I can use the add tool to calculate the total number of staff in both departments.
Action: add
Action Input: {'a': 22.0, 'b': 25.0}
[0m[1;3;34mObservation: 47.0
[0m> Running step 6f0bede