diff --git a/references/detection/train.py b/references/detection/train.py index 40459d110c3..00d94fd1636 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -115,7 +115,7 @@ def main(args): optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 - + if args.test_only: evaluate(model, data_loader_test, device=device) return