Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Sep 11, 2023
1 parent 5f3cd3c commit eb52875
Show file tree
Hide file tree
Showing 28 changed files with 1,625 additions and 166 deletions.
72 changes: 45 additions & 27 deletions rasa/cdu/commands/cancel_flow_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasa.cdu.stack.frames.flow_frame import BaseFlowStackFrame
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.flows.flow import FlowsList
from rasa.shared.core.flows.flow import Flow, FlowsList
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.cdu.stack.utils import top_user_flow_frame

Expand All @@ -36,6 +36,37 @@ def from_dict(cls, data: Dict[str, Any]) -> CancelFlowCommand:
"""
return CancelFlowCommand()

@staticmethod
def select_canceled_frames(stack: DialogueStack, current_flow: Flow) -> List[str]:
"""Selects the frames that were canceled.
Args:
dialogue_stack: The dialogue stack.
current_flow: The current flow.
Returns:
The frames that were canceled."""
canceled_frames = []
# we need to go through the original stack dump in reverse order
# to find the frames that were canceled. we cancel everthing from
# the top of the stack until we hit the user flow that was canceled.
# this will also cancel any patterns put ontop of that user flow,
# e.g. corrections.
for frame in reversed(stack.frames):
canceled_frames.append(frame.frame_id)
if (
isinstance(frame, BaseFlowStackFrame)
and frame.flow_id == current_flow.id
):
return canceled_frames
else:
# we should never get here as we should always find the user flow
# that was canceled.
raise ValueError(
f"Could not find the user flow '{current_flow.id}' "
f"on the stack. Current stack: {stack}."
)

def run_command_on_tracker(
self,
tracker: DialogueStateTracker,
Expand All @@ -52,40 +83,27 @@ def run_command_on_tracker(
Returns:
The events to apply to the tracker.
"""
original_dialogue_stack = DialogueStack.from_tracker(original_tracker)

dialogue_stack = DialogueStack.from_tracker(tracker)
current_user_frame = top_user_flow_frame(dialogue_stack)
current_top_flow = (
current_user_frame.flow(all_flows) if current_user_frame else None
)
if not current_top_flow:
stack = DialogueStack.from_tracker(tracker)
original_stack = DialogueStack.from_tracker(original_tracker)
user_frame = top_user_flow_frame(original_stack)
current_flow = user_frame.flow(all_flows) if user_frame else None

if not current_flow:
structlogger.debug(
"command_executor.skip_cancel_flow.no_active_flow", command=self
)
return []

canceled_frames = []
# we need to go through the original stack dump in reverse order
# to find the frames that were canceled. we cancel everthing from
# the top of the stack until we hit the user flow that was canceled.
# this will also cancel any patterns put ontop of that user flow,
# e.g. corrections.
for frame in reversed(original_dialogue_stack.frames):
canceled_frames.append(frame.frame_id)
if (
current_user_frame
and isinstance(frame, BaseFlowStackFrame)
and frame.flow_id == current_user_frame.flow_id
):
break
# we pass in the original dialogue stack (before any of the currently
# predicted commands were applied) to make sure we don't cancel any
# frames that were added by the currently predicted commands.
canceled_frames = self.select_canceled_frames(original_stack, current_flow)

dialogue_stack.push(
stack.push(
CancelPatternFlowStackFrame(
canceled_name=current_user_frame.flow(all_flows).readable_name()
if current_user_frame
else None,
canceled_name=current_flow.readable_name(),
canceled_frames=canceled_frames,
)
)
return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())]
return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())]
8 changes: 4 additions & 4 deletions rasa/cdu/commands/chit_chat_answer_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List
from rasa.cdu.commands import FreeFormAnswerCommand
from rasa.cdu.stack.dialogue_stack import DialogueStack
from rasa.cdu.stack.frames.chitchat_frame import ChitChatStackFrame
from rasa.cdu.stack.frames.chit_chat_frame import ChitChatStackFrame
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.events import Event, SlotSet
from rasa.shared.core.flows.flow import FlowsList
Expand Down Expand Up @@ -45,6 +45,6 @@ def run_command_on_tracker(
Returns:
The events to apply to the tracker.
"""
dialogue_stack = DialogueStack.from_tracker(tracker)
dialogue_stack.push(ChitChatStackFrame())
return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())]
stack = DialogueStack.from_tracker(tracker)
stack.push(ChitChatStackFrame())
return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())]
13 changes: 9 additions & 4 deletions rasa/cdu/commands/clarify_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def from_dict(cls, data: Dict[str, Any]) -> ClarifyCommand:
Returns:
The converted dictionary.
"""
return ClarifyCommand(options=data["options"])
try:
return ClarifyCommand(options=data["options"])
except KeyError as e:
raise ValueError(
f"Missing parameter '{e}' while parsing ClarifyCommand."
) from e

def run_command_on_tracker(
self,
Expand Down Expand Up @@ -67,8 +72,8 @@ def run_command_on_tracker(
)
return []

dialogue_stack = DialogueStack.from_tracker(tracker)
stack = DialogueStack.from_tracker(tracker)
relevant_flows = [all_flows.flow_by_id(opt) for opt in clean_options]
names = [flow.readable_name() for flow in relevant_flows if flow is not None]
dialogue_stack.push(ClarifyPatternFlowStackFrame(names=names))
return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())]
stack.push(ClarifyPatternFlowStackFrame(names=names))
return [SlotSet(DIALOGUE_STACK_SLOT, stack.as_dict())]

0 comments on commit eb52875

Please sign in to comment.