New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sync before and after deleting #2268
Changes from 5 commits
97c8abe
717953b
3bc59fe
ff564b8
707193f
4513a2d
bd56517
f51ec61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,15 @@ | |
Authors: | ||
* Abdel Heba 2020 | ||
* Aku Rouhe 2020 | ||
* Peter Plantinga 2023 | ||
""" | ||
import datetime | ||
import os | ||
import torch | ||
from functools import wraps | ||
|
||
MAIN_PROC_ENV = "MAIN_PROC_ONLY" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps this should have a SPEECHBRAIN_ prefix just in case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed to module-level variable, which makes this unnecessary |
||
|
||
|
||
def run_on_main( | ||
func, | ||
|
@@ -54,27 +57,15 @@ def run_on_main( | |
if post_kwargs is None: | ||
post_kwargs = {} | ||
|
||
if if_main_process(): | ||
# Main comes here | ||
try: | ||
func(*args, **kwargs) | ||
finally: | ||
ddp_barrier() | ||
else: | ||
# Others go here | ||
ddp_barrier() | ||
main_process_only(func)(*args, **kwargs) | ||
ddp_barrier() | ||
|
||
if post_func is not None: | ||
if run_post_on_main: | ||
# Just run on every process without any barrier. | ||
post_func(*post_args, **post_kwargs) | ||
elif not if_main_process(): | ||
# Others go here | ||
try: | ||
post_func(*post_args, **post_kwargs) | ||
finally: | ||
ddp_barrier() | ||
else: | ||
# But main comes here | ||
main_process_only(post_func)(*post_args, **post_kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the logic is now inverted, post_func is meant to be run on everything else except main (e.g. load a tokenizer that was just created). With run_post_on_main, post_func is also run on main. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha, you are totally right about this... I'll go ahead and fix this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be fixed in latest commit |
||
ddp_barrier() | ||
|
||
|
||
|
@@ -103,8 +94,11 @@ def main_process_only(function): | |
@wraps(function) | ||
def main_proc_wrapped_func(*args, **kwargs): | ||
"""This decorated function runs only if this is the main process.""" | ||
os.environ[MAIN_PROC_ENV] = "1" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Additionally I wonder if the environment variables (like MAIN_PROC_FLAG=0
def main_proc_wrapped_func(*args, **kwargs):
global __MAIN_PROC_FLAG
MAIN_PROC_FLAG = 1
...
MAIN_PROC_FLAG = 0
def ddp_barrier():
# Note: as long as this doesn't locally redefine MAIN_PROC_FLAG,
# it doesn't need to be marked as global, as it is not mutated.
if MAIN_PROC_FLAG == 1:
... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, a module-level flag is better here. |
||
if if_main_process(): | ||
return function(*args, **kwargs) | ||
result = function(*args, **kwargs) | ||
asumagic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
os.environ[MAIN_PROC_ENV] = "0" | ||
TParcollet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return result | ||
|
||
return main_proc_wrapped_func | ||
|
||
|
@@ -114,7 +108,10 @@ def ddp_barrier(): | |
torch.distributed.barrier() will block processes until the whole | ||
group enters this function. | ||
""" | ||
if torch.distributed.is_initialized(): | ||
# Check if we're in a single-threaded section, skip barrier | ||
if os.environ.get(MAIN_PROC_ENV, "0") == "1": | ||
return | ||
elif torch.distributed.is_initialized(): | ||
torch.distributed.barrier() | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pplantinga what will happen if torch_recovery is called outside of a run on main? These barrier would be hit and MAIN_PROC_ENV wouldn't be 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outside of
run_on_main
the program should be operating with multiple processes, so all should hit the barrier together. The only scenario where it would still freeze is if you are insideif if_main_process():
block, which we should discourage use of.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a solution could be developed to catch those bugs where you branch based on the main_process, but inside that branch you call some code which should hit a DDP barrier. So this will not automatically solve problems, but should help catch bugs. This would replace the if_main_process() (almost drop-in, just adds indentation).
This would be simply used to mark that you intend not to run into DDP Barriers in this part of the code:
So when
if_main_process()
is replaced by this, we should catch some bugs more easily.