In [5]:
import json
import traceback
import asyncio
from enum import Enum
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Callable, Tuple
from client import APIClient


class AgentTaskException(Exception):
    """Error raised by ``Agent.run_task`` when handling a task fails."""


class AgentState(Enum):
    """States an ``Agent`` is put into while it runs."""
    # Initial state; also restored when start_interactive()/run_task() finish.
    IDLE = "idle"
    # Set by Agent.start_interactive() for a console conversation.
    INTERACTIVE = "interactive"
    # Set by Agent.run_task(); the message handler then returns after the
    # model's final answer instead of prompting for more user input.
    IN_TASK = "in_task"
    # Set when an API call or the message handling raises.
    ERROR = "error"


@dataclass
class AgentConfiguration:
    """Tunable settings read by ``Agent``."""
    # Upper bound on API rounds; the message handler stops once it is reached.
    max_recursion_depth: int = 50
    # Try the streaming API first.
    enable_streaming: bool = True
    # When streaming raises, retry with the non-streaming API instead of
    # propagating the error.
    fallback_to_non_streaming: bool = True
    # NOTE(review): not read anywhere in the visible code - presumably
    # reserved for history compression; confirm before relying on it.
    auto_compression: bool = True


class SimpleMessage:
    """Plain assistant message carrying only text and no tool calls.

    ``Agent`` builds one from the accumulated streamed text when the stream
    ends without yielding a final message object.
    """

    def __init__(self, content):
        self.content = content
        # Fixed fields so it can be used like an API response message.
        self.role, self.tool_calls = "assistant", None


class ErrorMessage:
    """Assistant-style message that wraps an error description in an apology."""

    def __init__(self, error_msg):
        apology = f"Sorry, I encountered a technical problem: {error_msg}"
        self.content = apology
        # Same fixed fields as a regular assistant message without tool calls.
        self.role, self.tool_calls = "assistant", None


class Agent:
    """Conversational agent that calls an LLM API and executes tool calls.

    The conversation is kept as a list of message dicts. Each round sends the
    history to ``APIClient`` (streaming first, optionally falling back to the
    non-streaming call), stores the assistant reply, runs any tools the model
    requested and feeds their results back. Rounds continue until the model
    answers without tool calls (task mode), the user conversation ends
    (interactive mode) or ``max_recursion_depth`` rounds have been made.
    """

    def __init__(self,
                 tools: Optional[Dict[str, Callable]] = None,
                 system_prompt: str = "You are a helpful assistant.",
                 config: Optional[AgentConfiguration] = None,
                 ui_manager: Optional[Any] = None):
        """
        Initialize Agent with direct parameters

        Args:
            tools: Dict of tool_name -> callable function
            system_prompt: System prompt string
            config: Agent configuration
            ui_manager: Optional object exposing
                ``show_preparing_tool(tool_name, args)`` and
                ``show_tool_execution(tool_name, args, success=..., result=...)``.
                When omitted, tool progress is printed to the console.
        """
        self._config = config or AgentConfiguration()
        self._api_client = APIClient()
        self._system_prompt = system_prompt
        self._tools = tools or {}
        self._tool_descriptions = self._build_tool_descriptions()
        # Previously never assigned, which made every tool call fail with
        # AttributeError inside _execute_tool.
        self._ui_manager = ui_manager

        # State management
        self._state = AgentState.IDLE
        self._state_lock = asyncio.Lock()
        self._recursion_depth = 0
        self._messages = []
        self._token_usage = {"input_tokens": 0, "output_tokens": 0}

    def _build_tool_descriptions(self) -> List[Dict]:
        """Build the function-style tool specs sent with every request."""
        return [
            {
                "type": "function",
                "function": {
                    "name": name,
                    # The tool's docstring doubles as its description.
                    "description": (func.__doc__ or f"Execute {name} tool").strip(),
                    # Tools may carry a JSON schema in a ``_schema`` attribute.
                    "parameters": getattr(func, '_schema', {"type": "object", "properties": {}}),
                },
            }
            for name, func in self._tools.items()
        ]

    @property
    def messages(self):
        """The conversation history (list of message dicts)."""
        return self._messages

    def add_message(self, message):
        """Append a message dict to the conversation history."""
        self._messages.append(message)

    async def _set_state(self, new_state: AgentState):
        """Store ``new_state`` while holding the state lock."""
        async with self._state_lock:
            self._state = new_state

    async def _get_state(self) -> AgentState:
        """Return the current state while holding the state lock."""
        async with self._state_lock:
            return self._state

    async def start_interactive(self) -> None:
        """Start interactive conversation mode"""
        await self._set_state(AgentState.INTERACTIVE)
        # Give every session the full budget of API rounds, as run_task does;
        # otherwise a counter left over from an earlier run ends the session
        # prematurely.
        self._recursion_depth = 0

        # Add system message
        self.add_message(self._create_message("system", self._system_prompt))

        # Get initial user input
        user_input = input("👤 You: ")
        self.add_message(self._create_message("user", user_input))

        try:
            await self._recursive_message_handling()
        except Exception as e:
            await self._set_state(AgentState.ERROR)
            print(f"❌ System error: {e}")
            traceback.print_exc()
        finally:
            await self._set_state(AgentState.IDLE)

    async def run_task(self, user_input: str) -> str:
        """Run a single task and return the result.

        Starts from a fresh history (system prompt + ``user_input``).

        Returns:
            The text of the last assistant message that has content, or ``""``
            when there is none.

        Raises:
            AgentTaskException: if message handling raises. (The original
                ``return`` inside ``finally`` silently discarded it.)
        """
        await self._set_state(AgentState.IN_TASK)
        self._recursion_depth = 0
        self._messages = []  # Fresh start

        # Add system and user messages
        self.add_message(self._create_message("system", self._system_prompt))
        self.add_message(self._create_message("user", user_input))

        try:
            await self._recursive_message_handling()
        except Exception as e:
            await self._set_state(AgentState.ERROR)
            print(f"❌ Task error: {e}")
            traceback.print_exc()
            raise AgentTaskException(f"Task failed: {e}") from e
        finally:
            # Only state cleanup belongs here: returning from ``finally``
            # would swallow the exception raised above.
            await self._set_state(AgentState.IDLE)

        return self._last_assistant_text()

    def _last_assistant_text(self) -> str:
        """Return the text of the most recent assistant message with content."""
        for msg in reversed(self._messages):
            if msg.get("role") == "assistant" and msg.get("content"):
                content = msg["content"]
                if isinstance(content, list):
                    return content[0].get("text", "")
                return content
        return ""

    async def _recursive_message_handling(self):
        """Drive API rounds until the conversation or task is finished.

        Implemented as a loop (one iteration per API round); the name is kept
        for compatibility. ``_recursion_depth`` counts the rounds and is
        capped by ``max_recursion_depth``.
        """
        while True:
            if self._recursion_depth >= self._config.max_recursion_depth:
                print("❌ Maximum conversation depth reached")
                return

            self._recursion_depth += 1

            request = {
                "messages": self._get_messages_with_cache_mark(),
                "tools": self._tool_descriptions,
            }

            try:
                response_message, token_usage = await self._process_api_response(request)
            except Exception as e:
                # API failures end the run with an apology instead of raising.
                await self._set_state(AgentState.ERROR)
                print(f"🤖 Assistant: {ErrorMessage(str(e)).content}")
                return

            if token_usage:
                self._update_token_usage(token_usage)

            # Normalise "missing", None and empty tool_calls to None.
            tool_calls = getattr(response_message, 'tool_calls', None) or None
            self.add_message({
                "role": "assistant",
                "content": response_message.content,
                "tool_calls": tool_calls,
            })

            if tool_calls:
                # Tool results are appended to the history; ask the model again.
                await self._handle_tool_calls(tool_calls)
                continue

            if await self._get_state() == AgentState.IN_TASK:
                # Task mode stops at the first answer without tool calls.
                return

            user_input = input("👤 You: ")
            self.add_message(self._create_message("user", user_input))

    def _update_token_usage(self, usage):
        """Add the counters of ``usage`` to the running totals.

        Missing or ``None`` counters count as zero.
        """
        for key in ("input_tokens", "output_tokens"):
            self._token_usage[key] += getattr(usage, key, 0) or 0

    async def _process_api_response(self, request: Dict[str, Any]) -> Tuple[Any, Optional[Any]]:
        """Get a ``(message, usage)`` pair, preferring the streaming API.

        Raises:
            Exception: whatever the API call raised, when streaming fails and
                fallback is disabled, or when the non-streaming call fails.
        """
        if self._config.enable_streaming:
            try:
                return await self._process_streaming_response(request)
            except Exception as e:
                if not self._config.fallback_to_non_streaming:
                    raise
                print(f"❌ Streaming error: {e}")
                print("💭 Trying non-streaming mode...")

        return await self._process_non_streaming_response(request)

    async def _process_streaming_response(self, request: Dict[str, Any]) -> Tuple[Any, Optional[Any]]:
        """Consume the completion stream, echoing text chunks as they arrive.

        ``str`` chunks are printed and accumulated; a chunk whose ``role`` is
        ``'assistant'`` is taken as the final message and ends the stream;
        any other chunk with a ``usage`` attribute only updates token usage.
        """
        stream_generator = self._api_client.get_completion_stream(request)

        if stream_generator is None:
            raise Exception("Stream generator is None")

        response_message = None
        text_parts = []
        token_usage = None

        print("🤖 Assistant: ", end="", flush=True)

        try:
            for chunk in stream_generator:
                if isinstance(chunk, str):
                    text_parts.append(chunk)
                    print(chunk, end="", flush=True)
                elif hasattr(chunk, 'role') and chunk.role == 'assistant':
                    response_message = chunk
                    if hasattr(chunk, 'usage'):
                        token_usage = chunk.usage
                    break
                elif hasattr(chunk, 'usage'):
                    token_usage = chunk.usage
        finally:
            print()  # New line after streaming

        if response_message is None:
            # The stream carried text only: wrap it as an assistant message.
            response_message = SimpleMessage("".join(text_parts))

        return response_message, token_usage

    async def _process_non_streaming_response(self, request: Dict[str, Any]) -> Tuple[Any, Optional[Any]]:
        """Fetch a complete response in one call and print its content."""
        response_message, token_usage = self._api_client.get_completion(request)
        print(f"🤖 Assistant: {response_message.content}")
        return response_message, token_usage

    def _get_messages_with_cache_mark(self):
        """Return a copy of the history whose last content part is cache-marked.

        The ``cache_control`` marker is added to a copy of the last message
        only, so the stored history is never modified. (The original mutated
        the stored message, leaving a marker on every message ever sent.)
        Messages whose content is not a non-empty list of dict parts are
        returned unmarked.
        """
        messages = list(self._messages)
        if not messages:
            return messages

        last_message = messages[-1]
        content = last_message.get("content")
        if isinstance(content, list) and content and isinstance(content[-1], dict):
            marked_part = {**content[-1], "cache_control": {"type": "ephemeral"}}
            messages[-1] = {**last_message, "content": content[:-1] + [marked_part]}
        return messages

    async def _handle_tool_calls(self, tool_calls):
        """Handle tool calls with user approval when needed.

        Every tool call gets exactly one tool response in the history, also
        when its arguments cannot be parsed or the user denies it.
        """
        for i, tool_call in enumerate(tool_calls):
            is_last_tool = (i == len(tool_calls) - 1)
            try:
                args = json.loads(tool_call.function.arguments)
            except (json.JSONDecodeError, TypeError) as e:
                # TypeError covers arguments that are None / not a string.
                print(f"❌ Tool parameter parsing failed: {e}")
                self._add_tool_response(tool_call, f"tool call failed due to {type(e).__name__}", is_last_tool)
                continue

            if not isinstance(args, dict):
                # Valid JSON that is not an object cannot be used as keyword
                # arguments; report it instead of crashing on ``args.get``.
                print("❌ Tool parameter parsing failed: arguments must be a JSON object")
                self._add_tool_response(tool_call, "tool call failed: arguments must be a JSON object", is_last_tool)
                continue

            # Check if user approval is needed
            if args.get('need_user_approve', False):
                approval_content = f"Tool: {tool_call.function.name}, args: {args}"
                print(f"⚠️  Tool approval needed: {approval_content}")
                response = input("Approve? (y/n): ").lower().strip()
                if not response.startswith('y'):
                    reason = input("Reason for denial: ")
                    self._add_tool_response(tool_call, f"user denied tool execution: {reason}", is_last_tool)
                    continue

            await self._execute_tool(tool_call, args, is_last_tool)

    async def _execute_tool(self, tool_call, args, is_last_tool=False):
        """Execute a tool call and record its result (or failure) as a tool message."""
        tool_name = tool_call.function.name
        # The approval flag is agent metadata, not a tool parameter.
        tool_args = {k: v for k, v in args.items() if k != 'need_user_approve'}

        self._show_preparing_tool(tool_name, tool_args)

        try:
            if tool_name not in self._tools:
                raise Exception(f"Tool '{tool_name}' not found")

            tool_response = await self._run_tool_safely(self._tools[tool_name], tool_args)

            self._show_tool_execution(tool_name, tool_args, success=True, result=str(tool_response))
            self._add_tool_response(tool_call, json.dumps(tool_response), is_last_tool)
        except Exception as e:
            # Failures are reported back to the model rather than raised.
            self._show_tool_execution(tool_name, tool_args, success=False, result=str(e))
            self._add_tool_response(tool_call, f"tool call failed: {str(e)}", is_last_tool)

    def _show_preparing_tool(self, tool_name, tool_args):
        """Announce an upcoming tool call via the UI manager, or the console."""
        if self._ui_manager is not None:
            self._ui_manager.show_preparing_tool(tool_name, tool_args)
        else:
            print(f"🔧 Preparing tool: {tool_name} with args: {tool_args}")

    def _show_tool_execution(self, tool_name, tool_args, success, result):
        """Report a tool outcome via the UI manager, or the console."""
        if self._ui_manager is not None:
            self._ui_manager.show_tool_execution(tool_name, tool_args, success=success, result=result)
        else:
            status = "✅" if success else "❌"
            print(f"{status} Tool {tool_name}: {result}")

    async def _run_tool_safely(self, tool_func, args):
        """Safely run a tool function (sync or async).

        Raises:
            Exception: wrapping (and chained to) whatever the tool raised.
        """
        try:
            if asyncio.iscoroutinefunction(tool_func):
                return await tool_func(**args)
            return tool_func(**args)
        except Exception as e:
            raise Exception(f"Tool execution error: {str(e)}") from e

    def _add_tool_response(self, tool_call, content, is_last_tool=False):
        """Append a ``tool`` message answering ``tool_call`` to the history.

        ``is_last_tool`` is accepted for call-site compatibility and is
        currently unused.
        """
        self.add_message({
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": tool_call.function.name,
            "content": [{"type": "text", "text": content}],
        })

    def _create_message(self, role: str, text: str) -> dict:
        """Build a message dict with a single text content part."""
        return {
            "role": role,
            "content": [{"type": "text", "text": text}]
        }


# Helper functions for easy usage
def create_agent(tools=None, system_prompt="You are a helpful assistant.", *, ui_manager=None, **config_kwargs):
    """Create an agent with simple parameters.

    Args:
        tools: Dict of tool_name -> callable, passed to ``Agent``.
        system_prompt: System prompt string, passed to ``Agent``.
        ui_manager: Optional UI object used by the agent to report tool
            execution. It used to fall into ``config_kwargs`` and make
            ``AgentConfiguration`` raise ``TypeError``, although callers in
            this file pass it.
        **config_kwargs: Fields of ``AgentConfiguration``.

    Returns:
        The configured ``Agent``.
    """
    config = AgentConfiguration(**config_kwargs)
    agent = Agent(tools=tools, system_prompt=system_prompt, config=config)
    if ui_manager is not None:
        # Attached after construction so this works regardless of whether
        # Agent.__init__ takes a ui_manager argument.
        agent._ui_manager = ui_manager
    return agent

async def run_agent_task(user_input: str, tools=None, system_prompt="You are a helpful assistant.", **config_kwargs):
    """Build a throwaway agent via ``create_agent`` and return its ``run_task`` result."""
    return await create_agent(
        tools=tools,
        system_prompt=system_prompt,
        **config_kwargs,
    ).run_task(user_input)

async def start_interactive_agent(tools=None, system_prompt="You are a helpful assistant.", **config_kwargs):
    """Build an agent via ``create_agent`` and run its interactive session."""
    await create_agent(
        tools=tools,
        system_prompt=system_prompt,
        **config_kwargs,
    ).start_interactive()


# Example usage:
"""
# Define your tools
def file_reader(filename: str):
    with open(filename, 'r') as f:
        return f.read()

async def web_search(query: str):
    # Your web search implementation
    return f"Search results for: {query}"

# Create tools dict
tools = {
    "file_reader": file_reader,
    "web_search": web_search
}

# Simple task execution
result = await run_agent_task(
    "Read file.txt and summarize it",
    tools=tools,
    system_prompt="You are a file analysis assistant."
)

# Or create agent instance
agent = create_agent(
    tools=tools,
    system_prompt="You are a helpful coding assistant.",
    max_recursion_depth=30,
    enable_streaming=True
)

# Run task
result = await agent.run_task("Help me debug this code")

# Or start interactive mode
await agent.start_interactive()
"""

'\n# Define your tools\ndef file_reader(filename: str):\n    with open(filename, \'r\') as f:\n        return f.read()\n\nasync def web_search(query: str):\n    # Your web search implementation\n    return f"Search results for: {query}"\n\n# Create tools dict\ntools = {\n    "file_reader": file_reader,\n    "web_search": web_search\n}\n\n# Simple task execution\nresult = await run_agent_task(\n    "Read file.txt and summarize it",\n    tools=tools,\n    system_prompt="You are a file analysis assistant."\n)\n\n# Or create agent instance\nagent = create_agent(\n    tools=tools,\n    system_prompt="You are a helpful coding assistant.",\n    max_recursion_depth=30,\n    enable_streaming=True\n)\n\n# Run task\nresult = await agent.run_task("Help me debug this code")\n\n# Or start interactive mode\nawait agent.start_interactive()\n'

In [None]:
import asyncio
from typing import Literal

# Your tool functions
# Mock web search: logs the request and returns two canned results.
# (Documented with comments rather than a docstring: the agent uses a tool's
# __doc__ as the description it sends to the model.)
def internet_search(
    query: str,
    max_results: int = 5,
    topic: Literal["general", "news", "technology"] = "technology",
    include_raw_content: bool = False,
):
    # Mock implementation for testing - replace with actual tavily_client
    print(f"🔍 Searching: '{query}' (topic: {topic}, max: {max_results})")
    canned_results = [
        {
            "title": f"Result {index} for {query}",
            "url": f"https://example{index}.com",
            "snippet": f"Mock result {index}",
        }
        for index in (1, 2)
    ]
    return {"results": canned_results}

def calculate_math(expression: str):
    """Evaluate a basic arithmetic expression and return its result."""
    # SECURITY NOTE: the original called eval() on model-supplied text. The
    # character whitelist blocked names, but still allowed inputs such as
    # "9**9**9**9" that hang the process. eval() is deliberately replaced by
    # a small AST evaluator supporting numbers, + - * / // **, unary +/- and
    # parentheses, with a cap on the size of integer powers.
    #
    # Returns {"result": value, "expression": expression} on success and
    # {"error": message} for invalid characters, unsupported syntax or
    # arithmetic errors (e.g. division by zero). Never raises.
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_result_bits = 100_000  # upper bound on the size of an integer power

    def evaluate(node):
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > max_result_bits
            ):
                # left ** right needs about bit_length(left) * right bits.
                raise ValueError("Result too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError("Unsupported expression")

    try:
        allowed_chars = set('0123456789+-*/.() ')
        if not all(c in allowed_chars for c in expression):
            return {"error": "Invalid characters in expression"}
        # eval() ignored surrounding blanks; ast.parse() does not.
        tree = ast.parse(expression.strip(), mode="eval")
        return {"result": evaluate(tree), "expression": expression}
    except Exception as e:
        return {"error": str(e)}

# Write `content` to `filename`, replacing any existing file.
# Returns {"success": True, "message": ...} or {"error": <message>}; never raises.
def write_note(content: str, filename: str = "note.txt"):
    try:
        with open(filename, 'w') as note_file:
            note_file.write(content)
    except Exception as exc:
        return {"error": str(exc)}
    return {"success": True, "message": f"Note written to {filename}"}

# Read the whole of `filename`.
# Returns {"content": ..., "filename": ...} or {"error": <message>}; never raises.
def read_note(filename: str = "note.txt"):
    try:
        with open(filename, 'r') as note_file:
            return {"content": note_file.read(), "filename": filename}
    except Exception as exc:
        return {"error": str(exc)}

# Tool mapping: each tool is registered under its own function name
TOOL_MAPPING = {
    tool.__name__: tool
    for tool in (internet_search, calculate_math, write_note, read_note)
}

# System prompt handed to create_agent() by the test functions below; it
# describes the four tools registered in TOOL_MAPPING.
SYSTEM_PROMPT = """You are a helpful AI assistant that can use various tools to help users.
You have access to these tools:
- internet_search: Search the web for current information
- calculate_math: Perform mathematical calculations  
- write_note: Write content to files
- read_note: Read content from files
Use these tools strategically to complete user requests. Think step by step and explain your reasoning."""

# Mock UI Manager for testing
class MockUI:
    """Console-backed stand-in for a UI manager, meant for manual testing."""

    def __init__(self):
        # (role, content) pairs recorded by print_assistant_message.
        self.messages = []

    async def get_user_input(self):
        """Prompt on the console and return the line typed by the user."""
        prompt = "\n👤 You: "
        return input(prompt)

    def print_assistant_message(self, content):
        """Print an assistant message and record it in ``messages``."""
        print(f"\n🤖 Assistant: {content}")
        entry = ("assistant", content)
        self.messages.append(entry)

    def print_error(self, msg):
        """Print an error line."""
        print(f"\n❌ Error: {msg}")

    def print_info(self, msg):
        """Print an informational line."""
        print(f"\n💭 Info: {msg}")

    def start_stream_display(self):
        """Print the assistant prefix without a trailing newline."""
        print("\n🤖 Assistant: ", end="", flush=True)

    def print_streaming_content(self, content):
        """Print one streamed chunk without a trailing newline."""
        print(content, end="", flush=True)

    def stop_stream_display(self):
        """Terminate the streamed line."""
        print()  # New line after streaming

    def show_preparing_tool(self, tool_name, args):
        """Announce a tool call that is about to run."""
        print(f"\n🔧 Preparing tool: {tool_name} with args: {args}")

    def show_tool_execution(self, tool_name, args, success, result):
        """Print the outcome of a tool call, truncated to 100 characters."""
        if success:
            status = "✅"
        else:
            status = "❌"
        preview = result[:100]
        print(f"{status} Tool {tool_name}: {preview}...")

    async def wait_for_user_approval(self, approval_content):
        """Ask the user to approve a tool call.

        Returns ``(True, "approved")`` when the answer starts with ``y``,
        otherwise ``(False, reason)`` with the reason typed by the user.
        """
        print(f"\n⚠️  Tool approval needed: {approval_content}")
        answer = input("Approve? (y/n): ").lower().strip()
        if not answer.startswith('y'):
            return False, input("Reason for denial: ")
        return True, "approved"

# Mock API Client for testing
class MockAPIClient:
    """Canned replacement for the real API client, for offline testing.

    Both calls produce the same fixed assistant text, no tool calls, and a
    usage of 100 input / 50 output tokens.
    """

    _CONTENT = "I'll help you with that request using the available tools."

    class _MockUsage:
        """Fixed token counts."""

        def __init__(self):
            self.input_tokens = 100
            self.output_tokens = 50

    class _MockResponse:
        """Assistant message without tool calls; ``usage`` is set only when given."""

        def __init__(self, content, usage=None):
            self.content = content
            self.role = "assistant"
            self.tool_calls = None
            if usage is not None:
                self.usage = usage

    def get_completion(self, request):
        """Return ``(message, usage)`` for a non-streaming call."""
        # Mock response - in real usage, this would call your actual API
        return self._MockResponse(self._CONTENT), self._MockUsage()

    def get_completion_stream(self, request):
        """Yield the text one character at a time, then the final message.

        The final message carries ``usage``. The usage class is shared at
        class level: the original referenced a ``MockUsage`` that only
        existed inside ``get_completion``, so the last yield raised
        ``NameError``.
        """
        # Mock streaming response
        yield from self._CONTENT
        yield self._MockResponse(self._CONTENT, usage=self._MockUsage())

# Import your agent (assuming it's in a file called 'agent.py')
# from agent import Agent, AgentConfiguration

# For testing, let's create a simple test function
async def test_agent():
    """Run a fixed list of prompts through ``agent.run_task`` and print each outcome."""
    print("🚀 Testing Agent with Custom Tools")
    print("=" * 50)

    # Create agent with your tools and system prompt
    from agent import create_agent  # Import your agent module

    agent_options = dict(
        tools=TOOL_MAPPING,
        system_prompt=SYSTEM_PROMPT,
        max_recursion_depth=10,
        enable_streaming=False,  # Disable for easier testing
    )
    agent = create_agent(**agent_options)

    # Test different scenarios
    prompts = (
        "Calculate 15 + 27 * 3",
        "Search for latest AI news",
        "Write a note about my daily tasks",
        "Calculate the area of a circle with radius 5 (use 3.14159 for pi)",
    )

    case_number = 0
    for prompt in prompts:
        case_number += 1
        print(f"\n📝 Test Case {case_number}: {prompt}")
        print("-" * 40)

        # A failing case is reported and the remaining cases still run.
        try:
            outcome = await agent.run_task(prompt)
            print(f"✅ Result: {outcome}")
        except Exception as exc:
            print(f"❌ Error: {exc}")

        print()

# Interactive test function
async def test_interactive():
    """Start an interactive session backed by ``MockAPIClient`` and ``MockUI``."""
    banner = (
        "🎯 Starting Interactive Agent Test",
        "Type 'quit' to exit",
        "=" * 50,
    )
    for banner_line in banner:
        print(banner_line)

    from agent import create_agent

    # Override API client for testing
    mock_ui = MockUI()
    agent = create_agent(tools=TOOL_MAPPING, system_prompt=SYSTEM_PROMPT, ui_manager=mock_ui)

    # Replace API client with mock for testing
    agent._api_client = MockAPIClient()

    await agent.start_interactive()

# Test with tool approval
async def test_with_approval():
    """Run one task with an ``internet_search`` variant that asks for user approval."""
    for banner_line in ("🛡️  Testing Agent with Tool Approval", "=" * 50):
        print(banner_line)

    # Add approval flag to a tool for testing
    def safe_internet_search(query: str, need_user_approve: bool = True, **kwargs):
        return internet_search(query, **kwargs)

    # Same registry as TOOL_MAPPING, with the search tool swapped out.
    approval_tools = {**TOOL_MAPPING, "internet_search": safe_internet_search}

    from agent import create_agent

    prompt_with_notice = SYSTEM_PROMPT + "\nSome tools may require user approval."
    agent = create_agent(
        tools=approval_tools,
        system_prompt=prompt_with_notice,
        ui_manager=MockUI(),
    )

    agent._api_client = MockAPIClient()

    outcome = await agent.run_task("Search for Python programming tutorials")
    print(f"Final result: {outcome}")

if __name__ == "__main__":
    # Show the menu, then dispatch to the chosen test coroutine.
    menu_lines = (
        "🔧 Agent Testing Suite",
        "Choose test mode:",
        "1. Automated tests",
        "2. Interactive mode",
        "3. Tool approval test",
    )
    for menu_line in menu_lines:
        print(menu_line)

    choice = input("\nEnter choice (1-3): ").strip()

    test_modes = {
        "1": test_agent,
        "2": test_interactive,
        "3": test_with_approval,
    }
    selected_mode = test_modes.get(choice)
    if selected_mode is None:
        print("Invalid choice!")
    else:
        asyncio.run(selected_mode())

# Example of adding tool schemas to functions (optional enhancement)
def add_tool_schema(schema):
    """Return a decorator that stores ``schema`` on the function as ``_schema``.

    The decorated function itself is returned, not a wrapper.
    """
    def attach(tool_func):
        setattr(tool_func, "_schema", schema)
        return tool_func

    return attach

# Enhanced tool with schema
# Thin wrapper over internet_search exposing only `query` and `max_results`.
# (No docstring on purpose: the agent sends a tool's __doc__ to the model.)
def enhanced_search(query: str, max_results: int = 5):
    return internet_search(query, max_results=max_results)


# Attach the parameter schema by calling the decorator explicitly.
enhanced_search = add_tool_schema({
    "type": "object",
    "properties": {
        "query": {"type": "string", "description": "Search query"},
        "max_results": {"type": "integer", "description": "Max results", "default": 5},
    },
    "required": ["query"],
})(enhanced_search)

🔧 Agent Testing Suite
Choose test mode:
1. Automated tests
2. Interactive mode
3. Tool approval test
