From 4c2dc6b2a737a03885e8cf4e6e88ff606f3f02c8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 7 May 2021 09:24:26 +0100 Subject: [PATCH] Add checkpoints used for preemption. --- references/detection/train.py | 10 ++++++++-- references/segmentation/train.py | 18 +++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 712c41f658f..4eb39bf17f5 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -188,13 +188,19 @@ def main(args): train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) lr_scheduler.step() if args.output_dir: - utils.save_on_master({ + checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'args': args, - 'epoch': epoch}, + 'epoch': epoch + } + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) # evaluate after every epoch evaluate(model, data_loader_test, device=device) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 47907546dbc..fb6c7eeee15 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -157,15 +157,19 @@ def main(args): train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) + checkpoint = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args + } utils.save_on_master( - { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args - }, + checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time)))