diff --git a/stanza/models/identity_lemmatizer.py b/stanza/models/identity_lemmatizer.py
index 745fb25f1..f0b7a0df7 100644
--- a/stanza/models/identity_lemmatizer.py
+++ b/stanza/models/identity_lemmatizer.py
@@ -55,12 +55,17 @@ def main(args=None):
 
     # write to file and score
     batch.doc.set([LEMMA], preds)
-    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if system_pred_file is not None:
+        CoNLL.write_doc2conll(batch.doc, system_pred_file)
 
     if gold_file is not None:
+        system_pred_file = "{:C}\n\n".format(batch.doc)
+        system_pred_file = io.StringIO(system_pred_file)
         _, _, score = scorer.score(system_pred_file, gold_file)
 
         logger.info("Lemma score:")
         logger.info("{} {:.2f}".format(args['shorthand'], score*100))
 
+    return None, batch.doc
+
 if __name__ == '__main__':
     main()
diff --git a/stanza/models/lemmatizer.py b/stanza/models/lemmatizer.py
index c3f0a821a..6b3e1bfec 100644
--- a/stanza/models/lemmatizer.py
+++ b/stanza/models/lemmatizer.py
@@ -145,8 +145,7 @@ def train(args):
     model_file = build_model_filename(args)
     logger.info("Using full savename: %s", model_file)
 
-    # pred and gold path
-    system_pred_file = args['output_file']
+    # gold path
    gold_file = args['eval_file']
 
     utils.print_config(args)
@@ -169,7 +168,8 @@ def train(args):
             logger.info("Evaluating on dev set...")
             dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
             dev_batch.doc.set([LEMMA], dev_preds)
-            CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+            system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+            system_pred_file = io.StringIO(system_pred_file)
             _, _, dev_f = scorer.score(system_pred_file, gold_file)
 
             logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
@@ -223,7 +223,8 @@ def train(args):
                 logger.info("[Ensembling dict with seq2seq model...]")
                 dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
                 dev_batch.doc.set([LEMMA], dev_preds)
-                CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+                system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+                system_pred_file = io.StringIO(system_pred_file)
                 _, _, dev_score = scorer.score(system_pred_file, gold_file)
 
             train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
@@ -302,8 +303,11 @@ def evaluate(args):
 
     # write to file and score
     batch.doc.set([LEMMA], preds)
-    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if system_pred_file:
+        CoNLL.write_doc2conll(batch.doc, system_pred_file)
 
+    system_pred_file = "{:C}\n\n".format(batch.doc)
+    system_pred_file = io.StringIO(system_pred_file)
     _, _, score = scorer.score(system_pred_file, args['eval_file'])
 
     logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))
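The recurring pattern in this patch: instead of writing predictions to `--output_file` and re-reading them for scoring, the annotated Document is rendered to CoNLL-U with the `{:C}` format spec and wrapped in an `io.StringIO`, which the scorers accept in place of a filename. A minimal sketch of the pattern in isolation, assuming `doc` is a stanza Document and `gold_file` points at the gold CoNLL-U file:

```python
import io

# Render the Document as CoNLL-U text; ":C" is stanza's CoNLL-U format spec.
conllu_text = "{:C}\n\n".format(doc)
# Wrap it in a file-like object so the scorer can read it without a temp file.
system_pred_file = io.StringIO(conllu_text)
# _, _, score = scorer.score(system_pred_file, gold_file)
```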
diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index 29a67ae33..b9316cd44 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -9,6 +9,7 @@ composing the MWT, a classifier over the characters is used instead of the
 seq2seq
 """
 
+import io
 import sys
 import os
 import shutil
@@ -104,9 +105,9 @@ def main(args=None):
     logger.info("Running MWT expander in {} mode".format(args['mode']))
 
     if args['mode'] == 'train':
-        train(args)
+        return train(args)
     else:
-        evaluate(args)
+        return evaluate(args)
 
 def train(args):
     # load data
@@ -129,7 +130,6 @@ def train(args):
         save_each_name = utils.build_save_each_filename(save_each_name)
 
     # pred and gold path
-    system_pred_file = args['output_file']
     gold_file = args['gold_file']
 
     # skip training if the language does not have training or dev data
@@ -174,8 +174,9 @@ def train(args):
             dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
             doc = copy.deepcopy(dev_batch.doc)
             doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-            CoNLL.write_doc2conll(doc, system_pred_file)
-            _, _, dev_f = scorer.score(system_pred_file, gold_file)
+            system_preds = "{:C}\n\n".format(doc)
+            system_preds = io.StringIO(system_preds)
+            _, _, dev_f = scorer.score(system_preds, gold_file)
             logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
 
     if args.get('dict_only', False):
@@ -228,8 +229,9 @@ def train(args):
                 dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
                 doc = copy.deepcopy(dev_batch.doc)
                 doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-                CoNLL.write_doc2conll(doc, system_pred_file)
-                _, _, dev_score = scorer.score(system_pred_file, gold_file)
+                system_preds = "{:C}\n\n".format(doc)
+                system_preds = io.StringIO(system_preds)
+                _, _, dev_score = scorer.score(system_preds, gold_file)
 
             train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
             logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))
@@ -263,11 +265,14 @@ def train(args):
         dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
         doc = copy.deepcopy(dev_batch.doc)
         doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-        CoNLL.write_doc2conll(doc, system_pred_file)
-        _, _, dev_score = scorer.score(system_pred_file, gold_file)
+        system_preds = "{:C}\n\n".format(doc)
+        system_preds = io.StringIO(system_preds)
+        _, _, dev_score = scorer.score(system_preds, gold_file)
         logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
         best_f = max(best_f, dev_score)
 
+    return trainer, None
+
 def evaluate(args):
     # file paths
     system_pred_file = args['output_file']
@@ -310,13 +315,18 @@ def evaluate(args):
 
     # write to file and score
     doc = copy.deepcopy(batch.doc)
     doc.set_mwt_expansions(preds, fake_dependencies=True)
-    CoNLL.write_doc2conll(doc, system_pred_file)
+    if system_pred_file is not None:
+        CoNLL.write_doc2conll(doc, system_pred_file)
+    else:
+        system_pred_file = "{:C}\n\n".format(doc)
+        system_pred_file = io.StringIO(system_pred_file)
 
     if gold_file is not None:
         _, _, score = scorer.score(system_pred_file, gold_file)
 
         logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))
 
+    return trainer, doc
 
 if __name__ == '__main__':
     main()
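With `main()` now returning from both branches, each model entry point yields a pair: roughly `(trainer, None)` from training and `(trainer, doc)` from prediction, so callers can unpack either way. A hedged usage sketch; the flags shown are taken from the driver scripts below but are illustrative, not a complete predict invocation:

```python
from stanza.models import mwt_expander

# Predict mode: the second element is the Document with MWTs expanded.
trainer, expanded_doc = mwt_expander.main(['--eval_file', 'en_ewt.dev.in.conllu',
                                           '--gold_file', 'en_ewt.dev.gold.conllu',
                                           '--shorthand', 'en_ewt',
                                           '--mode', 'predict'])
```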
""" +import io import sys import os import copy @@ -153,7 +154,7 @@ def main(args=None): if args['mode'] == 'train': return train(args) else: - evaluate(args) + return evaluate(args) def model_file_name(args): return utils.standard_model_file_name(args, "parser") @@ -233,9 +234,6 @@ def train(args): dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True) - # pred path - system_pred_file = args['output_file'] - # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: logger.info("Skip training because no data available...") @@ -296,7 +294,9 @@ def train(args): dev_preds = predict_dataset(trainer, dev_batch) dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x]) - CoNLL.write_doc2conll(dev_batch.doc, system_pred_file) + + system_pred_file = "{:C}\n\n".format(dev_batch.doc) + system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, args['eval_file']) train_loss = train_loss / args['eval_interval'] # avg loss per batch @@ -333,7 +333,8 @@ def train(args): dev_preds = predict_dataset(trainer, dev_batch) dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x]) - CoNLL.write_doc2conll(dev_batch.doc, system_pred_file) + system_pred_file = "{:C}\n\n".format(dev_batch.doc) + system_pred_file = io.StringIO(system_pred_file) _, _, dev_score = scorer.score(system_pred_file, args['eval_file']) logger.info("Reloaded model with dev score %.4f", dev_score) @@ -379,7 +380,7 @@ def train(args): logger.info("Dev set never evaluated. Saving final model.") trainer.save(model_file) - return trainer + return trainer, _ def evaluate(args): model_file = model_file_name(args) @@ -392,7 +393,7 @@ def evaluate(args): # load model logger.info("Loading model from: {}".format(model_file)) trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args) - return evaluate_trainer(args, trainer, pretrain) + return trainer, evaluate_trainer(args, trainer, pretrain) def evaluate_trainer(args, trainer, pretrain): system_pred_file = args['output_file'] @@ -412,7 +413,8 @@ def evaluate_trainer(args, trainer, pretrain): # write to file and score batch.doc.set([HEAD, DEPREL], [y for x in preds for y in x]) - CoNLL.write_doc2conll(batch.doc, system_pred_file) + if system_pred_file: + CoNLL.write_doc2conll(batch.doc, system_pred_file) if args['gold_labels']: gold_doc = CoNLL.conll2doc(input_file=args['eval_file']) @@ -424,10 +426,14 @@ def evaluate_trainer(args, trainer, pretrain): raise ValueError("Gold document {} has a None at sentence {} word {}\n{:C}".format(args['eval_file'], sent_idx, word_idx, sentence)) scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex']) + system_pred_file = "{:C}\n\n".format(batch.doc) + system_pred_file = io.StringIO(system_pred_file) _, _, score = scorer.score(system_pred_file, args['eval_file']) logger.info("Parser score:") logger.info("{} {:.2f}".format(args['shorthand'], score*100)) + return batch.doc + if __name__ == '__main__': main() diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py index 46d42fe8b..e905dda39 100644 --- a/stanza/models/tagger.py +++ b/stanza/models/tagger.py @@ -8,6 +8,7 @@ import argparse import logging +import io import os import time import zipfile @@ -143,9 +144,9 @@ def main(args=None): logger.info("Running tagger in {} mode".format(args['mode'])) if args['mode'] == 
diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index 46d42fe8b..e905dda39 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -8,6 +8,7 @@
 
 import argparse
 import logging
+import io
 import os
 import time
 import zipfile
@@ -143,9 +144,9 @@ def main(args=None):
     logger.info("Running tagger in {} mode".format(args['mode']))
 
     if args['mode'] == 'train':
-        train(args)
+        return train(args)
     else:
-        evaluate(args)
+        return evaluate(args)
 
 def model_file_name(args):
     return utils.standard_model_file_name(args, "tagger")
@@ -282,14 +283,11 @@ def train(args):
     eval_type = get_eval_type(dev_data)
 
-    # pred and gold path
-    system_pred_file = args['output_file']
-
     # skip training if the language does not have training or dev data
     # sum(...) to check if all of the training files are empty
     if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
         logger.info("Skip training because no data available...")
-        return
+        return None, None
 
     if args['wandb']:
         import wandb
@@ -350,7 +348,9 @@ def train(args):
                     indices.extend(batch[-1])
                 dev_preds = utils.unsort(dev_preds, indices)
                 dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])
-                CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
+
+                system_pred_file = "{:C}\n\n".format(dev_data.doc)
+                system_pred_file = io.StringIO(system_pred_file)
                 _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)
@@ -410,6 +410,7 @@ def train(args):
         logger.info("Dev set never evaluated. Saving final model.")
         trainer.save(model_file)
 
+    return trainer, None
 
 def evaluate(args):
     # file paths
@@ -423,7 +424,8 @@ def evaluate(args):
     # load model
     logger.info("Loading model from: {}".format(model_file))
     trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
-    evaluate_trainer(args, trainer, pretrain)
+    result_doc = evaluate_trainer(args, trainer, pretrain)
+    return trainer, result_doc
 
 def evaluate_trainer(args, trainer, pretrain):
     system_pred_file = args['output_file']
@@ -455,12 +457,18 @@ def evaluate_trainer(args, trainer, pretrain):
 
     # write to file and score
     dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
-    CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
+    if system_pred_file:
+        CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
 
     if args['gold_labels']:
+        system_pred_file = "{:C}\n\n".format(dev_data.doc)
+        system_pred_file = io.StringIO(system_pred_file)
+
         _, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)
 
         logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)
 
+    return dev_data.doc
+
 if __name__ == '__main__':
     main()
diff --git a/stanza/tests/depparse/test_parser.py b/stanza/tests/depparse/test_parser.py
index a9299bf92..27fd81363 100644
--- a/stanza/tests/depparse/test_parser.py
+++ b/stanza/tests/depparse/test_parser.py
@@ -118,7 +118,7 @@ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, au
             args.extend(["--augment_nopunct", "0.0"])
         if extra_args is not None:
             args = args + extra_args
-        trainer = parser.main(args)
+        trainer, _ = parser.main(args)
 
         assert os.path.exists(save_file)
         pt = pretrain.Pretrain(wordvec_pretrain_file)
diff --git a/stanza/utils/datasets/prepare_depparse_treebank.py b/stanza/utils/datasets/prepare_depparse_treebank.py
index 2fd8ab394..6319268cd 100644
--- a/stanza/utils/datasets/prepare_depparse_treebank.py
+++ b/stanza/utils/datasets/prepare_depparse_treebank.py
@@ -20,7 +20,7 @@
 from stanza.resources.default_packages import default_charlms, pos_charlms
 import stanza.utils.datasets.common as common
 import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
-from stanza.utils.training.run_pos import wordvec_args
+from stanza.utils.training.common import build_pos_wordvec_args
 from stanza.utils.training.common import add_charlm_args, build_charlm_args, choose_charlm
 
 logger = logging.getLogger('stanza')
@@ -110,7 +110,7 @@ def process_treebank(treebank, model_type, paths, args) -> None:
         if args.wordvec_pretrain_file:
             base_args += ["--wordvec_pretrain_file", args.wordvec_pretrain_file]
         else:
-            base_args = base_args + wordvec_args(short_language, dataset, [])
+            base_args = base_args + build_pos_wordvec_args(short_language, dataset, [])
 
         # charlm for POS
diff --git a/stanza/utils/datasets/prepare_lemma_treebank.py b/stanza/utils/datasets/prepare_lemma_treebank.py
index 16a88c5f8..e0205d542 100644
--- a/stanza/utils/datasets/prepare_lemma_treebank.py
+++ b/stanza/utils/datasets/prepare_lemma_treebank.py
@@ -31,7 +31,7 @@ def check_lemmas(train_file):
     # but what if a later dataset includes lemmas?
     #if short_language in ('vi', 'fro', 'th'):
     #    return False
-    with open(train_file) as fin:
+    with open(train_file, encoding="utf-8") as fin:
         for line in fin:
             line = line.strip()
             if not line or line.startswith("#"):
diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py
index b61add1b1..8032c47a3 100755
--- a/stanza/utils/datasets/prepare_tokenizer_treebank.py
+++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -27,6 +27,7 @@
 import os
 import random
 import re
+import sys
 import tempfile
 import zipfile
@@ -501,7 +502,11 @@ def augment_quotes(sents, ratio=0.15):
 
         new_sents.append(new_sent)
 
-    print("Augmented {} quotes: {}".format(sum(counts.values()), counts))
+    # we go through this to make it simpler to execute on Windows
+    # rather than nagging the user to set utf-8
+    out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True)
+    print("Augmented {} quotes: {}".format(sum(counts.values()), counts), file=out)
+    out.detach()
 
     return new_sents
 
 def find_text_idx(sentence):
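The `augment_quotes` change prints through an explicit UTF-8 wrapper so the curly-quote counts do not crash on consoles whose default encoding cannot represent them (a common failure on Windows). A sketch of the idiom in isolation:

```python
import io
import sys

# Wrap the raw stdout buffer in a UTF-8 text layer; write_through avoids
# buffering surprises when mixing this with other writes to stdout.
out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True)
print("counts with “curly quotes” print safely", file=out)
# detach() releases sys.stdout.buffer without closing it; closing it
# would break any later prints to stdout.
out.detach()
```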
""" @@ -421,3 +415,28 @@ def build_tokenizer_charlm_args(short_language, dataset, charlm, base_args=True, charlm = choose_tokenizer_charlm(short_language, dataset, charlm) charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir, use_backward_model=False) return charlm_args + + +def build_wordvec_args(short_language, dataset, extra_args, task_pretrains): + if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args: + return [] + + if short_language in no_pretrain_languages: + # we couldn't find word vectors for a few languages...: + # coptic, naija, old russian, turkish german, swedish sign language + logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language)) + return ["--no_pretrain"] + else: + if short_language in task_pretrains and dataset in task_pretrains[short_language]: + dataset_pretrains = task_pretrains + else: + dataset_pretrains = {} + wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset) + return ["--wordvec_pretrain_file", wordvec_pretrain] + +def build_pos_wordvec_args(short_language, dataset, extra_args): + return build_wordvec_args(short_language, dataset, extra_args, pos_pretrains) + +def build_depparse_wordvec_args(short_language, dataset, extra_args): + return build_wordvec_args(short_language, dataset, extra_args, depparse_pretrains) + diff --git a/stanza/utils/training/run_charlm.py b/stanza/utils/training/run_charlm.py index 2aa251382..2369dcfc4 100644 --- a/stanza/utils/training/run_charlm.py +++ b/stanza/utils/training/run_charlm.py @@ -22,7 +22,7 @@ def add_charlm_args(parser): def run_treebank(mode, paths, treebank, short_name, - temp_output_file, command_args, extra_args): + command_args, extra_args): short_language, dataset_name = short_name.split("_", 1) train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train") diff --git a/stanza/utils/training/run_constituency.py b/stanza/utils/training/run_constituency.py index 4e993f56a..9322866ba 100644 --- a/stanza/utils/training/run_constituency.py +++ b/stanza/utils/training/run_constituency.py @@ -72,7 +72,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): return save_name -def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args): +def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): constituency_dir = paths["CONSTITUENCY_DATA_DIR"] short_language, dataset = short_name.split("_") diff --git a/stanza/utils/training/run_depparse.py b/stanza/utils/training/run_depparse.py index daa9ff3da..404ea2183 100644 --- a/stanza/utils/training/run_depparse.py +++ b/stanza/utils/training/run_depparse.py @@ -1,3 +1,4 @@ +import io import logging import os @@ -5,7 +6,7 @@ from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer -from stanza.utils.training.run_pos import wordvec_args +from stanza.utils.training.common import build_depparse_wordvec_args from stanza.resources.default_packages import default_charlms, depparse_charlms @@ -29,7 +30,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): train_args = ["--shorthand", short_name, "--mode", "train"] # TODO: also, this downloads the wordvec, which we might not want to do yet - train_args = train_args + wordvec_args(short_language, dataset, extra_args) 
diff --git a/stanza/utils/training/run_charlm.py b/stanza/utils/training/run_charlm.py
index 2aa251382..2369dcfc4 100644
--- a/stanza/utils/training/run_charlm.py
+++ b/stanza/utils/training/run_charlm.py
@@ -22,7 +22,7 @@ def add_charlm_args(parser):
 
 def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+                 command_args, extra_args):
     short_language, dataset_name = short_name.split("_", 1)
 
     train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train")
diff --git a/stanza/utils/training/run_constituency.py b/stanza/utils/training/run_constituency.py
index 4e993f56a..9322866ba 100644
--- a/stanza/utils/training/run_constituency.py
+++ b/stanza/utils/training/run_constituency.py
@@ -72,7 +72,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
 
     return save_name
 
-def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     constituency_dir = paths["CONSTITUENCY_DATA_DIR"]
     short_language, dataset = short_name.split("_")
diff --git a/stanza/utils/training/run_depparse.py b/stanza/utils/training/run_depparse.py
index daa9ff3da..404ea2183 100644
--- a/stanza/utils/training/run_depparse.py
+++ b/stanza/utils/training/run_depparse.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 
@@ -5,7 +6,7 @@
 
 from stanza.utils.training import common
 from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer
-from stanza.utils.training.run_pos import wordvec_args
+from stanza.utils.training.common import build_depparse_wordvec_args
 
 from stanza.resources.default_packages import default_charlms, depparse_charlms
 
@@ -29,7 +30,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
     train_args = ["--shorthand", short_name, "--mode", "train"]
     # TODO: also, this downloads the wordvec, which we might not want to do yet
-    train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
+    train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
     if command_args.save_name is not None:
         train_args.extend(["--save_name", command_args.save_name])
     if command_args.save_dir is not None:
@@ -39,17 +40,16 @@ def build_model_filename(paths, short_name, command_args, extra_args):
 
     return save_name
 
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)
 
     # TODO: refactor these blocks?
     depparse_dir = paths["DEPPARSE_DATA_DIR"]
     train_file = f"{depparse_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu"
-    dev_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.dev.pred.conllu"
+    dev_pred_file = f"{depparse_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu"
-    test_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.test.pred.conllu"
+    test_pred_file = f"{depparse_dir}/{short_name}.test.pred.conllu"
 
     eval_file = None
     if '--eval_file' in extra_args:
@@ -84,12 +84,11 @@ def run_treebank(mode, paths, treebank, short_name,
         train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                       "--train_file", train_file,
                       "--eval_file", eval_file if eval_file else dev_in_file,
-                      "--output_file", dev_pred_file,
                       "--batch_size", batch_size,
                       "--lang", short_language,
                      "--shorthand", short_name,
                       "--mode", "train"]
-        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         train_args = train_args + extra_args
         logger.info("Running train depparse for {} with args {}".format(treebank, train_args))
         parser.main(train_args)
@@ -97,37 +96,45 @@ def run_treebank(mode, paths, treebank, short_name,
     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                     "--eval_file", eval_file if eval_file else dev_in_file,
-                    "--output_file", dev_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
+        if command_args.save_output:
+            dev_args.extend(["--output_file", dev_pred_file])
-        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        dev_args = dev_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         dev_args = dev_args + extra_args
         logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
-        parser.main(dev_args)
+        _, dev_doc = parser.main(dev_args)
 
         if '--no_gold_labels' not in extra_args:
+            if not command_args.save_output:
+                dev_pred_file = "{:C}\n\n".format(dev_doc)
+                dev_pred_file = io.StringIO(dev_pred_file)
             results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)
             logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
-            if not temp_output_file:
+            if command_args.save_output:
                 logger.info("Output saved to %s", dev_pred_file)
 
     if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
         test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                      "--eval_file", eval_file if eval_file else test_in_file,
-                     "--output_file", test_pred_file,
                     "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "predict"]
-        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        if command_args.save_output:
+            test_args.extend(["--output_file", test_pred_file])
+        test_args = test_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         test_args = test_args + extra_args
         logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
-        parser.main(test_args)
+        _, test_doc = parser.main(test_args)
 
         if '--no_gold_labels' not in extra_args:
+            if not command_args.save_output:
+                test_pred_file = "{:C}\n\n".format(test_doc)
+                test_pred_file = io.StringIO(test_pred_file)
             results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)
             logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
-            if not temp_output_file:
+            if command_args.save_output:
                 logger.info("Output saved to %s", test_pred_file)
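All the driver scripts converge on the same flow: only pass `--output_file` when `--save_output` was requested; otherwise take the Document returned by `main()` in predict mode and hand the eval script an in-memory rendering. A condensed recap of the dev branch above, using the same names as run_depparse.py:

```python
_, dev_doc = parser.main(dev_args)      # predict mode returns (trainer, doc)
if not command_args.save_output:
    # nothing on disk: score the in-memory CoNLL-U instead
    dev_pred_file = io.StringIO("{:C}\n\n".format(dev_doc))
results = common.run_eval_script_depparse(dev_in_file, dev_pred_file)
```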
diff --git a/stanza/utils/training/run_ete.py b/stanza/utils/training/run_ete.py
index bd4ef1f00..67f3000f1 100644
--- a/stanza/utils/training/run_ete.py
+++ b/stanza/utils/training/run_ete.py
@@ -168,8 +168,7 @@ def run_ete(paths, dataset, short_name, command_args, extra_args):
     results = common.run_eval_script(gold_file, ete_file)
     logger.info("{} {} models on {} {} data:\n{}".format(RESULTS_STRING, short_name, test_short_name, dataset, results))
 
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     if mode == Mode.TRAIN:
         dataset = 'train'
     elif mode == Mode.SCORE_DEV:
diff --git a/stanza/utils/training/run_lemma.py b/stanza/utils/training/run_lemma.py
index a2668ba09..f3d7331c9 100644
--- a/stanza/utils/training/run_lemma.py
+++ b/stanza/utils/training/run_lemma.py
@@ -70,16 +70,15 @@ def build_model_filename(paths, short_name, command_args, extra_args):
     save_name = lemmatizer.build_model_filename(args)
     return save_name
 
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)
 
     lemma_dir = paths["LEMMA_DATA_DIR"]
     train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
-    dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
+    dev_pred_file = f"{lemma_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
-    test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"
+    test_pred_file = f"{lemma_dir}/{short_name}.test.pred.conllu"
 
     charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
@@ -94,15 +93,19 @@ def run_treebank(mode, paths, treebank, short_name,
         if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:
             train_args = ["--train_file", train_file,
                           "--eval_file", dev_in_file,
-                          "--output_file", dev_pred_file,
+                          "--gold_file", dev_in_file,
                           "--shorthand", short_name]
+            if command_args.save_output:
+                train_args.extend(["--output_file", dev_pred_file])
             logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
             identity_lemmatizer.main(train_args)
         elif mode == Mode.SCORE_TEST:
             train_args = ["--train_file", train_file,
                           "--eval_file", test_in_file,
-                          "--output_file", test_pred_file,
+                          "--gold_file", test_in_file,
                           "--shorthand", short_name]
+            if command_args.save_output:
+                train_args.extend(["--output_file", test_pred_file])
             logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
             identity_lemmatizer.main(train_args)
         else:
@@ -115,7 +118,6 @@ def run_treebank(mode, paths, treebank, short_name,
 
         train_args = ["--train_file", train_file,
                       "--eval_file", dev_in_file,
-                      "--output_file", dev_pred_file,
                       "--shorthand", short_name,
                       "--num_epoch", num_epochs,
                       "--mode", "train"]
@@ -125,18 +127,20 @@ def run_treebank(mode, paths, treebank, short_name,
 
         if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
             dev_args = ["--eval_file", dev_in_file,
-                        "--output_file", dev_pred_file,
                         "--shorthand", short_name,
                         "--mode", "predict"]
+            if command_args.save_output:
+                dev_args.extend(["--output_file", dev_pred_file])
             dev_args = dev_args + charlm_args + extra_args
             logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args))
             lemmatizer.main(dev_args)
 
         if mode == Mode.SCORE_TEST:
             test_args = ["--eval_file", test_in_file,
-                         "--output_file", test_pred_file,
                          "--shorthand", short_name,
                          "--mode", "predict"]
+            if command_args.save_output:
+                test_args.extend(["--output_file", test_pred_file])
             test_args = test_args + charlm_args + extra_args
             logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
             lemmatizer.main(test_args)
diff --git a/stanza/utils/training/run_lemma_classifier.py b/stanza/utils/training/run_lemma_classifier.py
index 73669e79b..68ef748ea 100644
--- a/stanza/utils/training/run_lemma_classifier.py
+++ b/stanza/utils/training/run_lemma_classifier.py
@@ -18,8 +18,7 @@ def add_lemma_args(parser):
 def build_model_filename(paths, short_name, command_args, extra_args):
     return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt")
 
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)
 
     base_args = []
diff --git a/stanza/utils/training/run_mwt.py b/stanza/utils/training/run_mwt.py
index af3171d71..74d4ba199 100644
--- a/stanza/utils/training/run_mwt.py
+++ b/stanza/utils/training/run_mwt.py
@@ -16,6 +16,7 @@
 
 """
 
+import io
 import logging
 import math
 
@@ -37,8 +38,7 @@ def check_mwt(filename):
     data = doc.get_mwt_expansions(False)
     return len(data) > 0
 
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language = short_name.split("_")[0]
 
     mwt_dir = paths["MWT_DATA_DIR"]
@@ -46,10 +46,10 @@ def run_treebank(mode, paths, treebank, short_name,
     train_file = f"{mwt_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu"
     dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu"
-    dev_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.dev.pred.conllu"
+    dev_output_file = f"{mwt_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu"
     test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu"
-    test_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.test.pred.conllu"
+    test_output_file = f"{mwt_dir}/{short_name}.test.pred.conllu"
 
     train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json"
     dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json"
@@ -76,7 +76,6 @@ def run_treebank(mode, paths, treebank, short_name,
         logger.info("Max len: %f" % max_mwt_len)
         train_args = ['--train_file', train_file,
                       '--eval_file', eval_file if eval_file else dev_in_file,
-                      '--output_file', dev_output_file,
                      '--gold_file', gold_file if gold_file else dev_gold_file,
                       '--lang', short_language,
                       '--shorthand', short_name,
@@ -88,28 +87,36 @@ def run_treebank(mode, paths, treebank, short_name,
 
     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ['--eval_file', eval_file if eval_file else dev_in_file,
-                    '--output_file', dev_output_file,
                     '--gold_file', gold_file if gold_file else dev_gold_file,
                     '--lang', short_language,
                     '--shorthand', short_name,
                     '--mode', 'predict']
+        if command_args.save_output:
+            dev_args.extend(['--output_file', dev_output_file])
         dev_args = dev_args + extra_args
         logger.info("Running dev step with args: {}".format(dev_args))
-        mwt_expander.main(dev_args)
+        _, dev_doc = mwt_expander.main(dev_args)
 
+        if not command_args.save_output:
+            dev_output_file = "{:C}\n\n".format(dev_doc)
+            dev_output_file = io.StringIO(dev_output_file)
         results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file)
         logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
 
     if mode == Mode.SCORE_TEST:
         test_args = ['--eval_file', eval_file if eval_file else test_in_file,
-                     '--output_file', test_output_file,
                      '--gold_file', gold_file if gold_file else test_gold_file,
                      '--lang', short_language,
                      '--shorthand', short_name,
                      '--mode', 'predict']
+        if command_args.save_output:
+            test_args.extend(['--output_file', test_output_file])
         test_args = test_args + extra_args
         logger.info("Running test step with args: {}".format(test_args))
-        mwt_expander.main(test_args)
+        _, test_doc = mwt_expander.main(test_args)
+        if not command_args.save_output:
+            test_output_file = "{:C}\n\n".format(test_doc)
+            test_output_file = io.StringIO(test_output_file)
         results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file)
         logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
len: %f" % max_mwt_len) train_args = ['--train_file', train_file, '--eval_file', eval_file if eval_file else dev_in_file, - '--output_file', dev_output_file, '--gold_file', gold_file if gold_file else dev_gold_file, '--lang', short_language, '--shorthand', short_name, @@ -88,28 +87,36 @@ def run_treebank(mode, paths, treebank, short_name, if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ['--eval_file', eval_file if eval_file else dev_in_file, - '--output_file', dev_output_file, '--gold_file', gold_file if gold_file else dev_gold_file, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict'] + if command_args.save_output: + dev_args.extend(['--output_file', dev_output_file]) dev_args = dev_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) - mwt_expander.main(dev_args) + _, dev_doc = mwt_expander.main(dev_args) + if not command_args.save_output: + dev_output_file = "{:C}\n\n".format(dev_doc) + dev_output_file = io.StringIO(dev_output_file) results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) if mode == Mode.SCORE_TEST: test_args = ['--eval_file', eval_file if eval_file else test_in_file, - '--output_file', test_output_file, '--gold_file', gold_file if gold_file else test_gold_file, '--lang', short_language, '--shorthand', short_name, '--mode', 'predict'] + if command_args.save_output: + test_args.extend(['--output_file', test_output_file]) test_args = test_args + extra_args logger.info("Running test step with args: {}".format(test_args)) - mwt_expander.main(test_args) + _, test_doc = mwt_expander.main(test_args) + if not command_args.save_output: + test_output_file = "{:C}\n\n".format(test_doc) + test_output_file = io.StringIO(test_output_file) results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) diff --git a/stanza/utils/training/run_ner.py b/stanza/utils/training/run_ner.py index 2651df58f..10fdabc18 100644 --- a/stanza/utils/training/run_ner.py +++ b/stanza/utils/training/run_ner.py @@ -93,8 +93,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): # However, to keep the naming consistent, we leave the # method which does the training as run_treebank # TODO: rename treebank -> dataset everywhere -def run_treebank(mode, paths, treebank, short_name, - temp_output_file, command_args, extra_args): +def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): ner_dir = paths["NER_DATA_DIR"] language, dataset = short_name.split("_") diff --git a/stanza/utils/training/run_pos.py b/stanza/utils/training/run_pos.py index dc19e16c4..2808d2f8a 100644 --- a/stanza/utils/training/run_pos.py +++ b/stanza/utils/training/run_pos.py @@ -1,13 +1,12 @@ - +import io import logging import os from stanza.models import tagger -from stanza.resources.default_packages import no_pretrain_languages, pos_pretrains, default_pretrains from stanza.utils.training import common -from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain +from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain, build_pos_wordvec_args logger = logging.getLogger('stanza') @@ -16,24 +15,6 @@ def add_pos_args(parser): parser.add_argument('--use_bert', default=False, 
action="store_true", help='Use the default transformer for this language') -# TODO: move this somewhere common -def wordvec_args(short_language, dataset, extra_args): - if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args: - return [] - - if short_language in no_pretrain_languages: - # we couldn't find word vectors for a few languages...: - # coptic, naija, old russian, turkish german, swedish sign language - logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language)) - return ["--no_pretrain"] - else: - if short_language in pos_pretrains and dataset in pos_pretrains[short_language]: - dataset_pretrains = pos_pretrains - else: - dataset_pretrains = {} - wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset) - return ["--wordvec_pretrain_file", wordvec_pretrain] - def build_model_filename(paths, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) @@ -45,7 +26,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): train_args = ["--shorthand", short_name, "--mode", "train"] # TODO: also, this downloads the wordvec, which we might not want to do yet - train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args + train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args if command_args.save_name is not None: train_args.extend(["--save_name", command_args.save_name]) if command_args.save_dir is not None: @@ -56,8 +37,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): -def run_treebank(mode, paths, treebank, short_name, - temp_output_file, command_args, extra_args): +def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): short_language, dataset = short_name.split("_", 1) pos_dir = paths["POS_DATA_DIR"] @@ -65,9 +45,9 @@ def run_treebank(mode, paths, treebank, short_name, if short_name == 'vi_vlsp22': train_file += f";{pos_dir}/vi_vtb.train.in.conllu" dev_in_file = f"{pos_dir}/{short_name}.dev.in.conllu" - dev_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.dev.pred.conllu" + dev_pred_file = f"{pos_dir}/{short_name}.dev.pred.conllu" test_in_file = f"{pos_dir}/{short_name}.test.in.conllu" - test_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.test.pred.conllu" + test_pred_file = f"{pos_dir}/{short_name}.test.pred.conllu" charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm) bert_args = common.choose_transformer(short_language, command_args, extra_args) @@ -95,51 +75,58 @@ def run_treebank(mode, paths, treebank, short_name, train_args = ["--wordvec_dir", paths["WORDVEC_DIR"], "--train_file", train_file, - "--output_file", dev_pred_file, "--lang", short_language, "--shorthand", short_name, "--mode", "train"] if eval_file is None: train_args += ['--eval_file', dev_in_file] - train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args train_args = train_args + extra_args logger.info("Running train POS for {} with args {}".format(treebank, train_args)) tagger.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"], - 
"--output_file", dev_pred_file, "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if eval_file is None: dev_args += ['--eval_file', dev_in_file] - dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + if command_args.save_output: + dev_args.extend(["--output_file", dev_pred_file]) + dev_args = dev_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args dev_args = dev_args + extra_args logger.info("Running dev POS for {} with args {}".format(treebank, dev_args)) - tagger.main(dev_args) + _, dev_doc = tagger.main(dev_args) + if not command_args.save_output: + dev_pred_file = "{:C}\n\n".format(dev_doc) + dev_pred_file = io.StringIO(dev_pred_file) results = common.run_eval_script_pos(eval_file if eval_file else dev_in_file, dev_pred_file) logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) - if not temp_output_file: + if command_args.save_output: logger.info("Output saved to %s", dev_pred_file) if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ["--wordvec_dir", paths["WORDVEC_DIR"], - "--output_file", test_pred_file, "--lang", short_language, "--shorthand", short_name, "--mode", "predict"] if eval_file is None: test_args += ['--eval_file', test_in_file] - test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + if command_args.save_output: + dev_args.extend(["--output_file", test_pred_file]) + test_args = test_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args test_args = test_args + extra_args logger.info("Running test POS for {} with args {}".format(treebank, test_args)) - tagger.main(test_args) + _, test_doc = tagger.main(test_args) + if not command_args.save_output: + test_pred_file = "{:C}\n\n".format(test_doc) + test_pred_file = io.StringIO(test_pred_file) results = common.run_eval_script_pos(eval_file if eval_file else test_in_file, test_pred_file) logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) - if not temp_output_file: + if command_args.save_output: logger.info("Output saved to %s", test_pred_file) diff --git a/stanza/utils/training/run_sentiment.py b/stanza/utils/training/run_sentiment.py index ae46f1275..a985b1b94 100644 --- a/stanza/utils/training/run_sentiment.py +++ b/stanza/utils/training/run_sentiment.py @@ -66,8 +66,7 @@ def build_model_filename(paths, short_name, command_args, extra_args): return save_name -def run_dataset(mode, paths, treebank, short_name, - temp_output_file, command_args, extra_args): +def run_dataset(mode, paths, treebank, short_name, command_args, extra_args): sentiment_dir = paths["SENTIMENT_DATA_DIR"] short_language, dataset = short_name.split("_", 1) diff --git a/stanza/utils/training/run_tokenizer.py b/stanza/utils/training/run_tokenizer.py index 2128e8e00..6c2173457 100644 --- a/stanza/utils/training/run_tokenizer.py +++ b/stanza/utils/training/run_tokenizer.py @@ -66,8 +66,7 @@ def uses_dictionary(short_language): return True return False -def run_treebank(mode, paths, treebank, short_name, - temp_output_file, command_args, extra_args): +def run_treebank(mode, paths, treebank, short_name, command_args, extra_args): tokenize_dir = paths["TOKENIZE_DATA_DIR"] short_language, dataset = short_name.split("_", 1)