Skip to content

Commit

Permalink
command improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Aug 20, 2023
1 parent 9ab70f1 commit fc63ed9
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 182 deletions.
120 changes: 72 additions & 48 deletions rasa/cdu/command_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Set, Type
from typing import List, Optional, Set, Type

import structlog
from rasa.cdu.commands import (
Expand Down Expand Up @@ -28,12 +28,16 @@
StackFrameType,
)
from rasa.shared.core.constants import (
CANCELLED_FLOW_SLOT,
CORRECTED_SLOTS_SLOT,
FLOW_STACK_SLOT,
)
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.flows.flow import END_STEP, FlowsList, QuestionFlowStep
from rasa.shared.core.flows.flow import (
END_STEP,
Flow,
FlowStep,
FlowsList,
QuestionFlowStep,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.constants import COMMANDS

Expand Down Expand Up @@ -77,8 +81,8 @@ def _get_commands_from_tracker(tracker: DialogueStateTracker) -> List[Command]:

def validate_state_of_commands(commands: List[Command]) -> None:
"""Validates the state of the commands."""
# assert that only the last command can be a cancel flow command
assert not any(isinstance(c, CancelFlowCommand) for c in commands[:-1])
# assert that there is only at max one cancel flow command
assert sum(isinstance(c, CancelFlowCommand) for c in commands) <= 1

# assert that free form answer commands are only at the beginning of the list
free_form_answer_commands = [
Expand All @@ -87,7 +91,7 @@ def validate_state_of_commands(commands: List[Command]) -> None:
assert free_form_answer_commands == commands[: len(free_form_answer_commands)]

# assert that there is at max only one correctslots command
assert len([c for c in commands if isinstance(c, CorrectSlotsCommand)]) <= 1
assert sum(isinstance(c, CorrectSlotsCommand) for c in commands) <= 1


def execute_commands(
Expand All @@ -107,6 +111,8 @@ def execute_commands(
flow_stack = FlowStack.from_tracker(tracker)
original_stack_dump = flow_stack.as_dict()

user_step, user_flow = flow_stack.topmost_user_frame(all_flows)

current_top_flow = flow_stack.top_flow(all_flows)

commands = clean_up_commands(commands, tracker, all_flows)
Expand All @@ -124,30 +130,34 @@ def execute_commands(
for command in reversed_commands:
if isinstance(command, CorrectSlotsCommand):
structlogger.debug("command_executor.correct_slots", command=command)
for correction in command.corrected_slots:
events.append(SlotSet(correction.name, correction.value))
events.append(
SlotSet(CORRECTED_SLOTS_SLOT, [s.name for s in command.corrected_slots])
proposed_slots = {c.name: c.value for c in command.corrected_slots}

reset_step = _find_earliest_updated_question(
user_step, user_flow, proposed_slots
)
context = {
"corrected_slots": proposed_slots,
"corrected_reset_point": {
"id": user_flow.id,
"step_id": reset_step.id if reset_step else None,
},
}
correction_frame = FlowStackFrame(
flow_id=FLOW_PATTERN_CORRECTION_ID,
frame_type=StackFrameType.CORRECTION,
context=context,
)
if (
not current_top_flow
or current_top_flow.id != FLOW_PATTERN_CORRECTION_ID
):
flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_CORRECTION_ID,
frame_type=StackFrameType.CORRECTION,
)
)
flow_stack.push(correction_frame)
else:
# wrap up the previous correction flow
flow_stack.frames[-1].step_id = END_STEP
# push a new correction flow
flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_CORRECTION_ID,
frame_type=StackFrameType.CORRECTION,
),
correction_frame,
# we allow the previous correction to finish first before
# starting the new one
index=-1,
Expand All @@ -168,30 +178,26 @@ def execute_commands(
"command_executor.skip_cancel_flow.no_active_flow", command=command
)
continue
# this logic only works if the cancel flow command is the first
# command in the list of commands. there is validation that
# ensures this. if the command is not the last command, another
# command might have added stackframes that unintentionally will
# be canceled now
canceled_user_flow = None
for frame in reversed(flow_stack.frames):
structlogger.debug(
"command_executor.cancel_flow", command=command, frame=frame
)

canceled_frames = []
original_frames = FlowStack.from_dict(original_stack_dump).frames
for i, frame in enumerate(reversed(original_frames)):
# Setting the stack frame to the end step so it is properly
# wrapped up by the flow policy
frame.step_id = END_STEP
if frame.frame_type in STACK_FRAME_TYPES_WITH_USER_FLOWS:
# as soon as we hit the first stack frame that is a "normal"
# user defined flow we cancel that and leave the remaining
# stack untouched
canceled_user_flow = frame.flow_id
canceled_frames.append(len(original_frames) - i - 1)
if frame.flow_id == user_flow.id:
break
events.append(SlotSet(CANCELLED_FLOW_SLOT, canceled_user_flow))

flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_CANCEl_ID,
frame_type=StackFrameType.REMARK,
context={
"canceled_name": user_flow.readable_name()
if user_flow
else None,
"canceled_frames": canceled_frames,
},
)
)
elif isinstance(command, ListenCommand):
Expand Down Expand Up @@ -232,6 +238,20 @@ def execute_commands(
return events


def _find_earliest_updated_question(
current_step: Optional[FlowStep], flow: Optional[Flow], updated_slots: List[str]
) -> Optional[FlowStep]:
"""Find the question that was updated."""
if not flow or not current_step:
return None
asked_question_steps = flow.previously_asked_questions(current_step.id)

for question_step in reversed(asked_question_steps):
if question_step.question in updated_slots:
return question_step
return None


def filled_slots_for_active_flow(
tracker: DialogueStateTracker, all_flows: FlowsList
) -> Set[str]:
Expand Down Expand Up @@ -328,6 +348,20 @@ def clean_up_commands(
structlogger.debug(
"command_executor.convert_command.correction", command=command
)
if (top := flow_stack.top()) and top.context:
already_corrected_slots = top.context.get("corrected_slots")
else:
already_corrected_slots = {}

if (
command.name in already_corrected_slots
and already_corrected_slots[command.name] == command.value
):
structlogger.debug(
"command_executor.skip_command.slot_already_corrected",
command=command,
)
continue

corrected_slot = CorrectedSlot(command.name, command.value)
for c in clean_commands:
Expand All @@ -351,14 +385,4 @@ def clean_up_commands(
clean_commands.insert(0, command)
else:
clean_commands.append(command)

# check if there is a cancel flow command in the list of commands
# if so, push it to the end of the list
for clean_command in clean_commands:
if isinstance(clean_command, CancelFlowCommand):
# push the command to the end
clean_commands.remove(clean_command)
clean_commands.append(clean_command)
break

return clean_commands
24 changes: 22 additions & 2 deletions rasa/cdu/flow_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Text, List, Optional
from typing import Any, Dict, Text, List, Optional, Tuple

from rasa.shared.core.constants import (
FLOW_STACK_SLOT,
Expand Down Expand Up @@ -113,6 +113,21 @@ def top_flow(self, flows: FlowsList) -> Optional[Flow]:

return flows.flow_by_id(top.flow_id)

def topmost_user_frame(
self, flows: FlowsList
) -> Tuple[Optional[FlowStep], Optional[Flow]]:
"""Returns the topmost user frame from the stack.
Returns:
The topmost user frame.
"""
for frame in reversed(self.frames):
if frame.frame_type in STACK_FRAME_TYPES_WITH_USER_FLOWS:
flow = flows.flow_by_id(frame.flow_id)
return flow.step_by_id(frame.step_id), flow

return None, None

def top_flow_step(self, flows: FlowsList) -> Optional[FlowStep]:
"""Get the current flow step.
Expand Down Expand Up @@ -219,6 +234,8 @@ class FlowStackFrame:
"""The ID of the current step."""
frame_type: StackFrameType = StackFrameType.REGULAR
"""The type of the frame. Defaults to `StackFrameType.REGULAR`."""
context: Optional[Dict[Text, Any]] = None
"""The context of the frame. Defaults to `None`."""

@staticmethod
def from_dict(data: Dict[Text, Any]) -> FlowStackFrame:
Expand All @@ -234,6 +251,7 @@ def from_dict(data: Dict[Text, Any]) -> FlowStackFrame:
data["flow_id"],
data["step_id"],
StackFrameType.from_str(data.get("frame_type")),
data["context"],
)

def as_dict(self) -> Dict[Text, Any]:
Expand All @@ -246,6 +264,7 @@ def as_dict(self) -> Dict[Text, Any]:
"flow_id": self.flow_id,
"step_id": self.step_id,
"frame_type": self.frame_type.value,
"context": self.context,
}

def with_updated_id(self, step_id: Text) -> FlowStackFrame:
Expand All @@ -263,5 +282,6 @@ def __repr__(self) -> Text:
return (
f"FlowState(flow_id: {self.flow_id}, "
f"step_id: {self.step_id}, "
f"frame_type: {self.frame_type.value})"
f"frame_type: {self.frame_type.value}, "
f"context: {self.context})"
)
3 changes: 3 additions & 0 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
"""List default actions."""
from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction
from rasa.core.actions.flows import ActionCancelFlow, ActionCorrectFlowSlot

return [
ActionListen(),
Expand All @@ -112,6 +113,8 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
ActionSendText(),
ActionBack(),
ActionExtractSlots(action_endpoint),
ActionCancelFlow(),
ActionCorrectFlowSlot(),
]


Expand Down

0 comments on commit fc63ed9

Please sign in to comment.