Put the global_step and dev score history into the model files so that when a checkpoint gets loaded, the training continues from the position it was formerly at rather than restarting from 0

Report some details of the model being loaded after loading it
AngledLuffa committed Jan 31, 2024
1 parent 207a0d4 commit 7bc8cc9
Showing 2 changed files with 47 additions and 32 deletions.
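At its core, the commit stores the training position alongside the model weights and restores it with safe defaults, so checkpoints written before this change still load. A minimal sketch of that idea, assuming a torch-style checkpoint dict with the same key names as the diff below (this is not the stanza Trainer API):

import torch

def save_checkpoint(path, model, global_step, last_best_step, dev_score_history):
    # training position is serialized next to the weights
    torch.save({
        'model': model.state_dict(),
        'global_step': global_step,
        'last_best_step': last_best_step,
        'dev_score_history': dev_score_history,
    }, path)

def load_checkpoint(path, model):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # .get() defaults keep older checkpoints, which lack these keys, loadable
    return (checkpoint.get('global_step', 0),
            checkpoint.get('last_best_step', 0),
            checkpoint.get('dev_score_history', []))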
20 changes: 18 additions & 2 deletions stanza/models/depparse/trainer.py
@@ -30,16 +30,26 @@ def unpack_batch(batch, device):
class Trainer(BaseTrainer):
""" A trainer for training models. """
def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
device=None, foundation_cache=None, ignore_model_config=False):
device=None, foundation_cache=None, ignore_model_config=False, reset_history=False):
self.global_step = 0
self.last_best_step = 0
self.dev_score_history = []

orig_args = copy.deepcopy(args)
# whether the training is in primary or secondary stage
# during FT (loading weights), etc., the training is considered to be in "secondary stage"
# during this time, we (optionally) use a different set of optimizers than that during "primary stage".
#
# Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary

if model_file is not None:
# load everything from file
self.load(model_file, pretrain, args, foundation_cache, device)

if reset_history:
self.global_step = 0
self.last_best_step = 0
self.dev_score_history = []
else:
# build model from scratch
self.args = args
@@ -112,7 +122,10 @@ def save(self, filename, skip_modules=True, save_optimizer=False):
params = {
'model': model_state,
'vocab': self.vocab.state_dict(),
'config': self.args
'config': self.args,
'global_step': self.global_step,
'last_best_step': self.last_best_step,
'dev_score_history': self.dev_score_history,
}

if save_optimizer and self.optimizer is not None:
@@ -157,3 +170,6 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
if optim_state_dict:
self.optimizer.load_state_dict(optim_state_dict)

self.global_step = checkpoint.get("global_step", 0)
self.last_best_step = checkpoint.get("last_best_step", 0)
self.dev_score_history = checkpoint.get("dev_score_history", list())
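The comment block in Trainer.__init__ refers to the two-stage optimizer scheme driven by the parser.py loop below: once the primary optimizer stops improving the dev score (or a configured step is reached), the best model is reloaded and training continues with a secondary optimizer. A rough, self-contained sketch of that switch, with illustrative names and thresholds rather than the stanza implementation:

import torch

def maybe_switch_stage(model, optimizer, is_second_stage,
                       global_step, last_best_step,
                       max_steps_before_stop, second_lr=1e-3):
    # Illustrative only: move to the secondary optimizer once the dev score
    # has not improved for max_steps_before_stop steps.  The real training
    # loop also reloads the best saved model at this point.
    if not is_second_stage and global_step - last_best_step >= max_steps_before_stop:
        optimizer = torch.optim.SGD(model.parameters(), lr=second_lr)
        is_second_stage = True
        last_best_step = global_step  # restart the patience counter
    return optimizer, is_second_stage, last_best_step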
59 changes: 29 additions & 30 deletions stanza/models/parser.py
@@ -198,45 +198,42 @@ def train(args):
wandb.run.define_metric('dev_score', summary='max')

logger.info("Training parser...")
# calculate checkpoint file name and the sav
model_to_load = None # used for general loading and reloading
checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
if args.get("checkpoint"):
model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
checkpoint_file = model_to_load
# calculate checkpoint file name from the save filename
checkpoint_file = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
args["checkpoint_save_name"] = checkpoint_file
if args["continue_from"]:
model_to_load = args["continue_from"]

if model_to_load is not None and os.path.exists(model_to_load):
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=model_to_load, device=args['device'], ignore_model_config=True)
if args.get("checkpoint") and os.path.exists(args["checkpoint_save_name"]):
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["checkpoint_save_name"], device=args['device'], ignore_model_config=True)
if len(trainer.dev_score_history) > 0:
logger.info("Continuing from checkpoint %s Model was previously trained for %d steps, with a best dev score of %.4f", args["checkpoint_save_name"], trainer.global_step, max(trainer.dev_score_history))
elif args["continue_from"]:
if not os.path.exists(args["continue_from"]):
raise FileNotFoundError("--continue_from specified, but the file %s does not exist" % args["continue_from"])
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["continue_from"], device=args['device'], ignore_model_config=True, reset_history=True)
else:
trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])

global_step = 0
max_steps = args['max_steps']
dev_score_history = []
best_dev_preds = []
current_lr = args['lr']
global_start_time = time.time()
format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

is_second_stage = False
last_best_step = 0
# start training
train_loss = 0
while True:
do_break = False
for i, batch in enumerate(train_batch):
start_time = time.time()
global_step += 1
trainer.global_step += 1
loss = trainer.update(batch, eval=False) # update step
train_loss += loss
if global_step % args['log_step'] == 0:
if trainer.global_step % args['log_step'] == 0:
duration = time.time() - start_time
logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr))

if global_step % args['eval_interval'] == 0:
if trainer.global_step % args['eval_interval'] == 0:
# eval on dev
logger.info("Evaluating on dev set...")
dev_preds = []
@@ -250,60 +247,62 @@ def train(args):
_, _, dev_score = scorer.score(system_pred_file, gold_file)

train_loss = train_loss / args['eval_interval'] # avg loss per batch
logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(trainer.global_step, train_loss, dev_score))

if args['wandb']:
wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

train_loss = 0

# save best model
if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
last_best_step = global_step
if len(trainer.dev_score_history) == 0 or dev_score > max(trainer.dev_score_history):
trainer.last_best_step = trainer.global_step
trainer.save(model_file)
logger.info("new best model saved.")
best_dev_preds = dev_preds

dev_score_history += [dev_score]
trainer.dev_score_history += [dev_score]

if not is_second_stage and args.get('second_optim', None) is not None:
if global_step - last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and global_step >= args['second_optim_start_step']):
if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']):
logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
args["second_stage"] = True
# if the loader gets a model file, it uses secondary optimizer
trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
model_file=model_file, device=args['device'])
logger.info('Reloading best model to continue from current local optimum')
is_second_stage = True
last_best_step = global_step
trainer.last_best_step = trainer.global_step
else:
if global_step - last_best_step >= args['max_steps_before_stop']:
if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']:
do_break = True
break

if global_step % args['eval_interval'] == 0:
if trainer.global_step % args['eval_interval'] == 0:
# if we need to save checkpoint, do so
# (save after switching the optimizer, if applicable, so that
# the new optimizer is the optimizer used if a restart happens)
if checkpoint_file is not None:
trainer.save(checkpoint_file, save_optimizer=True)
logger.info("new model checkpoint saved.")

if global_step >= args['max_steps']:
if trainer.global_step >= args['max_steps']:
do_break = True
break

if do_break: break

train_batch.reshuffle()

logger.info("Training ended with {} steps.".format(global_step))
logger.info("Training ended with {} steps.".format(trainer.global_step))

if args['wandb']:
wandb.finish()

if len(dev_score_history) > 0:
best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
if len(trainer.dev_score_history) > 0:
# TODO: technically the iteration position will be wrong if
# the eval_interval changed when running from a checkpoint
# could fix this by saving step & score instead of just score
best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1
logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
else:
logger.info("Dev set never evaluated. Saving final model.")
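The TODO above points out that the reported best iteration can be wrong if eval_interval changes between a checkpointed run and its continuation; the fix the comment suggests would be to record (step, score) pairs rather than bare scores. A small sketch of that bookkeeping, with made-up values, not part of this commit:

# dev_score_history as (global_step, score) pairs instead of bare scores
dev_score_history = [(500, 0.8812), (1000, 0.8907), (1500, 0.8894)]  # illustrative values

best_step, best_score = max(dev_score_history, key=lambda entry: entry[1])
print("Best dev F1 = {:.2f}, at step = {}".format(best_score * 100, best_step))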
