diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py index bd85c54c43..4b10d93ef5 100644 --- a/horovod/torch/mpi_ops.py +++ b/horovod/torch/mpi_ops.py @@ -40,8 +40,8 @@ _NULL = "" _basics = _HorovodBasics(__file__, 'mpi_lib_v2') + # import basic methods -init = _basics.init is_initialized = _basics.is_initialized start_timeline = _basics.start_timeline stop_timeline = _basics.stop_timeline @@ -61,10 +61,16 @@ ccl_built = _basics.ccl_built cuda_built = _basics.cuda_built rocm_built = _basics.rocm_built + def shutdown(*args, **kwargs): mpi_lib.horovod_torch_reset() return _basics.shutdown(*args, **kwargs) +def init(*args, **kwargs): + global _handle_map + _handle_map = {} + return _basics.init(*args, **kwargs) + # import reduction op values Average = _basics.Average Sum = _basics.Sum @@ -939,6 +945,7 @@ def synchronize(handle): output = _handle_map.pop(handle)[-1] return output except RuntimeError as e: + _handle_map.pop(handle, None) raise HorovodInternalError(e)