Skip to content

Commit

Permalink
implemented command improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Jul 25, 2023
1 parent 30ff38e commit b55f99a
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 25 deletions.
5 changes: 2 additions & 3 deletions rasa/cdu/command_generator/llm_command_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rasa.cdu.command_generator.base import CommandGenerator
from rasa.cdu.commands import (
Command,
ErrorCommand,
HandleInterruptionCommand,
SetSlotCommand,
CancelFlowCommand,
Expand Down Expand Up @@ -173,9 +174,7 @@ def parse_commands(
) -> List[Command]:
"""Parse the actions returned by the llm into intent and entities."""
if not actions:
# TODO: not quite sure yet how to handle this case - revisit!
# is predicting "no commands" an option?
return []
return [ErrorCommand()]

commands: List[Command] = []

Expand Down
85 changes: 70 additions & 15 deletions rasa/cdu/command_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from rasa.cdu.commands import (
CancelFlowCommand,
Command,
CorrectSlotCommand,
CorrectSlotsCommand,
CorrectedSlot,
ErrorCommand,
HandleInterruptionCommand,
ListenCommand,
SetSlotCommand,
Expand Down Expand Up @@ -32,6 +34,8 @@

FLOW_PATTERN_LISTEN_ID = RASA_DEFAULT_FLOW_PATTERN_PREFIX + "listen"

FLOW_PATTERN_INTERNAL_ERROR_ID = RASA_DEFAULT_FLOW_PATTERN_PREFIX + "internal_error"


def contains_command(commands: List[Command], typ: Type[Command]) -> bool:
"""Check if a list of commands contains a command of a given type.
Expand Down Expand Up @@ -89,6 +93,22 @@ def _get_commands_from_tracker(tracker: DialogueStateTracker) -> List[Command]:
return []


def validate_state_of_commands(commands: List[Command]) -> None:
"""Validates the state of the commands."""
# assert that that at max there is only one cancel flow command at
# the beginning of the list of commands
assert len([c for c in commands if isinstance(c, CancelFlowCommand)]) <= 1

# assert that interrupt commands are only at the beginning of the list
interrupt_commands = [
c for c in commands if isinstance(c, HandleInterruptionCommand)
]
assert interrupt_commands == commands[: len(interrupt_commands)]

# assert that there is at max only one correctslots command
assert len([c for c in commands if isinstance(c, CorrectSlotsCommand)]) <= 1


def execute_commands(
tracker: DialogueStateTracker, all_flows: FlowsList
) -> List[Event]:
Expand All @@ -111,28 +131,29 @@ def execute_commands(
commands = clean_up_commands(commands, tracker, all_flows)

events: List[Event] = []
collected_corrections = []
# TODO: should this really be reversed? 馃

# commands need to be reversed to make sure they end up in the right order
# on the stack. e.g. if there multiple start flow commands, the first one
# should be on top of the stack. this is achieved by reversing the list
# and then pushing the commands onto the stack in the reversed order.
reversed_commands = list(reversed(commands))

validate_state_of_commands(commands)

for i, command in enumerate(reversed_commands):
if isinstance(command, CorrectSlotCommand):
structlogger.debug("command_executor.correct_slot", command=command)
collected_corrections.append(command)
# pulling in all subsequent correction commands into a single correction
if i < (len(reversed_commands) - 1) and \
isinstance(reversed_commands[i+1], CorrectSlotCommand):
continue
for correction in collected_corrections:
if isinstance(command, CorrectSlotsCommand):
structlogger.debug("command_executor.correct_slots", command=command)
for correction in command.corrected_slots:
events.append(SlotSet(correction.name, correction.value))
events.append(SlotSet(CORRECTED_SLOTS_SLOT,
[s.name for s in collected_corrections]))
events.append(
SlotSet(CORRECTED_SLOTS_SLOT, [s.name for s in command.corrected_slots])
)
flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_CORRECTION_ID,
frame_type=StackFrameType.CORRECTION,
)
)
collected_corrections = []
elif isinstance(command, SetSlotCommand):
structlogger.debug("command_executor.set_slot", command=command)
events.append(SlotSet(command.name, command.value))
Expand All @@ -149,6 +170,10 @@ def execute_commands(
"command_executor.skip_cancel_flow.no_active_flow", command=command
)
continue
# in between the prediction and this canceling command, we might have
# added some stack frames. hence, we can't just cancle the current top frame
# but need to find the frame that was at the top before we started
# processing the commands.
for idx, frame in enumerate(flow_stack.frames):
if frame.flow_id == current_top_flow.id:
structlogger.debug("command_executor.cancel_flow", command=command)
Expand Down Expand Up @@ -178,6 +203,14 @@ def execute_commands(
frame_type=StackFrameType.DOCSEARCH,
)
)
elif isinstance(command, ErrorCommand):
structlogger.debug("command_executor.error", command=command)
flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_INTERNAL_ERROR_ID,
frame_type=StackFrameType.CORRECTION,
)
)

# if the flow stack has changed, persist it in a set slot event
if original_stack_dump != flow_stack.as_dict():
Expand Down Expand Up @@ -263,7 +296,29 @@ def clean_up_commands(
"command_executor.convert_command.correction", command=command
)

clean_commands.append(CorrectSlotCommand(command.name, command.value))
corrected_slot = CorrectedSlot(command.name, command.value)
for c in clean_commands:
if isinstance(c, CorrectSlotsCommand):
c.corrected_slots.append(corrected_slot)
break
else:
clean_commands.append(
CorrectSlotsCommand(corrected_slots=[corrected_slot])
)

elif isinstance(command, CancelFlowCommand) and contains_command(
clean_commands, CancelFlowCommand
):
structlogger.debug(
"command_executor.skip_command.already_cancelled_flow", command=command
)
continue
elif isinstance(command, HandleInterruptionCommand):
structlogger.debug(
"command_executor.prepend_command.handle_interruption", command=command
)
clean_commands.insert(0, command)
continue
else:
clean_commands.append(command)

Expand Down
26 changes: 23 additions & 3 deletions rasa/cdu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ def command_from_json(data: Dict[str, Any]) -> "Command":
if data.get("command") == "set slot":
return SetSlotCommand(name=data["name"], value=data["value"])
elif data.get("command") == "correct slot":
return CorrectSlotCommand(name=data["name"], value=data["value"])
return CorrectSlotsCommand(
corrected_slots=[
CorrectedSlot(s["name"], value=s["value"])
for s in data["corrected_slots"]
]
)
elif data.get("command") == "start flow":
return StartFlowCommand(flow=data["flow"])
elif data.get("command") == "cancel flow":
Expand Down Expand Up @@ -53,11 +58,19 @@ class SetSlotCommand(Command):


@dataclass
class CorrectSlotCommand(Command):
"""A command to correct the value of a slot."""
class CorrectedSlot:
"""A slot that was corrected."""

name: str
value: Any


@dataclass
class CorrectSlotsCommand(Command):
"""A command to correct the value of a slot."""

corrected_slots: List[CorrectedSlot]

command: str = "correct slot"


Expand Down Expand Up @@ -110,3 +123,10 @@ class HumanHandoffCommand(Command):
"""A command to indicate that the bot should handoff to a human."""

command: str = "human handoff"


@dataclass
class ErrorCommand(Command):
"""A command to indicate that the bot failed to handle the dialogue."""

command: str = "error"
8 changes: 4 additions & 4 deletions rasa/core/policies/default_flows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ responses:
utter_too_complex_rasa:
- text: This is too complex for me, let's take it one step at a time.

utter_openai_error_rasa:
utter_internal_error_rasa:
- text: Sorry, I'm having trouble understanding you right now. Please try again later.

slots:
Expand Down Expand Up @@ -97,8 +97,8 @@ flows:
- id: "0"
action: action_listen

pattern_openai_error:
description: openai error
pattern_internal_error:
description: internal error
steps:
- id: "0"
action: utter_openai_error_rasa
action: utter_internal_error_rasa

0 comments on commit b55f99a

Please sign in to comment.