# memstack-agent 自定义工具实现

本 notebook 演示如何实现复杂的自定义工具。

## 概述

memstack-agent 支持多种工具实现方式：
- 函数装饰器（简单场景）
- SimpleTool 类继承（中等复杂度）
- 实现 Tool Protocol（完全控制）

## 1. 导入必要模块

In [None]:
import sys
sys.path.insert(0, '/Users/tiejunsun/github/agi-demos/src')

import asyncio
import json
from typing import Optional, List, Dict, Any
from dataclasses import dataclass

from memstack_agent import (
    function_to_tool,
    ToolDefinition,
    Tool,
    ToolMetadata,
)
from memstack_agent.tools.protocol import SimpleTool

## 2. 使用 dataclass 定义参数

可以使用 dataclass 来定义复杂的工具参数，框架会自动推断 schema。

In [None]:
@dataclass
class SearchOptions:
    """搜索选项配置。"""
    # Search option configuration used by advanced_search. The docstring above
    # is kept verbatim in case function_to_tool reads it when building the
    # nested parameter schema (not visible here — verify).
    # Requested maximum number of results; echoed back by advanced_search.
    max_results: int = 10
    # Whether image results are requested; echoed back by advanced_search.
    include_images: bool = False
    # Requested result language code (defaults to "zh").
    language: str = "zh"


@function_to_tool
async def advanced_search(query: str, options: SearchOptions) -> str:
    """高级搜索功能。
    
    Args:
        query: 搜索查询词
        options: 搜索选项配置
    """
    # Advanced search demo: echoes the query and the effective options as a
    # JSON string. The docstring above is kept verbatim because
    # function_to_tool presumably derives the tool's parameter descriptions
    # from it.
    #
    # Arguments from an LLM tool call usually arrive as decoded JSON, so
    # ``options`` may be a plain dict rather than a SearchOptions instance.
    # Whether the framework converts it first is not visible here, so accept
    # both. Unknown keys raise TypeError from the dataclass constructor.
    if isinstance(options, dict):
        options = SearchOptions(**options)
    return json.dumps({
        "query": query,
        "max_results": options.max_results,
        "include_images": options.include_images,
        "language": options.language
    })


print("Advanced Search Tool:")
# Show the parameter schema that function_to_tool inferred for advanced_search
# (presumably including the fields of the SearchOptions dataclass).
inferred_schema = advanced_search.parameters
print(json.dumps(inferred_schema, indent=2, ensure_ascii=False))

## 3. 实现 Tool Protocol

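对于需要完全控制的场景，可以直接实现 Tool Protocol。下面的示例手动实现了 `name`、`description`、`permission`、`metadata` 和 `execute` 等成员，然后用 `function_to_tool` 把 `execute` 包装成 `ToolDefinition`。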

In [None]:
class DatabaseQueryTool:
    """Database query tool that implements the Tool Protocol members directly.

    The query itself is simulated. ``execute`` validates the SQL and returns a
    fixed example row instead of contacting ``db_connection_string``.
    """

    # Keywords of statements that change data or schema. They are matched as
    # whole words, so identifiers such as ``deleted_at`` are not rejected.
    _FORBIDDEN_KEYWORDS = (
        "DROP",
        "DELETE",
        "INSERT",
        "UPDATE",
        "ALTER",
        "TRUNCATE",
        "CREATE",
        "MERGE",
        "GRANT",
        "REVOKE",
    )
    # Bounds advertised for ``limit`` by get_parameters_schema().
    _MIN_LIMIT = 1
    _MAX_LIMIT = 1000

    def __init__(self, db_connection_string: str):
        # Stored for a real implementation; the simulated execute() ignores it.
        self._connection_string = db_connection_string
        # Number of execute() calls so far, including rejected queries.
        self._query_count = 0

    @property
    def name(self) -> str:
        return "database_query"

    @property
    def description(self) -> str:
        return "执行 SQL 查询并返回结果"

    @property
    def permission(self) -> Optional[str]:
        # Database read access is required.
        return "database:read"

    @property
    def metadata(self) -> ToolMetadata:
        return ToolMetadata(
            tags=["database", "sql", "query"],
            timeout_seconds=30,
            ui_category="Database Tools"
        )

    async def execute(self, query: str, limit: int = 100) -> str:
        """执行 SQL 查询。
        
        Args:
            query: SQL 查询语句
            limit: 返回结果数量限制
        """
        # Run a simulated read-only SQL query. Returns a JSON string on
        # success, or a plain "Error: ..." string when the query is rejected.
        # The docstring above is kept verbatim because this bound method is
        # what gets passed to function_to_tool, which presumably builds the
        # parameter descriptions from it.
        import re

        self._query_count += 1

        normalized = query.strip().upper()

        # Reject any data- or schema-modifying keyword, matched as a whole
        # word. This avoids false positives on identifiers like "deleted_at".
        forbidden = r"\b(?:" + "|".join(self._FORBIDDEN_KEYWORDS) + r")\b"
        if re.search(forbidden, normalized):
            return "Error: DDL/DML operations not allowed"
        # The schema promises "SELECT only". CTEs (WITH ...) and parenthesised
        # selects are also accepted.
        if not normalized.startswith(("SELECT", "WITH", "(")):
            return "Error: only SELECT queries are allowed"

        # Honour ``limit``, clamped to the schema's advertised 1..1000 range.
        limit = max(self._MIN_LIMIT, min(int(limit), self._MAX_LIMIT))
        rows = [{"id": 1, "name": "example"}][:limit]

        return json.dumps({
            "query": query,
            "results": rows,
            "row_count": len(rows),
            "execution_time_ms": 15
        })

    def get_parameters_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "SQL 查询语句 (SELECT only)"
                },
                "limit": {
                    "type": "integer",
                    "description": "返回结果数量限制",
                    "default": 100,
                    "minimum": 1,
                    "maximum": 1000
                }
            },
            "required": ["query"]
        }


# Create the tool instance and wrap its bound ``execute`` method as a ToolDefinition.
db_tool = DatabaseQueryTool("postgresql://localhost/mydb")
db_tool_def = function_to_tool(
    db_tool.execute,
    name=db_tool.name,
    description=db_tool.description,
    metadata=db_tool.metadata,
)
# NOTE(review): db_tool.permission ("database:read") is not passed to
# function_to_tool, so the "Permission" line below may show a default value
# rather than the tool's declared permission. Confirm whether function_to_tool
# accepts a ``permission`` argument. Likewise, the printed parameters come
# from wrapping ``execute``; get_parameters_schema() is presumably not used here.

print(f"Tool: {db_tool_def.name}")
print(f"Permission: {db_tool_def.permission}")
print(f"Parameters: {json.dumps(db_tool_def.parameters, indent=2, ensure_ascii=False)}")

In [None]:
# Run the wrapped tool once. Top-level ``await`` relies on the notebook
# (IPython) event loop; it is not valid in a plain .py script.
result = await db_tool_def.execute(query="SELECT * FROM users", limit=10)
print(f"Result: {result}")

## 4. 带状态的工具

某些工具需要在多次调用之间维护内部状态。下面的会话管理工具把会话保存在实例的字典中，因此同一实例的后续调用可以看到之前创建的会话。

In [None]:
class SessionManagerTool(SimpleTool):
    """Stateful session-management tool backed by an in-memory dict."""

    name = "session_manager"
    description = "管理用户会话的创建、查询和销毁"

    def __init__(self):
        # session id -> {"created_at": <time.time() value>, "data": <dict>}
        self._sessions: Dict[str, Dict[str, Any]] = {}

    @property
    def metadata(self) -> ToolMetadata:
        tag_list = ["session", "auth"]
        return ToolMetadata(tags=tag_list, ui_category="Auth Tools")

    async def execute(self, action: str, session_id: Optional[str] = None, data: Optional[Dict] = None) -> str:
        """管理会话。
        
        Args:
            action: 操作类型 (create/get/delete/list)
            session_id: 会话 ID
            data: 会话数据
        """
        # Dispatch on ``action`` and always answer with a JSON string. The
        # docstring above is kept verbatim because this method is wrapped by
        # function_to_tool, which presumably reads it.
        import time
        import uuid

        store = self._sessions

        if action == "create":
            # Short 8-character id taken from a random UUID4.
            sid = str(uuid.uuid4())[:8]
            store[sid] = {"created_at": time.time(), "data": data or {}}
            return json.dumps({"session_id": sid, "status": "created"})

        if action == "get":
            if session_id in store:
                return json.dumps(store[session_id])
            return json.dumps({"error": "Session not found"})

        if action == "delete":
            if session_id not in store:
                return json.dumps({"error": "Session not found"})
            store.pop(session_id)
            return json.dumps({"status": "deleted"})

        if action == "list":
            return json.dumps({"sessions": list(store)})

        return json.dumps({"error": f"Unknown action: {action}"})

    def get_parameters_schema(self) -> Dict[str, Any]:
        action_spec = {
            "type": "string",
            "enum": ["create", "get", "delete", "list"],
            "description": "操作类型",
        }
        props = {
            "action": action_spec,
            "session_id": {"type": "string", "description": "会话 ID"},
            "data": {"type": "object", "description": "会话数据"},
        }
        return {"type": "object", "properties": props, "required": ["action"]}


# 创建并测试
session_tool = SessionManagerTool()
session_def = function_to_tool(
    session_tool.execute,
    name=session_tool.name,
    description=session_tool.description,
)

# 创建会话
result1 = await session_def.execute(action="create", data={"user": "alice"})
print(f"Create: {result1}")

# 列出会话
result2 = await session_def.execute(action="list")
print(f"List: {result2}")

## 5. 异步批量操作工具

In [None]:
class BatchProcessorTool(SimpleTool):
    """Batch-processing tool comparing concurrent and sequential asyncio work."""

    name = "batch_processor"
    description = "批量处理多个任务"

    async def execute(self, tasks: List[Dict[str, Any]], parallel: bool = True) -> str:
        """批量处理任务。
        
        Args:
            tasks: 任务列表
            parallel: 是否并行执行
        """
        # Process every task (simulated) and return a JSON summary that
        # includes the elapsed wall time. The docstring above is kept verbatim
        # because this method is wrapped by function_to_tool, which presumably
        # reads it.
        import time

        async def process_single(task: Dict[str, Any]) -> Dict[str, Any]:
            """Process one task, simulating I/O with a 0.1 s sleep."""
            await asyncio.sleep(0.1)  # simulated I/O
            return {
                "task_id": task.get("id"),
                "status": "completed",
                "result": f"Processed: {task.get('name', 'unknown')}"
            }

        # perf_counter is monotonic and high-resolution. Unlike time.time, the
        # measured duration is not affected by system clock adjustments.
        start = time.perf_counter()

        if parallel:
            # gather returns results in the same order as the input tasks.
            results = await asyncio.gather(*[process_single(t) for t in tasks])
        else:
            results = [await process_single(t) for t in tasks]

        elapsed = time.perf_counter() - start

        return json.dumps({
            "total_tasks": len(tasks),
            "parallel": parallel,
            "elapsed_seconds": round(elapsed, 3),
            "results": results
        })

    def get_parameters_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "tasks": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "name": {"type": "string"}
                        }
                    },
                    "description": "任务列表"
                },
                "parallel": {
                    "type": "boolean",
                    "description": "是否并行执行",
                    "default": True
                }
            },
            "required": ["tasks"]
        }


# Test batch processing. BatchProcessorTool defines no ``metadata`` property of
# its own in this notebook, so no metadata is passed to function_to_tool here.
batch_tool = BatchProcessorTool()
batch_def = function_to_tool(
    batch_tool.execute,
    name=batch_tool.name,
    description=batch_tool.description,
)

# Run five tasks in parallel. Each one sleeps 0.1 s, so with asyncio.gather the
# reported elapsed time should be roughly 0.1 s rather than 0.5 s.
tasks = [{"id": f"task-{i}", "name": f"Task {i}"} for i in range(5)]

print("Parallel execution:")
result = await batch_def.execute(tasks=tasks, parallel=True)
print(json.dumps(json.loads(result), indent=2, ensure_ascii=False))

## 6. 带错误处理和重试的工具

In [None]:
class ResilientApiTool(SimpleTool):
    """Simulated external API call with retries and exponential back-off."""

    name = "resilient_api"
    description = "调用外部 API，支持自动重试"

    # Largest per-call ``retries`` override that is honoured. It matches the
    # schema's "maximum": 5 and caps the exponential back-off delay.
    _MAX_RETRIES_OVERRIDE = 5

    def __init__(self, max_retries: int = 3, base_delay: float = 1.0):
        # Default number of retries after the first attempt.
        self._max_retries = max_retries
        # Delay in seconds before the first retry; doubled for each later retry.
        self._base_delay = base_delay

    @property
    def metadata(self) -> ToolMetadata:
        return ToolMetadata(
            tags=["api", "http"],
            timeout_seconds=60
        )

    async def execute(self, url: str, method: str = "GET", retries: Optional[int] = None) -> str:
        """调用 API。
        
        Args:
            url: API URL
            method: HTTP 方法
            retries: 重试次数（覆盖默认值）
        """
        # Make a simulated API call, retrying on ConnectionError with
        # exponential back-off. Always returns a JSON string. The docstring
        # above is kept verbatim because this method is wrapped by
        # function_to_tool, which presumably reads it.
        import random

        if retries is None:
            max_retries = self._max_retries
        else:
            # Cap the caller's override, otherwise the back-off
            # (base_delay * 2**attempt) could grow without bound.
            max_retries = min(int(retries), self._MAX_RETRIES_OVERRIDE)
        # A negative count would make range() empty and skip every attempt;
        # always make at least one attempt.
        max_retries = max(0, max_retries)

        for attempt in range(max_retries + 1):
            try:
                # Simulated call: fails about 50% of the time, but never on the
                # final attempt. In this simulation the "failed" branch below is
                # therefore not reached; it is kept for a real network call.
                if random.random() < 0.5 and attempt < max_retries:
                    raise ConnectionError("Network error")

                return json.dumps({
                    "url": url,
                    "method": method,
                    "status": 200,
                    "data": {"result": "success"},
                    "attempts": attempt + 1
                })

            except ConnectionError as e:
                if attempt < max_retries:
                    delay = self._base_delay * (2 ** attempt)  # exponential back-off
                    print(f"Attempt {attempt + 1} failed, retrying in {delay}s...")
                    await asyncio.sleep(delay)
                else:
                    return json.dumps({
                        "error": str(e),
                        "attempts": attempt + 1,
                        "status": "failed"
                    })

        # Defensive fallback: the loop runs at least once and its last
        # iteration always returns, so this line is not expected to run.
        return json.dumps({"error": "Unexpected state"})

    def get_parameters_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "format": "uri",
                    "description": "API URL"
                },
                "method": {
                    "type": "string",
                    "enum": ["GET", "POST", "PUT", "DELETE"],
                    "default": "GET"
                },
                "retries": {
                    "type": "integer",
                    "minimum": 0,
                    "maximum": 5
                }
            },
            "required": ["url"]
        }


# 测试
api_tool = ResilientApiTool(max_retries=3)
api_def = function_to_tool(
    api_tool.execute,
    name=api_tool.name,
    description=api_tool.description,
)

print("Testing resilient API tool:")
result = await api_def.execute(url="https://api.example.com/data")
print(json.dumps(json.loads(result), indent=2, ensure_ascii=False))

## 7. 工具注册表模式

In [None]:
class ToolRegistry:
    """Registry that stores ToolDefinitions by name for lookup, listing and export."""

    def __init__(self):
        # tool name -> ToolDefinition (dicts preserve insertion order).
        self._tools: Dict[str, ToolDefinition] = {}

    def register(self, tool: ToolDefinition) -> None:
        """Register ``tool`` under ``tool.name``, replacing any tool with that name."""
        self._tools[tool.name] = tool

    def get(self, name: str) -> Optional[ToolDefinition]:
        """Return the tool registered as ``name``, or None if there is none."""
        return self._tools.get(name)

    def list_tools(self) -> List[str]:
        """Return the names of all registered tools, in dict insertion order."""
        return list(self._tools)

    def get_openai_tools(self) -> List[Dict[str, Any]]:
        """Return every registered tool converted with ``to_openai_format()``."""
        return [t.to_openai_format() for t in self._tools.values()]

    def get_by_tag(self, tag: str) -> List[ToolDefinition]:
        """Return the tools whose ``metadata.tags`` contain ``tag``.

        Tools registered without metadata, or whose metadata has no tags, are
        skipped instead of raising AttributeError/TypeError.
        """
        return [
            t for t in self._tools.values()
            if tag in (getattr(getattr(t, "metadata", None), "tags", None) or ())
        ]


# Create a registry and register the tools built in the earlier cells.
registry = ToolRegistry()

registry.register(db_tool_def)
registry.register(session_def)
registry.register(batch_def)
registry.register(api_def)

print("Registered tools:")
for name in registry.list_tools():
    tool = registry.get(name)
    desc = tool.description
    # Append an ellipsis only when the description was actually truncated.
    summary = desc if len(desc) <= 40 else f"{desc[:40]}..."
    print(f"  - {name}: {summary}")

In [None]:
# Filter registered tools by metadata tag. Each header is printed even when no
# tool matches; every header after the first is preceded by a blank line.
for index, tag in enumerate(("database", "api")):
    prefix = "\n" if index else ""
    print(f"{prefix}Tools with '{tag}' tag:")
    for tool in registry.get_by_tag(tag):
        print(f"  - {tool.name}")

In [None]:
# Export every registered tool in OpenAI tool format (via each ToolDefinition's
# to_openai_format()), then print only the first two entries.
openai_tools = registry.get_openai_tools()
print(f"OpenAI tools format ({len(openai_tools)} tools):")
print(json.dumps(openai_tools[:2], indent=2, ensure_ascii=False))

## 总结

本 notebook 演示了多种自定义工具实现方式：

1. **dataclass 参数** - 使用 dataclass 定义复杂参数
2. **Tool Protocol** - 完全控制工具行为
3. **带状态工具** - 维护内部状态
4. **批量操作** - 异步并发处理
5. **弹性工具** - 重试机制和错误处理
6. **工具注册表** - 集中管理工具

下一步：查看 `04-agent-context.ipynb` 了解 Agent 上下文和配置。