In [1]:

from typing import Any, List, Tuple
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.workflow import Workflow, Context, Event, StartEvent, StopEvent, step
import os
from llama_index.core.tools import QueryEngineTool


os.environ["OPENAI_API_KEY"] = ""

class FunctionCallEvent(Event):
    func_call: Tuple[str, str, str]  # Function name, raw inputs, output placeholder

class ValidateFunctionCallEvent(Event):
    validator: AsyncBaseTool
    input_data: List[Any]
    output_placeholder: str
    tool_output: Any

class InitializeEvent(Event):
    """Event for initializing the workflow context."""
    pass

class COAWorkFlow(Workflow):
    @step(pass_context=True)
    async def initialize_step(self, ctx: Context, ev: StartEvent | InitializeEvent) -> InitializeEvent | FunctionCallEvent:
        """Initialize the workflow context."""
        if isinstance(ev, StartEvent):
            # Perform initial setup, e.g., load tools, validators, etc.
            ctx.data["results"] = {}
            ctx.data["tools_by_name"] = ev.tools_by_name
            ctx.data["function_calls"] = ev.func_calls
            ctx.data["validators"] = ev.validators
            ctx.data["iteration"] = 0
            
            # Produce an InitializeEvent to start initialization
            return InitializeEvent()

        # After initialization, start processing the first function call
        first_func_call = ctx.data["function_calls"][0]
        return FunctionCallEvent(func_call=first_func_call)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute a function call and prepare for validation."""
        func_name, raw_inputs, output_placeholder = ev.func_call
        input_data = self._prepare_inputs(ctx, raw_inputs)

        try:
            tool_output = await ctx.data["tools_by_name"][func_name].acall(*input_data)
            
            print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output}")
        except Exception as e:
            return self._handle_error(f"Error in {func_name} with inputs {input_data}: {e}")

        validator = ctx.data["validators"].get(func_name)
        if validator:
            return ValidateFunctionCallEvent(
                validator=validator,
                input_data=input_data + [func_name, tool_output.content],
                output_placeholder=output_placeholder,
                tool_output=tool_output
            )

        ctx.data["results"][output_placeholder] = tool_output.content
        return await self._move_to_next_function(ctx)

    def _prepare_inputs(self, ctx: Context, raw_inputs: str) -> List[Any]:
        """Parse and prepare function inputs."""
        results = ctx.data["results"]
        input_data = []
        for raw_input in raw_inputs.split(","):
            raw_input = raw_input.strip()
            try:
                input_data.append(int(results.get(raw_input, raw_input)))
            except ValueError:
                input_data.append(raw_input)  # Handle non-integer inputs gracefully
        return input_data

    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Validate function output and decide on the next step."""
        try:
            validator_output = await ev.validator.acall(*ev.input_data)
            print(f"Validation result: {validator_output} | Tool output: {ev.tool_output}")
        except Exception as e:
            return self._handle_error(f"Validation error: {e}")

        if validator_output:
            ctx.data["results"][ev.output_placeholder] = str(ev.tool_output)
            return await self._move_to_next_function(ctx)

        return self._handle_error("Tool output does not match the expected validation result")

    async def _move_to_next_function(self, ctx: Context) -> FunctionCallEvent | StopEvent:
        """Move to the next function call if available."""
        iteration = ctx.data["iteration"]
        function_calls = ctx.data["function_calls"]

        if iteration + 1 < len(function_calls):
            ctx.data["iteration"] += 1
            next_func_call = function_calls[ctx.data["iteration"]]
            return FunctionCallEvent(func_call=next_func_call)

        # End of the workflow, produce a StopEvent
        return StopEvent({"message": "All function calls processed.", "results": ctx.data["results"]})

    def _handle_error(self, message: str) -> StopEvent:
        """Handle errors gracefully and provide feedback."""
        print(message)
        return StopEvent({"message": message, "results": {}})


In [2]:
import re
from calculator_tool_spec import CalculatorToolSpec
from arithmetic_validator_tool_spec import ArithmeticValidatorToolSpec


tools = CalculatorToolSpec().to_tool_list()
tools_by_name = {tool.metadata.name: tool for tool in tools}

tools = ArithmeticValidatorToolSpec().to_tool_list()
validators = {tool.metadata.name: tool for tool in tools}


validators_mapper = {
    "add": validators["eval_expression"],
    "subtract": validators["eval_expression"],
    "multiply": validators["eval_expression"]
}

solution = """ 
First, compute the sum of 3 and 4, resulting in [FUNC add(3, 4) = y1]. 
Next, multiply the result by 4, resulting in [FUNC multiply(y1, 4) = y2]. 
Finally, subtract 8 from the result, resulting in [FUNC add(y2, -8) = y3].
"""

func_calls = re.findall(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]", solution)

placeholders = set()
for match in re.finditer(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]", solution):
    placeholders.add(match.group(3))

In [27]:

w = COAWorkFlow()
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators_mapper)

Executed add with inputs [3, 4] -> y1: 7
Validation result: False | Tool output: 7
Executed multiply with inputs [7, 4] -> y2: 28
Validation result: False | Tool output: 28
Executed add with inputs [28, -8] -> y3: 20
Validation result: False | Tool output: 20


{'message': 'All function calls processed.',
 'results': {'y1': '7', 'y2': '28', 'y3': '20'}}

In [28]:
# Setup your tools and validators for Lyft and Uber
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI



from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import get_response_synthesizer
from llama_index.core.response_synthesizers import BaseSynthesizer
    
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-3-small", embed_batch_size=256
)
Settings.llm = OpenAI(model="gpt-4o", temperature=0.1)
storage_context = StorageContext.from_defaults(persist_dir="./storage/lyft")
lyft_index = load_index_from_storage(storage_context)

storage_context = StorageContext.from_defaults(persist_dir="./storage/uber")
uber_index = load_index_from_storage(storage_context)

lyft_engine = lyft_index.as_query_engine(similarity_top_k=2)
uber_engine = uber_index.as_query_engine(similarity_top_k=2)

query_engine_tools = [
    QueryEngineTool.from_defaults(
        query_engine=lyft_engine,
        name="lyft_10k",
        description="Provides information about Lyft financials for year 2021.",
    ),
    QueryEngineTool.from_defaults(
        query_engine=uber_engine,
        name="uber_10k",
        description="Provides information about Uber financials for year 2021.",
    ),
]

tools_by_name = {tool.metadata.name: tool for tool in query_engine_tools}

# Example function calls and validators
func_calls = [
    ("lyft_10k", "What was Lyft's revenue growth in 2021?", "result1"),
    ("uber_10k", "What was Uber's revenue growth in 2021?", "result2"),
]

validators = {
    "lyft_10k": None,  # Example, you might not need specific validators for simple queries
    "uber_10k": None,
}

# Run the workflow
w = COAWorkFlow(timeout=60)
await w.run(tools_by_name=tools_by_name, func_calls=func_calls, validators=validators)

Executed lyft_10k with inputs ["What was Lyft's revenue growth in 2021?"] -> result1: Lyft's revenue in 2021 was $3,208,323,000, compared to $2,364,681,000 in 2020. This represents a revenue growth of approximately 35.7% in 2021.
Executed uber_10k with inputs ["What was Uber's revenue growth in 2021?"] -> result2: Uber's revenue growth in 2021 was 57%.


{'message': 'All function calls processed.',
 'results': {'result1': "Lyft's revenue in 2021 was $3,208,323,000, compared to $2,364,681,000 in 2020. This represents a revenue growth of approximately 35.7% in 2021.",
  'result2': "Uber's revenue growth in 2021 was 57%."}}

In [29]:
import nest_asyncio
import asyncio
import re
from typing import Dict, Tuple, List

from llama_index.core.tools import AsyncBaseTool, ToolOutput
from llama_index.core.types import BaseOutputParser
from llama_index.core.workflow import Context

# Apply nest_asyncio to allow nested event loops
nest_asyncio.apply()

class ChainOfAbstractionParser(BaseOutputParser):
    """Chain of abstraction output parser."""

    def __init__(self, verbose: bool = False):
        """Initialize the parser with verbosity and workflow setup."""
        self._verbose = verbose

    def parse(
        self, solution: str, tools_by_name: Dict[str, AsyncBaseTool]
    ) -> Tuple[str, List[ToolOutput]]:
        """Run the async parse method, handling running event loops."""
        if asyncio.get_event_loop().is_running():
            # Use `await` if inside a running event loop (Jupyter Notebook case)
            return asyncio.get_event_loop().run_until_complete(self.aparse(solution, tools_by_name))
        else:
            # Normal use case outside of Jupyter, use asyncio.run
            return asyncio.run(self.aparse(solution, tools_by_name))

    async def aparse(
        self, solution: str, tools_by_name: Dict[str, AsyncBaseTool]
    ) -> Tuple[str, List[ToolOutput]]:
        """Asynchronously parse the solution and execute the workflow."""
        # Extract function calls
        func_calls = re.findall(r"\[FUNC (\w+)\((.*?)\) = (\w+)\]", solution)

        # Initialize the workflow
        workflow = COAWorkFlow(timeout=60)

        validators = {}
        # Run the workflow
        response = await workflow.run(timeout=60,tools_by_name=tools_by_name, func_calls=func_calls, validators=validators)
        
        results = response["results"]
        # print(context.data)

        # Collect results from the workflow
        tool_outputs = []
        for func_name, raw_inputs, output_placeholder in func_calls:
            if output_placeholder in results:
                tool_outputs.append(
                    ToolOutput(
                        content=results[output_placeholder],
                        tool_name=func_name,
                        raw_output=results[output_placeholder],
                        raw_input={"args": raw_inputs},
                        is_error=False,
                    )
                )
                print(tool_outputs)
            else:
                tool_outputs.append(
                    ToolOutput(
                        content="Error: No output generated",
                        tool_name=func_name,
                        raw_output=None,
                        raw_input={"args": raw_inputs},
                        is_error=True,
                    )
                )


        # Replace placeholders in the solution text
        for placeholder, value in results.items():
            solution = solution.replace(f"{placeholder}", '"' + str(value) + '"')

        print(solution)
        return solution, tool_outputs




In [30]:
query_engine_tools = [
    QueryEngineTool.from_defaults(
        query_engine=lyft_engine,
        name="lyft_10k",
        description="Provides information about Lyft financials for year 2021.",
    ),
    QueryEngineTool.from_defaults(
        query_engine=uber_engine,
        name="uber_10k",
        description="Provides information about Uber financials for year 2021.",
    ),
]

tools_by_name = {tool.metadata.name: tool for tool in query_engine_tools}


# Example solution text
solution = """1. Retrieve Uber's revenue growth for 2021 by querying the Uber financial tool with a specific question about revenue growth:
   - [FUNC uber_10k("What was Uber's revenue growth in 2021?") = y1]

2. Retrieve Lyft's revenue growth for 2021 by querying the Lyft financial tool with a similar question about revenue growth:
   - [FUNC lyft_10k("What was Lyft's revenue growth in 2021?") = y2]
"""

# Initialize parser and run the workflow
parser = ChainOfAbstractionParser(verbose=True)

solution, tool_outputs = await parser.aparse(solution, tools_by_name)

for output in tool_outputs:
    print(output.content)
# solution, tool_outputs = parser.parse(solution, tools_by_name)

# for output in tool_outputs:
#     print(output.content)

Executed uber_10k with inputs ['"What was Uber\'s revenue growth in 2021?"'] -> y1: Uber's revenue growth in 2021 was 57%.
Executed lyft_10k with inputs ['"What was Lyft\'s revenue growth in 2021?"'] -> y2: Lyft's revenue in 2021 was $3,208,323,000, compared to $2,364,681,000 in 2020. This represents a revenue growth of approximately 35.7% from 2020 to 2021.
[ToolOutput(content="Uber's revenue growth in 2021 was 57%.", tool_name='uber_10k', raw_input={'args': '"What was Uber\'s revenue growth in 2021?"'}, raw_output="Uber's revenue growth in 2021 was 57%.", is_error=False)]
[ToolOutput(content="Uber's revenue growth in 2021 was 57%.", tool_name='uber_10k', raw_input={'args': '"What was Uber\'s revenue growth in 2021?"'}, raw_output="Uber's revenue growth in 2021 was 57%.", is_error=False), ToolOutput(content="Lyft's revenue in 2021 was $3,208,323,000, compared to $2,364,681,000 in 2020. This represents a revenue growth of approximately 35.7% from 2020 to 2021.", tool_name='lyft_10k', r