In [1]:
import re
from typing import Tuple, Any, List
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    Event,
    step,
)

## Note:

So far, the workflow has only handled calculations expressed directly as function calls. What about dynamic text inputs, where the function calls are embedded in free-form text?

In [2]:
# Sample chain-of-abstraction text: each [FUNC name(args) = placeholder] marker is a
# tool call, and later calls can reference an earlier result by its placeholder (y1, y2, ...).
solution = """ 
First, compute the sum of 3 and 4, resulting in [FUNC add(3, 4) = y1]. 
Next, multiply the result by 4, resulting in [FUNC multiply(y1, 4) = y2]. 
Finally, subtract 8 from the result, resulting in [FUNC add(y2, -8) = y3].
"""

In [15]:
# Each match is a (function name, raw comma-separated arguments, output placeholder) tuple.
FUNC_CALL_PATTERN = re.compile(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]")

func_calls = FUNC_CALL_PATTERN.findall(solution)
# Output placeholders (e.g. "y1") assigned by the calls, taken from the same parse
# instead of re-scanning the text with a duplicated pattern.
placeholders = {output_placeholder for _, _, output_placeholder in func_calls}

In [16]:
# Inspect the parsed calls: (function name, raw arguments, output placeholder).
func_calls

[('add', '3, 4', 'y1'), ('multiply', 'y1, 4', 'y2'), ('add', 'y2, -8', 'y3')]

In [17]:
from calculator_tool_spec import CalculatorToolSpec

tools = CalculatorToolSpec().to_tool_list()
# Map each calculator tool's registered name (e.g. "add", "multiply") to the tool object.
tools_by_name = dict((calc_tool.metadata.name, calc_tool) for calc_tool in tools)
tools_by_name

{'multiply': <llama_index.core.tools.function_tool.FunctionTool at 0x13faf0d90>,
 'add': <llama_index.core.tools.function_tool.FunctionTool at 0x13faf0450>,
 'subtract': <llama_index.core.tools.function_tool.FunctionTool at 0x13f49b4d0>,
 'divide': <llama_index.core.tools.function_tool.FunctionTool at 0x13fc6cc90>}

In [18]:
from arithmetic_validator_tool_spec import ArithmeticValidatorToolSpec

# NOTE: this rebinds `tools`; the calculator tools stay reachable through `tools_by_name`.
tools = ArithmeticValidatorToolSpec().to_tool_list()
validators = dict((validator_tool.metadata.name, validator_tool) for validator_tool in tools)
validators

{'eval_expression': <llama_index.core.tools.function_tool.FunctionTool at 0x13fc6e450>}

In [19]:
# Sanity-check a single tool call: `acall` returns a ToolOutput whose `content`
# is the stringified result and whose `raw_output` is the Python value.
raw_tool_output = await tools_by_name["add"].acall(3, 4)
raw_tool_output

ToolOutput(content='7', tool_name='add', raw_input={'args': (3, 4), 'kwargs': {}}, raw_output=7, is_error=False)

In [20]:
results = {}

for func_call in func_calls:
    func_name, raw_inputs, output_placeholder = func_call
    input_data = []
    
    for raw_input in raw_inputs.split(","):
        raw_input = raw_input.strip()
        
        try:
            raw_input = results[raw_input] if raw_input in results else raw_input
            input_data.append(int(raw_input))
        
        except ValueError as e:
            print("Input string cannot be converted to integer")
            input_data.append(raw_input)
            
    tool_output = await tools_by_name[func_name].acall(*input_data)
    
    if output_placeholder not in results:
        results[output_placeholder] = str(tool_output)

results

{'y1': '7', 'y2': '28', 'y3': '20'}

In [21]:
# events.py
# Event types passed between the steps of the chain-of-abstraction workflow.
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.tools.function_tool import FunctionTool  # NOTE(review): not used in this cell
import json  # NOTE(review): not used in this cell

class FunctionCallEvent(Event):
    """Request to execute one parsed function call.

    ``func_call`` is a ``(function name, raw comma-separated arguments,
    output placeholder)`` tuple, in the shape produced by the ``[FUNC ...]`` regex.
    """
    # input_data: List[str]  -- NOTE(review): leftover field; arguments are parsed from func_call instead
    func_call: Tuple[str, str, str]
    
    
class ValidateFunctionCall(Event):
    """Request to check an executed tool's output against a validator tool."""
    # Validator tool, called with ``input_data`` as positional arguments.
    validator: AsyncBaseTool
    # Parsed call arguments followed by the function name.
    input_data: List[Any]
    # Placeholder name (e.g. "y1") under which the result is cached once validated.
    output_placeholder: str
    # Output of the executed tool; compared with the validator's output as strings.
    tool_output: Any


In [22]:
class COAWorkFlow(Workflow):
    """Chain-of-abstraction workflow: run parsed ``[FUNC ...]`` calls in order.

    Expected ``StartEvent`` fields:
        tools_by_name: mapping of function name -> tool exposing ``acall``.
        func_calls: list of ``(func_name, raw_inputs, output_placeholder)`` tuples.
        validators: mapping of validator name -> tool; ``"eval_expression"`` is
            used, when present, to double-check every tool output.

    Arguments may reference placeholders (``y1``, ``y2``, ...) produced by earlier
    calls. The workflow stops with a summary string once every call succeeds, or
    with an error message at the first failing tool call or validation.
    """

    @step(pass_context=True)
    async def main_step(self, ctx: Context, ev: StartEvent) -> FunctionCallEvent | StopEvent:
        """Initialize the context and start processing function calls."""
        ctx.data["results"] = {}
        ctx.data["tools_by_name"] = ev.tools_by_name
        ctx.data["function_calls"] = ev.func_calls
        ctx.data["validators"] = ev.validators
        ctx.data["iteration"] = 0

        # Indexing [0] on an empty list would raise IndexError; stop cleanly instead.
        if not ctx.data["function_calls"]:
            return StopEvent("No function calls to process.")

        # Start with the first function call
        return FunctionCallEvent(func_call=ctx.data["function_calls"][0])

    @step(pass_context=True)
    async def function_call_step(
        self, ctx: Context, ev: FunctionCallEvent
    ) -> ValidateFunctionCall | FunctionCallEvent | StopEvent:
        """Execute one function call and hand its output to validation.

        If no ``eval_expression`` validator was supplied, the output is cached
        without validation and the workflow advances to the next call.
        """
        func_name, raw_inputs, output_placeholder = ev.func_call
        tools_by_name = ctx.data["tools_by_name"]
        results = ctx.data["results"]

        input_data = self._parse_call_arguments(raw_inputs, results)

        # Execute the tool function; an unknown func_name (KeyError) is reported too.
        try:
            tool_output = await tools_by_name[func_name].acall(*input_data)
            print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output}")
        except Exception as e:
            return StopEvent(f"Error in {func_name} with inputs {input_data}: {e}")

        # .get(): a missing validator must not crash the run with KeyError.
        validator = (ctx.data["validators"] or {}).get("eval_expression")
        if validator is None:
            results[output_placeholder] = str(tool_output)
            return self._advance_to_next_call(ctx)

        # The validator receives the same arguments followed by the function name.
        return ValidateFunctionCall(
            validator=validator,
            tool_output=tool_output,
            input_data=[*input_data, func_name],
            output_placeholder=output_placeholder,
        )

    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCall) -> FunctionCallEvent | StopEvent:
        """Validate the function output and proceed based on the validation."""
        try:
            validator_output = await ev.validator.acall(*ev.input_data)
            print(f"Validation result: {validator_output} | Tool output: {ev.tool_output}")
        except Exception as e:
            return StopEvent(f"Validation error: {e}")

        # Outputs are compared via their string form (a ToolOutput renders as its content).
        if str(validator_output) != str(ev.tool_output):
            return StopEvent("Tool output does not match the expected validation result")

        ctx.data["results"][ev.output_placeholder] = str(ev.tool_output)  # Cache the result
        return self._advance_to_next_call(ctx)

    @staticmethod
    def _parse_call_arguments(raw_inputs: str, results: dict) -> list:
        """Convert a raw ``"a, b"`` argument string into positional tool arguments.

        Placeholders are replaced by their cached results. Each value is then
        converted to int, else float, else passed through unchanged as a string.
        """
        input_data = []
        for token in raw_inputs.split(","):
            token = token.strip()
            value = results.get(token, token)
            try:
                input_data.append(int(value))
                continue
            except ValueError:
                pass
            try:
                # Keep non-integer results (e.g. "3.5") numeric so the next tool
                # does not receive a string argument.
                input_data.append(float(value))
            except ValueError:
                print(f"Input {token!r} cannot be converted to a number; passing it through as a string")
                input_data.append(value)
        return input_data

    def _advance_to_next_call(self, ctx: Context) -> FunctionCallEvent | StopEvent:
        """Emit the next queued function call, or stop when none remain."""
        function_calls = ctx.data["function_calls"]
        next_index = ctx.data["iteration"] + 1
        if next_index < len(function_calls):
            ctx.data["iteration"] = next_index
            return FunctionCallEvent(func_call=function_calls[next_index])
        return StopEvent(f"All function calls processed. Final results: {ctx.data['results']}")

In [23]:
w = COAWorkFlow()
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators )

Executed add with inputs [3, 4] -> y1: 7
Validation result: 7 | Tool output: 7
Executed multiply with inputs [7, 4] -> y2: 28
Validation result: 28 | Tool output: 28
Executed add with inputs [28, -8] -> y3: 20
Validation result: 20 | Tool output: 20


"All function calls processed. Final results: {'y1': '7', 'y2': '28', 'y3': '20'}"