-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
30 changed files
with
1,060 additions
and
583 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from rasa.cdu.commands.command import Command | ||
from rasa.cdu.commands.free_form_answer_command import FreeFormAnswerCommand | ||
from rasa.cdu.commands.cancel_flow_command import CancelFlowCommand | ||
from rasa.cdu.commands.knowledge_answer_command import KnowledgeAnswerCommand | ||
from rasa.cdu.commands.chit_chat_answer_command import ChitChatAnswerCommand | ||
from rasa.cdu.commands.cant_handle_command import CantHandleCommand | ||
from rasa.cdu.commands.clarify_command import ClarifyCommand | ||
from rasa.cdu.commands.error_command import ErrorCommand | ||
from rasa.cdu.commands.set_slot_command import SetSlotCommand | ||
from rasa.cdu.commands.start_flow_command import StartFlowCommand | ||
from rasa.cdu.commands.human_handoff_command import HumanHandoffCommand | ||
from rasa.cdu.commands.correct_slots_command import CorrectSlotsCommand, CorrectedSlot | ||
|
||
|
||
__all__ = [ | ||
"Command", | ||
"FreeFormAnswerCommand", | ||
"CancelFlowCommand", | ||
"KnowledgeAnswerCommand", | ||
"ChitChatAnswerCommand", | ||
"CantHandleCommand", | ||
"ClarifyCommand", | ||
"ErrorCommand", | ||
"SetSlotCommand", | ||
"StartFlowCommand", | ||
"HumanHandoffCommand", | ||
"CorrectSlotsCommand", | ||
"CorrectedSlot", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
|
||
import structlog | ||
|
||
from rasa.cdu.commands import Command | ||
from rasa.cdu.patterns.cancel import CancelPatternFlowStackFrame | ||
from rasa.cdu.stack.dialogue_stack import DialogueStack | ||
from rasa.cdu.stack.frames.flow_frame import BaseFlowStackFrame | ||
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT | ||
from rasa.shared.core.events import Event, SlotSet | ||
from rasa.shared.core.flows.flow import FlowsList | ||
from rasa.shared.core.trackers import DialogueStateTracker | ||
from rasa.cdu.stack.utils import top_user_flow_frame | ||
|
||
structlogger = structlog.get_logger() | ||
|
||
|
||
@dataclass | ||
class CancelFlowCommand(Command): | ||
"""A command to cancel the current flow.""" | ||
|
||
@classmethod | ||
def command(cls) -> str: | ||
"""Returns the command type.""" | ||
return "cancel flow" | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> CancelFlowCommand: | ||
"""Converts the dictionary to a command. | ||
Returns: | ||
The converted dictionary. | ||
""" | ||
return CancelFlowCommand() | ||
|
||
def run_command_on_tracker( | ||
self, | ||
tracker: DialogueStateTracker, | ||
all_flows: FlowsList, | ||
original_tracker: DialogueStateTracker, | ||
) -> List[Event]: | ||
"""Runs the command on the tracker. | ||
Args: | ||
tracker: The tracker to run the command on. | ||
all_flows: All flows in the assistant. | ||
original_tracker: The tracker before any command was executed. | ||
Returns: | ||
The events to apply to the tracker. | ||
""" | ||
original_dialogue_stack = DialogueStack.from_tracker(original_tracker) | ||
original_stack_dump = original_dialogue_stack.as_dict() | ||
|
||
dialogue_stack = DialogueStack.from_tracker(tracker) | ||
current_user_frame = top_user_flow_frame(dialogue_stack) | ||
current_top_flow = ( | ||
current_user_frame.flow(all_flows) if current_user_frame else None | ||
) | ||
if not current_top_flow: | ||
structlogger.debug( | ||
"command_executor.skip_cancel_flow.no_active_flow", command=self | ||
) | ||
return [] | ||
|
||
canceled_frames = [] | ||
original_frames = DialogueStack.from_dict(original_stack_dump).frames | ||
# we need to go through the original stack dump in reverse order | ||
# to find the frames that were canceled. we cancel everthing from | ||
# the top of the stack until we hit the user flow that was canceled. | ||
# this will also cancel any patterns put ontop of that user flow, | ||
# e.g. corrections. | ||
for frame in reversed(original_frames): | ||
canceled_frames.append(frame.frame_id) | ||
if ( | ||
current_user_frame | ||
and isinstance(frame, BaseFlowStackFrame) | ||
and frame.flow_id == current_user_frame.flow_id | ||
): | ||
break | ||
|
||
dialogue_stack.push( | ||
CancelPatternFlowStackFrame( | ||
canceled_name=current_user_frame.flow(all_flows).readable_name() | ||
if current_user_frame | ||
else None, | ||
canceled_frames=canceled_frames, | ||
) | ||
) | ||
return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
from rasa.cdu.commands import Command | ||
from rasa.shared.core.events import Event | ||
from rasa.shared.core.flows.flow import FlowsList | ||
from rasa.shared.core.trackers import DialogueStateTracker | ||
|
||
|
||
@dataclass | ||
class CantHandleCommand(Command): | ||
"""A command to indicate that the bot can't handle the user's input.""" | ||
|
||
@classmethod | ||
def command(cls) -> str: | ||
"""Returns the command type.""" | ||
return "cant handle" | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> CantHandleCommand: | ||
"""Converts the dictionary to a command. | ||
Returns: | ||
The converted dictionary. | ||
""" | ||
return CantHandleCommand() | ||
|
||
def run_command_on_tracker( | ||
self, | ||
tracker: DialogueStateTracker, | ||
all_flows: FlowsList, | ||
original_tracker: DialogueStateTracker, | ||
) -> List[Event]: | ||
"""Runs the command on the tracker. | ||
Args: | ||
tracker: The tracker to run the command on. | ||
all_flows: All flows in the assistant. | ||
original_tracker: The tracker before any command was executed. | ||
Returns: | ||
The events to apply to the tracker. | ||
""" | ||
return [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
from rasa.cdu.commands import FreeFormAnswerCommand | ||
from rasa.cdu.stack.dialogue_stack import DialogueStack | ||
from rasa.cdu.stack.frames.chitchat_frame import ChitChatStackFrame | ||
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT | ||
from rasa.shared.core.events import Event, SlotSet | ||
from rasa.shared.core.flows.flow import FlowsList | ||
from rasa.shared.core.trackers import DialogueStateTracker | ||
|
||
|
||
@dataclass | ||
class ChitChatAnswerCommand(FreeFormAnswerCommand): | ||
"""A command to indicate a chitchat style free-form answer by the bot.""" | ||
|
||
@classmethod | ||
def command(cls) -> str: | ||
"""Returns the command type.""" | ||
return "chitchat" | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> ChitChatAnswerCommand: | ||
"""Converts the dictionary to a command. | ||
Returns: | ||
The converted dictionary. | ||
""" | ||
return ChitChatAnswerCommand() | ||
|
||
def run_command_on_tracker( | ||
self, | ||
tracker: DialogueStateTracker, | ||
all_flows: FlowsList, | ||
original_tracker: DialogueStateTracker, | ||
) -> List[Event]: | ||
"""Runs the command on the tracker. | ||
Args: | ||
tracker: The tracker to run the command on. | ||
all_flows: All flows in the assistant. | ||
original_tracker: The tracker before any command was executed. | ||
Returns: | ||
The events to apply to the tracker. | ||
""" | ||
dialogue_stack = DialogueStack.from_tracker(tracker) | ||
dialogue_stack.push(ChitChatStackFrame()) | ||
return [SlotSet(DIALOGUE_STACK_SLOT, dialogue_stack.as_dict())] |
Oops, something went wrong.