Skip to content

Commit

Permalink
feat: run evaluation for 1 epoch before training (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Yang committed Apr 10, 2021
1 parent 23c87b4 commit e79bff3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 10 additions & 1 deletion templates/image_classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ def _():
eval_engine.run(eval_dataloader, max_epochs=1)
eval_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")

# --------------------------------------------------
# let's try run evaluation first as a sanity check
# --------------------------------------------------

@train_engine.on(Events.STARTED)
def _():
eval_engine.run(eval_dataloader, max_epochs=1, epoch_length=2)
eval_engine.state.max_epochs = None

# ------------------------------------------
# setup if done. let's run the training
# ------------------------------------------
Expand Down Expand Up @@ -195,7 +204,7 @@ def main():

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
name = f'{config.model}-backend-{idist.backend()}-{now}'
name = f"{config.model}-backend-{idist.backend()}-{now}"
path = Path(config.output_dir, name)
path.mkdir(parents=True, exist_ok=True)
config.output_dir = path
Expand Down
11 changes: 10 additions & 1 deletion templates/single/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def _():
eval_engine.run(eval_dataloader, max_epochs=1)
eval_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")

# --------------------------------------------------
# let's try run evaluation first as a sanity check
# --------------------------------------------------

@train_engine.on(Events.STARTED)
def _():
eval_engine.run(eval_dataloader, max_epochs=1, epoch_length=2)
eval_engine.state.max_epochs = None

# ------------------------------------------
# setup if done. let's run the training
# ------------------------------------------
Expand Down Expand Up @@ -172,7 +181,7 @@ def main():

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
name = f'backend-{idist.backend()}-{now}'
name = f"backend-{idist.backend()}-{now}"
path = Path(config.output_dir, name)
path.mkdir(parents=True, exist_ok=True)
config.output_dir = path
Expand Down

0 comments on commit e79bff3

Please sign in to comment.