Skip to content

Commit

Permalink
Added tests for flows io, fixed bug that made serialization and deser…
Browse files Browse the repository at this point in the history
…ialization unequal
  • Loading branch information
twerkmeister committed Oct 27, 2023
1 parent 8c0260b commit 48c0718
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 26 deletions.
10 changes: 10 additions & 0 deletions data/test_flows/basic_flows.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
flows:
foo:
description: "A test flow"
steps:
- action: "utter_test"
bar:
description: "Another test flow"
steps:
- action: "utter_greet"
- collect: "important_info"
5 changes: 3 additions & 2 deletions rasa/shared/core/flows/flow_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def as_json(self) -> Dict[Text, Any]:
Returns:
The FlowStep as serialized data.
"""
data: Dict[Text, Any] = {"next": self.next.as_json(), "id": self.id}

data: Dict[Text, Any] = {"next": self.next.as_json()}
if self.custom_id:
data["id"] = self.custom_id
if self.description:
data["description"] = self.description
if self.metadata:
Expand Down
21 changes: 7 additions & 14 deletions rasa/shared/core/flows/flows_list.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Generator, Any, Optional, Dict, Text, Set

import rasa.shared.utils.io
from rasa.shared.core.flows.flow import Flow
from rasa.shared.core.flows.validation import validate_flow


@dataclass
class FlowsList:
"""A collection of flows.
Expand All @@ -15,24 +16,16 @@ class FlowsList:
specific attributes or collecting all utterances across all flows.
"""

def __init__(self, flows: List[Flow]) -> None:
"""Initializes the FlowsList object.
Args:
flows: The flows for this collection.
"""
self.underlying_flows = flows
underlying_flows: List[Flow]
"""The flows contained in this FlowsList."""

def __iter__(self) -> Generator[Flow, None, None]:
"""Iterates over the flows."""
yield from self.underlying_flows

def __eq__(self, other: Any) -> bool:
"""Compares this FlowsList to another one."""
return (
isinstance(other, FlowsList)
and self.underlying_flows == other.underlying_flows
)
def __len__(self):
"""Return the length of this FlowsList."""
return len(self.underlying_flows)

def is_empty(self) -> bool:
"""Returns whether the flows list is empty."""
Expand Down
4 changes: 2 additions & 2 deletions rasa/shared/importers/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_flows(self) -> FlowsList:
Returns:
`FlowsList` containing all loaded flows.
"""
return FlowsList(flows=[])
return FlowsList([])

def get_conversation_tests(self) -> StoryGraph:
"""Retrieves end-to-end conversation stories for testing.
Expand Down Expand Up @@ -310,7 +310,7 @@ def get_flows(self) -> FlowsList:
flow_lists = [importer.get_flows() for importer in self._importers]

return reduce(
lambda merged, other: merged.merge(other), flow_lists, FlowsList(flows=[])
lambda merged, other: merged.merge(other), flow_lists, FlowsList([])
)

@rasa.shared.utils.common.cached_method
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/importers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
"""Returns the flows from paths."""
from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader

flows = FlowsList(flows=[])
flows = FlowsList([])
for file in files:
flows = flows.merge(YAMLFlowsReader.read_from_file(file))
return flows
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,17 @@ def tracker_with_restarted_event(
events = initial_events_including_restart + events_after_restart

return DialogueStateTracker.from_events(sender_id=sender_id, evts=events)


@pytest.fixture(scope="session")
def tests_folder() -> str:
tests_folder = os.path.dirname(os.path.abspath(__file__))
assert os.path.isdir(tests_folder)
return tests_folder


@pytest.fixture(scope="session")
def tests_data_folder(tests_folder: str) -> str:
tests_data_folder = os.path.join(os.path.split(tests_folder)[0], "data")
assert os.path.isdir(tests_data_folder)
return tests_data_folder
2 changes: 1 addition & 1 deletion tests/core/flows/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def only_patterns() -> FlowsList:

@pytest.fixture
def empty_flowlist() -> FlowsList:
return FlowsList(flows=[])
return FlowsList([])


def test_user_flow_ids(user_flows_and_patterns: FlowsList):
Expand Down
41 changes: 41 additions & 0 deletions tests/core/flows/test_flows_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import pytest
import tempfile
from rasa.shared.core.flows.yaml_flows_io import (
is_flows_file,
YAMLFlowsReader,
YamlFlowsWriter,
)


@pytest.fixture(scope="module")
def basic_flows_file(tests_data_folder: str) -> str:
return os.path.join(tests_data_folder, "test_flows", "basic_flows.yml")


@pytest.mark.parametrize(
"path, expected_result",
[
(os.path.join("test_flows", "basic_flows.yml"), True),
(os.path.join("test_moodbot", "domain.yml"), False),
],
)
def test_is_flows_file(tests_data_folder: str, path: str, expected_result: bool):
full_path = os.path.join(tests_data_folder, path)
assert is_flows_file(full_path) == expected_result


def test_flow_reading(basic_flows_file: str):
flows_list = YAMLFlowsReader.read_from_file(basic_flows_file)
assert len(flows_list) == 2
assert flows_list.flow_by_id("foo") is not None
assert flows_list.flow_by_id("bar") is not None


def test_flow_writing(basic_flows_file: str):
flows_list = YAMLFlowsReader.read_from_file(basic_flows_file)
tmp_file_descriptor, tmp_file_name = tempfile.mkstemp()
YamlFlowsWriter.dump(flows_list.underlying_flows, tmp_file_name)

re_read_flows_list = YAMLFlowsReader.read_from_file(tmp_file_name)
assert re_read_flows_list == flows_list
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_run_command_skips_if_slot_is_set_to_same_value():
tracker = DialogueStateTracker.from_events("test", evts=[SlotSet("foo", "bar")])
command = SetSlotCommand(name="foo", value="bar")

assert command.run_command_on_tracker(tracker, FlowsList(flows=[]), tracker) == []
assert command.run_command_on_tracker(tracker, FlowsList([]), tracker) == []


def test_run_command_sets_slot_if_asked_for():
Expand Down
10 changes: 5 additions & 5 deletions tests/dialogue_understanding/stack/frames/test_flow_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def test_flow_get_flow():
name="foo flow",
description="foo flow description",
)
all_flows = FlowsList(flows=[flow])
all_flows = FlowsList([flow])
assert frame.flow(all_flows) == flow


def test_flow_get_flow_non_existant_id():
frame = UserFlowStackFrame(frame_id="test", flow_id="unknown", step_id="bar")
all_flows = FlowsList(
flows=[
[
Flow(
id="foo",
step_sequence=FlowStepSequence(child_steps=[]),
Expand All @@ -90,7 +90,7 @@ def test_flow_get_step():
next=FlowStepLinks(links=[]),
)
all_flows = FlowsList(
flows=[
[
Flow(
id="foo",
step_sequence=FlowStepSequence(child_steps=[step]),
Expand All @@ -105,7 +105,7 @@ def test_flow_get_step():
def test_flow_get_step_non_existant_id():
frame = UserFlowStackFrame(frame_id="test", flow_id="foo", step_id="unknown")
all_flows = FlowsList(
flows=[
[
Flow(
id="foo",
step_sequence=FlowStepSequence(child_steps=[]),
Expand All @@ -121,7 +121,7 @@ def test_flow_get_step_non_existant_id():
def test_flow_get_step_non_existant_flow_id():
frame = UserFlowStackFrame(frame_id="test", flow_id="unknown", step_id="unknown")
all_flows = FlowsList(
flows=[
[
Flow(
id="foo",
step_sequence=FlowStepSequence(child_steps=[]),
Expand Down

0 comments on commit 48c0718

Please sign in to comment.