Skip to content

Commit

Permalink
Add support for predicate evaluation as a slot validation method (#12802
Browse files Browse the repository at this point in the history
)

* implement changes

* update default flow

* implement expected behaviour after invalid values given twice

* fix CI unit tests after adding new default slot

* revert default slot, implement review suggestion

* some fixes

* amend if check to step id comparison

* adapt to list format of checks, add new default action for running predicates

* make changes for re-ask continuous loop until valid outcome

* fix tests

* refactor slot rejection to a dataclass, rename default action

* refactor action_run_slot_rejections

* add tests for default action

* address review comments
  • Loading branch information
ancalita committed Sep 22, 2023
1 parent 28edf5c commit 618fc43
Show file tree
Hide file tree
Showing 15 changed files with 811 additions and 31 deletions.
4 changes: 4 additions & 0 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
from rasa.dialogue_understanding.patterns.correction import ActionCorrectFlowSlot
from rasa.dialogue_understanding.patterns.cancel import ActionCancelFlow
from rasa.dialogue_understanding.patterns.clarify import ActionClarifyFlows
from rasa.core.actions.action_run_slot_rejections import (
ActionRunSlotRejections,
)

return [
ActionListen(),
Expand All @@ -118,6 +121,7 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
ActionCancelFlow(),
ActionCorrectFlowSlot(),
ActionClarifyFlows(),
ActionRunSlotRejections(),
]


Expand Down
131 changes: 131 additions & 0 deletions rasa/core/actions/action_run_slot_rejections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text

import structlog
from jinja2 import Template
from pypred import Predicate

from rasa.core.actions.action import Action, create_bot_utterance
from rasa.dialogue_understanding.patterns.collect_information import (
CollectInformationPatternFlowStackFrame,
)
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.shared.core.constants import ACTION_RUN_SLOT_REJECTIONS_NAME
from rasa.shared.core.events import Event, SlotSet

if TYPE_CHECKING:
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.channels.channel import OutputChannel
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker

structlogger = structlog.get_logger()


class ActionRunSlotRejections(Action):
"""Action which evaluates the predicate checks under rejections."""

def name(self) -> Text:
"""Return the name of the action."""
return ACTION_RUN_SLOT_REJECTIONS_NAME

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Run the predicate checks."""
events: List[Event] = []
violation = False
utterance = None
internal_error = False

dialogue_stack = DialogueStack.from_tracker(tracker)
top_frame = dialogue_stack.top()
if not isinstance(top_frame, CollectInformationPatternFlowStackFrame):
return []

if not top_frame.rejections:
return []

slot_name = top_frame.collect_information
slot_instance = tracker.slots.get(slot_name)
if slot_instance and not slot_instance.has_been_set:
# this is the first time the assistant asks for the slot value,
# therefore we skip the predicate validation because the slot
# value has not been provided
structlogger.debug(
"first.collect.slot.not.set",
slot_name=slot_name,
slot_value=slot_instance.value,
)
return []

slot_value = tracker.get_slot(slot_name)

current_context = dialogue_stack.current_context()
current_context[slot_name] = slot_value

structlogger.debug("run.predicate.context", context=current_context)
document = current_context.copy()

for rejection in top_frame.rejections:
condition = rejection.if_
utterance = rejection.utter

try:
rendered_template = Template(condition).render(current_context)
predicate = Predicate(rendered_template)
violation = predicate.evaluate(document)
structlogger.debug(
"run.predicate.result",
predicate=predicate.description(),
violation=violation,
)
except (TypeError, Exception) as e:
structlogger.error(
"run.predicate.error",
predicate=condition,
document=document,
error=str(e),
)
violation = True
internal_error = True

if violation:
break

if not violation:
return []

# reset slot value that was initially filled with an invalid value
events.append(SlotSet(top_frame.collect_information, None))

if internal_error:
utterance = "utter_internal_error_rasa"

if not isinstance(utterance, str):
structlogger.error(
"run.rejection.missing.utter",
utterance=utterance,
)
return events

message = await nlg.generate(
utterance,
tracker,
output_channel.name(),
)

if message is None:
structlogger.error(
"run.rejection.failed.finding.utter",
utterance=utterance,
)
else:
message["utter_action"] = utterance
events.append(create_bot_utterance(message))

return events
34 changes: 23 additions & 11 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
IfFlowLink,
EntryPromptFlowStep,
CollectInformationScope,
SlotRejection,
StepThatCanStartAFlow,
UserMessageStep,
LinkFlowStep,
Expand Down Expand Up @@ -171,6 +172,7 @@ def predict_action_probabilities(
domain: The model's domain.
rule_only_data: Slots and loops which are specific to rules and hence
should be ignored by this policy.
flows: The flows to use.
**kwargs: Depending on the specified `needs` section and the resulting
graph structure the policy can use different input to make predictions.
Expand Down Expand Up @@ -208,7 +210,7 @@ def _create_prediction_result(
domain: The model's domain.
score: The score of the predicted action.
Resturns:
Returns:
The prediction result where the score is used for one hot encoding.
"""
result = self._default_predictions(domain)
Expand Down Expand Up @@ -242,8 +244,9 @@ def __init__(
"""Initializes the `FlowExecutor`.
Args:
dialogue_stack_frame: State of the flow.
dialogue_stack: State of the flow.
all_flows: All flows.
domain: The domain.
"""
self.dialogue_stack = dialogue_stack
self.all_flows = all_flows
Expand All @@ -258,6 +261,7 @@ def from_tracker(
Args:
tracker: The tracker to create the `FlowExecutor` from.
flows: The flows to use.
domain: The domain to use.
Returns:
The created `FlowExecutor`.
Expand All @@ -270,7 +274,6 @@ def find_startable_flow(self, tracker: DialogueStateTracker) -> Optional[Flow]:
Args:
tracker: The tracker containing the conversation history up to now.
flows: The flows to use.
Returns:
The predicted action and the events to run.
Expand All @@ -296,7 +299,7 @@ def is_condition_satisfied(
) -> bool:
"""Evaluate a predicate condition."""

# attach context to the predicate evaluation to allow coditions using it
# attach context to the predicate evaluation to allow conditions using it
context = {"context": DialogueStack.from_tracker(tracker).current_context()}
document: Dict[str, Any] = context.copy()
for slot in self.domain.slots:
Expand Down Expand Up @@ -371,7 +374,7 @@ def render_template_variables(text: str, context: Dict[Text, Any]) -> str:
return Template(text).render(context)

def _slot_for_collect_information(self, collect_information: Text) -> Slot:
"""Find the slot for a collect information."""
"""Find the slot for the collect information step."""
for slot in self.domain.slots:
if slot.name == collect_information:
return slot
Expand Down Expand Up @@ -415,7 +418,6 @@ def advance_flows(self, tracker: DialogueStateTracker) -> ActionPrediction:
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The predicted action and the events to run.
Expand Down Expand Up @@ -456,7 +458,6 @@ def _select_next_action(
Args:
tracker: The tracker to get the next action for.
domain: The domain to get the next action for.
Returns:
The next action to execute, the events that should be applied to the
Expand Down Expand Up @@ -552,11 +553,14 @@ def _run_step(
"""
if isinstance(step, CollectInformationFlowStep):
structlogger.debug("flow.step.run.collect_information")
self.trigger_pattern_ask_collect_information(step.collect_information)
self.trigger_pattern_ask_collect_information(
step.collect_information, step.rejections
)

# reset the slot if its already filled and the collect infomation shouldn't
# reset the slot if its already filled and the collect information shouldn't
# be skipped
slot = tracker.slots.get(step.collect_information, None)

if slot and slot.has_been_set and step.ask_before_filling:
events = [SlotSet(step.collect_information, slot.initial_value)]
else:
Expand All @@ -567,8 +571,10 @@ def _run_step(
elif isinstance(step, ActionFlowStep):
if not step.action:
raise FlowException(f"Action not specified for step {step}")

context = {"context": self.dialogue_stack.current_context()}
action_name = self.render_template_variables(step.action, context)

if action_name in self.domain.action_names_or_texts:
structlogger.debug("flow.step.run.action", context=context)
return PauseFlowReturnPrediction(ActionPrediction(action_name, 1.0))
Expand Down Expand Up @@ -676,10 +682,16 @@ def trigger_pattern_completed(self, current_frame: DialogueStackFrame) -> None:
)
)

def trigger_pattern_ask_collect_information(self, collect_information: str) -> None:
def trigger_pattern_ask_collect_information(
self,
collect_information: str,
rejections: List[SlotRejection],
) -> None:
"""Trigger the pattern to ask for a slot value."""
self.dialogue_stack.push(
CollectInformationPatternFlowStackFrame(
collect_information=collect_information
collect_information=collect_information,
rejections=rejections,
)
)

Expand Down
15 changes: 14 additions & 1 deletion rasa/dialogue_understanding/patterns/collect_information.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStackFrame
from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
from rasa.shared.core.flows.flow import SlotRejection

FLOW_PATTERN_COLLECT_INFORMATION = (
RASA_DEFAULT_FLOW_PATTERN_PREFIX + "ask_collect_information"
Expand All @@ -20,6 +21,11 @@ class CollectInformationPatternFlowStackFrame(PatternFlowStackFrame):
collect_information: str = ""
"""The information that should be collected from the user.
this corresponds to the slot that will be filled."""
rejections: Optional[List[SlotRejection]] = None
"""The predicate check that should be applied to the collected information.
If a predicate check fails, its `utter` action indicated under rejections
will be executed.
"""

@classmethod
def type(cls) -> str:
Expand All @@ -36,10 +42,17 @@ def from_dict(data: Dict[str, Any]) -> CollectInformationPatternFlowStackFrame:
Returns:
The created `DialogueStackFrame`.
"""
rejections = data.get("rejections")
if rejections is not None:
rejections = [
SlotRejection.from_dict(rejection) for rejection in rejections
]

return CollectInformationPatternFlowStackFrame(
data["frame_id"],
step_id=data["step_id"],
collect_information=data["collect_information"],
rejections=rejections,
)

def context_as_dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ flows:
description: flow used to fill a slot
steps:
- id: "start"
action: action_extract_slots
action: action_run_slot_rejections
next: "validate"
- id: "validate"
action: validate_{{context.collect_information}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ def as_dict(self) -> Dict[str, Any]:

def custom_asdict_factory(fields: List[Tuple[str, Any]]) -> Dict[str, Any]:
"""Converts enum values to their value."""

def rename_internal(field_name: str) -> str:
return field_name[:-1] if field_name.endswith("_") else field_name

return {
field: value.value if isinstance(value, Enum) else value
rename_internal(field): value.value
if isinstance(value, Enum)
else value
for field, value in fields
}

Expand Down
2 changes: 2 additions & 0 deletions rasa/shared/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ACTION_CANCEL_FLOW = "action_cancel_flow"
ACTION_CLARIFY_FLOWS = "action_clarify_flows"
ACTION_CORRECT_FLOW_SLOT = "action_correct_flow_slot"
ACTION_RUN_SLOT_REJECTIONS_NAME = "action_run_slot_rejections"


DEFAULT_ACTION_NAMES = [
Expand All @@ -60,6 +61,7 @@
ACTION_CANCEL_FLOW,
ACTION_CORRECT_FLOW_SLOT,
ACTION_CLARIFY_FLOWS,
ACTION_RUN_SLOT_REJECTIONS_NAME,
]

ACTION_SHOULD_SEND_DOMAIN = "send_domain"
Expand Down
11 changes: 7 additions & 4 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@
import rasa.shared.utils.common
import rasa.shared.core.slot_mappings
from rasa.shared.core.events import SlotSet, UserUttered
from rasa.shared.core.slots import Slot, CategoricalSlot, TextSlot, AnySlot, ListSlot
from rasa.shared.core.slots import (
Slot,
CategoricalSlot,
TextSlot,
AnySlot,
ListSlot,
)
from rasa.shared.utils.validation import KEY_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.nlu.constants import (
ENTITY_ATTRIBUTE_TYPE,
Expand Down Expand Up @@ -969,9 +975,6 @@ def _add_categorical_slot_default_value(self) -> None:
def _add_flow_slots(self) -> None:
"""Adds the slots needed for the conversation flows.
Add a slot called `flow_step_slot` to the list of slots. The value of
this slot will hold the name of the id of the next step in the flow.
Add a slot called `dialogue_stack_slot` to the list of slots. The value of
this slot will be a call stack of the flow ids.
"""
Expand Down

0 comments on commit 618fc43

Please sign in to comment.