Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
varunshankar committed Aug 30, 2023
1 parent c25f53f commit e62077f
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 70 deletions.
8 changes: 4 additions & 4 deletions rasa/cli/initial_project_dm2/data/flows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ flows:
description: greet the user and ask how they are doing. cheer them up if needed.
steps:
- id: "0"
question: good_mood
collect_information: good_mood
description: "can be true or false"
next:
- if: good_mood
Expand All @@ -30,13 +30,13 @@ flows:
description: This flow recommends a restaurant
steps:
- id: "0"
question: cuisine
collect_information: cuisine
next: "1"
- id: "1"
question: price_range
collect_information: price_range
next: "2"
- id: "2"
question: city
collect_information: city
next: "3"
- id: "3"
action: utter_recommend_restaurant
Expand Down
10 changes: 5 additions & 5 deletions rasa/core/channels/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ <h2>Happy chatting!</h2>
title: ${flow?.id ?? "No Flow"}
---
flowchart TD
classDef question stroke-width:1px
classDef collect_information stroke-width:1px
classDef action fill:#FBFCFD,stroke:#A0B8CF
classDef link fill:#f43
classDef slot fill:#aaa
Expand All @@ -347,11 +347,11 @@ <h2>Happy chatting!</h2>

if (flow) {
flow.steps.forEach((step) => {
if (step.question) {
var slotValue = slots[step.question]
? `'${slots[step.question]}'`
if (step.collect_information) {
var slotValue = slots[step.collect_information]
? `'${slots[step.collect_information]}'`
: "\uD83D\uDCAC";
mermaidText += `${step.id}["${toHtmlEntities(inject(keepShort(step.question), currentContext))}\n${keepShort(slotValue)}"]:::question\n`;
mermaidText += `${step.id}["${toHtmlEntities(inject(keepShort(step.collect_information), currentContext))}\n${keepShort(slotValue)}"]:::collect_information\n`;
}
if (step.action) {
mermaidText += `${step.id}["${toHtmlEntities(inject(keepShort(step.action), currentContext))}"]:::action\n`;
Expand Down
39 changes: 21 additions & 18 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jinja2 import Template
from rasa.cdu.conversation_patterns import (
FLOW_PATTERN_ASK_QUESTION,
FLOW_PATTERN_COLLECT_INFORMATION,
FLOW_PATTERN_CLARIFICATION,
FLOW_PATTERN_COMPLETED,
FLOW_PATTERN_CONTINUE_INTERRUPTED,
Expand Down Expand Up @@ -46,12 +46,12 @@
GenerateResponseFlowStep,
IfFlowLink,
EntryPromptFlowStep,
QuestionScope,
CollectInformationScope,
StepThatCanStartAFlow,
UserMessageStep,
LinkFlowStep,
SetSlotsFlowStep,
QuestionFlowStep,
CollectInformationFlowStep,
StaticFlowLink,
)
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
Expand Down Expand Up @@ -382,22 +382,22 @@ def _select_next_step(
)
return step

def _slot_for_question(self, question: Text) -> Slot:
"""Find the slot for a question."""
def _slot_for_collect_information(self, collect_information: Text) -> Slot:
"""Find the slot for a collect information."""
for slot in self.domain.slots:
if slot.name == question:
if slot.name == collect_information:
return slot
else:
raise FlowException(
f"Question '{question}' does not map to an existing slot."
f"Collect Information '{collect_information}' does not map to an existing slot."
)

def _is_step_completed(
self, step: FlowStep, tracker: "DialogueStateTracker"
) -> bool:
"""Check if a step is completed."""
if isinstance(step, QuestionFlowStep):
return tracker.get_slot(step.question) is not None
if isinstance(step, CollectInformationFlowStep):
return tracker.get_slot(step.collect_information) is not None
else:
return True

Expand Down Expand Up @@ -536,10 +536,13 @@ def _reset_scoped_slots(
events: List[Event] = []
for step in current_flow.steps:
# reset all slots scoped to the flow
if isinstance(step, QuestionFlowStep) and step.scope == QuestionScope.FLOW:
slot = tracker.slots.get(step.question, None)
if (
isinstance(step, CollectInformationFlowStep)
and step.scope == CollectInformationScope.FLOW
):
slot = tracker.slots.get(step.collect_information, None)
initial_value = slot.initial_value if slot else None
events.append(SlotSet(step.question, initial_value))
events.append(SlotSet(step.collect_information, initial_value))
return events

def _run_step(
Expand All @@ -565,9 +568,9 @@ def _run_step(
Returns:
A result of running the step describing where to transition to.
"""
if isinstance(step, QuestionFlowStep):
structlogger.debug("flow.step.run.question")
self.trigger_pattern_ask_question(step.question)
if isinstance(step, CollectInformationFlowStep):
structlogger.debug("flow.step.run.collect information")
self.trigger_pattern_ask_collect_information(step.collect_information)
return ContinueFlowWithNextStep()

elif isinstance(step, ActionFlowStep):
Expand Down Expand Up @@ -679,13 +682,13 @@ def trigger_pattern_completed(self, current_frame: FlowStackFrame) -> None:
)
)

def trigger_pattern_ask_question(self, question: str) -> None:
def trigger_pattern_ask_collect_information(self, collect_information: str) -> None:
context = self.flow_stack.current_context().copy()
context["question"] = question
context["collect information"] = collect_information

self.flow_stack.push(
FlowStackFrame(
flow_id=FLOW_PATTERN_ASK_QUESTION,
flow_id=FLOW_PATTERN_COLLECT_INFORMATION,
frame_type=StackFrameType.REMARK,
context=context,
)
Expand Down
88 changes: 45 additions & 43 deletions rasa/shared/core/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,32 +291,32 @@ def first_step_in_flow(self) -> Optional[FlowStep]:
return None
return self.steps[0]

def previously_asked_questions(
def previously_asked_collect_information(
self, step_id: Optional[str]
) -> List[QuestionFlowStep]:
"""Returns the questions asked before the given step.
) -> List[CollectInformationFlowStep]:
"""Returns the collect informations asked before the given step.
Questions are returned roughly in reverse order, i.e. the first
question in the list is the one asked last. But due to circles
CollectInformations are returned roughly in reverse order, i.e. the first
collect information in the list is the one asked last. But due to circles
in the flow the order is not guaranteed to be exactly reverse.
"""

def _previously_asked_questions(
def _previously_asked_collect_information(
current_step_id: str, visited_steps: Set[str]
) -> List[QuestionFlowStep]:
"""Returns the questions asked before the given step.
) -> List[CollectInformationFlowStep]:
"""Returns the collect informations asked before the given step.
Keeps track of the steps that have been visited to avoid circles.
"""
current_step = self.step_by_id(current_step_id)

questions: List[QuestionFlowStep] = []
collect_informations: List[CollectInformationFlowStep] = []

if not current_step:
return questions
return collect_informations

if isinstance(current_step, QuestionFlowStep):
questions.append(current_step)
if isinstance(current_step, CollectInformationFlowStep):
collect_informations.append(current_step)

visited_steps.add(current_step.id)

Expand All @@ -326,12 +326,14 @@ def _previously_asked_questions(
continue
if previous_step.id in visited_steps:
continue
questions.extend(
_previously_asked_questions(previous_step.id, visited_steps)
collect_informations.extend(
_previously_asked_collect_information(
previous_step.id, visited_steps
)
)
return questions
return collect_informations

return _previously_asked_questions(step_id or START_STEP, set())
return _previously_asked_collect_information(step_id or START_STEP, set())

def is_handling_pattern(self) -> bool:
"""Returns whether the flow is handling a pattern."""
Expand Down Expand Up @@ -361,13 +363,13 @@ def is_rasa_default_flow(self) -> bool:
"""Test whether something is a rasa default flow."""
return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX)

def get_question_steps(self) -> List[QuestionFlowStep]:
"""Return the question steps of the flow."""
question_steps = []
def get_collect_information_steps(self) -> List[CollectInformationFlowStep]:
"""Return the collect information steps of the flow."""
collect_information_steps = []
for step in self.steps:
if isinstance(step, QuestionFlowStep):
question_steps.append(step)
return question_steps
if isinstance(step, CollectInformationFlowStep):
collect_information_steps.append(step)
return collect_information_steps


def step_from_json(flow_step_config: Dict[Text, Any]) -> FlowStep:
Expand All @@ -383,8 +385,8 @@ def step_from_json(flow_step_config: Dict[Text, Any]) -> FlowStep:
return ActionFlowStep.from_json(flow_step_config)
if "intent" in flow_step_config:
return UserMessageStep.from_json(flow_step_config)
if "question" in flow_step_config:
return QuestionFlowStep.from_json(flow_step_config)
if "collect_information" in flow_step_config:
return CollectInformationFlowStep.from_json(flow_step_config)
if "link" in flow_step_config:
return LinkFlowStep.from_json(flow_step_config)
if "set_slots" in flow_step_config:
Expand Down Expand Up @@ -951,37 +953,37 @@ def is_triggered(self, tracker: DialogueStateTracker) -> bool:
return False


# enumeration of question scopes. scope can either be flow or global
class QuestionScope(str, Enum):
# enumeration of collect information scopes. scope can either be flow or global
class CollectInformationScope(str, Enum):
FLOW = "flow"
GLOBAL = "global"

@staticmethod
def from_str(label: Optional[Text]) -> "QuestionScope":
"""Converts a string to a QuestionScope."""
def from_str(label: Optional[Text]) -> "CollectInformationScope":
"""Converts a string to a CollectInformationScope."""
if label is None:
return QuestionScope.FLOW
return CollectInformationScope.FLOW
elif label.lower() == "flow":
return QuestionScope.FLOW
return CollectInformationScope.FLOW
elif label.lower() == "global":
return QuestionScope.GLOBAL
return CollectInformationScope.GLOBAL
else:
raise NotImplementedError


@dataclass
class QuestionFlowStep(FlowStep):
"""Represents the configuration of a question flow step."""
class CollectInformationFlowStep(FlowStep):
"""Represents the configuration of a collect information flow step."""

question: Text
"""The question of the flow step."""
collect_information: Text
"""The collect information of the flow step."""
skip_if_filled: bool = True
"""Whether to skip the question if the slot is already filled."""
scope: QuestionScope = QuestionScope.FLOW
"""how the question is scoped, determins when to reset its value."""
"""Whether to skip the collect information if the slot is already filled."""
scope: CollectInformationScope = CollectInformationScope.FLOW
"""how the collect information is scoped, determins when to reset its value."""

@classmethod
def from_json(cls, flow_step_config: Dict[Text, Any]) -> QuestionFlowStep:
def from_json(cls, flow_step_config: Dict[Text, Any]) -> CollectInformationFlowStep:
"""Used to read flow steps from parsed YAML.
Args:
Expand All @@ -991,10 +993,10 @@ def from_json(cls, flow_step_config: Dict[Text, Any]) -> QuestionFlowStep:
The parsed flow step.
"""
base = super()._from_json(flow_step_config)
return QuestionFlowStep(
question=flow_step_config.get("question", ""),
return CollectInformationFlowStep(
collect_information=flow_step_config.get("collect_information", ""),
skip_if_filled=flow_step_config.get("skip_if_filled", True),
scope=QuestionScope.from_str(flow_step_config.get("scope")),
scope=CollectInformationScope.from_str(flow_step_config.get("scope")),
**base.__dict__,
)

Expand All @@ -1005,7 +1007,7 @@ def as_json(self) -> Dict[Text, Any]:
The flow step as a dictionary.
"""
dump = super().as_json()
dump["question"] = self.question
dump["collect_information"] = self.collect_information
dump["skip_if_filled"] = self.skip_if_filled
dump["scope"] = self.scope.value

Expand Down

0 comments on commit e62077f

Please sign in to comment.