Update the constituency conversion tool to take into account the 2022 updates to the VI dataset. This includes outputting a bunch more diagnostic information for broken trees - the number is now low enough that it is realistic to print them all out
AngledLuffa committed Oct 4, 2022
1 parent 1c768a7 commit 500435d
Showing 2 changed files with 50 additions and 13 deletions.
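For context, the second hunk of prepare_con_dataset.py wires the new dataset name into main(), so the 2022 data can be prepared the same way as the other constituency datasets. A minimal sketch of that entry point, assuming main() is called with the dataset short name as in the def main(dataset_name) signature shown below; this is not part of the commit itself:

```python
# Sketch only: dispatch the new dataset through main(), which (per the diff
# below) routes 'vi_vlsp22' to process_vlsp22(paths).
# Assumes the raw trees already sit under CONSTITUENCY_BASE/vietnamese/VLSP_2022.
from stanza.utils.datasets.constituency import prepare_con_dataset

prepare_con_dataset.main("vi_vlsp22")
```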
25 changes: 24 additions & 1 deletion stanza/utils/datasets/constituency/prepare_con_dataset.py
@@ -146,7 +146,28 @@ def process_vlsp21(paths):
raise FileNotFoundError("Could not find the 2021 dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(vlsp_file, paths["CONSTITUENCY_BASE"]))
with tempfile.TemporaryDirectory() as tmp_output_path:
vtb_convert.convert_files([vlsp_file], tmp_output_path)
# This produces a tiny test set, just as a placeholder until the actual test set is released
# This produces a 0 length test set, just as a placeholder until the actual test set is released
vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], short_name, train_size=0.9, dev_size=0.1)
_, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], short_name)
with open(test_file, "w"):
# create an empty test file - currently we don't have actual test data for VLSP 21
pass


def process_vlsp22(paths):
"""
Processes the VLSP 2022 dataset, which is four separate files for some reason
"""
short_name = "vi_vlsp22"
vlsp_dir = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VLSP_2022")
if not os.path.exists(vlsp_dir):
raise FileNotFoundError("Could not find the 2022 dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(vlsp_dir, paths["CONSTITUENCY_BASE"]))
vlsp_files = os.listdir(vlsp_dir)
vlsp_files = [os.path.join(vlsp_dir, x) for x in vlsp_files]
print("Procesing {}".format(vlsp_files))
with tempfile.TemporaryDirectory() as tmp_output_path:
vtb_convert.convert_files(vlsp_files, tmp_output_path, verbose=True)
# This produces a 0 length test set, just as a placeholder until the actual test set is released
vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], short_name, train_size=0.9, dev_size=0.1)
_, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], short_name)
with open(test_file, "w"):
@@ -232,6 +253,8 @@ def main(dataset_name):
process_vlsp09(paths)
elif dataset_name == 'vi_vlsp21':
process_vlsp21(paths)
elif dataset_name == 'vi_vlsp22':
process_vlsp22(paths)
elif dataset_name == 'da_arboretum':
process_arboretum(paths, dataset_name)
elif dataset_name == 'tr_starlang':
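The diagnostic printing the commit message mentions comes from the new verbose flag on the conversion functions in vtb_convert.py (next file). A small usage sketch assuming that signature; the input filename here is a hypothetical placeholder, not part of the commit:

```python
# Sketch only: run the tree conversion by itself with the new verbose flag,
# which prints a reason for each rejected tree instead of only the final counts.
# "my_vlsp_2022_trees.mrg" is a placeholder path.
import tempfile

from stanza.utils.datasets.constituency import vtb_convert

with tempfile.TemporaryDirectory() as out_dir:
    vtb_convert.convert_files(["my_vlsp_2022_trees.mrg"], out_dir, verbose=True)
```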
38 changes: 26 additions & 12 deletions stanza/utils/datasets/constituency/vtb_convert.py
@@ -13,7 +13,7 @@

from collections import Counter

from stanza.models.constituency.tree_reader import read_trees, MixedTreeError
from stanza.models.constituency.tree_reader import read_trees, MixedTreeError, UnlabeledTreeError

REMAPPING = {
'(MPD': '(MDP',
@@ -73,7 +73,7 @@ def unify_label(tree):
return tree


def is_closed_tree(tree):
def count_paren_parity(tree):
"""
Checks if the tree is properly closed
:param tree: tree as a string
@@ -85,7 +85,7 @@
count += 1
elif char == ')':
count -= 1
return count == 0
return count


def is_valid_line(line):
@@ -106,7 +106,7 @@ def is_valid_line(line):
# not clear if TP is supposed to be NP or PP - needs a native speaker to decode
WEIRD_LABELS = ["WP", "YP", "SNP", "STC", "UPC", "(TP"]

def convert_file(orig_file, new_file):
def convert_file(orig_file, new_file, verbose=False):
"""
:param orig_file: original directory storing original trees
:param new_file: new directory storing formatted constituency trees
@@ -120,7 +120,7 @@ def convert_file(orig_file, new_file):
# does not have a '(' that signifies the presence of constituents
tree = ""
reading_tree = False
for line in content:
for line_idx, line in enumerate(content):
line = ' '.join(line.split())
if line == '':
continue
@@ -136,11 +136,19 @@ def convert_file(orig_file, new_file):
errors["empty"] += 1
continue
tree += ')\n'
if not is_closed_tree(tree):
#print("Rejecting the following tree from {} for being unclosed: |{}|".format(orig_file, tree))
parity = count_paren_parity(tree)
if parity > 0:
if verbose:
print("Rejecting the following tree from {} line {} for being unclosed: |{}|".format(orig_file, line_idx, tree))
tree = ""
errors["unclosed"] += 1
continue
if parity < 0:
if verbose:
print("Rejecting the following tree from {} line {} for having extra parens: {}".format(orig_file, line_idx, tree))
tree = ""
errors["extra_parens"] += 1
continue
# TODO: these blocks eliminate 11 trees
# maybe those trees can be salvaged?
bad_label = False
@@ -160,27 +168,33 @@ def convert_file(orig_file, new_file):
reading_tree = False
tree = ""
except MixedTreeError:
#print("Skipping an illegal tree: {}".format(tree))
errors["illegal"] += 1
if verbose:
print("Skipping a tree with mixed leaves and constituents from {} line {}: {}".format(orig_file, line_idx, tree))
errors["mixed"] += 1
except UnlabeledTreeError:
if verbose:
print("Skipping a tree with unlabeled nodes from {} line {}: {}".format(orig_file, line_idx, tree))
errors["unlabeled"] += 1
else: # content line
if is_valid_line(line) and reading_tree:
tree += line
elif reading_tree:
errors["invalid"] += 1
#print("Invalid tree error in {}: |{}|, rejected because of line |{}|".format(orig_file, tree, line))
if verbose:
print("Invalid tree error in {} line {}: |{}|, rejected because of line |{}|".format(orig_file, line_idx, tree, line))
tree = ""
reading_tree = False

return errors

def convert_files(file_list, new_dir):
def convert_files(file_list, new_dir, verbose=False):
errors = Counter()
for filename in file_list:
base_name, _ = os.path.splitext(os.path.split(filename)[-1])
new_path = os.path.join(new_dir, base_name)
new_file_path = f'{new_path}.mrg'
# Convert the tree and write to new_file_path
errors += convert_file(filename, new_file_path)
errors += convert_file(filename, new_file_path, verbose)

errors = "\n ".join(sorted(["%s: %s" % x for x in errors.items()]))
print("Found the following error counts:\n {}".format(errors))
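To make the renamed helper's return value concrete, here is a standalone re-implementation of the counting logic shown above (not an import of the stanza function), illustrating how the sign of the result maps to the new unclosed and extra_parens error buckets:

```python
# Standalone illustration of the paren-parity check used in vtb_convert.py above:
# positive means unclosed '(', negative means extra ')', zero means balanced.
def count_paren_parity(tree):
    count = 0
    for char in tree:
        if char == '(':
            count += 1
        elif char == ')':
            count -= 1
    return count

print(count_paren_parity("(S (NP Stanza))"))   # 0  -> balanced, tree is kept
print(count_paren_parity("(S (NP Stanza)"))    # 1  -> counted under errors["unclosed"]
print(count_paren_parity("(S (NP Stanza)))"))  # -1 -> counted under errors["extra_parens"]
```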
