Skip to content

Commit

Permalink
Merge pull request #2 from tomgun132/rasa_role_group
Browse files Browse the repository at this point in the history
Rasa role group
  • Loading branch information
tomgun132 committed Jun 27, 2021
2 parents d4fb7b2 + 2dfdd0e commit 9b814ad
Show file tree
Hide file tree
Showing 17 changed files with 319 additions and 32 deletions.
8 changes: 7 additions & 1 deletion chatette/adapters/rasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def _write_batch(self, output_file_handle, batch):

def prepare_example(self, example):
def entity_to_rasa(entity):
return {
entity_dict = {
"entity": entity.slot_name,
"value": entity.value,
"start": entity._start_index,
"end": entity._start_index + entity._len,
}
if entity.role is not None:
entity_dict['role'] = entity.role
if entity.group is not None:
entity_dict['group'] = entity.group

return entity_dict

return {
"intent": example.intent_name,
Expand Down
15 changes: 11 additions & 4 deletions chatette/adapters/rasa_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,18 @@ def prepare_example(self, example):
)
result = example.text[:]
for entity in sorted_entities:
entity_annotation_text = ']{"entity": "' + entity.slot_name
entity_text = result[entity._start_index:entity._start_index + entity._len]
if entity_text != entity.value:
entity_annotation_text += '", "value": "{}'.format(entity.value)
if entity.role is not None:
entity_annotation_text += '", "role": "{}'.format(entity.role)
if entity.group is not None:
entity_annotation_text += '", "group": "{}'.format(entity.group)
result = \
result[:entity._start_index] + "[" + \
result[entity._start_index:entity._start_index + entity._len] + \
"](" + entity.slot_name + ")" + \
result[entity._start_index + entity._len:]
entity_text + entity_annotation_text + '"}' + \
result[entity._start_index + entity._len:] # New rasa entity format
return result


Expand All @@ -105,7 +112,7 @@ def _get_base_to_extend(self):
if self._base_file_contents is None:
if self._base_filepath is None:
return self._get_empty_base()
with io.open(self._base_filepath, 'r') as base_file:
with io.open(self._base_filepath, 'r', encoding='utf-8') as base_file:
self._base_file_contents = ''.join(base_file.readlines())
self.check_base_file_contents()
return self._base_file_contents
Expand Down
114 changes: 100 additions & 14 deletions chatette/adapters/rasa_yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import io
from collections import OrderedDict
import ruamel.yaml as yaml
from ruamel.yaml.scalarstring import DoubleQuotedScalarString
from ruamel.yaml.error import YAMLError
from ruamel.yaml.constructor import DuplicateKeyError

from chatette.adapters._base import Adapter
from chatette.utils import append_to_list_in_dict, cast_to_unicode

YAML_VERSION = (1, 2)

def intent_dict_to_list_of_dict(data):
list_data = []
for key, values in data.items():
Expand All @@ -18,6 +23,18 @@ def intent_dict_to_list_of_dict(data):

return list_data

def fix_yaml_loader() -> None:
    """Register a string constructor on the YAML loader classes.

    Installs a constructor for the ``tag:yaml.org,2002:str`` tag on both
    ``yaml.Loader`` and ``yaml.SafeLoader``. The constructor simply returns
    ``construct_scalar(node)``. Adapted from Rasa's YAML reader, where the
    stated goal is that any string read by yaml is represented as unicode.
    """
    string_tag = "tag:yaml.org,2002:str"

    def _scalar_as_text(loader, node):
        # Replaces the default string handling: return the scalar as-is.
        return loader.construct_scalar(node)

    for loader_class in (yaml.Loader, yaml.SafeLoader):
        loader_class.add_constructor(string_tag, _scalar_as_text)


class RasaYMLAdapter(Adapter):
def __init__(self, base_filepath=None):
super(RasaYMLAdapter, self).__init__(base_filepath, None)
Expand Down Expand Up @@ -68,10 +85,17 @@ def prepare_example(self, example):
)
result = example.text[:]
for entity in sorted_entities:
entity_annotation_text = ']{"entity": "' + entity.slot_name
entity_text = result[entity._start_index:entity._start_index + entity._len]
if entity_text != entity.value:
entity_annotation_text += '", "value": "{}'.format(entity.value)
if entity.role is not None:
entity_annotation_text += '", "role": "{}'.format(entity.role)
if entity.group is not None:
entity_annotation_text += '", "group": "{}'.format(entity.group)
result = \
result[:entity._start_index] + "[" + \
result[entity._start_index:entity._start_index + entity._len] + \
']{"entity": "' + entity.slot_name + '"}' + \
entity_text + entity_annotation_text + '"}' + \
result[entity._start_index + entity._len:] # New rasa entity format
return result

Expand All @@ -87,18 +111,80 @@ def __format_synonyms(cls, synonyms):
if len(synonyms[slot_name]) > 1
]

def _read_yaml(self, content):
    """Parse a YAML document and return the loaded data.

    :param content: text of the YAML document to parse.
    :returns: whatever the parser's ``load`` returns for `content`.

    The parser is a ``yaml.YAML(typ='safe')`` instance pinned to
    `YAML_VERSION`. Exceptions raised by the parser are not caught here.
    """
    fix_yaml_loader()
    yaml_parser = yaml.YAML(typ='safe')
    yaml_parser.version = YAML_VERSION
    yaml_parser.preserve_quotes = True
    # Configure the parser instance itself: assigning this flag on the
    # `yaml` module has no effect on the parser created above.
    yaml_parser.allow_duplicate_keys = False

    return yaml_parser.load(content)

def _get_base_to_extend(self):
    """Return the NLU data that should be extended.

    When no base file path is configured, a fresh empty base is returned
    (and not cached). Otherwise the base file is read as UTF-8 text and
    parsed as YAML the first time this is called; the result is cached in
    `self._base_file_contents`, validated with
    `check_base_file_contents`, and returned on every later call.

    :raises YamlSyntaxException: if the base file is not valid YAML or
        contains duplicate keys.
    :raises SyntaxError: if the parsed contents fail validation.
    """
    if self._base_file_contents is None:
        if self._base_filepath is None:
            return self._get_empty_base()
        with io.open(self._base_filepath, 'r', encoding='utf-8') as base_file:
            try:
                self._base_file_contents = self._read_yaml(base_file.read())
            except (YAMLError, DuplicateKeyError) as e:
                # Wrap parser errors so the message names the offending file.
                raise YamlSyntaxException(self._base_filepath, e)
        self.check_base_file_contents()
    return self._base_file_contents

def _get_empty_base(self):
    """Return a new, empty base document.

    The result is an ordered mapping with a 'version' entry (the
    double-quoted string "2.0") followed by an empty 'nlu' list.
    """
    base = OrderedDict()
    # Wrapped in DoubleQuotedScalarString so the value is a double-quoted
    # YAML scalar rather than a bare 2.0.
    base['version'] = DoubleQuotedScalarString('2.0')
    base['nlu'] = list()
    return base

def check_base_file_contents(self):
    """
    Validates the data held in `self._base_file_contents`.
    Nothing is checked when no contents are loaded (`None`). Otherwise the
    contents must be a `dict` containing the key 'nlu'; if they are not,
    `self._base_file_contents` is reset to `None` and a `SyntaxError`
    naming the base file is raised.
    """
    contents = self._base_file_contents
    if contents is None:
        return

    if isinstance(contents, dict):
        if "nlu" in contents:
            return
        message_start = "Expected 'nlu' as a root of base file '"
    else:
        message_start = "Couldn't load valid data from base file '"

    # Drop the invalid data before reporting the problem.
    self._base_file_contents = None
    raise SyntaxError(message_start + self._base_filepath + "'")


class YamlSyntaxException(Exception):
    """Raised when a YAML file can not be parsed properly due to a syntax error.

    Adapted from ``rasa.shared.exceptions.YamlSyntaxException``.
    """

    def __init__(self, filename, underlying_yaml_exception):
        """Store the file name and the parser error that caused the failure.

        :param filename: path of the file that failed to parse (may be falsy
            when unknown).
        :param underlying_yaml_exception: the exception raised by the YAML
            parser (may be falsy when unavailable).
        """
        # Forward the arguments to Exception so that `args`, `repr()` and
        # copy/pickle reconstruction carry them.
        super(YamlSyntaxException, self).__init__(
            filename, underlying_yaml_exception
        )
        self.filename = filename
        self.underlying_yaml_exception = underlying_yaml_exception

    def __str__(self):
        """Build the user-facing message: what failed, why, and a hint."""
        if self.filename:
            exception_text = "Failed to read '{}'.".format(self.filename)
        else:
            exception_text = "Failed to read YAML."

        if self.underlying_yaml_exception:
            # Clear `warn` and `note` on the wrapped exception before it is
            # formatted. NOTE(review): presumably these hold extra hints the
            # YAML library would append to its message -- confirm.
            self.underlying_yaml_exception.warn = None
            self.underlying_yaml_exception.note = None
            exception_text += " {}".format(self.underlying_yaml_exception)

        if self.filename:
            # The content is parsed from a string, so the parser's message
            # refers to "<unicode string>"; show the real file name instead.
            exception_text = exception_text.replace(
                'in "<unicode string>"', 'in "{}"'.format(self.filename)
            )

        exception_text += (
            "\n\nYou can use https://yamlchecker.com/ to validate the "
            "YAML syntax of your file."
        )
        return exception_text
9 changes: 8 additions & 1 deletion chatette/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from future.utils import with_metaclass

from chatette.units.modifiable.choice import Choice
from chatette.units.modifiable.unit_reference import UnitReference
from chatette.units.modifiable.unit_reference import UnitReference, SlotRoleGroupReference
from chatette.units.modifiable.definitions.alias import AliasDefinition
from chatette.units.modifiable.definitions.slot import SlotDefinition
from chatette.units.modifiable.definitions.intent import IntentDefinition
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(self):
self.identifier = None
self.variation = None
self.arg_value = None
self.slot_rolegroup = None

def _check_information(self):
super(UnitRefBuilder, self)._check_information()
Expand All @@ -108,6 +109,12 @@ def _build_modifiers_repr(self):

def create_concrete(self):
self._check_information()
if self.slot_rolegroup is not None:
return SlotRoleGroupReference(
self.identifier, self.type,
self.leading_space, self._build_modifiers_repr(),
self.slot_rolegroup
)
return UnitReference(
self.identifier, self.type,
self.leading_space, self._build_modifiers_repr()
Expand Down
15 changes: 14 additions & 1 deletion chatette/parsing/lexing/rule_unit_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extract_identifier, \
CASE_GEN_SYM, UNIT_END_SYM

from chatette.parsing.lexing.rule_annotation import RuleAnnotation
from chatette.parsing.lexing.rule_unit_start import RuleUnitStart
from chatette.parsing.lexing.rule_variation import RuleVariation
from chatette.parsing.lexing.rule_rand_gen import RuleRandGen
Expand Down Expand Up @@ -55,11 +56,13 @@ def _apply_strategy(self, **kwargs):
"using character '" + UNIT_END_SYM + "')."
return False

is_slot = False
# TODO maybe making a function for this would be useful
if self._tokens[0].type == TerminalType.alias_ref_start:
unit_end_type = TerminalType.alias_ref_end
elif self._tokens[0].type == TerminalType.slot_ref_start:
unit_end_type = TerminalType.slot_ref_end
is_slot = True
elif self._tokens[0].type == TerminalType.intent_ref_start:
unit_end_type = TerminalType.intent_ref_end
else: # Should never happen
Expand All @@ -72,5 +75,15 @@ def _apply_strategy(self, **kwargs):
self._next_index += 1
self._update_furthest_matched_index()
self._tokens.append(LexicalToken(unit_end_type, UNIT_END_SYM))


# Lex an optional annotation right after a slot reference, to support Rasa's entity roles and groups.
# Reference: https://rasa.com/docs/rasa/nlu-training-data/#entities-roles-and-groups
annotation_rule = RuleAnnotation(self._text, self._next_index)

# ? Should we raise error if RuleAnnotation doesn't match, i.e. wrong pattern
if is_slot and annotation_rule.matches():
self._next_index = annotation_rule.get_next_index_to_match()
self._update_furthest_matched_index()
self._tokens.extend(annotation_rule.get_lexical_tokens())

return True
35 changes: 33 additions & 2 deletions chatette/parsing/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from chatette.units.rule import Rule

from chatette.parsing import \
ChoiceBuilder, UnitRefBuilder, \
ChoiceBuilder, UnitRefBuilder,\
AliasDefBuilder, SlotDefBuilder, IntentDefBuilder


Expand Down Expand Up @@ -398,12 +398,18 @@ def _parse_rule(self, tokens):
elif (
token.type in \
(TerminalType.alias_ref_end,
TerminalType.slot_ref_end,
TerminalType.intent_ref_end)
):
rule_contents.append(current_builder.create_concrete())
current_builder = None
leading_space = False
elif token.type == TerminalType.slot_ref_end:
# checking annotation after slot reference
rolegroup_annotation, i = self._check_for_annotations(tokens, i)
current_builder.slot_rolegroup = rolegroup_annotation
rule_contents.append(current_builder.create_concrete())
current_builder = None
leading_space = False
elif token.type == TerminalType.unit_identifier:
current_builder.identifier = token.text
elif token.type == TerminalType.choice_start:
Expand Down Expand Up @@ -505,3 +511,28 @@ def _parse_choice(self, tokens):
)

return rules

def _check_for_annotations(self, tokens, i):
    """
    Looks for an annotation directly after the token at index `i`.
    Returns `(None, i)` if `i` is the last index or the next token is not
    an annotation start. Otherwise collects the texts of the key/value
    tokens into a dict until an annotation end token is met, and returns
    that dict together with the index of the annotation end token.
    Reports a syntax error through the input file manager when a key is
    given a value more than once.
    """
    first = i + 1
    if first == len(tokens):
        return None, i
    if tokens[first].type != TerminalType.annotation_start:
        return None, i

    annotation = {}
    current_key = None
    for position in range(first, len(tokens)):
        token = tokens[position]
        if token.type == TerminalType.annotation_end:
            return annotation, position
        if token.type == TerminalType.key:
            current_key = token.text
        elif token.type == TerminalType.value:
            if current_key in annotation:
                self.input_file_manager.syntax_error(
                    "Annotation contained the key '" + current_key + \
                    "' twice."
                )
            annotation[current_key] = token.text

    # No annotation end token was found: the index is left unchanged.
    return annotation, i
16 changes: 14 additions & 2 deletions chatette/units/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ class Entity(object):
Represents an entity as it will be contained in examples
(instances of `Example`).
"""
def __init__(self, name, length, value=None, start_index=0,
             role=None, group=None):
    """Create an entity.

    :param name: name of the entity (not the associated text).
    :param length: stored as `_len`; added to the start index to obtain
        the entity's end index.
    :param value: value of the entity, if any.
    :param start_index: index at which the entity starts.
    :param role: optional role of the entity; omitted from outputs when
        `None`.
    :param group: optional group of the entity; omitted from outputs when
        `None`.
    """
    self.slot_name = name  # name of the entity (not the associated text)
    self.value = value
    self._len = length
    self._start_index = start_index
    self.role = role
    self.group = group

def _remove_leading_space(self):
"""
Expand All @@ -146,17 +148,27 @@ def _remove_leading_space(self):
return True

def as_dict(self):
    """Return the entity as a dictionary.

    The dictionary always holds 'slot-name', 'value', 'start-index' and
    'end-index' (start index plus length). 'role' and 'group' are only
    included when they are not `None`.
    """
    entity_dict = {
        "slot-name": self.slot_name,
        "value": self.value,
        "start-index": self._start_index,
        "end-index": self._start_index + self._len
    }
    if self.role is not None:
        entity_dict['role'] = self.role
    if self.group is not None:
        entity_dict['group'] = self.group
    return entity_dict

def __repr__(self):
    """Return a short description of the entity.

    Starts with the slot name, then appends the value, the role and the
    group (in that order), each only when it is not `None`.
    """
    head = "entity '" + self.slot_name + "'"
    # (prefix, attribute) pairs, in the order they appear in the output.
    optional_parts = (
        (":'", self.value),
        (", 'role' :'", self.role),
        (", 'group' :'", self.group),
    )
    tail = [
        prefix + text + "'"
        for prefix, text in optional_parts
        if text is not None
    ]
    return head + "".join(tail)
def __str__(self):
return \
Expand Down
Loading

0 comments on commit 9b814ad

Please sign in to comment.