Skip to content
Merged
10 changes: 5 additions & 5 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ jobs:
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="CI=1 python ${example_path}/main.py run"
export test_cmd="CI=1 python ${example_path}/main.py run --checkpoint_every=200"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-None-1_stop-on-500/training_checkpoint_400.pt"
Expand All @@ -268,7 +268,7 @@ jobs:
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="CI=1 python -u -m torch.distributed.launch --nproc_per_node=2 --use_env ${example_path}/main.py run --backend=nccl"
export test_cmd="CI=1 python -u -m torch.distributed.launch --nproc_per_node=2 --use_env ${example_path}/main.py run --backend=nccl --checkpoint_every=200"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-nccl-2_stop-on-500/training_checkpoint_400.pt"
Expand All @@ -280,7 +280,7 @@ jobs:
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="CI=1 python -u ${example_path}/main.py run --backend=nccl --nproc_per_node=2"
export test_cmd="CI=1 python -u ${example_path}/main.py run --backend=nccl --nproc_per_node=2 --checkpoint_every=200"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-nccl-2_stop-on-500/training_checkpoint_400.pt"
Expand Down Expand Up @@ -334,7 +334,7 @@ jobs:
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="cd ${example_path} && CI=1 horovodrun -np 2 python -u main.py run --backend=horovod"
export test_cmd="cd ${example_path} && CI=1 horovodrun -np 2 python -u main.py run --backend=horovod --checkpoint_every=200"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
Expand All @@ -346,7 +346,7 @@ jobs:
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="cd ${example_path} && CI=1 python -u main.py run --backend=horovod --nproc_per_node=2"
export test_cmd="cd ${example_path} && CI=1 python -u main.py run --backend=horovod --nproc_per_node=2 --checkpoint_every=200"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
Expand Down
10 changes: 7 additions & 3 deletions examples/contrib/cifar10/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run_validation(engine):
n_saved=2,
global_step_transform=global_step_from_engine(trainer),
score_name="test_accuracy",
score_function=Checkpoint.get_default_score_fn("accuracy"),
score_function=Checkpoint.get_default_score_fn("Accuracy"),
)
evaluator.add_event_handler(
Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
Expand Down Expand Up @@ -173,7 +173,7 @@ def run(
learning_rate (float): peak of piecewise linear learning rate scheduler. Default, 0.4.
num_warmup_epochs (int): number of warm-up epochs before learning rate decay. Default, 4.
validate_every (int): run model's validation every ``validate_every`` epochs. Default, 3.
checkpoint_every (int): store training checkpoint every ``checkpoint_every`` iterations. Default, 200.
checkpoint_every (int): store training checkpoint every ``checkpoint_every`` iterations. Default, 1000.
backend (str, optional): backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu",
"gloo" etc. Default, None.
nproc_per_node (int, optional): optional argument to setup number of processes per node. It is useful,
Expand Down Expand Up @@ -258,9 +258,13 @@ def log_basic_info(logger, config):
logger.info(f"- PyTorch version: {torch.__version__}")
logger.info(f"- Ignite version: {ignite.__version__}")
if torch.cuda.is_available():
# import cudnn explicitly, because torch.backends.cudnn
# cannot be pickled when Horovod (hvd) spawns processes
from torch.backends import cudnn

logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
logger.info(f"- CUDA version: {torch.version.cuda}")
logger.info(f"- CUDNN version: {torch.backends.cudnn.version()}")
logger.info(f"- CUDNN version: {cudnn.version()}")

logger.info("\n")
logger.info("Configuration:")
Expand Down
14 changes: 11 additions & 3 deletions examples/contrib/cifar10_qat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,17 @@ def run_validation(engine):
n_saved=2,
global_step_transform=global_step_from_engine(trainer),
score_name="test_accuracy",
score_function=Checkpoint.get_default_score_fn("accuracy"),
score_function=Checkpoint.get_default_score_fn("Accuracy"),
)
evaluator.add_event_handler(
Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
)

trainer.run(train_loader, max_epochs=config["num_epochs"])
try:
trainer.run(train_loader, max_epochs=config["num_epochs"])
except Exception as e:
logger.exception("")
raise e

if rank == 0:
tb_logger.close()
Expand Down Expand Up @@ -241,9 +245,13 @@ def log_basic_info(logger, config):
logger.info(f"- PyTorch version: {torch.__version__}")
logger.info(f"- Ignite version: {ignite.__version__}")
if torch.cuda.is_available():
# import cudnn explicitly, because torch.backends.cudnn
# cannot be pickled when Horovod (hvd) spawns processes
from torch.backends import cudnn

logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
logger.info(f"- CUDA version: {torch.version.cuda}")
logger.info(f"- CUDNN version: {torch.backends.cudnn.version()}")
logger.info(f"- CUDNN version: {cudnn.version()}")

logger.info("\n")
logger.info("Configuration:")
Expand Down