Skip to content

Commit

Permalink
removed branch flow step and fixed predicate validation, added test
Browse files Browse the repository at this point in the history
  • Loading branch information
twerkmeister committed Oct 26, 2023
1 parent c5d8e51 commit cb80889
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 72 deletions.
9 changes: 4 additions & 5 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
UserMessageStep,
)
from rasa.shared.core.flows.steps.link import LinkFlowStep
from rasa.shared.core.flows.steps.branch import BranchFlowStep
from rasa.shared.core.flows.steps.action import ActionFlowStep
from rasa.shared.core.flows.flow import Flow
from rasa.shared.core.flows.flows_list import FlowsList
Expand Down Expand Up @@ -678,10 +677,6 @@ def run_step(
structlogger.debug("flow.step.run.user_message")
return ContinueFlowWithNextStep()

elif isinstance(step, BranchFlowStep):
structlogger.debug("flow.step.run.branch")
return ContinueFlowWithNextStep()

elif isinstance(step, GenerateResponseFlowStep):
structlogger.debug("flow.step.run.generate_response")
generated = step.generate(tracker)
Expand All @@ -702,6 +697,10 @@ def run_step(
reset_events = self._reset_scoped_slots(flow, tracker)
return ContinueFlowWithNextStep(events=reset_events)

elif isinstance(step, FlowStep):
structlogger.debug("flow.step.run.base_flow_step")
return ContinueFlowWithNextStep()

else:
raise FlowException(f"Unknown flow step type {type(step)}")

Expand Down
7 changes: 3 additions & 4 deletions rasa/shared/core/flows/flow_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def step_from_json(data: Dict[Text, Any]) -> FlowStep:
LinkFlowStep,
SetSlotsFlowStep,
GenerateResponseFlowStep,
BranchFlowStep,
)

if "action" in data:
Expand All @@ -50,7 +49,7 @@ def step_from_json(data: Dict[Text, Any]) -> FlowStep:
if "generation_prompt" in data:
return GenerateResponseFlowStep.from_json(data)
else:
return BranchFlowStep.from_json(data)
return FlowStep.from_json(data)


@dataclass
Expand All @@ -69,7 +68,7 @@ class FlowStep:
"""The next steps of the flow step."""

@classmethod
def _from_json(cls, flow_step_config: Dict[Text, Any]) -> FlowStep:
def from_json(cls, flow_step_config: Dict[Text, Any]) -> FlowStep:
"""Used to read flow steps from parsed YAML.
Args:
Expand Down Expand Up @@ -122,7 +121,7 @@ def default_id(self) -> str:
@property
def default_id_postfix(self) -> str:
"""Returns the default id postfix of the flow step."""
raise NotImplementedError()
return "step"

@property
def utterances(self) -> Set[str]:
Expand Down
2 changes: 0 additions & 2 deletions rasa/shared/core/flows/steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .action import ActionFlowStep
from .branch import BranchFlowStep
from .collect import CollectInformationFlowStep
from .continuation import ContinueFlowStep
from .end import EndFlowStep
Expand All @@ -13,7 +12,6 @@
# to make ruff happy and use the imported names
all_steps = [
ActionFlowStep,
BranchFlowStep,
CollectInformationFlowStep,
ContinueFlowStep,
EndFlowStep,
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def from_json(cls, data: Dict[Text, Any]) -> ActionFlowStep:
Returns:
An ActionFlowStep object
"""
base = super()._from_json(data)
base = super().from_json(data)
return ActionFlowStep(
action=data["action"],
**base.__dict__,
Expand Down
40 changes: 0 additions & 40 deletions rasa/shared/core/flows/steps/branch.py

This file was deleted.

2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def from_json(cls, data: Dict[Text, Any]) -> CollectInformationFlowStep:
Returns:
A CollectInformationFlowStep object
"""
base = super()._from_json(data)
base = super().from_json(data)
return CollectInformationFlowStep(
collect=data["collect"],
utter=data.get("utter", f"utter_ask_{data['collect']}"),
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/generate_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def from_json(cls, data: Dict[Text, Any]) -> GenerateResponseFlowStep:
Returns:
A GenerateResponseFlowStep object
"""
base = super()._from_json(data)
base = super().from_json(data)
return GenerateResponseFlowStep(
generation_prompt=data["generation_prompt"],
llm_config=data.get("llm"),
Expand Down
5 changes: 5 additions & 0 deletions rasa/shared/core/flows/steps/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def as_json(self) -> Dict[Text, Any]:
"Internal flow steps are ephemeral and are not to be serialized "
"or de-serialized."
)

@property
def default_id_postfix(self) -> str:
"""Returns the default id postfix of the flow step."""
raise ValueError("Internal flow steps do not need a default id")
2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_json(cls, data: Dict[Text, Any]) -> LinkFlowStep:
Returns:
a LinkFlowStep object
"""
base = super()._from_json(data)
base = super().from_json(data)
return LinkFlowStep(
link=data.get("link", ""),
**base.__dict__,
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/set_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_json(cls, data: Dict[Text, Any]) -> SetSlotsFlowStep:
Returns:
a SetSlotsFlowStep object
"""
base = super()._from_json(data)
base = super().from_json(data)
slots = [
{"key": k, "value": v}
for slot_sets in data["set_slots"]
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/flows/steps/user_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def from_json(cls, flow_step_config: Dict[Text, Any]) -> UserMessageStep:
Returns:
The parsed flow step.
"""
base = super()._from_json(flow_step_config)
base = super().from_json(flow_step_config)

trigger_conditions = []
if "intent" in flow_step_config:
Expand Down
29 changes: 14 additions & 15 deletions rasa/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rasa.shared.core.flows.flow_step_links import IfFlowStepLink
from rasa.shared.core.flows.steps.set_slots import SetSlotsFlowStep
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
from rasa.shared.core.flows.steps.branch import BranchFlowStep
from rasa.shared.core.flows.flow_step import FlowStep

Check failure on line 13 in rasa/validator.py

View workflow job for this annotation

GitHub Actions / Code Quality

F401 [*] `rasa.shared.core.flows.flow_step.FlowStep` imported but unused
from rasa.shared.core.flows.steps.action import ActionFlowStep
from rasa.shared.core.flows.flows_list import FlowsList
import rasa.shared.nlu.constants
Expand Down Expand Up @@ -631,24 +631,23 @@ def _construct_predicate(
return pred, all_good

def verify_predicates(self) -> bool:
"""Checks that predicates used in branch flow steps or `collect` steps are valid.""" # noqa: E501
"""Validate predicates used in flow step links and slot rejections."""
all_good = True
for flow in self.flows.underlying_flows:
for step in flow.steps:
if isinstance(step, BranchFlowStep):
for link in step.next.links:
if isinstance(link, IfFlowStepLink):
predicate, all_good = Validator._construct_predicate(
link.condition, step.id
for link in step.next.links:
if isinstance(link, IfFlowStepLink):
predicate, all_good = Validator._construct_predicate(
link.condition, step.id
)
if predicate and not predicate.is_valid():
logger.error(
f"Detected invalid condition '{link.condition}' "
f"at step '{step.id}' for flow id '{flow.id}'. "
f"Please make sure that all conditions are valid."
)
if predicate and not predicate.is_valid():
logger.error(
f"Detected invalid condition '{link.condition}' "
f"at step '{step.id}' for flow id '{flow.id}'. "
f"Please make sure that all conditions are valid."
)
all_good = False
elif isinstance(step, CollectInformationFlowStep):
all_good = False
if isinstance(step, CollectInformationFlowStep):
predicates = [predicate.if_ for predicate in step.rejections]
for predicate in predicates:
pred, all_good = Validator._construct_predicate(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import pytest
from _pytest.logging import LogCaptureFixture
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.core.domain import Domain
from rasa.shared.core.flows.yaml_flows_io import flows_from_str
from rasa.shared.core.training_data.structures import StoryGraph
from rasa.shared.nlu.training_data.training_data import TrainingData

from rasa.validator import Validator

Expand Down Expand Up @@ -1349,6 +1353,24 @@ def test_verify_predicates_invalid_rejection_if(
assert error_log in caplog.text


def test_flow_predicate_validation_fails_for_faulty_flow_link_predicates():
flows = flows_from_str(
"""
flows:
pattern_bar:
steps:
- id: first
action: action_listen
next:
- if: xxx !!!
then: END
- else: END
"""
)
validator = Validator(Domain.empty(), TrainingData(), StoryGraph([]), flows, None)
assert not validator.verify_predicates()


@pytest.fixture
def domain_file_name(tmp_path: Path) -> Path:
domain_file_name = tmp_path / "domain.yml"
Expand Down

0 comments on commit cb80889

Please sign in to comment.