Skip to content

Commit

Permalink
chore: responded to review comments. adding FlowList.user_flows, upda…
Browse files Browse the repository at this point in the history
…ting FLowList.user_flow_ids, making FlowLists itterable.
  • Loading branch information
djcowley committed Oct 8, 2023
1 parent 8194f0b commit ff6de9f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 66 deletions.
2 changes: 1 addition & 1 deletion rasa/dialogue_understanding/commands/start_flow_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def run_command_on_tracker(
"command_executor.skip_command.already_started_flow", command=self
)
return []
elif self.flow not in all_flows.non_pattern_flows():
elif self.flow not in all_flows.user_flow_ids:
structlogger.debug(
"command_executor.skip_command.start_invalid_flow_id", command=self
)
Expand Down
121 changes: 66 additions & 55 deletions rasa/dialogue_understanding/generator/llm_command_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.core.flows.flow import FlowStep, FlowsList, CollectInformationFlowStep
from rasa.shared.core.flows.flow import (
Flow,
FlowStep,
FlowsList,
CollectInformationFlowStep,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.slots import (
BooleanSlot,
Expand Down Expand Up @@ -67,7 +72,7 @@
is_trainable=True,
)
class LLMCommandGenerator(GraphComponent, CommandGenerator):
"""An LLM based command generator."""
"""An LLM-based command generator."""

@staticmethod
def get_default_config() -> Dict[str, Any]:
Expand Down Expand Up @@ -168,47 +173,29 @@ def render_template(
Returns:
The rendered prompt template.
"""
flows_without_patterns = FlowsList(
[f for f in flows.underlying_flows if not f.is_handling_pattern()]
)
top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker))
top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None
current_step = top_relevant_frame.step(flows) if top_relevant_frame else None
if top_flow is not None:
flow_slots = [
{
"name": info_step.collect,
"value": self.slot_value(tracker, info_step.collect),
"type": tracker.slots[info_step.collect].type_name,
"allowed_values": self.allowed_values_for_slot(
tracker.slots[info_step.collect]
),
"description": info_step.description,
}
for info_step in top_flow.get_collect_steps()
if self.is_extractable(info_step, tracker, current_step)
]
else:
flow_slots = []

collect_information, collect_information_description = (
(current_step.collect, current_step.description)
if isinstance(current_step, CollectInformationFlowStep)
else (None, None)
flow_slots = self.prepare_current_flow_slots_for_template(
top_flow, current_step, tracker
)
current_slot, current_slot_description = self.prepare_current_slot_for_template(
current_step
)
current_conversation = tracker_as_readable_transcript(tracker)
latest_user_message = sanitize_message_for_prompt(message.get(TEXT))
current_conversation += f"\nUSER: {latest_user_message}"

inputs = {
"available_flows": self.create_template_inputs(
flows_without_patterns, tracker
"available_flows": self.prepare_flows_for_template(
flows.user_flows, tracker
),
"current_conversation": current_conversation,
"flow_slots": flow_slots,
"current_flow": top_flow.id if top_flow is not None else None,
"collect_information": collect_information,
"collect_information_description": collect_information_description,
"collect_information": current_slot,
"collect_information_description": current_slot_description,
"user_message": latest_user_message,
}

Expand Down Expand Up @@ -307,7 +294,7 @@ def clean_extracted_value(value: str) -> str:
"""Clean up the extracted value from the llm."""
# replace any combination of single quotes, double quotes, and spaces
# from the beginning and end of the string
return re.sub(r"^['\"\s]+|['\"\s]+$", "", value)
return value.strip("'\" ")

@classmethod
def coerce_slot_value(
Expand Down Expand Up @@ -345,10 +332,10 @@ def coerce_slot_value(
return nullable_value

@classmethod
def create_template_inputs(
def prepare_flows_for_template(
cls, flows: FlowsList, tracker: DialogueStateTracker
) -> List[Dict[str, Any]]:
"""Create the template inputs for the flows.
"""Format data on available flows for insertion into the prompt template.
Args:
flows: The flows available to the user.
Expand All @@ -358,29 +345,24 @@ def create_template_inputs(
The inputs for the prompt template.
"""
result = []
for flow in flows.underlying_flows:
# TODO: check if we should filter more flows; e.g. flows that are
# linked to by other flows and that shouldn't be started directly.
# we might need a separate flag for that.
if not flow.is_rasa_default_flow():

slots_with_info = [
{"name": q.collect, "description": q.description}
for q in flow.get_collect_steps()
if cls.is_extractable(q, tracker)
]
result.append(
{
"name": flow.id,
"description": flow.description,
"slots": slots_with_info,
}
)
for flow in flows.user_flows:
slots_with_info = [
{"name": q.collect, "description": q.description}
for q in flow.get_collect_steps()
if cls.is_extractable(q, tracker)
]
result.append(
{
"name": flow.id,
"description": flow.description,
"slots": slots_with_info,
}
)
return result

@staticmethod
def is_extractable(
info_step: CollectInformationFlowStep,
collect_step: CollectInformationFlowStep,
tracker: DialogueStateTracker,
current_step: Optional[FlowStep] = None,
) -> bool:
Expand All @@ -391,27 +373,27 @@ def is_extractable(
slot has been filled already.
Args:
info_step: The collect_information step.
collect_step: The collect_information step.
tracker: The tracker containing the current state of the conversation.
current_step: The current step in the flow.
Returns:
`True` if the slot can be filled, `False` otherwise.
"""
slot = tracker.slots.get(info_step.collect)
slot = tracker.slots.get(collect_step.collect)
if slot is None:
return False

return (
# we can fill because this is a slot that can be filled ahead of time
not info_step.ask_before_filling
not collect_step.ask_before_filling
# we can fill because the slot has been filled already
or slot.has_been_set
# we can fill because the is currently getting asked
or (
current_step is not None
and isinstance(current_step, CollectInformationFlowStep)
and current_step.collect == info_step.collect
and current_step.collect == collect_step.collect
)
)

Expand Down Expand Up @@ -440,3 +422,32 @@ def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str:
return "undefined"
else:
return str(slot_value)

def prepare_current_flow_slots_for_template(
self, top_flow: Flow, current_step: FlowStep, tracker: DialogueStateTracker
) -> List[Dict[str, Any]]:
if top_flow is not None:
flow_slots = [
{
"name": collect_step.collect,
"value": self.slot_value(tracker, collect_step.collect),
"type": tracker.slots[collect_step.collect].type_name,
"allowed_values": self.allowed_values_for_slot(
tracker.slots[collect_step.collect]
),
"description": collect_step.description,
}
for collect_step in top_flow.get_collect_steps()
if self.is_extractable(collect_step, tracker, current_step)
]
else:
flow_slots = []
return flow_slots

def prepare_current_slot_for_template(self, current_step: FlowStep):
"""Prepare the current slot for the template."""
return (
(current_step.collect, current_step.description)
if isinstance(current_step, CollectInformationFlowStep)
else (None, None)
)
29 changes: 19 additions & 10 deletions rasa/shared/core/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def __init__(self, flows: List[Flow]) -> None:
"""
self.underlying_flows = flows

def __iter__(self) -> Generator[Flow, None, None]:
"""Iterates over the flows."""
yield from self.underlying_flows

def is_empty(self) -> bool:
"""Returns whether the flows list is empty."""
return len(self.underlying_flows) == 0
Expand Down Expand Up @@ -254,15 +258,23 @@ def validate(self) -> None:
for flow in self.underlying_flows:
flow.validate()

def non_pattern_flows(self) -> List[str]:
"""Get all flows that can be started.
@property
def user_flow_ids(self) -> List[str]:
"""Get all ids of flows that can be started by a user.
Args:
all_flows: All flows.
Returns:
The ids of all flows that can be started by a user."""
return [f.id for f in self.user_flows]

@property
def user_flows(self) -> FlowsList:
"""Get all flows that can be started by a user.
Returns:
All flows that can be started."""
return [f.id for f in self.underlying_flows if not f.is_handling_pattern()]
All flows that can be started by a user."""
return FlowsList(
[f for f in self.underlying_flows if not f.is_rasa_default_flow]
)


@dataclass
Expand Down Expand Up @@ -495,10 +507,6 @@ def _previously_asked_collect(

return _previously_asked_collect(step_id or START_STEP, set())

def is_handling_pattern(self) -> bool:
"""Returns whether the flow is handling a pattern."""
return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX)

def get_trigger_intents(self) -> Set[str]:
"""Returns the trigger intents of the flow"""
results: Set[str] = set()
Expand All @@ -519,6 +527,7 @@ def is_user_triggerable(self) -> bool:
"""Test whether a user can trigger the flow with an intent."""
return len(self.get_trigger_intents()) > 0

@property
def is_rasa_default_flow(self) -> bool:
"""Test whether something is a rasa default flow."""
return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX)
Expand Down

0 comments on commit ff6de9f

Please sign in to comment.