In [1]:
import os

# Do not hardcode (or blank out) the key in the notebook: keep whatever value is
# already exported in the environment, and only fall back to an empty string so
# that later lookups of the variable still find it defined.
os.environ.setdefault("OPENAI_API_KEY", "")

In [2]:
# Prompt used for the very first reasoning step. It is rendered with str.format and
# receives exactly two fields: `functions` and `question`. Every other brace pair
# must be doubled ({{ }}) so that it reaches the model as literal text.
#
# Design note (not wired in yet): later iterations are meant to also receive the
# previous steps, e.g. filtered_retrieval(guidelines for management of frozen shoulder = y1),
# and the evaluator outcome, e.g. "retrieved sources are irrelevant and do not answer
# the question of management of frozen shoulder, the reasoner should execute a new
# retrieval step", so that the next step can broaden the search, e.g.
# structured_retrieval(frozen shoulder management = y1). These notes used to live
# inside the template as unescaped {...} fields, which made .format() raise KeyError.
INITIAL_REASONING_PROMPT_TEMPLATE = """
Generate an initial reasoning step using placeholders for the specific values and function calls needed.
The goal is to determine the first function to call based on the question provided and the available functions.

Use the placeholders labeled y1, y2, etc., to represent outputs if needed.
The reasoning should lead to only one function call, represented as an inline string like [FUNC {{function_name}}({{input1}}, {{input2}}, ...) = {{output_placeholder}}].

You are not required to use all functions, but you must use at least one function that best matches the question's intent.
If the question can be answered without any function, you can conclude so.

Assume someone will read the plan after this function has been executed to continue further steps.

Example:
-----------
Available functions:
```python
def add(a: int, b: int) -> int:
    \"\"\"Add two numbers together.\"\"\"
    ...

def multiply(a: int, b: int) -> int:
    \"\"\"Multiply two numbers together.\"\"\"
    ...
```

Question:
Sally has 3 apples and buys 2 more. Then magically, a wizard casts a spell that multiplies the number of apples by 3. How many apples does Sally have now?

Abstract plan of reasoning:
After buying the apples, Sally has [FUNC add(3, 2) = y1] apples.

Your Turn:
-----------
Available functions:
```python
{functions}
```

Question:
{question}

Abstract plan of reasoning:
"""

# Prompt used for every reasoning step after the first one. It is rendered with
# str.format and receives exactly three fields: `functions`, `question` and
# `function_call_output`. Literal braces are doubled so .format() leaves them alone.
# The worked example ends with "= y2" because the step parser only recognises calls
# written as [FUNC name(args) = placeholder].
NEXT_STEP_REASONING_PROMPT_TEMPLATE = """
Based on the outputs generated from the previous function calls, decide whether additional steps are necessary:
No further steps are necessary if the retrieved context completely answers the question.
If no further steps are required, return "NO".
If the context does not answer the question, briefly explain why more information is needed (e.g., the similarity threshold was not met so no nodes were retrieved, the nodes are irrelevant, no guidelines or drug monographs exist, or the filter was too restrictive) and then give the next reasoning step.
If further steps are needed, provide a clear reasoning step using placeholders (y1, y2, etc.) for specific values and function calls. The objective is to determine the next function to call, using available outputs and aligning with the intent of the question.

Use placeholders (y1, y2, etc.) to represent the outputs of previous function calls if needed.
The reasoning should lead to a single, well-justified function call, represented as: [FUNC {{function_name}}({{input1}}, {{input2}}, ...) = {{output_placeholder}}].
You are not required to use all functions, but you must select at least one function that best matches the question's intent or utilizes keywords from the previous output to refine the results.
If the question can be fully answered without invoking additional functions, conclude accordingly.
Assumptions:
Assume that someone will read this plan after executing the current function to determine the next steps.

Example:
-----------
Available functions:
```python
def add(a: int, b: int) -> int:
    \"\"\"Add two numbers together.\"\"\"
    ...

def multiply(a: int, b: int) -> int:
    \"\"\"Multiply two numbers together.\"\"\"
    ...
```
Question:
Sally has 3 apples and buys 2 more. Then magically, a wizard casts a spell that multiplies the number of apples by 3. How many apples does Sally have now?

Previous Function Call Output:
After buying the apples, Sally has [FUNC add(3, 2) = y1] apples. After executing the first function, we have y1 = 5.

Abstract Plan of Reasoning for Next Step:
Since a wizard casts a spell that multiplies the number of apples by 3, use the output from the previous function call to determine the total apples. Thus, the next function call is [FUNC multiply(y1, 3) = y2].



Your Turn:
-----------
Available functions:
```python
{functions}
```

Question:
{question}

Previous Function Call Output:
{function_call_output}

Abstract plan of reasoning for next step:
"""

In [3]:
from agents.coa_agent.prompts.evaluator import prometheus_relevancy_eval_prompt_template, prometheus_relevancy_refine_prompt_template
from agents.coa_agent.validator.relevancy_eval import GPT4RelevancyEvaluator
import nest_asyncio
# Patch asyncio so an already-running event loop can be re-entered. The sync tool
# wrappers in this notebook (consult_diet / consult_nutrition) call asyncio.run()
# and are themselves invoked while the notebook's loop is running.
nest_asyncio.apply()

# Prompt handed to the relevancy evaluator. It exposes two format fields,
# {query_str} and {context_str}. The YES and NO rubric lines must describe opposite
# outcomes, otherwise the judge has no criterion for answering NO.
relevancy_eval_prompt_template = """###Task Description: An instruction (might include an Input inside it), a query, context, and a score rubric representing evaluation criteria are given. 
       1. You are provided with evaluation task with the help of a query and context output by function tool.
       2. Write a detailed feedback based on evaluation task and the given score rubric, not evaluating in general. 
       3. After writing a feedback, write a score that is YES or NO. You should refer to the score rubric. 
       4. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (YES or NO)" 
       5. Please do not generate any other opening, closing, and explanations. 

        ###The instruction to evaluate: Your task is to evaluate whether the source nodes retrieved for the query are in line with the context information provided.

        ###Query: {query_str} 

        ###Context: {context_str}
            
        ###Score Rubrics: 
        Score YES: If the context information provided is sufficient and in line to answer the query.
        Score NO: If the context information provided is not sufficient or not in line to answer the query.
        
        ###Feedback: """

  from .autonotebook import tqdm as notebook_tqdm



In [4]:
from typing import List, Optional, Any
from llama_index.llms.openai import OpenAI
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core import Settings


class StepGenerator:
    """Produce chain-of-abstraction reasoning steps, one function call at a time.

    The generator keeps the text of the last successful reasoning step in
    ``self.state``. While the state is empty the initial prompt template is
    used; afterwards the next-step template is used and the stored text is
    passed in as the previous function call output.
    """

    # Matches inline calls written as ``[FUNC name(raw, inputs) = placeholder]``
    # and captures (function name, raw inputs, output placeholder).
    _FUNC_CALL_PATTERN = r"\[FUNC (\w+)\((.*?)\) = (\w+)\]"

    def __init__(self, tools_strs: List[str],  llm: Optional[Any] = None, ):
        """
        Args:
            tools_strs (List[str]): Source-like descriptions of the available
                functions; they are joined with newlines inside the prompt.
            llm (Optional[Any]): Chat model exposing an async ``achat`` method.
                Defaults to ``OpenAI(temperature=0, model="gpt-4o")``.
        """
        self.state = ""
        self.tools_strs = tools_strs  # Store available functions
        self.initial_reasoning_template = INITIAL_REASONING_PROMPT_TEMPLATE
        self.next_step_reasoning_template = NEXT_STEP_REASONING_PROMPT_TEMPLATE

        if llm is None:
            self.llm = OpenAI(temperature=0, model="gpt-4o")
        else:
            self.llm = llm

    def select_prompt(self, query: str) -> str:
        """
        Selects the appropriate prompt based on the current state.

        Args:
            query (str): The question or input for which reasoning is being generated.

        Returns:
            str: The formatted reasoning prompt.
        """
        if not self.state:
            # Use the initial reasoning template
            reasoning_prompt = self.initial_reasoning_template.format(
                functions="\n".join(self.tools_strs),
                question=query
            )
        else:
            # Use the next step reasoning template
            reasoning_prompt = self.next_step_reasoning_template.format(
                functions="\n".join(self.tools_strs),
                question=query,
                function_call_output=self.state
            )

        return reasoning_prompt

    @staticmethod
    def _extract_last_func_call(text: Optional[str]) -> Optional[tuple]:
        """Return the last ``(name, raw_inputs, placeholder)`` found in ``text``, or None."""
        import re

        matches = re.findall(StepGenerator._FUNC_CALL_PATTERN, text or "")
        return matches[-1] if matches else None

    async def generate_step(self, query: str) -> Optional[tuple]:
        """
        Generates the reasoning step by selecting the prompt and executing the reasoning.

        Args:
            query (str): The question or input for which reasoning is being generated.

        Returns:
            Optional[tuple]: ``(function_name, raw_inputs, output_placeholder)`` for
            the last function call in the model's answer, or ``None`` when the
            answer contains no ``[FUNC ...]`` call (for example when the model
            replies "NO"). ``self.state`` is only updated when a call was found.
        """
        # Select the appropriate reasoning prompt
        reasoning_prompt = self.select_prompt(query)
        reasoning_message = ChatMessage(role="user", content=reasoning_prompt)

        # Run the reasoning prompt
        response = await self.llm.achat([reasoning_message])
        solution = response.message.content

        func_call = self._extract_last_func_call(solution)
        if func_call is None:
            # Nothing to execute; callers treat a falsy result as "no valid reasoning".
            return None

        # Remember the reasoning text so the next call uses the follow-up prompt.
        self.state = solution
        return func_call


In [5]:
from typing import Any, List, Tuple
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.workflow import Workflow, Context, Event, StartEvent, StopEvent, step
import os
from llama_index.core.tools import FunctionTool


class FunctionCallEvent(Event):
    """Workflow event that carries one parsed function call to be executed."""
    # (function name, raw comma-separated inputs as text, output placeholder such as "y1")
    func_call: Tuple[str, str, str]  # Function name, raw inputs, output placeholder

class ValidateFunctionCallEvent(Event):
    """Workflow event that hands a tool result over for validation."""
    # Arguments forwarded to the validator.
    # NOTE(review): a draft step later in this notebook builds this as the prepared
    # inputs plus the tool's raw_output — confirm against the validator's signature.
    input_data: List[Any]
    # Placeholder name (e.g. "y1") under which the result is stored.
    output_placeholder: str
    # The tool's output object as returned by the tool call.
    tool_output: Any

class InitializeEvent(Event):
    """Event for initializing the workflow context.

    Carries no payload; it only signals that the context has been set up so the
    initialize step can continue with the first reasoning step.
    """
    pass




In [6]:
# workflow.py
from llama_index.core.workflow import Context, StartEvent, Workflow, StopEvent, step
from agents.coa_agent.validator.relevancy_eval import GPT4RelevancyEvaluator
from typing import List, Any
from agents.coa.tools_handler.tool_retriver import ToolRetriever


class ChainOfAbstractionSteps:
    """Step implementations used by the chain-of-abstraction workflow.

    The object keeps per-run helpers (tool retriever, step generator) on itself,
    while everything else that belongs to a run lives in ``ctx.data``.
    """

    def __init__(self):
        self.validator = GPT4RelevancyEvaluator(relevancy_eval_prompt_template, prometheus_relevancy_refine_prompt_template)
        self.tools_retriever = None
        # Created in initialize_step once the tools for the query are known.
        self.step_generator = None

    @step(pass_context=True)
    async def initialize_step(self, ctx: Context, ev: StartEvent | InitializeEvent) -> InitializeEvent | FunctionCallEvent | StopEvent:
        """Initialize the workflow context, then emit the first function call.

        On a StartEvent the query and tools are read from the event, the context
        is populated and an InitializeEvent is returned. On the following
        InitializeEvent the first queued function call is emitted, or one is
        generated with the step generator when the queue is empty.

        Returns:
            InitializeEvent after setup, FunctionCallEvent for the call to run, or
            StopEvent when the query/tools are missing or no reasoning step could
            be generated.
        """
        if isinstance(ev, StartEvent):
            query = ev.get("query")
            if not query:
                return StopEvent(result={"message": "Please provide query"})

            # Use .get so a missing `tools` argument yields a clear message instead
            # of an AttributeError raised from the event object.
            tools = ev.get("tools")
            if tools is None:
                return StopEvent(result={"message": "Please provide tools"})

            ctx.data["query"] = query
            self.tools_retriever = ToolRetriever(tools)
            # Perform initial setup, e.g., load tools, validators, etc.
            retrieved_tools = self.tools_retriever.prepare_tools(ctx.data["query"])

            ctx.data["results"] = {}

            self.step_generator = StepGenerator(tools_strs=retrieved_tools["tools_strs"] or [])
            ctx.data["tools_by_name"] = retrieved_tools["tools_by_name"]
            ctx.data["function_calls"] = []
            ctx.data["iteration"] = 0
            ctx.data["accumulated_sources"] = 0

            return InitializeEvent()

        # After initialization, check if there are function calls to process
        if ctx.data["function_calls"]:
            first_func_call = ctx.data["function_calls"][0]
            return FunctionCallEvent(func_call=first_func_call)

        # Generate the first reasoning step
        first_reasoning = await self.step_generator.generate_step(ctx.data["query"])
        if not first_reasoning:
            return StopEvent(result={"message": "No valid reasoning could be generated.", "results": ctx.data["results"]})

        return FunctionCallEvent(func_call=first_reasoning)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> StopEvent:
        """Execute one function call and stop the workflow with the results.

        The tool is looked up by name in ``ctx.data["tools_by_name"]`` and awaited
        with the prepared inputs. Its ``content`` is stored under the output
        placeholder. Any exception (including an unknown tool name) is turned into
        an error StopEvent.

        NOTE: output validation through ``self.validator`` is not wired in; the
        workflow ends after this single call.
        """
        func_name, raw_inputs, output_placeholder = ev.func_call
        input_data = self._prepare_inputs(ctx, raw_inputs)

        try:
            tool = ctx.data["tools_by_name"][func_name]
            # Log the tool's name; the tool object's repr is just a memory address.
            print(f"Execute the function {func_name} with inputs: {', '.join(map(str, input_data))}")
            tool_output = await tool.acall(*input_data)
        except Exception as e:
            error_message = f"Error in {func_name} with inputs {input_data}: {str(e)}"
            print(error_message)
            return self._handle_error(error_message)

        ctx.data["results"][output_placeholder] = tool_output.content
        return StopEvent(result={"message": "No function calls to process.", "results": ctx.data["results"]})

    @staticmethod
    def _split_arguments(raw_inputs: str) -> List[str]:
        """Split raw call arguments on commas that are not inside a quoted string."""
        pieces: List[str] = []
        current: List[str] = []
        open_quote = None
        for char in raw_inputs:
            if open_quote is not None:
                current.append(char)
                if char == open_quote:
                    open_quote = None
            elif char in ("'", '"') and not "".join(current).strip():
                # Only a quote at the start of an argument opens a string, so an
                # apostrophe inside bare text does not swallow the next comma.
                open_quote = char
                current.append(char)
            elif char == ",":
                pieces.append("".join(current))
                current = []
            else:
                current.append(char)
        pieces.append("".join(current))
        return [piece.strip() for piece in pieces]

    def _prepare_inputs(self, ctx: Context, raw_inputs: str) -> List[Any]:
        """Parse the raw argument text of a function call into positional inputs.

        * Blank ``raw_inputs`` (a zero-argument call) yields an empty list.
        * A quoted argument is passed as text with the surrounding quotes removed.
        * A bare argument naming an earlier placeholder (a key of
          ``ctx.data["results"]``) is replaced by that stored value.
        * Text that is an integer literal is converted to ``int``; everything
          else is passed through unchanged.
        """
        if not raw_inputs or not raw_inputs.strip():
            return []

        results = ctx.data["results"]
        input_data: List[Any] = []
        for token in self._split_arguments(raw_inputs):
            if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
                input_data.append(token[1:-1])
                continue

            # Resolve the placeholder first so a non-integer stored result is still
            # passed on, instead of falling back to the placeholder's name.
            value = results.get(token, token)
            if isinstance(value, str):
                try:
                    value = int(value)
                except ValueError:
                    pass  # Not an integer literal: keep the text as-is.
            input_data.append(value)
        return input_data

    def _handle_error(self, message: str) -> StopEvent:
        """Build the StopEvent returned on failure: the message plus empty results."""
        return StopEvent(result={"message": message, "results": {}})
            
            


In [7]:
class ChainOfAbstractionWorkflow(Workflow):
    """Workflow that wires the ChainOfAbstractionSteps implementations together.

    Flow: StartEvent -> initialize_step -> InitializeEvent -> initialize_step ->
    FunctionCallEvent -> function_call_step -> StopEvent.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments (e.g. ``timeout``) to ``Workflow`` and create the step helper."""
        super().__init__(*args, **kwargs)
        # One helper per workflow instance: ChainOfAbstractionSteps stores per-run
        # objects (tool retriever, step generator) on itself, so a single
        # class-level instance would be shared and overwritten across workflows.
        self.steps = ChainOfAbstractionSteps()

    @step(pass_context=True)
    async def initialize_step(self, ctx: Context, ev: StartEvent | InitializeEvent) -> InitializeEvent | FunctionCallEvent | StopEvent:
        """Initialize the workflow context and decide initial reasoning."""
        return await self.steps.initialize_step(ctx, ev)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> StopEvent:
        """Execute the requested function call and finish with its result."""
        return await self.steps.function_call_step(ctx, ev)

In [8]:
# main.py
import asyncio
from typing import Dict
from pydantic import BaseModel, Field
from agents.pdf_reader_agent.workflow import ConciergeWorkflow as DietConsultantAgent, ConciergeWorkflow as NutritionConsultantAgent


class DietQuery(BaseModel):
    """Input schema for a diet consultation; `consult_diet` builds it from the raw query."""
    query: str = Field(description="A question or query related to diet.")

async def consult_diet_async(query: str) -> Dict[str, str]:
    """Ask the diet consultant agent about ``query``.

    Args:
        query: The diet question, forwarded unchanged to the agent.

    Returns:
        Whatever the agent's ``run`` coroutine resolves to for the
        "pdf-diet-docs" collection.
    """
    agent_options = {"timeout": 180, "verbose": True}
    return await DietConsultantAgent(**agent_options).run(
        query=query,
        collection_name="pdf-diet-docs",
    )

def consult_diet(query: str) -> Dict[str, object]:
    """Synchronous entry point for the diet consultation.

    The query is first validated through ``DietQuery``; the async consultation
    is then driven to completion with ``asyncio.run`` and its result returned.
    """
    validated_query = DietQuery(query=query).query
    pending = consult_diet_async(validated_query)
    return asyncio.run(pending)


class NutritionQuery(BaseModel):
    """Input schema for a nutrition consultation; `consult_nutrition` builds it from the raw query."""
    query: str = Field(description="A question or query related to nutrition.")

async def consult_nutrition_async(query: str) -> Dict[str, str]:
    """Ask the nutrition consultant agent about ``query``.

    Args:
        query: The nutrition question, forwarded unchanged to the agent.

    Returns:
        Whatever the agent's ``run`` coroutine resolves to for the
        "pdf-nutrition-docs" collection.
    """
    agent_options = {"timeout": 180, "verbose": True}
    return await NutritionConsultantAgent(**agent_options).run(
        query=query,
        collection_name="pdf-nutrition-docs",
    )

def consult_nutrition(query: str) -> Dict[str, object]:
    """Synchronous entry point for the nutrition consultation.

    The query is first validated through ``NutritionQuery``; the async
    consultation is then driven to completion with ``asyncio.run`` and its
    result returned.
    """
    validated_query = NutritionQuery(query=query).query
    pending = consult_nutrition_async(validated_query)
    return asyncio.run(pending)


# Tool wrapper around the diet consultation. The workflow invokes tools through
# `acall`, so the native coroutine is registered as `async_fn`; this lets the call
# be awaited directly rather than going through the sync wrapper and its
# asyncio.run(). `fn` stays in place for synchronous callers.
diet_tool = FunctionTool.from_defaults(
    fn=consult_diet,
    async_fn=consult_diet_async,
    name="consult_diet",
    description="Consults on diet based on a query.",
)

# Tool wrapper around the nutrition consultation. The workflow invokes tools
# through `acall`, so the native coroutine is registered as `async_fn`; this lets
# the call be awaited directly rather than going through the sync wrapper and its
# asyncio.run(). `fn` stays in place for synchronous callers.
nutrition_tool = FunctionTool.from_defaults(
    fn=consult_nutrition,
    async_fn=consult_nutrition_async,
    name="consult_nutrition",
    description="Provides nutritional information based on a query.",
)

# Tools offered to the workflow; passed as `tools=` when the workflow is run.
tools = [diet_tool, nutrition_tool]

In [9]:
# Run the workflow once. `query` and `tools` travel on the StartEvent and are read
# back in initialize_step.
# NOTE(review): the outer timeout (100) is smaller than the 180 given to the inner
# consultant agents — confirm this is intended.
w = ChainOfAbstractionWorkflow(timeout=100)
await w.run(query="Tell me more about Indo-mediterranean diet", tools=tools)

('consult_diet', '"Indo-Mediterranean diet"', 'y1')
Execute the function <llama_index.core.tools.function_tool.FunctionTool object at 0x29e4508d0> with inputs: "Indo-Mediterranean diet"
Running step concierge
Step concierge produced event InitializeEvent
Running step initialize
Step initialize produced event ConciergeEvent
Running step concierge
Step concierge produced event OrchestratorEvent
Running step orchestrator
Step orchestrator produced event QueryEvent
Running step query_index
Step query_index produced event StopEvent


{'message': 'No function calls to process.',
 'results': {'y1': '{\'query_result\': \'The "Indo-Mediterranean diet" is a dietary pattern used in a randomized trial involving Indian patients with pre-existing coronary heart disease or high cardiovascular risk. This diet is rich in whole grains, fruits, vegetables, walnuts, almonds, and oils such as mustard or soybean oil, which are high in alpha-linolenic acid. Compared to a control group following a step I National Cholesterol Education Program (NCEP) diet, patients on the "Indo-Mediterranean diet" experienced approximately a 60% reduction in the rate of cardiovascular death and about a 50% reduction in the risk of non-fatal myocardial infarction.\', \'source_node\': [\'Int. J. Environ. Res. Public Health 2019, 16, 942 6 of 16 (the Mediterranean Diet Adherence Screener, or MeDiet score) and was found to be inversely associated with the rate of cardiovascular events [36]. Other analyses on the population of the PREDIMED study further sh

In [None]:


# NOTE(review): draft cell. The decorator below sits at column 0 while the method
# it decorates is indented and no class header is visible in this cell, so the
# cell does not parse as written — presumably these methods belong inside
# ChainOfAbstractionSteps; confirm before running.
@step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute a function call and prepare for validation.

        Looks the tool up by name in ``ctx.data["tools_by_name"]`` and awaits it with
        the prepared inputs. Any exception is reported through ``_handle_error``.
        When ``self.validator`` is truthy a ValidateFunctionCallEvent is returned;
        otherwise the tool's ``content`` is stored under the output placeholder and
        the workflow moves on to the next queued call.
        """
        func_name, raw_inputs, output_placeholder = ev.func_call
        input_data = self._prepare_inputs(ctx, raw_inputs)

        try:
            tool_output = await ctx.data["tools_by_name"][func_name].acall(*input_data)
            
            # print(f"Executed {func_name} with inputs {input_data} -> {output_placeholder}: {tool_output} \n")
        except Exception as e:
            error_message = f"Error in {func_name} with inputs {input_data}: {str(e)}"
            print(error_message)  # Log error for debugging
            return self._handle_error(error_message)

        
        if self.validator:
            # The validator receives the call inputs followed by the tool's raw output.
            return ValidateFunctionCallEvent(
                input_data=input_data + [tool_output.raw_output],
                output_placeholder=output_placeholder,
                tool_output=tool_output
            )


        # No validator configured: accept the output as-is.
        ctx.data["results"][output_placeholder] = tool_output.content
        return await self._move_to_next_function(ctx)

    @staticmethod
    def _split_arguments(raw_inputs: str) -> List[str]:
        """Split raw call arguments on commas that are not inside a quoted string."""
        pieces: List[str] = []
        current: List[str] = []
        open_quote = None
        for char in raw_inputs:
            if open_quote is not None:
                current.append(char)
                if char == open_quote:
                    open_quote = None
            elif char in ("'", '"') and not "".join(current).strip():
                # Only a quote at the start of an argument opens a string, so an
                # apostrophe inside bare text does not swallow the next comma.
                open_quote = char
                current.append(char)
            elif char == ",":
                pieces.append("".join(current))
                current = []
            else:
                current.append(char)
        pieces.append("".join(current))
        return [piece.strip() for piece in pieces]

    def _prepare_inputs(self, ctx: Context, raw_inputs: str) -> List[Any]:
        """Parse the raw argument text of a function call into positional inputs.

        * Blank ``raw_inputs`` (a zero-argument call) yields an empty list.
        * A quoted argument is passed as text with the surrounding quotes removed.
        * A bare argument naming an earlier placeholder (a key of
          ``ctx.data["results"]``) is replaced by that stored value.
        * Text that is an integer literal is converted to ``int``; everything
          else is passed through unchanged.
        """
        if not raw_inputs or not raw_inputs.strip():
            return []

        results = ctx.data["results"]
        input_data: List[Any] = []
        for token in self._split_arguments(raw_inputs):
            if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
                input_data.append(token[1:-1])
                continue

            # Resolve the placeholder first so a non-integer stored result is still
            # passed on, instead of falling back to the placeholder's name.
            value = results.get(token, token)
            if isinstance(value, str):
                try:
                    value = int(value)
                except ValueError:
                    pass  # Not an integer literal: keep the text as-is.
            input_data.append(value)
        return input_data

    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Validate function output and decide on the next step.

        ``self.validator.evaluate_sources`` is called with the event's
        ``input_data``. A falsy result, or an exception raised by the validator,
        ends the workflow with an error StopEvent. Otherwise the result is unpacked
        as ``(source_nodes, length_nodes)``; when ``length_nodes`` is positive the
        sources are stored under the output placeholder and counted towards
        ``ctx.data["accumulated_sources"]``.

        The optional ``self.max_num_sources`` attribute caps the accumulated
        count. It is read with ``getattr`` because nothing in this notebook sets
        it; when it is absent or None no cap is applied.
        """
        try:
            validator_output = self.validator.evaluate_sources(*ev.input_data)
        except Exception as e:
            return self._handle_error(f"Validation error: {e}")

        if not validator_output:
            return self._handle_error("Tool output does not match the expected validation result")

        source_nodes, length_nodes = validator_output

        if length_nodes > 0:
            ctx.data["results"][ev.output_placeholder] = str(source_nodes)
            # Accumulate source if validation is successful
            ctx.data["accumulated_sources"] += length_nodes

            # Check if the maximum number of sources has been accumulated
            max_num_sources = getattr(self, "max_num_sources", None)
            if max_num_sources is not None and ctx.data["accumulated_sources"] >= max_num_sources:
                return StopEvent({"message": "Maximum number of sources accumulated.", "results": ctx.data["results"]})

        return await self._move_to_next_function(ctx)

    async def _move_to_next_function(self, ctx: Context) -> FunctionCallEvent | StopEvent:
        """Advance to the following queued function call, or finish the workflow.

        Returns a FunctionCallEvent for the next entry of
        ``ctx.data["function_calls"]`` (after bumping ``ctx.data["iteration"]``), or
        a StopEvent carrying the collected results once the queue is exhausted.
        """
        next_index = ctx.data["iteration"] + 1
        pending_calls = ctx.data["function_calls"]

        if next_index >= len(pending_calls):
            # Nothing left to run: report everything gathered so far.
            return StopEvent({"message": "All function calls processed.", "results": ctx.data["results"]})

        ctx.data["iteration"] += 1
        return FunctionCallEvent(func_call=pending_calls[ctx.data["iteration"]])

    def _handle_error(self, message: str) -> StopEvent:
        """Wrap ``message`` in a StopEvent whose ``results`` entry is an empty dict."""
        failure_payload = {"message": message, "results": {}}
        return StopEvent(failure_payload)

In [None]:
# NOTE(review): unexecuted draft of an extended workflow. As written it cannot be
# defined or run with what is visible in this notebook:
#   * EvaluateFunctionEvent and StepGeneratorEvent are not defined anywhere visible.
#   * The visible ChainOfAbstractionSteps has no initial_reasoning_step or
#     evaluate_function_step method.
#   * StepGenerator requires `tools_strs`, but it is constructed without arguments.
#   * The instance attribute `self.step_generator` shares its name with the
#     `step_generator` step method, and the visible StepGenerator has no
#     `generate_next_steps`.
#   * __init__ accepts no arguments, so `timeout=` cannot be passed.
#   * Both initial_reasoning_step and function_call_step consume FunctionCallEvent.
class ChainOfAbstractionWorkflow(Workflow):
    """Draft workflow: initialize -> reason -> call -> validate -> evaluate -> next step."""

    # Single class-level helper, shared by every instance of this workflow.
    steps = ChainOfAbstractionSteps()

    def __init__(self):
        super().__init__()
        self.step_generator = StepGenerator()

    @step(pass_context=True)
    async def initialize_step(self, ctx: Context, ev: StartEvent | InitializeEvent) -> InitializeEvent | FunctionCallEvent | StopEvent:
        """Initialize the workflow context and decide initial reasoning."""
        return await self.steps.initialize_step(ctx, ev)

    @step(pass_context=True)
    async def initial_reasoning_step(self, ctx: Context, ev: FunctionCallEvent) -> FunctionCallEvent | StopEvent:
        """Generate initial reasoning to decide the first function call."""
        return await self.steps.initial_reasoning_step(ctx, ev)

    @step(pass_context=True)
    async def function_call_step(self, ctx: Context, ev: FunctionCallEvent) -> ValidateFunctionCallEvent | StopEvent:
        """Execute the function based on the current plan."""
        return await self.steps.function_call_step(ctx, ev)

    @step(pass_context=True)
    async def validate_function_step(self, ctx: Context, ev: ValidateFunctionCallEvent) -> EvaluateFunctionEvent | StopEvent:
        """Validate the function output and prepare for evaluation."""
        return await self.steps.validate_function_step(ctx, ev)

    @step(pass_context=True)
    async def evaluate_function_step(self, ctx: Context, ev: EvaluateFunctionEvent) -> StepGeneratorEvent | StopEvent:
        """Evaluate the validated function output to decide the next step."""
        evaluation_result = await self.steps.evaluate_function_step(ctx, ev)
        # The evaluation result is always wrapped in a StepGeneratorEvent here.
        return StepGeneratorEvent(evaluation_result=evaluation_result)

    @step(pass_context=True)
    async def step_generator(self, ctx: Context, ev: StepGeneratorEvent) -> FunctionCallEvent | StopEvent:
        """Generate the next step based on the evaluation result."""
        return await self.step_generator.generate_next_steps(ctx, ev)