Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions openai_agents/customer_service/customer_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations as _annotations

from typing import Dict, Tuple

from agents import Agent, RunContextWrapper, function_tool, handoff
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
from pydantic import BaseModel
Expand All @@ -23,19 +25,20 @@ class AirlineAgentContext(BaseModel):
description_override="Lookup frequently asked questions.",
)
async def faq_lookup_tool(question: str) -> str:
if "bag" in question or "baggage" in question:
question_lower = question.lower()
if "bag" in question_lower or "baggage" in question_lower:
return (
"You are allowed to bring one bag on the plane. "
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
)
elif "seats" in question or "plane" in question:
elif "seats" in question_lower or "plane" in question_lower:
return (
"There are 120 seats on the plane. "
"There are 22 business class seats and 98 economy seats. "
"Exit rows are rows 4 and 16. "
"Rows 5-8 are Economy Plus, with extra legroom. "
)
elif "wifi" in question:
elif "wifi" in question_lower:
return "We have free wifi on the plane, join Airline-Wifi"
return "I'm sorry, I don't know the answer to that question."

Expand Down Expand Up @@ -74,7 +77,9 @@ async def on_seat_booking_handoff(
### AGENTS


def init_agents() -> Agent[AirlineAgentContext]:
def init_agents() -> Tuple[
Agent[AirlineAgentContext], Dict[str, Agent[AirlineAgentContext]]
]:
"""
Initialize the agents for the airline customer service workflow.
:return: triage agent
Expand Down Expand Up @@ -121,7 +126,9 @@ def init_agents() -> Agent[AirlineAgentContext]:

faq_agent.handoffs.append(triage_agent)
seat_booking_agent.handoffs.append(triage_agent)
return triage_agent
return triage_agent, {
agent.name: agent for agent in [faq_agent, seat_booking_agent, triage_agent]
}


class ProcessUserMessageInput(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def main():
CustomerServiceWorkflow.process_user_message, message_input
)
history.extend(new_history)
print(*new_history, sep="\n")
print(*new_history[1:], sep="\n")
except WorkflowUpdateFailedError:
print("** Stale conversation. Reloading...")
length = len(history)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

from agents import (
Agent,
HandoffCallItem,
HandoffOutputItem,
ItemHelpers,
MessageOutputItem,
Expand All @@ -12,6 +12,7 @@
TResponseInputItem,
trace,
)
from pydantic import dataclasses
from temporalio import workflow

from openai_agents.customer_service.customer_service import (
Expand All @@ -21,32 +22,65 @@
)


@dataclasses.dataclass
class CustomerServiceWorkflowState:
printed_history: list[str]
current_agent_name: str
context: AirlineAgentContext
input_items: list[TResponseInputItem]


@workflow.defn
class CustomerServiceWorkflow:
@workflow.init
def __init__(self, input_items: list[TResponseInputItem] | None = None):
def __init__(
self, customer_service_state: CustomerServiceWorkflowState | None = None
):
self.run_config = RunConfig()
self.chat_history: list[str] = []
self.current_agent: Agent[AirlineAgentContext] = init_agents()
self.context = AirlineAgentContext()
self.input_items = [] if input_items is None else input_items

starting_agent, self.agent_map = init_agents()
self.current_agent = (
self.agent_map[customer_service_state.current_agent_name]
if customer_service_state
else starting_agent
)
self.context = (
customer_service_state.context
if customer_service_state
else AirlineAgentContext()
)
self.printed_history: list[str] = (
customer_service_state.printed_history if customer_service_state else []
)
self.input_items = (
customer_service_state.input_items if customer_service_state else []
)

@workflow.run
async def run(self, input_items: list[TResponseInputItem] | None = None):
async def run(
self, customer_service_state: CustomerServiceWorkflowState | None = None
):
await workflow.wait_condition(
lambda: workflow.info().is_continue_as_new_suggested()
and workflow.all_handlers_finished()
)
workflow.continue_as_new(self.input_items)
workflow.continue_as_new(
CustomerServiceWorkflowState(
printed_history=self.printed_history,
current_agent_name=self.current_agent.name,
context=self.context,
input_items=self.input_items,
)
)

@workflow.query
def get_chat_history(self) -> list[str]:
return self.chat_history
return self.printed_history

@workflow.update
async def process_user_message(self, input: ProcessUserMessageInput) -> list[str]:
length = len(self.chat_history)
self.chat_history.append(f"User: {input.user_input}")
length = len(self.printed_history)
self.printed_history.append(f"User: {input.user_input}")
with trace("Customer service", group_id=workflow.info().workflow_id):
self.input_items.append({"content": input.user_input, "role": "user"})
result = await Runner.run(
Expand All @@ -59,33 +93,38 @@ async def process_user_message(self, input: ProcessUserMessageInput) -> list[str
for new_item in result.new_items:
agent_name = new_item.agent.name
if isinstance(new_item, MessageOutputItem):
self.chat_history.append(
self.printed_history.append(
f"{agent_name}: {ItemHelpers.text_message_output(new_item)}"
)
elif isinstance(new_item, HandoffOutputItem):
self.chat_history.append(
self.printed_history.append(
f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}"
)
elif isinstance(new_item, HandoffCallItem):
self.printed_history.append(
f"{agent_name}: Handed off to tool {new_item.raw_item.name}"
)
elif isinstance(new_item, ToolCallItem):
self.chat_history.append(f"{agent_name}: Calling a tool")
self.printed_history.append(f"{agent_name}: Calling a tool")
elif isinstance(new_item, ToolCallOutputItem):
self.chat_history.append(
self.printed_history.append(
f"{agent_name}: Tool call output: {new_item.output}"
)
else:
self.chat_history.append(
self.printed_history.append(
f"{agent_name}: Skipping item: {new_item.__class__.__name__}"
)
self.input_items = result.to_input_list()
self.current_agent = result.last_agent
workflow.set_current_details("\n\n".join(self.chat_history))
return self.chat_history[length:]
workflow.set_current_details("\n\n".join(self.printed_history))

return self.printed_history[length:]

@process_user_message.validator
def validate_process_user_message(self, input: ProcessUserMessageInput) -> None:
if not input.user_input:
raise ValueError("User input cannot be empty.")
if len(input.user_input) > 1000:
raise ValueError("User input is too long. Please limit to 1000 characters.")
if input.chat_length != len(self.chat_history):
if input.chat_length != len(self.printed_history):
raise ValueError("Stale chat history. Please refresh the chat.")
Loading