7 changes: 6 additions & 1 deletion stanza/models/identity_lemmatizer.py
@@ -55,12 +55,17 @@ def main(args=None):

# write to file and score
batch.doc.set([LEMMA], preds)
CoNLL.write_doc2conll(batch.doc, system_pred_file)
if system_pred_file is not None:
CoNLL.write_doc2conll(batch.doc, system_pred_file)
if gold_file is not None:
system_pred_file = "{:C}\n\n".format(batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, score = scorer.score(system_pred_file, gold_file)

logger.info("Lemma score:")
logger.info("{} {:.2f}".format(args['shorthand'], score*100))

return None, batch.doc

if __name__ == '__main__':
main()
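
The same replacement recurs in every model script below, so here is the pattern in isolation: render the document to CoNLL-U in memory via the `{:C}` format spec and hand the scorer a `StringIO`, writing the on-disk file only when the caller supplied one. This is a minimal sketch; `score_in_memory` and its signature are names invented for the example, not part of the PR.

```python
# Minimal sketch of the optional-file + in-memory scoring pattern; the helper
# name and its signature are assumptions made for this example only.
import io

from stanza.utils.conll import CoNLL


def score_in_memory(doc, gold_file, scorer, output_file=None):
    # keep writing the CoNLL-U file for callers that still ask for one
    if output_file is not None:
        CoNLL.write_doc2conll(doc, output_file)
    # "{:C}" renders a stanza Document as CoNLL-U text; the trailing blank
    # line terminates the final sentence, exactly as the diff above does
    pred_file = io.StringIO("{:C}\n\n".format(doc))
    # the scorer only needs a readable file-like object, so no temp file
    # is required
    _, _, f1 = scorer.score(pred_file, gold_file)
    return f1
```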
14 changes: 9 additions & 5 deletions stanza/models/lemmatizer.py
@@ -145,8 +145,7 @@ def train(args):
model_file = build_model_filename(args)
logger.info("Using full savename: %s", model_file)

# pred and gold path
system_pred_file = args['output_file']
# gold path
gold_file = args['eval_file']

utils.print_config(args)
@@ -169,7 +168,8 @@
logger.info("Evaluating on dev set...")
dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
dev_batch.doc.set([LEMMA], dev_preds)
CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_f = scorer.score(system_pred_file, gold_file)
logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

@@ -223,7 +223,8 @@ def train(args):
logger.info("[Ensembling dict with seq2seq model...]")
dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
dev_batch.doc.set([LEMMA], dev_preds)
CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, gold_file)

train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
@@ -302,8 +303,11 @@ def evaluate(args):

# write to file and score
batch.doc.set([LEMMA], preds)
CoNLL.write_doc2conll(batch.doc, system_pred_file)
if system_pred_file:
CoNLL.write_doc2conll(batch.doc, system_pred_file)

system_pred_file = "{:C}\n\n".format(batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, score = scorer.score(system_pred_file, args['eval_file'])
logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))

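For reference, a self-contained round trip of the `{:C}` format spec these diffs rely on. It assumes `CoNLL.conll2doc` accepts an `input_str` keyword, as it does in current stanza; the two-word sentence is made up for the demo.

```python
# Round-trip demo of the "{:C}" format spec used throughout this PR.
import io

from stanza.utils.conll import CoNLL

conllu_in = (
    "# text = hello world\n"
    "1\thello\thello\tINTJ\t_\t_\t0\troot\t_\t_\n"
    "2\tworld\tworld\tNOUN\t_\t_\t1\tobj\t_\t_\n"
    "\n"
)
doc = CoNLL.conll2doc(input_str=conllu_in)

# render back to CoNLL-U text and wrap it as a file-like object
rendered = "{:C}\n\n".format(doc)
pred_file = io.StringIO(rendered)
print(pred_file.readline().rstrip())   # expected: "# text = hello world"
```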
30 changes: 20 additions & 10 deletions stanza/models/mwt_expander.py
@@ -9,6 +9,7 @@
composing the MWT, a classifier over the characters is used instead of the seq2seq
"""

import io
import sys
import os
import shutil
@@ -104,9 +105,9 @@ def main(args=None):
logger.info("Running MWT expander in {} mode".format(args['mode']))

if args['mode'] == 'train':
train(args)
return train(args)
else:
evaluate(args)
return evaluate(args)

def train(args):
# load data
@@ -129,7 +130,6 @@ def train(args):
save_each_name = utils.build_save_each_filename(save_each_name)

# pred and gold path
system_pred_file = args['output_file']
gold_file = args['gold_file']

# skip training if the language does not have training or dev data
@@ -174,8 +174,9 @@
dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_f = scorer.score(system_pred_file, gold_file)
system_preds = "{:C}\n\n".format(doc)
system_preds = io.StringIO(system_preds)
_, _, dev_f = scorer.score(system_preds, gold_file)
logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

if args.get('dict_only', False):
@@ -228,8 +229,9 @@
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, gold_file)
system_preds = "{:C}\n\n".format(doc)
system_preds = io.StringIO(system_preds)
_, _, dev_score = scorer.score(system_preds, gold_file)
train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))

@@ -263,11 +265,14 @@ def train(args):
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, gold_file)
system_preds = "{:C}\n\n".format(doc)
system_preds = io.StringIO(system_preds)
_, _, dev_score = scorer.score(system_preds, gold_file)
logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
best_f = max(best_f, dev_score)

return trainer, _

def evaluate(args):
# file paths
system_pred_file = args['output_file']
@@ -310,13 +315,18 @@ def evaluate(args):
# write to file and score
doc = copy.deepcopy(batch.doc)
doc.set_mwt_expansions(preds, fake_dependencies=True)
CoNLL.write_doc2conll(doc, system_pred_file)
if system_pred_file is not None:
CoNLL.write_doc2conll(doc, system_pred_file)
else:
system_pred_file = "{:C}\n\n".format(doc)
system_pred_file = io.StringIO(system_pred_file)

if gold_file is not None:
_, _, score = scorer.score(system_pred_file, gold_file)

logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))

return trainer, doc

if __name__ == '__main__':
main()
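
Since `main()` now forwards the `(trainer, document)` pair from `train()`/`evaluate()`, callers can invoke the expander programmatically and inspect the result without re-reading any output file. A hedged sketch only: the flag values are placeholders, not a tested configuration, and a trained MWT model is assumed to be available.

```python
# Hypothetical programmatic use of the new return value; paths and the
# shorthand are placeholders, and a trained model is assumed to exist.
from stanza.models import mwt_expander

trainer, doc = mwt_expander.main([
    "--mode", "predict",
    "--eval_file", "dev.in.conllu",    # placeholder input
    "--gold_file", "dev.gold.conllu",  # placeholder gold file
    "--shorthand", "en_test",          # placeholder shorthand
])

# the expanded document is returned directly, so no output_file is needed
print(len(doc.sentences))
```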
24 changes: 15 additions & 9 deletions stanza/models/parser.py
@@ -9,6 +9,7 @@
Training and evaluation for the parser.
"""

import io
import sys
import os
import copy
@@ -153,7 +154,7 @@ def main(args=None):
if args['mode'] == 'train':
return train(args)
else:
evaluate(args)
return evaluate(args)

def model_file_name(args):
return utils.standard_model_file_name(args, "parser")
@@ -233,9 +234,6 @@ def train(args):
dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)

# pred path
system_pred_file = args['output_file']

# skip training if the language does not have training or dev data
if len(train_batch) == 0 or len(dev_batch) == 0:
logger.info("Skip training because no data available...")
@@ -296,7 +294,9 @@ def train(args):
dev_preds = predict_dataset(trainer, dev_batch)

dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)

system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, args['eval_file'])

train_loss = train_loss / args['eval_interval'] # avg loss per batch
@@ -333,7 +333,8 @@ def train(args):

dev_preds = predict_dataset(trainer, dev_batch)
dev_batch.doc.set([HEAD, DEPREL], [y for x in dev_preds for y in x])
CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
system_pred_file = "{:C}\n\n".format(dev_batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, args['eval_file'])
logger.info("Reloaded model with dev score %.4f", dev_score)

@@ -379,7 +380,7 @@ def train(args):
logger.info("Dev set never evaluated. Saving final model.")
trainer.save(model_file)

return trainer
return trainer, _

def evaluate(args):
model_file = model_file_name(args)
@@ -392,7 +393,7 @@ def evaluate(args):
# load model
logger.info("Loading model from: {}".format(model_file))
trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
return evaluate_trainer(args, trainer, pretrain)
return trainer, evaluate_trainer(args, trainer, pretrain)

def evaluate_trainer(args, trainer, pretrain):
system_pred_file = args['output_file']
@@ -412,7 +413,8 @@ def evaluate_trainer(args, trainer, pretrain):

# write to file and score
batch.doc.set([HEAD, DEPREL], [y for x in preds for y in x])
CoNLL.write_doc2conll(batch.doc, system_pred_file)
if system_pred_file:
CoNLL.write_doc2conll(batch.doc, system_pred_file)

if args['gold_labels']:
gold_doc = CoNLL.conll2doc(input_file=args['eval_file'])
@@ -424,10 +426,14 @@ def evaluate_trainer(args, trainer, pretrain):
raise ValueError("Gold document {} has a None at sentence {} word {}\n{:C}".format(args['eval_file'], sent_idx, word_idx, sentence))

scorer.score_named_dependencies(batch.doc, gold_doc, args['output_latex'])
system_pred_file = "{:C}\n\n".format(batch.doc)
system_pred_file = io.StringIO(system_pred_file)
_, _, score = scorer.score(system_pred_file, args['eval_file'])

logger.info("Parser score:")
logger.info("{} {:.2f}".format(args['shorthand'], score*100))

return batch.doc

if __name__ == '__main__':
main()
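
The `gold_labels` branch above raises on missing annotations before scoring. Here is that sanity check as a standalone sketch: the loop shape and the error message follow the diff, while the function name and the exact None test are assumptions made for the example.

```python
# Standalone version of the gold-annotation check from evaluate_trainer();
# the function name and the head/deprel test are invented for this sketch.
from stanza.utils.conll import CoNLL


def check_gold_labels(eval_file):
    gold_doc = CoNLL.conll2doc(input_file=eval_file)
    for sent_idx, sentence in enumerate(gold_doc.sentences):
        for word_idx, word in enumerate(sentence.words):
            # a None head or deprel would make UAS/LAS undefined
            if word.head is None or word.deprel is None:
                raise ValueError(
                    "Gold document {} has a None at sentence {} word {}".format(
                        eval_file, sent_idx, word_idx))
    return gold_doc
```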
26 changes: 17 additions & 9 deletions stanza/models/tagger.py
@@ -8,6 +8,7 @@

import argparse
import logging
import io
import os
import time
import zipfile
@@ -143,9 +144,9 @@ def main(args=None):
logger.info("Running tagger in {} mode".format(args['mode']))

if args['mode'] == 'train':
train(args)
return train(args)
else:
evaluate(args)
return evaluate(args)

def model_file_name(args):
return utils.standard_model_file_name(args, "tagger")
@@ -282,14 +283,11 @@ def train(args):

eval_type = get_eval_type(dev_data)

# pred and gold path
system_pred_file = args['output_file']

# skip training if the language does not have training or dev data
# sum(...) to check if all of the training files are empty
if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
logger.info("Skip training because no data available...")
return
return None, None

if args['wandb']:
import wandb
@@ -350,7 +348,9 @@ def train(args):
indices.extend(batch[-1])
dev_preds = utils.unsort(dev_preds, indices)
dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])
CoNLL.write_doc2conll(dev_data.doc, system_pred_file)

system_pred_file = "{:C}\n\n".format(dev_data.doc)
system_pred_file = io.StringIO(system_pred_file)

_, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

@@ -410,6 +410,7 @@ def train(args):
logger.info("Dev set never evaluated. Saving final model.")
trainer.save(model_file)

return trainer, _

def evaluate(args):
# file paths
@@ -423,7 +424,8 @@ def evaluate(args):
# load model
logger.info("Loading model from: {}".format(model_file))
trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
evaluate_trainer(args, trainer, pretrain)
result_doc = evaluate_trainer(args, trainer, pretrain)
return trainer, result_doc

def evaluate_trainer(args, trainer, pretrain):
system_pred_file = args['output_file']
@@ -455,12 +457,18 @@ def evaluate_trainer(args, trainer, pretrain):

# write to file and score
dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
if system_pred_file:
CoNLL.write_doc2conll(dev_data.doc, system_pred_file)

if args['gold_labels']:
system_pred_file = "{:C}\n\n".format(dev_data.doc)
system_pred_file = io.StringIO(system_pred_file)

_, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)

return dev_data.doc

if __name__ == '__main__':
main()
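
As with the parser, the tagger's `evaluate()` now hands back a `(trainer, doc)` pair. A hedged usage sketch follows; the paths are placeholders, the flag names are assumed from the surrounding scripts, and a trained model is presumed to exist for the shorthand.

```python
# Hypothetical caller of the tagger's new return value; all paths are
# placeholders and a trained model is assumed for the given shorthand.
from stanza.models import tagger

trainer, doc = tagger.main([
    "--mode", "predict",
    "--eval_file", "dev.in.conllu",   # placeholder
    "--shorthand", "en_test",         # placeholder
])

# predicted tags can be read off the returned document directly
for sentence in doc.sentences[:2]:
    print([(word.text, word.upos) for word in sentence.words])
```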
2 changes: 1 addition & 1 deletion stanza/tests/depparse/test_parser.py
@@ -118,7 +118,7 @@ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, au
args.extend(["--augment_nopunct", "0.0"])
if extra_args is not None:
args = args + extra_args
trainer = parser.main(args)
trainer, _ = parser.main(args)

assert os.path.exists(save_file)
pt = pretrain.Pretrain(wordvec_pretrain_file)
4 changes: 2 additions & 2 deletions stanza/utils/datasets/prepare_depparse_treebank.py
@@ -20,7 +20,7 @@
from stanza.resources.default_packages import default_charlms, pos_charlms
import stanza.utils.datasets.common as common
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
from stanza.utils.training.run_pos import wordvec_args
from stanza.utils.training.common import build_pos_wordvec_args
from stanza.utils.training.common import add_charlm_args, build_charlm_args, choose_charlm

logger = logging.getLogger('stanza')
@@ -110,7 +110,7 @@ def process_treebank(treebank, model_type, paths, args) -> None:
if args.wordvec_pretrain_file:
base_args += ["--wordvec_pretrain_file", args.wordvec_pretrain_file]
else:
base_args = base_args + wordvec_args(short_language, dataset, [])
base_args = base_args + build_pos_wordvec_args(short_language, dataset, [])


# charlm for POS
2 changes: 1 addition & 1 deletion stanza/utils/datasets/prepare_lemma_treebank.py
@@ -31,7 +31,7 @@ def check_lemmas(train_file):
# but what if a later dataset includes lemmas?
#if short_language in ('vi', 'fro', 'th'):
# return False
with open(train_file) as fin:
with open(train_file, encoding="utf-8") as fin:
for line in fin:
line = line.strip()
if not line or line.startswith("#"):
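
The one-line `check_lemmas` change pins the file encoding. Without it, `open()` falls back to the platform locale encoding (often cp1252 on Windows), which can raise `UnicodeDecodeError` on UTF-8 CoNLL-U data. A throwaway demonstration, with the temp file standing in for a real training file:

```python
# Demonstrates why the explicit encoding matters; the temp file stands in
# for a real CoNLL-U training file.
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".conllu", encoding="utf-8",
                                 delete=False) as tmp:
    tmp.write("# text = naïve café\n")

# explicit encoding reads the same bytes on every platform; relying on the
# locale default would break wherever it is not UTF-8
with open(tmp.name, encoding="utf-8") as fin:
    print(fin.read().strip())
```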
7 changes: 6 additions & 1 deletion stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -27,6 +27,7 @@
import os
import random
import re
import sys
import tempfile
import zipfile

@@ -501,7 +502,11 @@ def augment_quotes(sents, ratio=0.15):

new_sents.append(new_sent)

print("Augmented {} quotes: {}".format(sum(counts.values()), counts))
# we go through this to make it simpler to execute on Windows
# rather than nagging the user to set utf-8
out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True)
print("Augmented {} quotes: {}".format(sum(counts.values()), counts), file=out)
out.detach()
return new_sents

def find_text_idx(sentence):
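
The `TextIOWrapper` dance above is worth isolating: it gives a single `print()` a UTF-8 view of stdout on Windows, then detaches so the interpreter's own `sys.stdout` is left intact. A self-contained sketch with an invented helper name:

```python
# Isolated sketch of the UTF-8 printing pattern from augment_quotes();
# print_utf8 is an invented name for this example.
import io
import sys


def print_utf8(message):
    # wrap the underlying byte buffer in a UTF-8 text layer for one print
    out = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8",
                           write_through=True)
    try:
        print(message, file=out)
    finally:
        # detach() releases the buffer without closing sys.stdout, which
        # closing the wrapper would otherwise do
        out.detach()


print_utf8("Augmented 2 quotes: {'«»': 1, '“”': 1}")
```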