<a href="https://colab.research.google.com/github/vatsalagarwal09/GenAI/blob/main/UsingReactAgentForReportGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [79]:
!pip install colorama



In [80]:
import google.generativeai as genai
from pprint import pprint
from IPython.display import display_markdown
from google.colab import userdata

In [81]:
from getpass import getpass

gemini_key = getpass("Enter Gemini Key")

Enter Gemini Key··········


In [82]:
import os
os.environ["GOOGLE_API_KEY"] = gemini_key

In [83]:
!pip install tavily-python



In [84]:
# @title
import json
import re
from dataclasses import dataclass
from typing import Callable


def get_fn_signature(fn: Callable) -> dict:
    """
    Generates the signature for a given function.

    Args:
        fn (Callable): The function whose signature needs to be extracted.

    Returns:
        dict: A dictionary containing the function's name, description,
              and parameter types.
    """
    fn_signature: dict = {
        "name": fn.__name__,
        "description": fn.__doc__,
        "parameters": {"properties": {}},
    }
    schema = {
        k: {"type": v.__name__} for k, v in fn.__annotations__.items() if k != "return"
    }
    fn_signature["parameters"]["properties"] = schema
    return fn_signature


def validate_arguments(tool_call: dict, tool_signature: dict) -> dict:
    """
    Validates and converts arguments in the input dictionary to match the expected types.

    Args:
        tool_call (dict): A dictionary containing the arguments passed to the tool.
        tool_signature (dict): The expected function signature and parameter types.

    Returns:
        dict: The tool call dictionary with the arguments converted to the correct types if necessary.
    """
    properties = tool_signature["parameters"]["properties"]

    # TODO: This is overly simplified but enough for simple Tools.
    type_mapping = {
        "int": int,
        "str": str,
        "bool": bool,
        "float": float,
        "list": list
    }

    for arg_name, arg_value in tool_call["arguments"].items():
        expected_type = properties[arg_name].get("type")

        if not isinstance(arg_value, type_mapping[expected_type]):
            tool_call["arguments"][arg_name] = type_mapping[expected_type](arg_value)

    return tool_call


class Tool:
    """
    A class representing a tool that wraps a callable and its signature.

    Attributes:
        name (str): The name of the tool (function).
        fn (Callable): The function that the tool represents.
        fn_signature (str): JSON string representation of the function's signature.
    """

    def __init__(self, name: str, fn: Callable, fn_signature: str):
        self.name = name
        self.fn = fn
        self.fn_signature = fn_signature

    def __str__(self):
        return self.fn_signature

    def run(self, **kwargs):
        """
        Executes the tool (function) with provided arguments.

        Args:
            **kwargs: Keyword arguments passed to the function.

        Returns:
            The result of the function call.
        """
        return self.fn(**kwargs)


def tool(fn: Callable):
    """
    A decorator that wraps a function into a Tool object.

    Args:
        fn (Callable): The function to be wrapped.

    Returns:
        Tool: A Tool object containing the function, its name, and its signature.
    """

    def wrapper():
        fn_signature = get_fn_signature(fn)
        return Tool(
            name=fn_signature.get("name"), fn=fn, fn_signature=json.dumps(fn_signature)
        )

    return wrapper()


@dataclass
class TagContentResult:
    """
    A data class to represent the result of extracting tag content.

    Attributes:
        content (List[str]): A list of strings containing the content found between the specified tags.
        found (bool): A flag indicating whether any content was found for the given tag.
    """

    content: list[str]
    found: bool


def extract_tag_content(text: str, tag: str) -> TagContentResult:
    """
    Extracts all content enclosed by specified tags (e.g., <thought>, <response>, etc.).

    Parameters:
        text (str): The input string containing multiple potential tags.
        tag (str): The name of the tag to search for (e.g., 'thought', 'response').

    Returns:
        TagContentResult: A dataclass with two fields:
            - content (list[str]): The stripped text found between each pair of tags.
            - found (bool): True if at least one tag pair was found.
    """
    # Build the regex pattern dynamically to find multiple occurrences of the tag
    tag_pattern = rf"<{tag}>(.*?)</{tag}>"

    # Use findall to capture all content between the specified tag
    matched_contents = re.findall(tag_pattern, text, re.DOTALL)

    # Return the dataclass instance with the result
    return TagContentResult(
        content=[content.strip() for content in matched_contents],
        found=bool(matched_contents),
    )
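
To see what the signature and validation helpers in this cell produce, the sketch below rebuilds both steps in a few self-contained lines. `describe` and `coerce` are hypothetical stand-ins for `get_fn_signature` and `validate_arguments`, and `add` is an illustrative function that does not appear in the notebook.

```python
import json
from typing import Callable


# Hypothetical stand-in for get_fn_signature: name, docstring, parameter types.
def describe(fn: Callable) -> dict:
    properties = {
        name: {"type": annotation.__name__}
        for name, annotation in fn.__annotations__.items()
        if name != "return"
    }
    return {
        "name": fn.__name__,
        "description": fn.__doc__,
        "parameters": {"properties": properties},
    }


# Hypothetical stand-in for validate_arguments: cast values of the wrong type.
TYPE_MAPPING = {"int": int, "str": str, "bool": bool, "float": float, "list": list}


def coerce(arguments: dict, signature: dict) -> dict:
    properties = signature["parameters"]["properties"]
    coerced = {}
    for name, value in arguments.items():
        expected = TYPE_MAPPING[properties[name]["type"]]
        coerced[name] = value if isinstance(value, expected) else expected(value)
    return coerced


def add(a: int, b: int) -> int:
    """Adds two integers."""
    return a + b


signature = describe(add)
print(json.dumps(signature["parameters"]))

# A model may send numbers as strings; the cast repairs that before the call.
arguments = coerce({"a": "2", "b": 3}, signature)
print(add(**arguments))
```

Two limits of this scheme are worth keeping in mind: an argument name that is missing from the signature raises `KeyError`, and `bool("false")` evaluates to `True`, so string-encoded booleans would need separate handling.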


In [85]:
import os
from getpass import getpass
from tavily import TavilyClient

# Read the Tavily API key from the environment, or prompt for it, instead of
# hard-coding it in the notebook. Never commit a real key to source control.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY") or getpass("Enter Tavily Key")

# Default to None so get_research_response can detect a missing client.
tavily = None

if not TAVILY_API_KEY:
    print("Error: no Tavily API key was provided.")
else:
    # Initialize the Tavily client with your API key.
    # The client will use the key to authenticate your requests.
    try:
        tavily = TavilyClient(api_key=TAVILY_API_KEY)
    except ValueError as e:
        print(f"Error initializing Tavily client: {e}")

@tool
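def get_research_response(questions: list[str]) -> dict:
    """
    Performs a web search for each question using the Tavily API
    and prints the title and content of the results.

    Args:
        questions (list): A list of strings, where each string is a search query.

    Returns:
        dict: Maps each question to a list of {"title", "content"} dicts.
              Questions with no results or a failed search are omitted.
    """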
    if not tavily:
        print("Tavily client not initialized. Aborting search.")
        return {}

    all_results = {}

    # Loop through each question provided in the list.
    for question in questions:
        print("===========================================================")
        print(f"Searching for: '{question}'...")
        print("===========================================================")
        question_results = []

        try:
            # Perform the search. The `query` is the search string.
            # We set `max_results` to limit the number of results returned.
            response = tavily.search(query=question, max_results=1)

            # Check if the search response contains any results.
            if 'results' in response and response['results']:
                # Iterate through each result in the search response.
                for i, result in enumerate(response['results']):
                    # Extract the title and content (snippet) of the search result.
                    title = result.get('title', 'No Title Available')
                    content = result.get('content', 'No Content Available')

                    # Create a dictionary for the current result and add it to the list
                    question_results.append({
                        "title": title,
                        "content": content
                    })

                    print(f"Result {i+1}:")
                    print(f"  Title: {title}")
                    print(f"  Content: {content}\n")

                all_results[question] = question_results

            else:
                print("No results found for this query.")
        except Exception as e:
            print(f"An error occurred during the search for '{question}': {e}")

        print("\n")
    return all_results
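
For testing without network access or an API key, the search loop in this cell can be exercised against a stub. The sketch below assumes only what the cell itself relies on: a client with a `search(query=..., max_results=...)` method that returns a dict containing a `results` list. `FakeSearchClient` and `collect_results` are hypothetical names, and the canned data is invented for illustration.

```python
class FakeSearchClient:
    """Hypothetical stand-in for TavilyClient that returns canned results."""

    def __init__(self, canned: dict):
        self.canned = canned

    def search(self, query: str, max_results: int = 1) -> dict:
        return {"results": self.canned.get(query, [])[:max_results]}


def collect_results(client, questions: list) -> dict:
    """Mirrors the loop in get_research_response, without the printing."""
    all_results = {}
    for question in questions:
        response = client.search(query=question, max_results=1)
        hits = response.get("results") or []
        if hits:
            all_results[question] = [
                {
                    "title": hit.get("title", "No Title Available"),
                    "content": hit.get("content", "No Content Available"),
                }
                for hit in hits
            ]
    return all_results


client = FakeSearchClient(
    {"What is phishing?": [{"title": "Phishing 101", "content": "Illustrative text."}]}
)
results = collect_results(client, ["What is phishing?", "Unknown topic"])
print(results)
```

Questions without hits are left out of the returned dict, which matches the behaviour of the loop in the tool.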



In [86]:
import re
import time

from colorama import Fore
from colorama import Style

from dataclasses import dataclass


def completions_create(client, messages: list, model: "genai.GenerativeModel") -> str:
    """
    Sends the chat history to a Gemini model and returns the text of its reply.

    Args:
        client: Unused by this Gemini implementation; kept for call-site compatibility.
        messages (list[dict]): A list of message objects containing chat history for the model.
        model (genai.GenerativeModel): The model instance used to generate tool calls and responses.

    Returns:
        str: The text of the first part of the first candidate, or a fallback
             message if the model returned no candidates.
    """
    response = model.generate_content(messages)
    if not response.candidates:
        return "Model did not return any candidates."
    output = response.candidates[0].content.parts[0].text
    return str(output)


def build_prompt_structure(prompt: str, role: str, tag: str = "") -> dict:
    """
    Builds a structured prompt that includes the role and content.

    Args:
        prompt (str): The actual content of the prompt.
        role (str): The role of the speaker (e.g., user, model).
        tag (str, optional): If given, the prompt is wrapped in <tag>...</tag>.

    Returns:
        dict: A dictionary representing the structured prompt.
    """
    if tag:
        prompt = f"<{tag}>{prompt}</{tag}>"
    return {"role": role, "parts": [{"text": prompt}]}

def update_chat_history(history: list, msg: str, role: str):
    """
    Updates the chat history by appending the latest response.

    Args:
        history (list): The list representing the current chat history.
        msg (str): The message to append.
        role (str): The role type ('user' or 'model' for Gemini)
    """
    history.append(build_prompt_structure(prompt=msg, role=role))


class ChatHistory(list):
    def __init__(self, messages: list | None = None, total_length: int = -1):
        """Initialise the history with an optional maximum length.

        Args:
            messages (list | None): A list of initial messages
            total_length (int): The maximum number of messages the chat history can hold.
                The default of -1 never matches len(self), so the history is unbounded.
        """
        if messages is None:
            messages = []

        super().__init__(messages)
        self.total_length = total_length

    def append(self, msg: dict):
        """Add a message to the queue, evicting the oldest one when the queue is full.

        Args:
            msg (dict): The message to be added to the queue
        """
        if len(self) == self.total_length:
            self.pop(0)
        super().append(msg)



class FixedFirstChatHistory(ChatHistory):
    def __init__(self, messages: list | None = None, total_length: int = -1):
        """Initialise the queue with a fixed total length.

        Args:
            messages (list | None): A list of initial messages
            total_length (int): The maximum number of messages the chat history can hold.
        """
        super().__init__(messages, total_length)

    def append(self, msg: dict):
        """Add a message to the queue. The first message always stays fixed.

        Args:
            msg (dict): The message to be added to the queue
        """
        if len(self) == self.total_length:
            self.pop(1)
        super().append(msg)

def fancy_print(message: str) -> None:
    """
    Displays a fancy print message.

    Args:
        message (str): The message to display.
    """
    print(Style.BRIGHT + Fore.CYAN + f"\n{'=' * 50}")
    print(Fore.MAGENTA + f"{message}")
    print(Style.BRIGHT + Fore.CYAN + f"{'=' * 50}\n")
    time.sleep(0.5)


def fancy_step_tracker(step: int, total_steps: int) -> None:
    """
    Displays a fancy step tracker for each iteration of the generation-reflection loop.

    Args:
        step (int): The current step in the loop.
        total_steps (int): The total number of steps in the loop.
    """
    fancy_print(f"STEP {step + 1}/{total_steps}")
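
The difference between the two history classes defined in this cell is easiest to see side by side. The sketch below uses compact copies of both classes so that it runs on its own; the string messages are placeholders, whereas the agent stores dicts.

```python
class ChatHistory(list):
    """Compact copy of the notebook's class, for illustration only."""

    def __init__(self, messages=None, total_length: int = -1):
        super().__init__(messages or [])
        self.total_length = total_length

    def append(self, msg):
        if len(self) == self.total_length:
            self.pop(0)  # evict the oldest message
        super().append(msg)


class FixedFirstChatHistory(ChatHistory):
    def append(self, msg):
        if len(self) == self.total_length:
            self.pop(1)  # evict the oldest message after the first one
        super().append(msg)


rolling = ChatHistory(total_length=3)
pinned = FixedFirstChatHistory(total_length=3)
for message in ["system", "q1", "a1", "q2"]:
    rolling.append(message)
    pinned.append(message)

print(rolling)  # ['q1', 'a1', 'q2']
print(pinned)   # ['system', 'a1', 'q2']
```

With the default `total_length=-1` the length check never matches, so the history grows without bound.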


In [88]:
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

class ToolResponse(BaseModel):
    Title: str = Field(description="Title from the tool execution result")
    Content: str = Field(description="Content from the tool execution result")


class SectionResponse(BaseModel):
    Question: str = Field(description="Each question generated during thinking stage for further deep dive")
    Tool: list[ToolResponse] = Field(description="Response from the tool")

# Define your desired data structure - like a python data class.
class ReviewAnalysisResponse(BaseModel):
    ProjectTitle: str = Field(description="Title for the Research Report")
    Introduction: str = Field(description="A brief introduction for the report - max 3 points")
    Section: list[SectionResponse] = Field(description="Response for each field")
    Conclusion: str = Field(description="Final conclusion of the complete report")

# Set up a parser + inject instructions into the prompt template.
parser = PydanticOutputParser(pydantic_object=ReviewAnalysisResponse)
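
Before handing a report to the parser, a quick key check can surface naming mistakes such as a lower-case `conclusion`. The sketch below uses only the standard library; `REQUIRED_KEYS` and `missing_keys` are hypothetical names that mirror the top-level fields of `ReviewAnalysisResponse`. It is a convenience check under those assumptions, not a replacement for Pydantic validation.

```python
import json

# Top-level field names copied from ReviewAnalysisResponse.
REQUIRED_KEYS = {"ProjectTitle", "Introduction", "Section", "Conclusion"}


def missing_keys(report_json: str) -> set:
    """Returns the required top-level keys that the JSON object lacks."""
    report = json.loads(report_json)
    return REQUIRED_KEYS - set(report)


candidate = json.dumps(
    {"ProjectTitle": "Demo", "Introduction": "...", "Section": [], "conclusion": "..."}
)
print(missing_keys(candidate))  # {'Conclusion'}
```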

In [113]:
def getReportResponse(response: str) -> dict:
    """
    Checks that the model output is valid JSON, after removing an optional
    Markdown code fence around it.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    try:
        stripped_json_input = response.strip()

        if stripped_json_input.startswith('```json'):
            # Remove the 7-character opening marker ('```json'); lstrip() then
            # drops the newline and any other whitespace that follows it.
            stripped_json_input = stripped_json_input[7:].lstrip()

        # Check for and remove the ending marker
        if stripped_json_input.endswith('```'):
            # Slice the string to remove the last 3 characters ('```')
            stripped_json_input = stripped_json_input[:-3].rstrip()

        # json.loads is called only to validate; the string itself is returned.
        json.loads(stripped_json_input)
        return {"stripped_json_input": stripped_json_input}
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        raise
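
Slicing fixed prefixes works when the output starts with a `json`-labelled fence. A regular expression is one alternative that also accepts an unlabelled fence. The sketch below is an assumption about inputs the model might produce, and `strip_json_fence` is a hypothetical helper, not part of the notebook.

```python
import json
import re

FENCE = "`" * 3  # built at runtime so the example contains no literal fence

# Optional "json" label after the opening fence; content is captured lazily.
FENCE_PATTERN = re.compile(
    rf"^\s*{FENCE}(?:json)?\s*(.*?)\s*{FENCE}\s*$", re.DOTALL
)


def strip_json_fence(text: str) -> str:
    """Returns the text inside a surrounding code fence, or the stripped text."""
    match = FENCE_PATTERN.match(text)
    return match.group(1) if match else text.strip()


fenced = FENCE + 'json\n{"ProjectTitle": "Demo"}\n' + FENCE
bare = FENCE + "\n[1, 2]\n" + FENCE
plain = '  {"ProjectTitle": "Demo"}  '

for sample in (fenced, bare, plain):
    print(json.loads(strip_json_fence(sample)))
```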

In [115]:
import json
import re

from colorama import Fore


BASE_SYSTEM_PROMPT = ""


REACT_SYSTEM_PROMPT = """
You are provided with a topic to do research on in the following stages : Thinking, Action and Result Generation.
For the purpose of research you will start with the Thinking Stage. During thinking, you have to figure out 2 relevant questions for that topic, that you can dive deep further.

Once, you have generated 2 questions for the topic, you start with the Action Stage.
In this stage, you can use the tools present under the <tool></tool> tag and get the response from it. Pass the list of questions in the tool present as the parameter.

For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:

<tool_call>
{"name": <function-name>,"arguments": <args-dict>, "id": <monotonically-increasing-id>}
</tool_call>

Here are the available tools / actions:

<tools>
%s
</tools>

After this, start the Result Generation Stage. In this stage, return a final report in the following format : %s

Example session:

<question>Climate Change</question>
<thought>
1. What causes climate change?
2. What are recent climate trends?
3. What causes global warming?
4. What are government policies on climate change?
5. What are the impacts of climate change?
</thought>
<tool_call>{"name": "get_research_response","arguments": {"questions": ["What causes climate change?", "What are recent climate trends?", "What causes global warming?",
"What are government policies on climate change?", "What are the impacts of climate change?"]}, "id": 0}</tool_call>

You will be called again with this:

<observation>{0: {"response": "Responses for the question"}}</observation>

You will then output:

<response>
"ProjectTitle": "Cybersecurity Research Report",
"Introduction": "This report provides an overview of current cybersecurity landscapes. It identifies prevalent cyber threats and their prevention methods. The report also explores the latest technological advancements and trends in cybersecurity.",
"Section": [
  {
    "Question": "What are the most common types of cybersecurity threats and how can they be prevented?",
    "Tool": [
      {
        "Title": "Top 10 Cyber Security Threats and How to Prevent Them",
        "Content": "Today's most common cyber attacks include phishing, social engineering, malware, ransomware, zero-day vulnerabilities, insider threats, supply chain attacks, denial of service, distributed denial of service, and system intrusion. These often exploit human psychology to trick individuals into revealing sensitive information. Prevention strategies involve educating employees on cybercrime risks, utilizing the latest software and technology for threat recognition, and implementing robust network security controls."
      }
    ]
  },
  {
    "Question": "What are the latest trends and technologies in cybersecurity?",
    "Tool": [
      {
        "Title": "Explore the emerging Cybersecurity Technologies and Trends",
        "Content": "The latest technologies and trends in cybersecurity include Artificial Intelligence (AI) and Machine Learning (ML), Behavioral Biometrics, Zero Trust Architecture, Blockchain, Quantum Computing, Cloud Security, and IoT Security."
      }
    ]
  }
],
"Conclusion": "Cybersecurity remains a critical domain, continuously evolving to combat sophisticated threats. Common attacks like phishing and social engineering often leverage human vulnerabilities, emphasizing the importance of user education and strong security protocols. The integration of advanced technologies such as AI/ML, Zero Trust, and blockchain, alongside specialized areas like cloud and IoT security, signifies a proactive shift towards more resilient and intelligent defense mechanisms against emerging cyber risks."
</response>

Additional constraints:

- If the user asks you something unrelated to any of the tools above, answer freely enclosing your answer with <response></response> tags.
"""


class ReactAgent:
    """
    A class that represents an agent using the ReAct logic that interacts with tools to process
    user inputs, make decisions, and execute tool calls. The agent can run interactive sessions,
    collect tool signatures, and process multiple tool calls in a given round of interaction.

    Attributes:
        client: Always None; genai.configure() sets the API key globally and returns nothing.
        model (genai.GenerativeModel): The Gemini model instance used for generating responses,
            built from the model name passed to __init__ (default "gemini-2.5-flash").
        tools (list[Tool]): A list of Tool instances available for execution.
        tools_dict (dict): A dictionary mapping tool names to their corresponding Tool instances.
    """

    def __init__(
        self,
        tools: Tool | list[Tool],
        model: str = "gemini-2.5-flash",
        system_prompt: str = BASE_SYSTEM_PROMPT,
    ) -> None:
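        # genai.configure() sets the API key globally and returns None, so there is
        # no client object to keep. self.client is retained only because
        # completions_create() still accepts a client argument.
        genai.configure(api_key=gemini_key)
        self.client = None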

        self.model = genai.GenerativeModel(model)
        self.system_prompt = system_prompt
        self.tools = tools if isinstance(tools, list) else [tools]
        self.tools_dict = {tool.name: tool for tool in self.tools}

    def add_tool_signatures(self) -> str:
        """
        Collects the function signatures of all available tools.

        Returns:
            str: A concatenated string of all tool function signatures in JSON format.
        """
        # Join with newlines so that multiple signatures stay separate JSON objects.
        return "\n".join([tool.fn_signature for tool in self.tools])

    def process_tool_calls(self, tool_calls_content: list) -> dict:
        """
        Processes each tool call, validates arguments, executes the tools, and collects results.

        Args:
            tool_calls_content (list): List of strings, each representing a tool call in JSON format.

        Returns:
            dict: A dictionary where the keys are tool call IDs and values are the results from the tools.
        """
        observations = {}
        for tool_call_str in tool_calls_content:
            tool_call = json.loads(tool_call_str)
            tool_name = tool_call["name"]
            tool = self.tools_dict[tool_name]

            print(Fore.GREEN + f"\nUsing Tool: {tool_name}")

            # Validate and execute the tool call
            validated_tool_call = validate_arguments(
                tool_call, json.loads(tool.fn_signature)
            )
            print(Fore.GREEN + f"\nTool call dict: \n{validated_tool_call}")

            result = tool.run(**validated_tool_call["arguments"])
            print(Fore.GREEN + f"\nTool result: \n{result}")

            # Store the result using the tool call ID
            observations[validated_tool_call["id"]] = result

        return observations

    def run(
        self,
        user_msg: str,
        max_rounds: int = 3,
    ) -> dict | str:
        """
        Executes a user interaction session, where the agent processes user input, generates responses,
        handles tool calls, and updates chat history until a final response is ready or the maximum
        number of rounds is reached.

        Args:
            user_msg (str): The user's input message to start the interaction.
            max_rounds (int, optional): Maximum number of interaction rounds the agent should perform. Default is 3.

        Returns:
            dict | str: A dict holding the validated report JSON under "stripped_json_input" if a round
                produces valid JSON; otherwise the text of one last completion requested after the loop.
        """
        user_prompt = build_prompt_structure(
            prompt=user_msg, role="user", tag="question"
        )

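        # Build the prompt locally so that calling run() more than once does not
        # append the ReAct instructions to self.system_prompt again.
        system_prompt = self.system_prompt
        if self.tools:
            system_prompt += (
                "\n" + REACT_SYSTEM_PROMPT % (self.add_tool_signatures(), parser.get_format_instructions())
            )

        print(Fore.GREEN + f"\nSystem prompt: \n{system_prompt}")

        # The instructions are sent as the first message, with role "model".
        chat_history = ChatHistory(
            [
                build_prompt_structure(
                    prompt=system_prompt,
                    role="model",
                ),
                user_prompt,
            ]
        )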

        if self.tools:
            # Run the ReAct loop for max_rounds
            for round_number in range(max_rounds):
                print(Fore.YELLOW + f"\nRoundNumber: {round_number}")
                completion = completions_create(self.client, chat_history, self.model)

                # A completion that parses as JSON is treated as the final report.
                try:
                    finalResponse = getReportResponse(str(completion))
                    return finalResponse
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

                thought = extract_tag_content(str(completion), "thought")
                tool_calls = extract_tag_content(str(completion), "tool_call")

                update_chat_history(chat_history, completion, "model")

                if thought.found:
                    print(Fore.MAGENTA + f"\nThought: {thought.content[0]}")

                if tool_calls.found:
                    observations = self.process_tool_calls(tool_calls.content)
                    print(Fore.BLUE + f"\nObservations: {observations}")
                    update_chat_history(chat_history, f"{observations}", "user")

        return completions_create(self.client, chat_history, self.model)
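
One round of the loop in `run` can be traced offline with a canned completion. The sketch below is self-contained: `echo_research` and `TOOLS` are hypothetical stand-ins for the Tavily-backed tool and `tools_dict`, the completion text is invented, and the tag extractor is a shortened version that returns a plain list.

```python
import json
import re


def extract_tag_content(text: str, tag: str) -> list:
    """Returns the stripped contents of every <tag>...</tag> pair in text."""
    pattern = rf"<{tag}>(.*?)</{tag}>"
    return [match.strip() for match in re.findall(pattern, text, re.DOTALL)]


# Hypothetical tool and registry standing in for get_research_response.
def echo_research(questions: list) -> dict:
    return {q: [{"title": "stub", "content": f"answer to {q}"}] for q in questions}


TOOLS = {"echo_research": echo_research}

# A canned completion in the shape the system prompt asks for.
completion = (
    "<thought>1. What is phishing?</thought>\n"
    '<tool_call>{"name": "echo_research", '
    '"arguments": {"questions": ["What is phishing?"]}, "id": 0}</tool_call>'
)

print(extract_tag_content(completion, "thought"))

observations = {}
for raw_call in extract_tag_content(completion, "tool_call"):
    call = json.loads(raw_call)
    observations[call["id"]] = TOOLS[call["name"]](**call["arguments"])

# This string is what the loop appends to the history as the next user turn.
print(f"{observations}")
```

Note that the agent appends the observations as a Python `repr` string without `<observation>` tags, whereas the example session in the prompt shows them wrapped in those tags.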


In [116]:
agent = ReactAgent(tools=[get_research_response])

In [117]:
agent.run(user_msg="CyberSecurity")

System prompt: 


You are provided with a topic to do research on in the following stages : Thinking, Action and Result Generation.
For the purpose of research you will start with the Thinking Stage. During thinking, you have to figure out 2 relevant questions for that topic, that you can dive deep further.

Once, you have generated 2 questions for the topic, you start with the Action Stage.
In this stage, you can use the tools present under the <tool></tool> tag and get the response from it. Pass the list of questions in the tool present as the parameter.

For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:

<tool_call>
{"name": <function-name>,"arguments": <args-dict>, "id": <monotonically-increasing-id>}
</tool_call>

Here are the available tools / actions:

<tools>
{"name": "get_research_response", "description": "\n    Performs a web search for each question using the Tavily API\n    and prints the 

{'stripped_json_input': '{"ProjectTitle": "Cybersecurity Research Report", "Introduction": "This report provides an overview of current cybersecurity landscapes. It identifies prevalent cyber threats and their prevention methods. The report also explores the latest technological advancements and trends in cybersecurity.", "Section": [{"Question": "What are the most common types of cybersecurity threats and how can they be prevented?", "Tool": [{"Title": "Top 10 Cyber Security Threats and How to Prevent Them", "Content": "The following are today’s top 10 most common and impactful cyber attacks: phishing, social engineering, malware, ransomware, zero-day vulnerabilities, insider threats, supply chain attacks, denial of service, distributed denial of service, and system intrusion. These attacks involve criminals exploitinghuman psychologyrather than technical vulnerabilities to trick people into providing them with sensitive information or access to data, networks, and systems.Social engi