Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rasa role group #2

Merged
merged 19 commits into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion chatette/adapters/rasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def _write_batch(self, output_file_handle, batch):

def prepare_example(self, example):
def entity_to_rasa(entity):
return {
entity_dict = {
"entity": entity.slot_name,
"value": entity.value,
"start": entity._start_index,
"end": entity._start_index + entity._len,
}
if entity.role is not None:
entity_dict['role'] = entity.role
if entity.group is not None:
entity_dict['group'] = entity.group

return entity_dict

return {
"intent": example.intent_name,
Expand Down
15 changes: 11 additions & 4 deletions chatette/adapters/rasa_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,18 @@ def prepare_example(self, example):
)
result = example.text[:]
for entity in sorted_entities:
entity_annotation_text = ']{"entity": "' + entity.slot_name
entity_text = result[entity._start_index:entity._start_index + entity._len]
if entity_text != entity.value:
entity_annotation_text += '", "value": "{}'.format(entity.value)
if entity.role is not None:
entity_annotation_text += '", "role": "{}'.format(entity.role)
if entity.group is not None:
entity_annotation_text += '", "group": "{}'.format(entity.group)
result = \
result[:entity._start_index] + "[" + \
result[entity._start_index:entity._start_index + entity._len] + \
"](" + entity.slot_name + ")" + \
result[entity._start_index + entity._len:]
entity_text + entity_annotation_text + '"}' + \
result[entity._start_index + entity._len:] # New rasa entity format
return result


Expand All @@ -105,7 +112,7 @@ def _get_base_to_extend(self):
if self._base_file_contents is None:
if self._base_filepath is None:
return self._get_empty_base()
with io.open(self._base_filepath, 'r') as base_file:
with io.open(self._base_filepath, 'r', encoding='utf-8') as base_file:
self._base_file_contents = ''.join(base_file.readlines())
self.check_base_file_contents()
return self._base_file_contents
Expand Down
114 changes: 100 additions & 14 deletions chatette/adapters/rasa_yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import io
from collections import OrderedDict
import ruamel.yaml as yaml
from ruamel.yaml.scalarstring import DoubleQuotedScalarString
from ruamel.yaml.error import YAMLError
from ruamel.yaml.constructor import DuplicateKeyError

from chatette.adapters._base import Adapter
from chatette.utils import append_to_list_in_dict, cast_to_unicode

YAML_VERSION = (1, 2)

def intent_dict_to_list_of_dict(data):
list_data = []
for key, values in data.items():
Expand All @@ -18,6 +23,18 @@ def intent_dict_to_list_of_dict(data):

return list_data

def fix_yaml_loader() -> None:
"""Ensure that any string read by yaml is represented as unicode."""
"""Code from Rasa yaml reader"""
def construct_yaml_str(self, node):
# Override the default string handling function
# to always return unicode objects
return self.construct_scalar(node)

yaml.Loader.add_constructor("tag:yaml.org,2002:str", construct_yaml_str)
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str", construct_yaml_str)


class RasaYMLAdapter(Adapter):
def __init__(self, base_filepath=None):
super(RasaYMLAdapter, self).__init__(base_filepath, None)
Expand Down Expand Up @@ -68,10 +85,17 @@ def prepare_example(self, example):
)
result = example.text[:]
for entity in sorted_entities:
entity_annotation_text = ']{"entity": "' + entity.slot_name
entity_text = result[entity._start_index:entity._start_index + entity._len]
if entity_text != entity.value:
entity_annotation_text += '", "value": "{}'.format(entity.value)
if entity.role is not None:
entity_annotation_text += '", "role": "{}'.format(entity.role)
if entity.group is not None:
entity_annotation_text += '", "group": "{}'.format(entity.group)
result = \
result[:entity._start_index] + "[" + \
result[entity._start_index:entity._start_index + entity._len] + \
']{"entity": "' + entity.slot_name + '"}' + \
entity_text + entity_annotation_text + '"}' + \
result[entity._start_index + entity._len:] # New rasa entity format
return result

Expand All @@ -87,18 +111,80 @@ def __format_synonyms(cls, synonyms):
if len(synonyms[slot_name]) > 1
]

def _read_yaml(self, content):
fix_yaml_loader()
yaml_parser = yaml.YAML(typ='safe')
yaml_parser.version = YAML_VERSION
yaml_parser.preserve_quotes = True
yaml.allow_duplicate_keys = False

return yaml_parser.load(content)

def _get_base_to_extend(self):
### TODO Implement later
return self._get_empty_base()
# if self._base_file_contents is None:
# if self._base_filepath is None:
# return self._get_empty_base()
# with io.open(self._base_filepath, 'r', encoding='utf-8') as base_file:
# self._base_file_contents = json.load(base_file)
# self.check_base_file_contents()
# return self._base_file_contents
if self._base_file_contents is None:
if self._base_filepath is None:
return self._get_empty_base()
with io.open(self._base_filepath, 'r', encoding='utf-8') as base_file:
try:
self._base_file_contents = self._read_yaml(base_file.read())
except (YAMLError, DuplicateKeyError) as e:
raise YamlSyntaxException(self._base_filepath, e)
self.check_base_file_contents()
return self._base_file_contents

def _get_empty_base(self):
return {
"nlu": list()
}
base = OrderedDict()
base['version'] = DoubleQuotedScalarString('2.0')
base['nlu'] = list()
return base

def check_base_file_contents(self):
"""
Checks that `self._base_file_contents` contains well formatted NLU dictionary.
Throws a `SyntaxError` if the data is incorrect.
"""
if self._base_file_contents is None:
return
if not isinstance(self._base_file_contents, dict):
self._base_file_contents = None
raise SyntaxError(
"Couldn't load valid data from base file '" + \
self._base_filepath + "'"
)
else:
if "nlu" not in self._base_file_contents:
self._base_file_contents = None
raise SyntaxError(
"Expected 'nlu' as a root of base file '" + \
self._base_filepath + "'")


class YamlSyntaxException(Exception):
"""Raised when a YAML file can not be parsed properly due to a syntax error."""
"""code from rasa.shared.exceptions.YamlSyntaxException"""

def __init__(self, filename, underlying_yaml_exception):
self.filename = filename
self.underlying_yaml_exception = underlying_yaml_exception

def __str__(self):
if self.filename:
exception_text = "Failed to read '{}'.".format(self.filename)
else:
exception_text = "Failed to read YAML."

if self.underlying_yaml_exception:
self.underlying_yaml_exception.warn = None
self.underlying_yaml_exception.note = None
exception_text += " {}".format(self.underlying_yaml_exception)

if self.filename:
exception_text = exception_text.replace(
'in "<unicode string>"', 'in "{}"'.format(self.filename)
)

exception_text += (
"\n\nYou can use https://yamlchecker.com/ to validate the "
"YAML syntax of your file."
)
return exception_text
9 changes: 8 additions & 1 deletion chatette/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from future.utils import with_metaclass

from chatette.units.modifiable.choice import Choice
from chatette.units.modifiable.unit_reference import UnitReference
from chatette.units.modifiable.unit_reference import UnitReference, SlotRoleGroupReference
from chatette.units.modifiable.definitions.alias import AliasDefinition
from chatette.units.modifiable.definitions.slot import SlotDefinition
from chatette.units.modifiable.definitions.intent import IntentDefinition
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(self):
self.identifier = None
self.variation = None
self.arg_value = None
self.slot_rolegroup = None

def _check_information(self):
super(UnitRefBuilder, self)._check_information()
Expand All @@ -108,6 +109,12 @@ def _build_modifiers_repr(self):

def create_concrete(self):
self._check_information()
if self.slot_rolegroup is not None:
return SlotRoleGroupReference(
self.identifier, self.type,
self.leading_space, self._build_modifiers_repr(),
self.slot_rolegroup
)
return UnitReference(
self.identifier, self.type,
self.leading_space, self._build_modifiers_repr()
Expand Down
15 changes: 14 additions & 1 deletion chatette/parsing/lexing/rule_unit_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extract_identifier, \
CASE_GEN_SYM, UNIT_END_SYM

from chatette.parsing.lexing.rule_annotation import RuleAnnotation
from chatette.parsing.lexing.rule_unit_start import RuleUnitStart
from chatette.parsing.lexing.rule_variation import RuleVariation
from chatette.parsing.lexing.rule_rand_gen import RuleRandGen
Expand Down Expand Up @@ -55,11 +56,13 @@ def _apply_strategy(self, **kwargs):
"using character '" + UNIT_END_SYM + "')."
return False

is_slot = False
# TODO maybe making a function for this would be useful
if self._tokens[0].type == TerminalType.alias_ref_start:
unit_end_type = TerminalType.alias_ref_end
elif self._tokens[0].type == TerminalType.slot_ref_start:
unit_end_type = TerminalType.slot_ref_end
is_slot = True
elif self._tokens[0].type == TerminalType.intent_ref_start:
unit_end_type = TerminalType.intent_ref_end
else: # Should never happen
Expand All @@ -72,5 +75,15 @@ def _apply_strategy(self, **kwargs):
self._next_index += 1
self._update_furthest_matched_index()
self._tokens.append(LexicalToken(unit_end_type, UNIT_END_SYM))


# This is for adding new rasa training mode that has role and group entity
# Reference: https://rasa.com/docs/rasa/nlu-training-data/#entities-roles-and-groups
annotation_rule = RuleAnnotation(self._text, self._next_index)

# ? Should we raise error if RuleAnnotation doesn't match, i.e. wrong pattern
if is_slot and annotation_rule.matches():
self._next_index = annotation_rule.get_next_index_to_match()
self._update_furthest_matched_index()
self._tokens.extend(annotation_rule.get_lexical_tokens())

return True
35 changes: 33 additions & 2 deletions chatette/parsing/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from chatette.units.rule import Rule

from chatette.parsing import \
ChoiceBuilder, UnitRefBuilder, \
ChoiceBuilder, UnitRefBuilder,\
AliasDefBuilder, SlotDefBuilder, IntentDefBuilder


Expand Down Expand Up @@ -398,12 +398,18 @@ def _parse_rule(self, tokens):
elif (
token.type in \
(TerminalType.alias_ref_end,
TerminalType.slot_ref_end,
TerminalType.intent_ref_end)
):
rule_contents.append(current_builder.create_concrete())
current_builder = None
leading_space = False
elif token.type == TerminalType.slot_ref_end:
# checking annotation after slot reference
rolegroup_annotation, i = self._check_for_annotations(tokens, i)
current_builder.slot_rolegroup = rolegroup_annotation
rule_contents.append(current_builder.create_concrete())
current_builder = None
leading_space = False
elif token.type == TerminalType.unit_identifier:
current_builder.identifier = token.text
elif token.type == TerminalType.choice_start:
Expand Down Expand Up @@ -505,3 +511,28 @@ def _parse_choice(self, tokens):
)

return rules

def _check_for_annotations(self, tokens, i):
if (
i+1 == len(tokens)
or tokens[i+1].type != TerminalType.annotation_start
):
return None, i

annotation = {}
current_key = None
for j, token in enumerate(tokens[i+1:]):
if token.type == TerminalType.annotation_end:
i += j+1
break
elif token.type == TerminalType.key:
current_key = token.text
elif token.type == TerminalType.value:
if current_key in annotation:
self.input_file_manager.syntax_error(
"Annotation contained the key '" + current_key + \
"' twice."
)
annotation[current_key] = token.text

return annotation, i
16 changes: 14 additions & 2 deletions chatette/units/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ class Entity(object):
Represents an entity as it will be contained in examples
(instances of `Example`).
"""
def __init__(self, name, length, value=None, start_index=0):
def __init__(self, name, length, value=None, start_index=0, role=None, group=None):
self.slot_name = name # name of the entity (not the associated text)
self.value = value
self._len = length
self._start_index = start_index
self.role = role
self.group = group

def _remove_leading_space(self):
"""
Expand All @@ -146,17 +148,27 @@ def _remove_leading_space(self):
return True

def as_dict(self):
return {
entity_dict = {
"slot-name": self.slot_name,
"value": self.value,
"start-index": self._start_index,
"end-index": self._start_index + self._len
}
if self.role is not None:
entity_dict['role'] = self.role
if self.group is not None:
entity_dict['group'] = self.group
return entity_dict

def __repr__(self):
representation = "entity '" + self.slot_name + "'"
if self.value is not None:
representation += ":'" + self.value + "'"
# ? There might be better representation format?
if self.role is not None:
representation += ", 'role' :'" + self.role + "'"
if self.group is not None:
representation += ", 'group' :'" + self.group + "'"
return representation
def __str__(self):
return \
Expand Down
Loading