Skip to content

Commit

Permalink
Merge pull request #417 from pykale/modify_ckpt_in_cifa_cnntransformer
Browse files Browse the repository at this point in the history
Add resume and test checkpoint loading to main.py in example/cifas_cn…
  • Loading branch information
xianyuanliu committed Oct 20, 2023
2 parents 1240a4a + 48c28e1 commit d53af25
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/cifar_cnntransformer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def arg_parse():
help="gpu id(s) to use. int(0) for cpu. list[x,y] for xth, yth GPU."
"str(x) for the first x GPUs. str(-1)/int(-1) for all available GPUs",
)
parser.add_argument("--ckpt_resume", default="", help="path to train checkpoint file", type=str)
parser.add_argument("--ckpt_test", default="best", help="path to test checkpoint file", type=str)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -82,10 +84,10 @@ def main():
)

# ---- start training ----
trainer.fit(model, train_loader, valid_loader)
trainer.fit(model, train_loader, valid_loader, ckpt_path=args.ckpt_resume)

# ---- start testing ----
trainer.test(model, valid_loader)
trainer.test(model, valid_loader, ckpt_path=args.ckpt_test)


if __name__ == "__main__":
Expand Down

0 comments on commit d53af25

Please sign in to comment.