From a78fd9cce623515c1a6dcfbd98acce44fe150995 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 01:00:42 -0700
Subject: [PATCH 1/9] Connect MWT to the scorer using io.StringIO rather than an output file when --save_output is not set.

Avoiding tempfiles makes the script work on Windows

Addresses the MWT part of
https://github.com/stanfordnlp/stanza-train/issues/20
---
 stanza/models/mwt_expander.py    | 30 ++++++++++++++++++++----------
 stanza/utils/training/common.py  |  2 +-
 stanza/utils/training/run_mwt.py | 22 +++++++++++++++-------
 3 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index 29a67ae336..b9316cd44e 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -9,6 +9,7 @@ composing the MWT, a classifier over the characters is used instead of the
 seq2seq
 """

+import io
 import sys
 import os
 import shutil
@@ -104,9 +105,9 @@ def main(args=None):
     logger.info("Running MWT expander in {} mode".format(args['mode']))

     if args['mode'] == 'train':
-        train(args)
+        return train(args)
     else:
-        evaluate(args)
+        return evaluate(args)

 def train(args):
     # load data
@@ -129,7 +130,6 @@ def train(args):
         save_each_name = utils.build_save_each_filename(save_each_name)

     # pred and gold path
-    system_pred_file = args['output_file']
     gold_file = args['gold_file']

     # skip training if the language does not have training or dev data
@@ -174,8 +174,9 @@ def train(args):
         dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
         doc = copy.deepcopy(dev_batch.doc)
         doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-        CoNLL.write_doc2conll(doc, system_pred_file)
-        _, _, dev_f = scorer.score(system_pred_file, gold_file)
+        system_preds = "{:C}\n\n".format(doc)
+        system_preds = io.StringIO(system_preds)
+        _, _, dev_f = scorer.score(system_preds, gold_file)
         logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

     if args.get('dict_only', False):
@@ -228,8 +229,9 @@ def train(args):
             dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
             doc = copy.deepcopy(dev_batch.doc)
             doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-            CoNLL.write_doc2conll(doc, system_pred_file)
-            _, _, dev_score = scorer.score(system_pred_file, gold_file)
+            system_preds = "{:C}\n\n".format(doc)
+            system_preds = io.StringIO(system_preds)
+            _, _, dev_score = scorer.score(system_preds, gold_file)

             train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
             logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))
@@ -263,11 +265,14 @@ def train(args):
             dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
             doc = copy.deepcopy(dev_batch.doc)
             doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
-            CoNLL.write_doc2conll(doc, system_pred_file)
-            _, _, dev_score = scorer.score(system_pred_file, gold_file)
+            system_preds = "{:C}\n\n".format(doc)
+            system_preds = io.StringIO(system_preds)
+            _, _, dev_score = scorer.score(system_preds, gold_file)
             logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
             best_f = max(best_f, dev_score)

+    return trainer, _
+
 def evaluate(args):
     # file paths
     system_pred_file = args['output_file']
@@ -310,13 +315,18 @@ def evaluate(args):
     # write to file and score
     doc = copy.deepcopy(batch.doc)
     doc.set_mwt_expansions(preds, fake_dependencies=True)
-    CoNLL.write_doc2conll(doc, system_pred_file)
+    if system_pred_file is not None:
+        CoNLL.write_doc2conll(doc, system_pred_file)
+    else:
+        system_pred_file = "{:C}\n\n".format(doc)
+        system_pred_file = io.StringIO(system_pred_file)

     if gold_file is not None:
         _, _, score = scorer.score(system_pred_file, gold_file)

         logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))

+    return trainer, doc
+
 if __name__ == '__main__':
     main()

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index 875b11d98f..c176c090ea 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -195,7 +195,7 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
     else:
         logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

-    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer':
+    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'mwt_expander':
         with tempfile.NamedTemporaryFile() as temp_output_file:
             run_treebank(mode, paths, treebank, short_name,
                          temp_output_file.name, command_args, extra_args + save_name_args)

diff --git a/stanza/utils/training/run_mwt.py b/stanza/utils/training/run_mwt.py
index af3171d715..0185ab52c3 100644
--- a/stanza/utils/training/run_mwt.py
+++ b/stanza/utils/training/run_mwt.py
@@ -16,6 +16,7 @@

 """

+import io
 import logging
 import math

@@ -46,10 +47,10 @@ def run_treebank(mode, paths, treebank, short_name,
     train_file = f"{mwt_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu"
     dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu"
-    dev_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.dev.pred.conllu"
+    dev_output_file = f"{mwt_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu"
     test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu"
-    test_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.test.pred.conllu"
+    test_output_file = f"{mwt_dir}/{short_name}.test.pred.conllu"

     train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json"
     dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json"
@@ -76,7 +77,6 @@ def run_treebank(mode, paths, treebank, short_name,
         logger.info("Max len: %f" % max_mwt_len)
         train_args = ['--train_file', train_file,
                       '--eval_file', eval_file if eval_file else dev_in_file,
-                      '--output_file', dev_output_file,
                       '--gold_file', gold_file if gold_file else dev_gold_file,
                       '--lang', short_language,
                       '--shorthand', short_name,
@@ -88,28 +88,36 @@ def run_treebank(mode, paths, treebank, short_name,

     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ['--eval_file', eval_file if eval_file else dev_in_file,
-                    '--output_file', dev_output_file,
                     '--gold_file', gold_file if gold_file else dev_gold_file,
                     '--lang', short_language,
                     '--shorthand', short_name,
                     '--mode', 'predict']
+        if command_args.save_output:
+            dev_args.extend(['--output_file', dev_output_file])
         dev_args = dev_args + extra_args
         logger.info("Running dev step with args: {}".format(dev_args))
-        mwt_expander.main(dev_args)
+        _, dev_doc = mwt_expander.main(dev_args)
+        if not command_args.save_output:
+            dev_output_file = "{:C}\n\n".format(dev_doc)
+            dev_output_file = io.StringIO(dev_output_file)

         results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file)
         logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

     if mode == Mode.SCORE_TEST:
         test_args = ['--eval_file', eval_file if eval_file else test_in_file,
-                     '--output_file', test_output_file,
                      '--gold_file', gold_file if gold_file else test_gold_file,
                      '--lang', short_language,
                      '--shorthand', short_name,
                      '--mode', 'predict']
+        if command_args.save_output:
+            test_args.extend(['--output_file', test_output_file])
         test_args = test_args + extra_args
         logger.info("Running test step with args: {}".format(test_args))
-        mwt_expander.main(test_args)
+        _, test_doc = mwt_expander.main(test_args)
+        if not command_args.save_output:
+            test_output_file = "{:C}\n\n".format(test_doc)
+            test_output_file = io.StringIO(test_output_file)

         results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file)
         logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
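The pattern patch 1 relies on, shown here as a minimal standalone sketch (not part of the diff): render the predicted Document to CoNLL-U text with the "{:C}" format spec and hand the scorer an in-memory io.StringIO instead of a temp file path. The names `doc`, `scorer` and `gold_file`, and the (precision, recall, f1) return shape, are taken from the code above; everything else is an assumption for illustration.

import io

def conll_in_memory(doc):
    # "{:C}" formats a stanza Document as CoNLL-U text -- the same content
    # CoNLL.write_doc2conll(doc, path) would have written to disk; the extra
    # blank line matches the file-based output.
    return io.StringIO("{:C}\n\n".format(doc))

# Usage, mirroring the dev evaluation above:
#   system_preds = conll_in_memory(doc)
#   _, _, dev_f = scorer.score(system_preds, gold_file)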
From 091656e7b15734e1f6f6b58b5e81774d45584efc Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 08:15:43 -0700
Subject: [PATCH 2/9] Refactor build_pos_wordvec_args out of run_pos so it can also be used in depparse (other models look for wordvec_args as well)

prepare_depparse also uses the POS wordvec args when retagging the depparse files.

depparse now calls build_depparse_wordvec_args instead of reusing the POS version.
---
 .../datasets/prepare_depparse_treebank.py |  4 +--
 stanza/utils/training/common.py           | 26 +++++++++++++++++
 stanza/utils/training/run_depparse.py     | 10 +++----
 stanza/utils/training/run_pos.py          | 29 ++++---------------
 4 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/stanza/utils/datasets/prepare_depparse_treebank.py b/stanza/utils/datasets/prepare_depparse_treebank.py
index 2fd8ab3948..6319268cd1 100644
--- a/stanza/utils/datasets/prepare_depparse_treebank.py
+++ b/stanza/utils/datasets/prepare_depparse_treebank.py
@@ -20,7 +20,7 @@
 from stanza.resources.default_packages import default_charlms, pos_charlms
 import stanza.utils.datasets.common as common
 import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
-from stanza.utils.training.run_pos import wordvec_args
+from stanza.utils.training.common import build_pos_wordvec_args
 from stanza.utils.training.common import add_charlm_args, build_charlm_args, choose_charlm

 logger = logging.getLogger('stanza')
@@ -110,7 +110,7 @@ def process_treebank(treebank, model_type, paths, args) -> None:
     if args.wordvec_pretrain_file:
         base_args += ["--wordvec_pretrain_file", args.wordvec_pretrain_file]
     else:
-        base_args = base_args + wordvec_args(short_language, dataset, [])
+        base_args = base_args + build_pos_wordvec_args(short_language, dataset, [])

     # charlm for POS

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index c176c090ea..15816ca393 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -10,6 +10,7 @@
 from enum import Enum

 from stanza.resources.default_packages import default_charlms, lemma_charlms, tokenizer_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS
+from stanza.resources.default_packages import no_pretrain_languages, pos_pretrains, depparse_pretrains, default_pretrains
 from stanza.models.common.constant import treebank_to_short_name
 from stanza.models.common.utils import ud_scores
 from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError
@@ -421,3 +422,28 @@ def build_tokenizer_charlm_args(short_language, dataset, charlm, base_args=True,
     charlm = choose_tokenizer_charlm(short_language, dataset, charlm)
     charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir, use_backward_model=False)
     return charlm_args
+
+
+def build_wordvec_args(short_language, dataset, extra_args, task_pretrains):
+    if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args:
+        return []
+
+    if short_language in no_pretrain_languages:
+        # we couldn't find word vectors for a few languages...:
+        # coptic, naija, old russian, turkish german, swedish sign language
+        logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language))
+        return ["--no_pretrain"]
+    else:
+        if short_language in task_pretrains and dataset in task_pretrains[short_language]:
+            dataset_pretrains = task_pretrains
+        else:
+            dataset_pretrains = {}
+        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset)
+        return ["--wordvec_pretrain_file", wordvec_pretrain]
+
+def build_pos_wordvec_args(short_language, dataset, extra_args):
+    return build_wordvec_args(short_language, dataset, extra_args, pos_pretrains)
+
+def build_depparse_wordvec_args(short_language, dataset, extra_args):
+    return build_wordvec_args(short_language, dataset, extra_args, depparse_pretrains)
+

diff --git a/stanza/utils/training/run_depparse.py b/stanza/utils/training/run_depparse.py
index daa9ff3da9..c7680e5b78 100644
--- a/stanza/utils/training/run_depparse.py
+++ b/stanza/utils/training/run_depparse.py
@@ -5,7 +5,7 @@

 from stanza.utils.training import common
 from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer
-from stanza.utils.training.run_pos import wordvec_args
+from stanza.utils.training.common import build_depparse_wordvec_args

 from stanza.resources.default_packages import default_charlms, depparse_charlms

@@ -29,7 +29,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
     train_args = ["--shorthand", short_name,
                   "--mode", "train"]
     # TODO: also, this downloads the wordvec, which we might not want to do yet
-    train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
+    train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
     if command_args.save_name is not None:
         train_args.extend(["--save_name", command_args.save_name])
     if command_args.save_dir is not None:
@@ -89,7 +89,7 @@ def run_treebank(mode, paths, treebank, short_name,
                       "--lang", short_language,
                       "--shorthand", short_name,
                       "--mode", "train"]
-        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        train_args = train_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         train_args = train_args + extra_args
         logger.info("Running train depparse for {} with args {}".format(treebank, train_args))
         parser.main(train_args)
@@ -101,7 +101,7 @@ def run_treebank(mode, paths, treebank, short_name,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
-        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        dev_args = dev_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         dev_args = dev_args + extra_args
         logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
         parser.main(dev_args)
@@ -119,7 +119,7 @@ def run_treebank(mode, paths, treebank, short_name,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "predict"]
-        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        test_args = test_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         test_args = test_args + extra_args
         logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
         parser.main(test_args)

diff --git a/stanza/utils/training/run_pos.py b/stanza/utils/training/run_pos.py
index dc19e16c43..e3909811e7 100644
--- a/stanza/utils/training/run_pos.py
+++ b/stanza/utils/training/run_pos.py
@@ -5,9 +5,8 @@

 from stanza.models import tagger

-from stanza.resources.default_packages import no_pretrain_languages, pos_pretrains, default_pretrains
 from stanza.utils.training import common
-from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain
+from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain, build_pos_wordvec_args

 logger = logging.getLogger('stanza')

@@ -16,24 +15,6 @@ def add_pos_args(parser):
     parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')

-# TODO: move this somewhere common
-def wordvec_args(short_language, dataset, extra_args):
-    if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args:
-        return []
-
-    if short_language in no_pretrain_languages:
-        # we couldn't find word vectors for a few languages...:
-        # coptic, naija, old russian, turkish german, swedish sign language
-        logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language))
-        return ["--no_pretrain"]
-    else:
-        if short_language in pos_pretrains and dataset in pos_pretrains[short_language]:
-            dataset_pretrains = pos_pretrains
-        else:
-            dataset_pretrains = {}
-        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset)
-        return ["--wordvec_pretrain_file", wordvec_pretrain]
-
 def build_model_filename(paths, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)

@@ -45,7 +26,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
     train_args = ["--shorthand", short_name,
                   "--mode", "train"]
     # TODO: also, this downloads the wordvec, which we might not want to do yet
-    train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
+    train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
     if command_args.save_name is not None:
         train_args.extend(["--save_name", command_args.save_name])
     if command_args.save_dir is not None:
@@ -101,7 +82,7 @@ def run_treebank(mode, paths, treebank, short_name,
                       "--mode", "train"]
         if eval_file is None:
             train_args += ['--eval_file', dev_in_file]
-        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        train_args = train_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         train_args = train_args + extra_args
         logger.info("Running train POS for {} with args {}".format(treebank, train_args))
         tagger.main(train_args)
@@ -114,7 +95,7 @@ def run_treebank(mode, paths, treebank, short_name,
                     "--mode", "predict"]
         if eval_file is None:
             dev_args += ['--eval_file', dev_in_file]
-        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        dev_args = dev_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         dev_args = dev_args + extra_args
         logger.info("Running dev POS for {} with args {}".format(treebank, dev_args))
         tagger.main(dev_args)
@@ -132,7 +113,7 @@ def run_treebank(mode, paths, treebank, short_name,
                      "--mode", "predict"]
         if eval_file is None:
             test_args += ['--eval_file', test_in_file]
-        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
+        test_args = test_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         test_args = test_args + extra_args
         logger.info("Running test POS for {} with args {}".format(treebank, test_args))
         tagger.main(test_args)
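A hedged usage sketch of the refactored helpers from patch 2. The function names come from the diff above; the example language/dataset and the exact flags returned are assumptions. Both wrappers delegate to build_wordvec_args with a task-specific pretrain table.

from stanza.utils.training.common import build_pos_wordvec_args, build_depparse_wordvec_args

extra_args = []
# Returns either ["--wordvec_pretrain_file", <chosen pretrain>] or ["--no_pretrain"],
# depending on whether word vectors are known for the language.
pos_flags = build_pos_wordvec_args("en", "ewt", extra_args)
dep_flags = build_depparse_wordvec_args("en", "ewt", extra_args)
# Passing "--wordvec_pretrain_file" or "--no_pretrain" in extra_args short-circuits
# both helpers and they return [].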
From 426526c83fab40090f38b83c4b0211ed8a0bcc08 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 10:06:42 -0700
Subject: [PATCH 3/9] Make it so run_pos.py functions on Windows as well.

https://github.com/stanfordnlp/stanza-train/issues/20

Don't need to create a tempfile for POS
---
 stanza/models/tagger.py          | 26 +++++++++++++++++---------
 stanza/utils/training/common.py  |  2 +-
 stanza/utils/training/run_pos.py | 27 +++++++++++++++++----------
 3 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index 46d42fe8be..e905dda398 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -8,6 +8,7 @@

 import argparse
 import logging
+import io
 import os
 import time
 import zipfile
@@ -143,9 +144,9 @@ def main(args=None):
     logger.info("Running tagger in {} mode".format(args['mode']))

     if args['mode'] == 'train':
-        train(args)
+        return train(args)
     else:
-        evaluate(args)
+        return evaluate(args)

 def model_file_name(args):
     return utils.standard_model_file_name(args, "tagger")
@@ -282,14 +283,11 @@ def train(args):

     eval_type = get_eval_type(dev_data)

-    # pred and gold path
-    system_pred_file = args['output_file']
-
     # skip training if the language does not have training or dev data
     # sum(...) to check if all of the training files are empty
     if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
         logger.info("Skip training because no data available...")
-        return
+        return None, None

     if args['wandb']:
         import wandb
@@ -350,7 +348,9 @@ def train(args):
                 indices.extend(batch[-1])
             dev_preds = utils.unsort(dev_preds, indices)
             dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])
-            CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
+
+            system_pred_file = "{:C}\n\n".format(dev_data.doc)
+            system_pred_file = io.StringIO(system_pred_file)

             _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

@@ -410,6 +410,7 @@ def train(args):
         logger.info("Dev set never evaluated.  Saving final model.")
         trainer.save(model_file)

+    return trainer, _

 def evaluate(args):
     # file paths
@@ -423,7 +424,8 @@ def evaluate(args):
     # load model
     logger.info("Loading model from: {}".format(model_file))
     trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
-    evaluate_trainer(args, trainer, pretrain)
+    result_doc = evaluate_trainer(args, trainer, pretrain)
+    return trainer, result_doc

 def evaluate_trainer(args, trainer, pretrain):
     system_pred_file = args['output_file']
@@ -455,12 +457,18 @@ def evaluate_trainer(args, trainer, pretrain):

     # write to file and score
     dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
-    CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
+    if system_pred_file:
+        CoNLL.write_doc2conll(dev_data.doc, system_pred_file)

     if args['gold_labels']:
+        system_pred_file = "{:C}\n\n".format(dev_data.doc)
+        system_pred_file = io.StringIO(system_pred_file)
+
         _, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

         logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)

+    return dev_data.doc
+
 if __name__ == '__main__':
     main()

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index 15816ca393..2ad86e3920 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -196,7 +196,7 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
     else:
         logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

-    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'mwt_expander':
+    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander':
         with tempfile.NamedTemporaryFile() as temp_output_file:
             run_treebank(mode, paths, treebank, short_name,
                          temp_output_file.name, command_args, extra_args + save_name_args)

diff --git a/stanza/utils/training/run_pos.py b/stanza/utils/training/run_pos.py
index e3909811e7..353090d60e 100644
--- a/stanza/utils/training/run_pos.py
+++ b/stanza/utils/training/run_pos.py
@@ -1,5 +1,5 @@

-
+import io
 import logging
 import os

@@ -46,9 +46,9 @@ def run_treebank(mode, paths, treebank, short_name,
     if short_name == 'vi_vlsp22':
         train_file += f";{pos_dir}/vi_vtb.train.in.conllu"
     dev_in_file = f"{pos_dir}/{short_name}.dev.in.conllu"
-    dev_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.dev.pred.conllu"
+    dev_pred_file = f"{pos_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{pos_dir}/{short_name}.test.in.conllu"
-    test_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.test.pred.conllu"
+    test_pred_file = f"{pos_dir}/{short_name}.test.pred.conllu"

     charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm)
     bert_args = common.choose_transformer(short_language, command_args, extra_args)
@@ -76,7 +76,6 @@ def run_treebank(mode, paths, treebank, short_name,

         train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                       "--train_file", train_file,
-                      "--output_file", dev_pred_file,
                       "--lang", short_language,
                       "--shorthand", short_name,
                       "--mode", "train"]
@@ -89,38 +88,46 @@ def run_treebank(mode, paths, treebank, short_name,

     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
-                    "--output_file", dev_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
         if eval_file is None:
             dev_args += ['--eval_file', dev_in_file]
+        if command_args.save_output:
+            dev_args.extend(["--output_file", dev_pred_file])
         dev_args = dev_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         dev_args = dev_args + extra_args
         logger.info("Running dev POS for {} with args {}".format(treebank, dev_args))
-        tagger.main(dev_args)
+        _, dev_doc = tagger.main(dev_args)
+        if not command_args.save_output:
+            dev_pred_file = "{:C}\n\n".format(dev_doc)
+            dev_pred_file = io.StringIO(dev_pred_file)

         results = common.run_eval_script_pos(eval_file if eval_file else dev_in_file, dev_pred_file)
         logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
-        if not temp_output_file:
+        if command_args.save_output:
             logger.info("Output saved to %s", dev_pred_file)

     if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
         test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
-                     "--output_file", test_pred_file,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "predict"]
         if eval_file is None:
             test_args += ['--eval_file', test_in_file]
+        if command_args.save_output:
+            test_args.extend(["--output_file", test_pred_file])
         test_args = test_args + build_pos_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         test_args = test_args + extra_args
         logger.info("Running test POS for {} with args {}".format(treebank, test_args))
-        tagger.main(test_args)
+        _, test_doc = tagger.main(test_args)
+        if not command_args.save_output:
+            test_pred_file = "{:C}\n\n".format(test_doc)
+            test_pred_file = io.StringIO(test_pred_file)

         results = common.run_eval_script_pos(eval_file if eval_file else test_in_file, test_pred_file)
         logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
-        if not temp_output_file:
+        if command_args.save_output:
             logger.info("Output saved to %s", test_pred_file)

From 21bdbb22f0aac5f7fd4d6b463da43250de22cc0a Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 10:16:48 -0700
Subject: [PATCH 4/9] Attempt to fix an encoding issue for a log line on Windows

---
 stanza/utils/datasets/prepare_tokenizer_treebank.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py
index b61add1b15..8032c47a3a 100755
--- a/stanza/utils/datasets/prepare_tokenizer_treebank.py
+++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -27,6 +27,7 @@
 import os
 import random
 import re
+import sys
 import tempfile
 import zipfile

@@ -501,7 +502,11 @@ def augment_quotes(sents, ratio=0.15):

         new_sents.append(new_sent)

-    print("Augmented {} quotes: {}".format(sum(counts.values()), counts))
+    # we go through this to make it simpler to execute on Windows
+    # rather than nagging the user to set utf-8
+    out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True)
+    print("Augmented {} quotes: {}".format(sum(counts.values()), counts), file=out)
+    out.detach()

     return new_sents

 def find_text_idx(sentence):
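The encoding fix in patch 4 is a general-purpose trick; here is a standalone sketch of it (standard library only, not tied to the repo): wrap the raw stdout byte buffer in a UTF-8 text layer instead of asking the user to change the console code page or set PYTHONIOENCODING.

import io
import sys

def print_utf8(message):
    # Wrap the raw byte stream in a UTF-8 text layer just for this message.
    out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True)
    print(message, file=out)
    # detach() releases the underlying buffer without closing it, so normal
    # printing through sys.stdout keeps working afterwards.
    out.detach()

print_utf8("Augmented quotes: «…» “…” ‘…’")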
From 15e42410b9874ccc8dcab0ed806e252c2311aeec Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 14:18:36 -0700
Subject: [PATCH 5/9] Redo the depparse a bit so that it doesn't need temp files for training or scoring

---
 stanza/models/parser.py               | 24 +++++++++++++++---------
 stanza/tests/depparse/test_parser.py  |  2 +-
 stanza/utils/training/common.py       |  2 +-
 stanza/utils/training/run_depparse.py | 26 +++++++++++++++++---------
 4 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/stanza/models/parser.py b/stanza/models/parser.py
index 67da9d3052..6518678193 100644
--- a/stanza/models/parser.py
+++ b/stanza/models/parser.py
@@ -9,6 +9,7 @@
 Training and evaluation for the parser.
 """

+import io
 import sys
 import os
 import copy
@@ -153,7 +154,7 @@ def main(args=None):
     if args['mode'] == 'train':
         return train(args)
     else:
-        evaluate(args)
+        return evaluate(args)

 def model_file_name(args):
     return utils.standard_model_file_name(args, "parser")
@@ -233,9 +234,6 @@ def train(args):
         dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
         dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)

-    # pred path
-    system_pred_file = args['output_file']
-
     # skip training if the language does not have training or dev data
     if len(train_batch) == 0 or len(dev_batch) == 0:
         logger.info("Skip training because no data available...")
@@ -296,7 +294,9 @@ def train(args):

                 dev_preds = predict_dataset(trainer, dev_batch)
                 dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
-                CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+
+                system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+                system_pred_file = io.StringIO(system_pred_file)
                 _, _, dev_score = scorer.score(system_pred_file, args['eval_file'])

                 train_loss = train_loss / args['eval_interval'] # avg loss per batch
@@ -333,7 +333,8 @@ def train(args):

         dev_preds = predict_dataset(trainer, dev_batch)
         dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
-        CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+        system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+        system_pred_file = io.StringIO(system_pred_file)
         _, _, dev_score = scorer.score(system_pred_file, args['eval_file'])
         logger.info("Reloaded model with dev score %.4f", dev_score)

@@ -379,7 +380,7 @@ def train(args):
         logger.info("Dev set never evaluated.  Saving final model.")
         trainer.save(model_file)

-    return trainer
+    return trainer, _

 def evaluate(args):
     model_file = model_file_name(args)
@@ -392,7 +393,7 @@ def evaluate(args):
     # load model
     logger.info("Loading model from: {}".format(model_file))
     trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
-    return evaluate_trainer(args, trainer, pretrain)
+    return trainer, evaluate_trainer(args, trainer, pretrain)

 def evaluate_trainer(args, trainer, pretrain):
     system_pred_file = args['output_file']
@@ -412,7 +413,8 @@ def evaluate_trainer(args, trainer, pretrain):

     # write to file and score
     batch.doc.set([HEAD, DEPREL], [y for x in preds for y in x])
-    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if system_pred_file:
+        CoNLL.write_doc2conll(batch.doc, system_pred_file)

     if args['gold_labels']:
         gold_doc = CoNLL.conll2doc(input_file=args['eval_file'])
@@ -424,10 +426,14 @@ def evaluate_trainer(args, trainer, pretrain):
                     raise ValueError("Gold document {} has a None at sentence {} word {}\n{:C}".format(args['eval_file'], sent_idx, word_idx, sentence))
         scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex'])

+        system_pred_file = "{:C}\n\n".format(batch.doc)
+        system_pred_file = io.StringIO(system_pred_file)
         _, _, score = scorer.score(system_pred_file, args['eval_file'])

         logger.info("Parser score:")
         logger.info("{} {:.2f}".format(args['shorthand'], score*100))

+    return batch.doc
+
 if __name__ == '__main__':
     main()

diff --git a/stanza/tests/depparse/test_parser.py b/stanza/tests/depparse/test_parser.py
index a9299bf925..27fd813630 100644
--- a/stanza/tests/depparse/test_parser.py
+++ b/stanza/tests/depparse/test_parser.py
@@ -118,7 +118,7 @@ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, au
             args.extend(["--augment_nopunct", "0.0"])
         if extra_args is not None:
             args = args + extra_args
-        trainer = parser.main(args)
+        trainer, _ = parser.main(args)

         assert os.path.exists(save_file)
         pt = pretrain.Pretrain(wordvec_pretrain_file)

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index 2ad86e3920..5aef349a8e 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -196,7 +196,7 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
     else:
         logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

-    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander':
+    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander' and model_name != 'parser':
         with tempfile.NamedTemporaryFile() as temp_output_file:
             run_treebank(mode, paths, treebank, short_name,
                          temp_output_file.name, command_args, extra_args + save_name_args)

diff --git a/stanza/utils/training/run_depparse.py b/stanza/utils/training/run_depparse.py
index c7680e5b78..d5b6b4f1a7 100644
--- a/stanza/utils/training/run_depparse.py
+++ b/stanza/utils/training/run_depparse.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import os

@@ -47,9 +48,9 @@ def run_treebank(mode, paths, treebank, short_name,
     depparse_dir = paths["DEPPARSE_DATA_DIR"]
     train_file = f"{depparse_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu"
-    dev_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.dev.pred.conllu"
+    dev_pred_file = f"{depparse_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu"
-    test_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.test.pred.conllu"
+    test_pred_file = f"{depparse_dir}/{short_name}.test.pred.conllu"

     eval_file = None
     if '--eval_file' in extra_args:
@@ -84,7 +85,6 @@ def run_treebank(mode, paths, treebank, short_name,
         train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                       "--train_file", train_file,
                       "--eval_file", eval_file if eval_file else dev_in_file,
-                      "--output_file", dev_pred_file,
                       "--batch_size", batch_size,
                       "--lang", short_language,
                       "--shorthand", short_name,
@@ -97,37 +97,45 @@ def run_treebank(mode, paths, treebank, short_name,
     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                     "--eval_file", eval_file if eval_file else dev_in_file,
-                    "--output_file", dev_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
+        if command_args.save_output:
+            dev_args.extend(["--output_file", dev_pred_file])
         dev_args = dev_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         dev_args = dev_args + extra_args
         logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
-        parser.main(dev_args)
+        _, dev_doc = parser.main(dev_args)

         if '--no_gold_labels' not in extra_args:
+            if not command_args.save_output:
+                dev_pred_file = "{:C}\n\n".format(dev_doc)
+                dev_pred_file = io.StringIO(dev_pred_file)
             results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)
             logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
-            if not temp_output_file:
+            if command_args.save_output:
                 logger.info("Output saved to %s", dev_pred_file)

     if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
         test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                      "--eval_file", eval_file if eval_file else test_in_file,
-                     "--output_file", test_pred_file,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "predict"]
+        if command_args.save_output:
+            test_args.extend(["--output_file", test_pred_file])
         test_args = test_args + build_depparse_wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
         test_args = test_args + extra_args
         logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
-        parser.main(test_args)
+        _, test_doc = parser.main(test_args)

         if '--no_gold_labels' not in extra_args:
+            if not command_args.save_output:
+                test_pred_file = "{:C}\n\n".format(test_doc)
+                test_pred_file = io.StringIO(test_pred_file)
             results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)
             logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
-            if not temp_output_file:
+            if command_args.save_output:
                 logger.info("Output saved to %s", test_pred_file)
From ab3e26ee926e998bb6b47b2c807556a18761c087 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 14:52:36 -0700
Subject: [PATCH 6/9] Need to specify utf-8 for preparing lemma data on Windows

---
 stanza/utils/datasets/prepare_lemma_treebank.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/stanza/utils/datasets/prepare_lemma_treebank.py b/stanza/utils/datasets/prepare_lemma_treebank.py
index 16a88c5f8e..e0205d5427 100644
--- a/stanza/utils/datasets/prepare_lemma_treebank.py
+++ b/stanza/utils/datasets/prepare_lemma_treebank.py
@@ -31,7 +31,7 @@ def check_lemmas(train_file):
     # but what if a later dataset includes lemmas?
     #if short_language in ('vi', 'fro', 'th'):
     #    return False
-    with open(train_file) as fin:
+    with open(train_file, encoding="utf-8") as fin:
         for line in fin:
             line = line.strip()
             if not line or line.startswith("#"):

From b12bc27f119b771ca31032d5108b1f9dc4a9b6d0 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 14:54:14 -0700
Subject: [PATCH 7/9] Make the identity lemmatizer run on Windows without using a tempfile

---
 stanza/models/identity_lemmatizer.py |  7 ++++++-
 stanza/utils/training/run_lemma.py   | 12 ++++++++----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/stanza/models/identity_lemmatizer.py b/stanza/models/identity_lemmatizer.py
index 745fb25f16..f0b7a0df78 100644
--- a/stanza/models/identity_lemmatizer.py
+++ b/stanza/models/identity_lemmatizer.py
@@ -55,12 +55,17 @@ def main(args=None):

     # write to file and score
     batch.doc.set([LEMMA], preds)
-    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if system_pred_file is not None:
+        CoNLL.write_doc2conll(batch.doc, system_pred_file)

     if gold_file is not None:
+        system_pred_file = "{:C}\n\n".format(batch.doc)
+        system_pred_file = io.StringIO(system_pred_file)
         _, _, score = scorer.score(system_pred_file, gold_file)

         logger.info("Lemma score:")
         logger.info("{} {:.2f}".format(args['shorthand'], score*100))

+    return None, batch.doc
+
 if __name__ == '__main__':
     main()

diff --git a/stanza/utils/training/run_lemma.py b/stanza/utils/training/run_lemma.py
index a2668ba092..56562ae339 100644
--- a/stanza/utils/training/run_lemma.py
+++ b/stanza/utils/training/run_lemma.py
@@ -77,9 +77,9 @@ def run_treebank(mode, paths, treebank, short_name,
     lemma_dir = paths["LEMMA_DATA_DIR"]
     train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
     dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
-    dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
+    dev_pred_file = f"{lemma_dir}/{short_name}.dev.pred.conllu"
     test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
-    test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"
+    test_pred_file = f"{lemma_dir}/{short_name}.test.pred.conllu"

     charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)

@@ -94,15 +94,19 @@ def run_treebank(mode, paths, treebank, short_name,
         if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:
             train_args = ["--train_file", train_file,
                           "--eval_file", dev_in_file,
-                          "--output_file", dev_pred_file,
+                          "--gold_file", dev_in_file,
                           "--shorthand", short_name]
+            if command_args.save_output:
+                train_args.extend(["--output_file", dev_pred_file])
             logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
             identity_lemmatizer.main(train_args)
         elif mode == Mode.SCORE_TEST:
             train_args = ["--train_file", train_file,
                           "--eval_file", test_in_file,
-                          "--output_file", test_pred_file,
+                          "--gold_file", test_in_file,
                           "--shorthand", short_name]
+            if command_args.save_output:
+                train_args.extend(["--output_file", test_pred_file])
             logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
             identity_lemmatizer.main(train_args)
         else:

From 15e80f2a2a1a17bb4ad3435ccf0c6f7005d9a849 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 17:31:19 -0700
Subject: [PATCH 8/9] Don't try to save the output of the lemmatizer unless --save_output is chosen.

Addresses the lemmatizer portion of
https://github.com/stanfordnlp/stanza-train/issues/20
---
 stanza/models/lemmatizer.py        | 14 +++++++++-----
 stanza/utils/training/common.py    |  2 +-
 stanza/utils/training/run_lemma.py |  7 ++++---
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/stanza/models/lemmatizer.py b/stanza/models/lemmatizer.py
index c3f0a821a6..6b3e1bfec8 100644
--- a/stanza/models/lemmatizer.py
+++ b/stanza/models/lemmatizer.py
@@ -145,8 +145,7 @@ def train(args):
     model_file = build_model_filename(args)
     logger.info("Using full savename: %s", model_file)

-    # pred and gold path
-    system_pred_file = args['output_file']
+    # gold path
     gold_file = args['eval_file']

     utils.print_config(args)
@@ -169,7 +168,8 @@ def train(args):
         logger.info("Evaluating on dev set...")
         dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
         dev_batch.doc.set([LEMMA], dev_preds)
-        CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+        system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+        system_pred_file = io.StringIO(system_pred_file)
         _, _, dev_f = scorer.score(system_pred_file, gold_file)
         logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

@@ -223,7 +223,8 @@ def train(args):
                 logger.info("[Ensembling dict with seq2seq model...]")
                 dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
             dev_batch.doc.set([LEMMA], dev_preds)
-            CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+            system_pred_file = "{:C}\n\n".format(dev_batch.doc)
+            system_pred_file = io.StringIO(system_pred_file)
             _, _, dev_score = scorer.score(system_pred_file, gold_file)

             train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
@@ -302,8 +303,11 @@ def evaluate(args):

     # write to file and score
     batch.doc.set([LEMMA], preds)
-    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if system_pred_file:
+        CoNLL.write_doc2conll(batch.doc, system_pred_file)

+    system_pred_file = "{:C}\n\n".format(batch.doc)
+    system_pred_file = io.StringIO(system_pred_file)
     _, _, score = scorer.score(system_pred_file, args['eval_file'])
     logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index 5aef349a8e..a77e12e0a5 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -196,7 +196,7 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
     else:
         logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

-    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander' and model_name != 'parser':
+    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander' and model_name != 'parser' and model_name != 'lemma':
         with tempfile.NamedTemporaryFile() as temp_output_file:
             run_treebank(mode, paths, treebank, short_name,
                          temp_output_file.name, command_args, extra_args + save_name_args)

diff --git a/stanza/utils/training/run_lemma.py b/stanza/utils/training/run_lemma.py
index 56562ae339..56ff861270 100644
--- a/stanza/utils/training/run_lemma.py
+++ b/stanza/utils/training/run_lemma.py
@@ -119,7 +119,6 @@ def run_treebank(mode, paths, treebank, short_name,

         train_args = ["--train_file", train_file,
                       "--eval_file", dev_in_file,
-                      "--output_file", dev_pred_file,
                       "--shorthand", short_name,
                       "--num_epoch", num_epochs,
                       "--mode", "train"]
@@ -129,18 +128,20 @@ def run_treebank(mode, paths, treebank, short_name,

     if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
         dev_args = ["--eval_file", dev_in_file,
-                    "--output_file", dev_pred_file,
                     "--shorthand", short_name,
                     "--mode", "predict"]
+        if command_args.save_output:
+            dev_args.extend(["--output_file", dev_pred_file])
         dev_args = dev_args + charlm_args + extra_args
         logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args))
         lemmatizer.main(dev_args)

     if mode == Mode.SCORE_TEST:
         test_args = ["--eval_file", test_in_file,
-                     "--output_file", test_pred_file,
                      "--shorthand", short_name,
                      "--mode", "predict"]
+        if command_args.save_output:
+            test_args.extend(["--output_file", test_pred_file])
         test_args = test_args + charlm_args + extra_args
         logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
         lemmatizer.main(test_args)
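Taken together, patches 1, 3, 5, 7 and 8 converge on one control flow in the run_* scripts. Below is a condensed sketch of that flow; the helper is hypothetical, written only to illustrate the pattern, and the real scripts inline it step by step rather than calling a function like this.

import io

def score_step(model_main, args, pred_path, save_output, run_eval, gold_path):
    if save_output:
        args = args + ["--output_file", pred_path]      # only write a file on request
    _, doc = model_main(args)                           # models now return (trainer, doc)
    if save_output:
        pred = pred_path                                # evaluate the saved file
    else:
        pred = io.StringIO("{:C}\n\n".format(doc))      # evaluate entirely in memory
    return run_eval(gold_path, pred)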
From 94752fb57ba52a558d69aebbe15e8f4cfb14124a Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Mon, 22 Sep 2025 18:19:37 -0700
Subject: [PATCH 9/9] Remove the temp_output_file functionality from all training scripts, since they now do the testing in memory when not saving output.

https://github.com/stanfordnlp/stanza-train/issues/20
---
 stanza/utils/training/common.py               |  9 +--------
 stanza/utils/training/run_charlm.py           |  2 +-
 stanza/utils/training/run_constituency.py     |  2 +-
 stanza/utils/training/run_depparse.py         |  3 +--
 stanza/utils/training/run_ete.py              |  3 +--
 stanza/utils/training/run_lemma.py            |  3 +--
 stanza/utils/training/run_lemma_classifier.py |  3 +--
 stanza/utils/training/run_mwt.py              |  3 +--
 stanza/utils/training/run_ner.py              |  3 +--
 stanza/utils/training/run_pos.py              |  3 +--
 stanza/utils/training/run_sentiment.py        |  3 +--
 stanza/utils/training/run_tokenizer.py        |  3 +--
 12 files changed, 12 insertions(+), 28 deletions(-)

diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index a77e12e0a5..727e4a7c08 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -5,7 +5,6 @@
 import pathlib
 import random
 import sys
-import tempfile

 from enum import Enum

@@ -196,13 +195,7 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
     else:
         logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

-    if not command_args.save_output and model_name != 'ete' and model_name != 'tokenizer' and model_name != 'tagger' and model_name != 'mwt_expander' and model_name != 'parser' and model_name != 'lemma':
-        with tempfile.NamedTemporaryFile() as temp_output_file:
-            run_treebank(mode, paths, treebank, short_name,
-                         temp_output_file.name, command_args, extra_args + save_name_args)
-    else:
-        run_treebank(mode, paths, treebank, short_name,
-                     None, command_args, extra_args + save_name_args)
+    run_treebank(mode, paths, treebank, short_name, command_args, extra_args + save_name_args)

 def run_eval_script(gold_conllu_file, system_conllu_file, evals=None):
     """ Wrapper for lemma scorer. """

diff --git a/stanza/utils/training/run_charlm.py b/stanza/utils/training/run_charlm.py
index 2aa251382f..2369dcfc41 100644
--- a/stanza/utils/training/run_charlm.py
+++ b/stanza/utils/training/run_charlm.py
@@ -22,7 +22,7 @@ def add_charlm_args(parser):


 def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+                 command_args, extra_args):
     short_language, dataset_name = short_name.split("_", 1)

     train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train")

diff --git a/stanza/utils/training/run_constituency.py b/stanza/utils/training/run_constituency.py
index 4e993f56ae..9322866ba3 100644
--- a/stanza/utils/training/run_constituency.py
+++ b/stanza/utils/training/run_constituency.py
@@ -72,7 +72,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):

     return save_name

-def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     constituency_dir = paths["CONSTITUENCY_DATA_DIR"]
     short_language, dataset = short_name.split("_")

diff --git a/stanza/utils/training/run_depparse.py b/stanza/utils/training/run_depparse.py
index d5b6b4f1a7..404ea21830 100644
--- a/stanza/utils/training/run_depparse.py
+++ b/stanza/utils/training/run_depparse.py
@@ -40,8 +40,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):

     return save_name

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)

     # TODO: refactor these blocks?

diff --git a/stanza/utils/training/run_ete.py b/stanza/utils/training/run_ete.py
index bd4ef1f006..67f3000f1f 100644
--- a/stanza/utils/training/run_ete.py
+++ b/stanza/utils/training/run_ete.py
@@ -168,8 +168,7 @@ def run_ete(paths, dataset, short_name, command_args, extra_args):
     results = common.run_eval_script(gold_file, ete_file)
     logger.info("{} {} models on {} {} data:\n{}".format(RESULTS_STRING, short_name, test_short_name, dataset, results))

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     if mode == Mode.TRAIN:
         dataset = 'train'
     elif mode == Mode.SCORE_DEV:

diff --git a/stanza/utils/training/run_lemma.py b/stanza/utils/training/run_lemma.py
index 56ff861270..f3d7331c97 100644
--- a/stanza/utils/training/run_lemma.py
+++ b/stanza/utils/training/run_lemma.py
@@ -70,8 +70,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
     save_name = lemmatizer.build_model_filename(args)
     return save_name

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)

     lemma_dir = paths["LEMMA_DATA_DIR"]

diff --git a/stanza/utils/training/run_lemma_classifier.py b/stanza/utils/training/run_lemma_classifier.py
index 73669e79b0..68ef748ead 100644
--- a/stanza/utils/training/run_lemma_classifier.py
+++ b/stanza/utils/training/run_lemma_classifier.py
@@ -18,8 +18,7 @@ def add_lemma_args(parser):
 def build_model_filename(paths, short_name, command_args, extra_args):
     return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt")

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)

     base_args = []

diff --git a/stanza/utils/training/run_mwt.py b/stanza/utils/training/run_mwt.py
index 0185ab52c3..74d4ba1993 100644
--- a/stanza/utils/training/run_mwt.py
+++ b/stanza/utils/training/run_mwt.py
@@ -38,8 +38,7 @@ def check_mwt(filename):
     data = doc.get_mwt_expansions(False)
     return len(data) > 0

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language = short_name.split("_")[0]

     mwt_dir = paths["MWT_DATA_DIR"]

diff --git a/stanza/utils/training/run_ner.py b/stanza/utils/training/run_ner.py
index 2651df58fa..10fdabc182 100644
--- a/stanza/utils/training/run_ner.py
+++ b/stanza/utils/training/run_ner.py
@@ -93,8 +93,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):
 # However, to keep the naming consistent, we leave the
 # method which does the training as run_treebank
 # TODO: rename treebank -> dataset everywhere
-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     ner_dir = paths["NER_DATA_DIR"]
     language, dataset = short_name.split("_")

diff --git a/stanza/utils/training/run_pos.py b/stanza/utils/training/run_pos.py
index 353090d60e..2808d2f8af 100644
--- a/stanza/utils/training/run_pos.py
+++ b/stanza/utils/training/run_pos.py
@@ -37,8 +37,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):



-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     short_language, dataset = short_name.split("_", 1)

     pos_dir = paths["POS_DATA_DIR"]

diff --git a/stanza/utils/training/run_sentiment.py b/stanza/utils/training/run_sentiment.py
index ae46f12750..a985b1b945 100644
--- a/stanza/utils/training/run_sentiment.py
+++ b/stanza/utils/training/run_sentiment.py
@@ -66,8 +66,7 @@ def build_model_filename(paths, short_name, command_args, extra_args):

     return save_name

-def run_dataset(mode, paths, treebank, short_name,
-                temp_output_file, command_args, extra_args):
+def run_dataset(mode, paths, treebank, short_name, command_args, extra_args):
     sentiment_dir = paths["SENTIMENT_DATA_DIR"]
     short_language, dataset = short_name.split("_", 1)

diff --git a/stanza/utils/training/run_tokenizer.py b/stanza/utils/training/run_tokenizer.py
index 2128e8e006..6c21734579 100644
--- a/stanza/utils/training/run_tokenizer.py
+++ b/stanza/utils/training/run_tokenizer.py
@@ -66,8 +66,7 @@ def uses_dictionary(short_language):
         return True
     return False

-def run_treebank(mode, paths, treebank, short_name,
-                 temp_output_file, command_args, extra_args):
+def run_treebank(mode, paths, treebank, short_name, command_args, extra_args):
     tokenize_dir = paths["TOKENIZE_DATA_DIR"]
     short_language, dataset = short_name.split("_", 1)