<a href="https://colab.research.google.com/github/zahere-dev/openai-agents-sdk-tutorial/blob/main/openai_agents_sdk_guardrails_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Imports**

In [None]:
! pip install openai-agents nest_asyncio colorama


In [2]:
import os
from google.colab import userdata
import asyncio
from agents import (
    Agent,
    function_tool,
    set_default_openai_key,
    set_tracing_disabled,
    GuardrailFunctionOutput,
    InputGuardrailTripwireTriggered,
    RunContextWrapper,
    Runner,
    TResponseInputItem,
    input_guardrail,
    output_guardrail,
    OutputGuardrailTripwireTriggered,
)
from colorama import Fore, Style
from pydantic import BaseModel, Field


# Disable the Agents SDK's tracing for this tutorial.
set_tracing_disabled(True)
# Use the OpenAI API key stored in Colab's secrets (google.colab.userdata) as
# the SDK-wide default key.
set_default_openai_key(userdata.get("OPENAI_API_KEY"))


## **Output Guardrail**

In [3]:
# The agent's output type.
# Structured output that OrderAgent must produce (it is passed as that agent's
# ``output_type``). The Field descriptions become part of the JSON schema the
# model receives, so they double as prompt text. Documented with comments
# rather than a class docstring because pydantic would copy a docstring into
# that schema as well. Both fields are scanned by sensitive_data_check.
class MessageOutput(BaseModel):
    # The model's private reasoning about how to answer.
    reasoning: str = Field(description="Thoughts on how to respond to the user's message")
    # The customer-facing reply.
    response: str = Field(description="The response to the user's message")



@output_guardrail
async def sensitive_data_check(
    context: RunContextWrapper, agent: Agent, output: MessageOutput
) -> GuardrailFunctionOutput:
    """Trip the output guardrail when the agent's output may leak payment-card data.

    Both the customer-facing ``response`` and the model's ``reasoning`` are
    scanned. A field is flagged when it mentions the word "card" in any letter
    case, or contains a run of 13-19 digits (optionally separated by single
    spaces or hyphens) that looks like a card number.

    Args:
        context: Run context supplied by the Agents SDK (unused here).
        agent: The agent whose output is being checked (unused here).
        output: The structured output produced by the agent.

    Returns:
        A ``GuardrailFunctionOutput`` whose ``output_info`` reports, per field,
        whether sensitive data was found, and whose tripwire fires if either
        field is flagged.
    """
    import re

    # A plain `"card" in text` check misses "Card"/"CARD" and bare numbers such
    # as "1122-1122-3334-5555", so match case-insensitively and look for
    # card-number-shaped digit runs too.
    card_number_pattern = r"(?:\d[ -]?){12,18}\d"

    def _looks_sensitive(text: str) -> bool:
        return "card" in text.casefold() or re.search(card_number_pattern, text) is not None

    card_present_in_response = _looks_sensitive(output.response)
    card_present_in_reasoning = _looks_sensitive(output.reasoning)

    return GuardrailFunctionOutput(
        output_info={
            "card_present_in_response": card_present_in_response,
            "card_present_in_reasoning": card_present_in_reasoning,
        },
        tripwire_triggered=card_present_in_response or card_present_in_reasoning,
    )

## **Order Agent**


In [4]:
import nest_asyncio # required for notebooks
# Patch asyncio so that asyncio.run() can be called from inside an event loop
# that is already running, as the notebook kernel's loop typically is; the
# chat cell's asyncio.run(chat()) relies on this.
nest_asyncio.apply()

@function_tool
def check_order_status(order_id: str) -> str:
    """Check the status of an order with the given order ID.

    Args:
        order_id: The customer's order ID, e.g. "12345". Surrounding
            whitespace and a leading "#" are ignored.

    Returns:
        A human-readable status message, or a "not found" message when the ID
        is unknown.
    """
    # Demo data. Order 12346 embeds a card number, presumably so that the
    # sensitive_data_check output guardrail has something to catch.
    order_statuses = {
        "12345": "Your order 12345 is being prepared and will be delivered in 20 minutes.",
        "67890": "Your order 67890 has been dispatched and will arrive in 10 minutes.",
        "11121": "Your order 11121 is still being processed. Please wait a little longer.",
        "12346": "Your order 12346 was processed with card number 1122-1122-3334-5555",
    }
    # The model may pass the ID as "#12345" or with stray spaces; normalize
    # so such variants still hit the lookup table.
    normalized_id = order_id.strip().lstrip("#").strip()
    return order_statuses.get(normalized_id, "Order ID not found. Please check and try again.")

# Specialist agent for order-status questions. Its structured output is
# screened by the sensitive_data_check output guardrail.
order_agent = Agent(
    name="OrderAgent",
    # output_type already makes the SDK request output matching MessageOutput,
    # so the prompt only needs to say what belongs in each field.
    # Interpolating the class object itself would only render its repr
    # ("<class '__main__.MessageOutput'>"), which tells the model nothing.
    instructions=(
        "Help customers with their order status. If they provide an order ID, "
        "fetch the status. Put your thoughts on how to respond in `reasoning` "
        "and the message for the customer in `response`."
    ),
    tools=[check_order_status],
    output_type=MessageOutput,
    output_guardrails=[sensitive_data_check],
)



## **FAQ Agent**


In [5]:
@function_tool
def answer_faq(question: str) -> str:
    """Look up a canned answer to a common restaurant question.

    The input must be one of: hours, menu, location, contact, reservation,
    delivery or allergies.

    Args:
        question: The topic keyword to look up (one of the values above).
            Case, surrounding whitespace and surrounding punctuation are
            ignored.

    Returns:
        The canned answer for the topic, or a fallback pointing the customer
        to the helpline when the topic is unknown.
    """
    faq_responses = {
        "hours": "We are open from 10 AM to 11 PM every day.",
        "menu": "You can find our menu at restaurant.com/menu.",
        "location": "We are located at 123 Main Street, Cityville.",
        "contact": "You can reach us at 555-1234 or email support@restaurant.com.",
        "reservation": "We accept reservations online at restaurant.com/reservations or by calling 555-1234.",
        "delivery": "We offer delivery through our website and on major food delivery platforms like Uber Eats and DoorDash.",
        "allergies": "We accommodate allergies! Please let us know your dietary restrictions when placing an order."
    }
    # The model sometimes passes "Hours?" or " menu "; normalize before lookup
    # so such variants don't fall through to the fallback answer.
    topic = question.strip().strip("?!.,;:'\"").strip().lower()
    return faq_responses.get(topic, "I'm not sure, but you can call our helpline at 555-1234.")

# Specialist agent for general restaurant questions, backed by answer_faq.
faq_agent = Agent(
    name="FAQAgent",
    # Implicit string concatenation rather than backslash continuations, which
    # embedded each continuation line's leading spaces in the prompt. The topic
    # list matches every keyword answer_faq supports.
    instructions=(
        "Answer common customer questions about hours, menu, location, contact "
        "details, reservations, delivery and allergies. "
        "Augment the answer based on the tone and details requested in the query. "
        "Pick up the relevant keyword from the user's query and pass that as input. "
        "Example: If user is asking about time then the input keyword is hours."
    ),
    tools=[answer_faq],
)

## **Complaint Handler Agent**


In [6]:
@function_tool
def handle_complaint(complaint: str):
    """Handle customer complaints and ensure respectful communication."""
    # The docstring above is the tool description the model sees, so it is
    # left verbatim. The complaint text is accepted (giving the model a slot to
    # pass it in) but not inspected: every complaint gets the same reply.
    acknowledgement = (
        "Thank you for your feedback. We take complaints seriously "
        "and will address your concern as soon as possible."
    )
    return acknowledgement

# Specialist agent for complaints; its only tool returns a fixed
# acknowledgement.
complaint_agent = Agent(
    tools=[handle_complaint],
    instructions=(
        "Handle customer complaints and ensure respectful communication."
    ),
    name="ComplaintAgent",
)



## **Reservation Agent**

In [7]:
@function_tool
def handle_reservation(request: str) -> str:
    """Handle a reservation request identified by a keyword.

    The input must be one of: make, modify, cancel or availability.

    Args:
        request: The reservation action keyword (one of the values above).
            Case, surrounding whitespace and surrounding punctuation are
            ignored.

    Returns:
        A canned confirmation for the action, or a message asking the customer
        to call when the keyword is unknown.
    """
    reservation_responses = {
        "make": "Your reservation request has been received. Please check your email for confirmation.",
        "modify": "Your reservation modification request has been received. Please check your email for updates.",
        "cancel": "Your reservation has been canceled. We hope to see you another time!",
        "availability": "We have availability for dinner slots from 6 PM to 9 PM. Please book online or call us."
    }
    # Normalize model-supplied variants such as "Cancel." or " make ".
    action = request.strip().strip("?!.,;:'\"").strip().lower()
    return reservation_responses.get(action, "I'm not sure about that request. Please call us at 555-1234 for assistance.")

# Specialist agent for reservations, backed by handle_reservation.
reservation_agent = Agent(
    name="ReservationAgent",
    # Implicit concatenation instead of backslash continuations (which leaked
    # indentation into the prompt); also mentions the "availability" keyword
    # that handle_reservation supports.
    instructions=(
        "Assist customers with making, modifying, or canceling reservations, "
        "and with checking availability. "
        "Pick up the relevant keyword from the user's query and pass that as input. "
        "Example: If the user is asking about making a reservation, the input keyword is make."
    ),
    tools=[handle_reservation],
)

## **Input Guardrail**

In [8]:
# Structured verdict produced by guardrail_agent. Like MessageOutput, the
# fields carry descriptions that are sent to the model as part of the JSON
# schema; they also make the polarity of ``is_impolite`` explicit.
class ImpoliteOutput(BaseModel):
    reasoning: str = Field(description="Why the user's message was judged polite or impolite")
    is_impolite: bool = Field(
        description="True if the user's message is impolite or uses harsh or abusive words; otherwise False"
    )

# Helper agent that impolite_guardrail runs to judge the user's tone; its
# answer is parsed into ImpoliteOutput.
guardrail_agent = Agent(
    output_type=ImpoliteOutput,
    instructions=(
        "Check if the user is polite and not using any harsh or abusive words."
    ),
    name="Guardrail check",
)

@input_guardrail
async def impolite_guardrail(
    context: RunContextWrapper[None], agent: Agent, input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
    """Screen the incoming user input for rudeness before the agent runs.

    The raw input is forwarded to ``guardrail_agent``. Its ``ImpoliteOutput``
    verdict becomes this guardrail's ``output_info``, and ``is_impolite``
    drives the tripwire. When the tripwire fires, the verdict's reasoning is
    printed first.
    """
    verdict_run = await Runner.run(guardrail_agent, input, context=context.context)
    verdict = verdict_run.final_output_as(ImpoliteOutput)
    tripped = verdict.is_impolite

    if tripped:
        print(f"Reasoning {verdict.reasoning}")

    return GuardrailFunctionOutput(
        output_info=verdict,
        tripwire_triggered=tripped,
    )


In [9]:
# Entry-point agent: its input is screened by impolite_guardrail, and it hands
# the conversation off to one of the specialist agents listed in `handoffs`.
classifier_agent = Agent(
    handoffs=[
        order_agent,
        faq_agent,
        complaint_agent,
        reservation_agent,
    ],
    input_guardrails=[impolite_guardrail],
    instructions="Handoff to appropriate agent based on user query",
    model="gpt-4o-mini",
    name="User Interface Agent",
)


## **Chat**

In [None]:
async def chat():
    """Run an interactive support chat loop until the user types 'exit'.

    Each turn sends the conversation so far plus the new user message to
    ``classifier_agent`` and prints the final output. Input/output guardrail
    tripwires are caught and reported. A tripped turn is not added to the
    history, because ``response`` keeps the last successful run.
    """
    print("Welcome to the Restaurant Customer Support chat! Type 'exit' to end the chat.")
    # Result of the last successful run; None until a turn succeeds.
    response = None
    while True:
        user_input = input(Fore.GREEN + "You: " + Style.RESET_ALL)
        # Tolerate stray whitespace around the exit command.
        if user_input.strip().lower() == "exit":
            print(Fore.RED + "Goodbye!" + Style.RESET_ALL)
            break
        # Don't spend a model call on an empty message.
        if not user_input.strip():
            continue

        new_message = {"role": "user", "content": user_input}
        if response is not None:
            input_with_context = response.to_input_list() + [new_message]
        else:
            input_with_context = [new_message]

        try:
            response = await Runner.run(classifier_agent, input=input_with_context)
        except InputGuardrailTripwireTriggered as e:
            message = "Sorry, I can't help you. Please be polite."
            print(Fore.RED + f"Input Guardrail tripped. Info: {e.guardrail_result.output.output_info}" + Style.RESET_ALL)
            # This is the reply meant for the user, not guardrail diagnostics.
            print(Fore.RED + f"Support Agent: {message}" + Style.RESET_ALL)
        except OutputGuardrailTripwireTriggered as e:
            print(Fore.RED + f"Output Guardrail tripped. Info: {e.guardrail_result.output.output_info}" + Style.RESET_ALL)
        else:
            print(Fore.BLUE + f"Support Agent: {response.final_output}" + Style.RESET_ALL)





# In a notebook kernel an event loop is typically already running, where
# asyncio.run() would normally refuse to start; this call relies on the
# nest_asyncio.apply() patch made in an earlier cell.
if __name__ == "__main__":
    asyncio.run(chat())


Welcome to the Restaurant Customer Support chat! Type 'exit' to end the chat.
You: Why is my order delayed? You guys are pathetic
Reasoning The user is expressing frustration and has used the term "pathetic" which can be considered impolite or insulting when addressing a customer service issue.
Input Guardrail tripped. Info: reasoning='The user is expressing frustration and has used the term "pathetic" which can be considered impolite or insulting when addressing a customer service issue.' is_impolite=True
Input Guardrail tripped. Info: Sorry, I can't help you. Please be polite.
You: What is the status of order 12346
Output Guardrail tripped. Info: {'card_present_in_response': False, 'card_present_in_reasoning': True}
