Commit dbe6d1f
Add a short test method that a single optimizer case saves checkpoints and the checkpoints are loadable
AngledLuffa committed Jan 30, 2024
1 parent 938da0e commit dbe6d1f

Showing 2 changed files with 26 additions and 4 deletions.
stanza/models/parser.py (6 changes: 4 additions & 2 deletions)

@@ -129,7 +129,7 @@ def main(args=None):
     logger.info("Running parser in {} mode".format(args['mode']))
 
     if args['mode'] == 'train':
-        train(args)
+        return train(args)
     else:
         evaluate(args)
 
@@ -202,7 +202,8 @@ def train(args):
     checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
     if args.get("checkpoint"):
         model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
-        checkpoint_file = copy.deepcopy(model_to_load)
+        checkpoint_file = model_to_load
+        args["checkpoint_save_name"] = checkpoint_file
     if args["continue_from"]:
         model_to_load = args["continue_from"]
 
@@ -302,6 +303,7 @@ def train(args):
         logger.info("Dev set never evaluated. Saving final model.")
         trainer.save(model_file)
 
+    return trainer
 
 def evaluate(args):
     # file paths
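With train() now returning its Trainer and main() forwarding that value, a caller can train in-process and inspect the result directly instead of reloading from disk. A minimal sketch of the new calling pattern (the argument list is hypothetical and abbreviated; a real run needs the parser's full set of required flags and data files):

    from stanza.models import parser

    # in train mode, main() now returns the Trainer instead of None
    trainer = parser.main([
        "--mode", "train",
        "--wordvec_pretrain_file", "/path/to/pretrain.pt",  # hypothetical paths
        "--train_file", "train.conllu",
        "--eval_file", "dev.conllu",
        "--gold_file", "dev.conllu",
    ])
    print(type(trainer.optimizer))  # e.g. inspect the optimizer after training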
stanza/tests/depparse/test_parser.py (24 changes: 22 additions & 2 deletions)

@@ -7,6 +7,8 @@
 import os
 import pytest
 
+import torch
+
 from stanza.models import parser
 from stanza.models.common import pretrain
 from stanza.models.depparse.trainer import Trainer
@@ -108,12 +110,13 @@ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, au
             args.extend(["--augment_nopunct", "0.0"])
         if extra_args is not None:
             args = args + extra_args
-        parser.main(args)
+        trainer = parser.main(args)
 
         assert os.path.exists(save_file)
         pt = pretrain.Pretrain(wordvec_pretrain_file)
         # test loading the saved model
         saved_model = Trainer(pretrain=pt, model_file=save_file)
-        return saved_model
+        return trainer
 
     def test_train(self, tmp_path, wordvec_pretrain_file):
         """
@@ -127,3 +130,20 @@ def test_with_bert(self, tmp_path, wordvec_pretrain_file):
     def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
         self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])
 
+    def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])
+
+        save_dir = trainer.args['save_dir']
+        save_name = trainer.args['save_name']
+        checkpoint_name = trainer.args["checkpoint_save_name"]
+
+        assert os.path.exists(os.path.join(save_dir, save_name))
+        assert checkpoint_name is not None
+        assert os.path.exists(checkpoint_name)
+
+        assert isinstance(trainer.optimizer, torch.optim.Adam)
+
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
+        assert checkpoint.optimizer is not None
+        assert isinstance(checkpoint.optimizer, torch.optim.Adam)
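The new test exercises the single-optimizer checkpoint path end to end: train with --optim adam, confirm both the final model and the checkpoint exist on disk, then reload the checkpoint and check that the optimizer round-trips. Assuming a standard pytest setup run from a stanza checkout (plus the fixtures the module needs, such as the word-vector pretrain file), it can be run on its own, sketched here via pytest's Python entry point:

    import pytest

    # run only the new checkpoint test, with verbose output
    pytest.main(["stanza/tests/depparse/test_parser.py",
                 "-k", "test_single_optimizer_checkpoint", "-v"])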