In [1]:
import re
from typing import Tuple, Any, List
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    Event,
    step,
)

## Note:

- So far, the workflow works with the function like calculation. What about dynamic text inputs?

## Specification Question:
- Should each function have one validator or list of validators? (maybe abstract it by allowing just one validator)
- For the planning phase, should the agent still return the respond with all of the steps or one step at a time?

In [2]:
solution = """ 
First, compute the sum of 3 and 4, resulting in [FUNC add(3, 4) = y1]. 
Next, multiply the result by 4, resulting in [FUNC multiply(y1, 4) = y2]. 
Finally, subtract 8 from the result, resulting in [FUNC add(y2, -8) = y3].
"""

In [3]:
func_calls = re.findall(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]", solution)

placeholders = set()
for match in re.finditer(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]", solution):
    placeholders.add(match.group(3))

In [4]:
func_calls

[('add', '3, 4', 'y1'), ('multiply', 'y1, 4', 'y2'), ('add', 'y2, -8', 'y3')]

In [5]:
from calculator_tool_spec import CalculatorToolSpec
tools = CalculatorToolSpec().to_tool_list()
tools_by_name = {tool.metadata.name: tool for tool in tools}
tools_by_name

{'multiply': <llama_index.core.tools.function_tool.FunctionTool at 0x1115b3c90>,
 'add': <llama_index.core.tools.function_tool.FunctionTool at 0x166a8d7d0>,
 'subtract': <llama_index.core.tools.function_tool.FunctionTool at 0x166a8fad0>,
 'divide': <llama_index.core.tools.function_tool.FunctionTool at 0x166a98390>}

In [6]:
from arithmetic_validator_tool_spec import ArithmeticValidatorToolSpec
tools = ArithmeticValidatorToolSpec().to_tool_list()
validators = {tool.metadata.name: tool for tool in tools}
validators

{'eval_expression': <llama_index.core.tools.function_tool.FunctionTool at 0x166a8e390>}

In [7]:
raw_tool_output = await tools_by_name["add"].acall(3, 4)
raw_tool_output

ToolOutput(content='7', tool_name='add', raw_input={'args': (3, 4), 'kwargs': {}}, raw_output=7, is_error=False)

In [8]:
results = {}

for func_call in func_calls:
    func_name, raw_inputs, output_placeholder = func_call
    input_data = []
    
    for raw_input in raw_inputs.split(","):
        raw_input = raw_input.strip()
        
        try:
            raw_input = results[raw_input] if raw_input in results else raw_input
            input_data.append(int(raw_input))
        
        except ValueError as e:
            print("Input string cannot be converted to integer")
            input_data.append(raw_input)
            
    tool_output = await tools_by_name[func_name].acall(*input_data)
    
    if output_placeholder not in results:
        results[output_placeholder] = str(tool_output)

results

{'y1': '7', 'y2': '28', 'y3': '20'}

In [9]:
# events.py
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.tools.function_tool import FunctionTool
import json

class FunctionCallEvent(Event):
    func_call: Tuple[str, str, str]    
    
class ValidateFunctionCallEvent(Event):
    validator: AsyncBaseTool
    input_data: List[Any]
    output_placeholder: str
    tool_output: Any


In [10]:
class COAWorkFlow(Workflow):
    @step(pass_context=True)
    async def main_step(self, ctx: Context, ev: StartEvent) -> FunctionCallEvent:
        """Initialize the context and start processing function calls."""
        ctx.data["results"] = {} 
        ctx.data["tools_by_name"] = ev.tools_by_name
        ctx.data["function_calls"] = ev.func_calls
        ctx.data["validators"] = ev.validators
        ctx.data["iteration"] = 0
        
        # Start with the first function call
        first_func_call = ctx.data["function_calls"][0]
        return FunctionCallEvent(func_call=first_func_call)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute function call"""
        func_name, raw_inputs, output_placeholder = ev.func_call
        tools_by_name = ctx.data["tools_by_name"]
        results = ctx.data["results"]
        
        input_data = []
        for raw_input in raw_inputs.split(","):
            raw_input = raw_input.strip()
            
            try:
                raw_input = results[raw_input] if raw_input in results else raw_input
                input_data.append(int(raw_input))
            
            except ValueError:
                print("Input string cannot be converted to integer")
                input_data.append(raw_input)
        
        # Execute the tool function
        try:
            tool_output = await tools_by_name[func_name].acall(*input_data)
            print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output}")
        except Exception as e:
            return StopEvent(f"Error in {func_name} with inputs {input_data}: {e}")
        
        input_data.append(func_name)
        
        validator = ctx.data["validators"]["eval_expression"]
            
        if validator is not None:
            return ValidateFunctionCallEvent(validator=validator, tool_output=tool_output, input_data=input_data, output_placeholder=output_placeholder)

        return StopEvent(f"This is the result of the final function {results}")
    
    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Validate the function output and proceed based on the validation."""
        input_data = ev.input_data
        tool_output = ev.tool_output
        validator = ev.validator
        output_placeholder = ev.output_placeholder
        results = ctx.data["results"]

        # Execute the validation
        try:
            validator_output = await validator.acall(*input_data)
            print(f"Validation result: {validator_output} | Tool output: {tool_output}")
        except Exception as e:
            return StopEvent(f"Validation error: {e}")

        # Compare the validator's output with the tool's output
        if str(validator_output) == str(tool_output):
            results[output_placeholder] = str(tool_output)  # Cache the result
            return await self._move_to_next_function(ctx, results)
        else:
            return StopEvent("Tool output does not match the expected validation result")

    async def _move_to_next_function(self, ctx: Context, results: dict) -> FunctionCallEvent | StopEvent:
        """Helper function to move to the next function call."""
        iteration = ctx.data["iteration"]
        function_calls = ctx.data["function_calls"]

        if iteration + 1 < len(function_calls):
            next_func_call = function_calls[iteration + 1]
            ctx.data["iteration"] += 1
            return FunctionCallEvent(func_call=next_func_call, validator_name=next_func_call[0])

        return StopEvent(f"All function calls processed. Final results: {results}")

In [11]:
w = COAWorkFlow()
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators )

Executed add with inputs [3, 4] -> y1: 7


"Validation error: ArithmeticValidatorToolSpec.eval_expression() missing 1 required positional argument: 'tool_output'"

In [12]:
validators_mapper = {
    "add": validators["eval_expression"],
    "subtract": validators["eval_expression"],
    "multiply": validators["eval_expression"]
}

In [13]:
validators_mapper["add"]

<llama_index.core.tools.function_tool.FunctionTool at 0x166a8e390>

In [14]:
class COAWorkFlow(Workflow):
    @step(pass_context=True)
    async def main_step(self, ctx: Context, ev: StartEvent) -> FunctionCallEvent:
        """Initialize the context and start processing function calls."""
        ctx.data["results"] = {} 
        ctx.data["tools_by_name"] = ev.tools_by_name
        ctx.data["function_calls"] = ev.func_calls
        ctx.data["validators"] = ev.validators
        ctx.data["iteration"] = 0
        
        # Start with the first function call
        first_func_call = ctx.data["function_calls"][0]
        return FunctionCallEvent(func_call=first_func_call)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute function call"""
        func_name, raw_inputs, output_placeholder = ev.func_call
        tools_by_name = ctx.data["tools_by_name"]
        results = ctx.data["results"]
        
        input_data = []
        for raw_input in raw_inputs.split(","):
            raw_input = raw_input.strip()
            
            try:
                raw_input = results[raw_input] if raw_input in results else raw_input
                input_data.append(int(raw_input))
            
            except ValueError:
                print("Input string cannot be converted to integer")
                input_data.append(raw_input)
        
        # Execute the tool function
        try:
            tool_output = await tools_by_name[func_name].acall(*input_data)
            print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output}")
        except Exception as e:
            return StopEvent(f"Error in {func_name} with inputs {input_data}: {e}")
        
        input_data.append(func_name)
        input_data.append(tool_output.content)
        
        validator = ctx.data["validators"][func_name]
            
        if validator is not None:
            return ValidateFunctionCallEvent(validator=validator, tool_output=tool_output, input_data=input_data, output_placeholder=output_placeholder)

        return StopEvent(f"This is the result of the final function {results}")
    
    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Validate the function output and proceed based on the validation."""
        input_data = ev.input_data
        tool_output = ev.tool_output
        validator = ev.validator
        output_placeholder = ev.output_placeholder
        results = ctx.data["results"]

        # Execute the validation
        try:
            validator_output = await validator.acall(*input_data)
            print(f"Validation result: {validator_output} | Tool output: {tool_output}")
        except Exception as e:
            return StopEvent(f"Validation error: {e}")

        # Compare the validator's output with the tool's output
        if validator_output:
            results[output_placeholder] = str(tool_output)  # Cache the result
            return await self._move_to_next_function(ctx, results)
        else:
            return StopEvent("Tool output does not match the expected validation result")

    async def _move_to_next_function(self, ctx: Context, results: dict) -> FunctionCallEvent | StopEvent:
        """Helper function to move to the next function call."""
        iteration = ctx.data["iteration"]
        function_calls = ctx.data["function_calls"]

        if iteration + 1 < len(function_calls):
            next_func_call = function_calls[iteration + 1]
            ctx.data["iteration"] += 1
            return FunctionCallEvent(func_call=next_func_call, validator_name=next_func_call[0])

        return StopEvent(f"All function calls processed. Final results: {results}")

In [15]:
w = COAWorkFlow()
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators_mapper)

Executed add with inputs [3, 4] -> y1: 7
Validation result: False | Tool output: 7
Executed multiply with inputs [7, 4] -> y2: 28
Validation result: False | Tool output: 28
Executed add with inputs [28, -8] -> y3: 20
Validation result: False | Tool output: 20


"All function calls processed. Final results: {'y1': '7', 'y2': '28', 'y3': '20'}"

In [26]:
from typing import Any, List, Tuple
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.workflow import Workflow, Context, Event, StartEvent, StopEvent, step

class FunctionCallEvent(Event):
    func_call: Tuple[str, str, str]  # Function name, raw inputs, output placeholder

class ValidateFunctionCallEvent(Event):
    validator: AsyncBaseTool
    input_data: List[Any]
    output_placeholder: str
    tool_output: Any

class InitializeEvent(Event):
    """Event for initializing the workflow context."""
    pass

class COAWorkFlow(Workflow):
    @step(pass_context=True)
    async def initialize_step(self, ctx: Context, ev: StartEvent | InitializeEvent) -> InitializeEvent | FunctionCallEvent:
        """Initialize the workflow context."""
        if isinstance(ev, StartEvent):
            # Perform initial setup, e.g., load tools, validators, etc.
            ctx.data["results"] = {}
            ctx.data["tools_by_name"] = ev.tools_by_name
            ctx.data["function_calls"] = ev.func_calls
            ctx.data["validators"] = ev.validators
            ctx.data["iteration"] = 0
            
            # Produce an InitializeEvent to start initialization
            return InitializeEvent()

        # After initialization, start processing the first function call
        first_func_call = ctx.data["function_calls"][0]
        return FunctionCallEvent(func_call=first_func_call)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute a function call and prepare for validation."""
        func_name, raw_inputs, output_placeholder = ev.func_call
        input_data = self._prepare_inputs(ctx, raw_inputs)

        try:
            tool_output = await ctx.data["tools_by_name"][func_name].acall(*input_data)
            print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output}")
        except Exception as e:
            return self._handle_error(f"Error in {func_name} with inputs {input_data}: {e}")

        validator = ctx.data["validators"].get(func_name)
        if validator:
            return ValidateFunctionCallEvent(
                validator=validator,
                input_data=input_data + [func_name, tool_output.content],
                output_placeholder=output_placeholder,
                tool_output=tool_output
            )

        ctx.data["results"][output_placeholder] = tool_output.content
        return await self._move_to_next_function(ctx)

    def _prepare_inputs(self, ctx: Context, raw_inputs: str) -> List[Any]:
        """Parse and prepare function inputs."""
        results = ctx.data["results"]
        input_data = []
        for raw_input in raw_inputs.split(","):
            raw_input = raw_input.strip()
            try:
                input_data.append(int(results.get(raw_input, raw_input)))
            except ValueError:
                input_data.append(raw_input)  # Handle non-integer inputs gracefully
        return input_data

    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Validate function output and decide on the next step."""
        try:
            validator_output = await ev.validator.acall(*ev.input_data)
            print(f"Validation result: {validator_output} | Tool output: {ev.tool_output}")
        except Exception as e:
            return self._handle_error(f"Validation error: {e}")

        if validator_output:
            ctx.data["results"][ev.output_placeholder] = str(ev.tool_output)
            return await self._move_to_next_function(ctx)

        return self._handle_error("Tool output does not match the expected validation result")

    async def _move_to_next_function(self, ctx: Context) -> FunctionCallEvent | StopEvent:
        """Move to the next function call if available."""
        iteration = ctx.data["iteration"]
        function_calls = ctx.data["function_calls"]

        if iteration + 1 < len(function_calls):
            ctx.data["iteration"] += 1
            next_func_call = function_calls[ctx.data["iteration"]]
            return FunctionCallEvent(func_call=next_func_call)

        # End of the workflow, produce a StopEvent
        return StopEvent(f"All function calls processed. Final results: {ctx.data['results']}")

    def _handle_error(self, message: str) -> StopEvent:
        """Handle errors gracefully and provide feedback."""
        print(message)
        return StopEvent(message)


In [27]:
w = COAWorkFlow()
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators_mapper)

Executed add with inputs [3, 4] -> y1: 7
Validation result: False | Tool output: 7
Executed multiply with inputs [7, 4] -> y2: 28
Validation result: False | Tool output: 28
Executed add with inputs [28, -8] -> y3: 20
Validation result: False | Tool output: 20


"All function calls processed. Final results: {'y1': '7', 'y2': '28', 'y3': '20'}"