Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync before and after deleting #2268

Merged
merged 8 commits into from Nov 30, 2023
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train.py
Expand Up @@ -171,7 +171,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
Expand Down
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
Expand Up @@ -173,7 +173,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
Expand Down
2 changes: 1 addition & 1 deletion recipes/CommonVoice/ASR/transformer/train.py
Expand Up @@ -183,7 +183,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
Expand Down
2 changes: 1 addition & 1 deletion recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
Expand Up @@ -251,7 +251,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current

# report different epoch stages according current stage
Expand Down
2 changes: 1 addition & 1 deletion recipes/IWSLT22_lowresource/train.py
Expand Up @@ -148,7 +148,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
current_epoch = self.hparams.epoch_counter.current

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current
old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
stage_stats["BLEU"]
Expand Down
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/ASR/transformer/train.py
Expand Up @@ -221,7 +221,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
Expand Down
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/LM/train.py
Expand Up @@ -81,7 +81,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
Expand Down
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transducer/train.py
Expand Up @@ -256,7 +256,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# Perform end-of-iteration things, like annealing, logging, etc.
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
Expand Down
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transformer/train.py
Expand Up @@ -194,7 +194,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
Expand Down
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/LM/train.py
Expand Up @@ -72,7 +72,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
Expand Down
Expand Up @@ -84,7 +84,7 @@ def compute_objectives(self, forward_outputs, batch, stage):
loss, accuracy = self.hparams.loss(embeddings, targets, negs)

# This is only used for logging purpose
if stage != sb.Stage.TRAIN and sb.utils.distributed.if_main_process():
if stage != sb.Stage.TRAIN:
self.acc_metric.append(accuracy)

objectives = {
Expand Down
2 changes: 1 addition & 1 deletion recipes/Switchboard/ASR/transformer/train.py
Expand Up @@ -247,7 +247,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
Expand Down
2 changes: 1 addition & 1 deletion recipes/Switchboard/LM/train.py
Expand Up @@ -69,7 +69,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
Expand Down
31 changes: 20 additions & 11 deletions recipes/TIMIT/ASR/transducer/train.py
Expand Up @@ -21,7 +21,7 @@
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main, if_main_process
from speechbrain.utils.distributed import run_on_main

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -136,16 +136,25 @@ def on_stage_end(self, stage, stage_loss, epoch):
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats={"loss": stage_loss, "PER": per},
)
if if_main_process():
with open(self.hparams.test_wer_file, "w") as w:
w.write("Transducer loss stats:\n")
self.transducer_metrics.write_stats(w)
w.write("\nPER stats:\n")
self.per_metrics.write_stats(w)
print(
"Transducer and PER stats written to file",
self.hparams.test_wer_file,
)
run_on_main(
save_metrics_to_file,
args=[
self.hparams.test_wer_file,
self.transducer_metrics,
self.per_metrics,
],
)


def save_metrics_to_file(wer_file, transducer_metrics, per_metrics):
    """Write transducer loss and PER statistics to a file.

    Intended to be invoked via ``run_on_main`` so only the main
    process touches the filesystem.

    Arguments
    ---------
    wer_file : str
        Path of the output stats file.
    transducer_metrics : object
        Metric tracker exposing ``write_stats(file_obj)`` with the
        transducer loss statistics.
    per_metrics : object
        Metric tracker exposing ``write_stats(file_obj)`` with the
        phone-error-rate statistics.
    """
    with open(wer_file, "w") as w:
        w.write("Transducer loss stats:\n")
        transducer_metrics.write_stats(w)
        w.write("\nPER stats:\n")
        per_metrics.write_stats(w)
    # Bug fix: previously referenced the undefined global ``hparams``
    # (NameError at module level); report the path actually written to.
    print(
        "Transducer and PER stats written to file", wer_file,
    )


def dataio_prep(hparams):
Expand Down
2 changes: 1 addition & 1 deletion recipes/Tedlium2/ASR/transformer/train.py
Expand Up @@ -184,7 +184,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
Expand Down
16 changes: 15 additions & 1 deletion speechbrain/utils/checkpoints.py
Expand Up @@ -60,7 +60,11 @@
import warnings
from packaging import version
import speechbrain.utils._workarounds as __wa
from speechbrain.utils.distributed import main_process_only, if_main_process
from speechbrain.utils.distributed import (
main_process_only,
if_main_process,
ddp_barrier,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -990,11 +994,21 @@ def delete_checkpoints(
)
)

# Sync before deleting to avoid another process saving at the same time.
# This has led to errors as documented here:
# https://github.com/speechbrain/speechbrain/issues/2250
ddp_barrier()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pplantinga what will happen if torch_recovery is called outside of a run_on_main? This barrier would be hit and MAIN_PROC_ENV wouldn't be 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of run_on_main the program should be operating with multiple processes, so all should hit the barrier together. The only scenario where it would still freeze is if you are inside an if_main_process(): block, which we should discourage use of.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

Copy link
Collaborator

@Gastron Gastron Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a solution could be developed to catch those bugs where you branch based on the main_process, but inside that branch you call some code which should hit a DDP barrier. So this will not automatically solve problems, but should help catch bugs. This would replace the if_main_process() (almost drop-in, just adds indentation).

BARRIER_PROTECTOR = "SPEECHBRAIN_DDP_BARRIER_PROTECTOR"
os.environ[BARRIER_PROTECTOR] = "0"  # environment values must be strings

class DDPProtector(object):
    """Protects from running into DDP Barrier in a code block that has already branched"""
    def __enter__(self):
        # Increment so that we can support nested protectors
        os.environ[BARRIER_PROTECTOR] = str(int(os.environ[BARRIER_PROTECTOR])+1)

    def on_main_process(self):
        # ...There would be a check here...
        return  ## True if on main process, else False

    def __exit__(self, exception_type, exception_value, traceback):
        <something to possibly handle exceptions>
        os.environ[BARRIER_PROTECTOR] = str(int(os.environ[BARRIER_PROTECTOR])-1)
        return

def ddp_barrier():
    """In DDP mode, this function will synchronize all processes.
    torch.distributed.barrier() will block processes until the whole
    group enters this function.
    """
    if int(os.environ[BARRIER_PROTECTOR]) > 0:
        raise RuntimeError("DDP Barrier inside a main process only branch, this will create a deadlock or a subtle bug.")
    # Check if we're in a single-threaded section, skip barrier
    elif os.environ.get(MAIN_PROC_ENV, "0") == "1":
        return
    elif torch.distributed.is_initialized():
        torch.distributed.barrier()

This would be simply used to mark that you intend not to run into DDP Barriers in this part of the code:

with DDPProtector() as protector:
    if protector.on_main_process():
        ...

So when if_main_process() is replaced by this, we should catch some bugs more easily.


# Delete unprotected checkpoints
for ckpt in potential_deletions:
if ckpt not in protected_checkpoints:
Checkpointer._delete_checkpoint(ckpt, verbosity=verbosity)

# Sync after deleting to avoid another process saving at the same time.
# This has led to errors as documented here:
# https://github.com/speechbrain/speechbrain/issues/2250
ddp_barrier()

@staticmethod
@main_process_only
def _delete_checkpoint(checkpoint, verbosity=logging.INFO):
Expand Down
33 changes: 15 additions & 18 deletions speechbrain/utils/distributed.py
Expand Up @@ -3,12 +3,15 @@
Authors:
* Abdel Heba 2020
* Aku Rouhe 2020
* Peter Plantinga 2023
"""
import datetime
import os
import torch
from functools import wraps

MAIN_PROC_ENV = "MAIN_PROC_ONLY"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this should have a SPEECHBRAIN_ prefix just in case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to module-level variable, which makes this unnecessary



def run_on_main(
func,
Expand Down Expand Up @@ -54,27 +57,15 @@ def run_on_main(
if post_kwargs is None:
post_kwargs = {}

if if_main_process():
# Main comes here
try:
func(*args, **kwargs)
finally:
ddp_barrier()
else:
# Others go here
ddp_barrier()
main_process_only(func)(*args, **kwargs)
ddp_barrier()

if post_func is not None:
if run_post_on_main:
# Just run on every process without any barrier.
post_func(*post_args, **post_kwargs)
elif not if_main_process():
# Others go here
try:
post_func(*post_args, **post_kwargs)
finally:
ddp_barrier()
else:
# But main comes here
main_process_only(post_func)(*post_args, **post_kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic is now inverted, post_func is meant to be run on everything else except main (e.g. load a tokenizer that was just created). With run_post_on_main, post_func is also run on main.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, you are totally right about this... I'll go ahead and fix this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed in latest commit

ddp_barrier()


Expand Down Expand Up @@ -103,8 +94,11 @@ def main_process_only(function):
@wraps(function)
def main_proc_wrapped_func(*args, **kwargs):
"""This decorated function runs only if this is the main process."""
os.environ[MAIN_PROC_ENV] = "1"
Copy link
Collaborator

@Gastron Gastron Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally I wonder if the environment variables (like MAIN_PROC_ENV here) are the right way to do this sort of process-wide communication. I think something like a variable in a module (Python modules are singletons) should be enough here. So instead of this, I think we could just have:

MAIN_PROC_FLAG=0

def main_proc_wrapped_func(*args, **kwargs):
    global MAIN_PROC_FLAG
    MAIN_PROC_FLAG = 1
    ...
    MAIN_PROC_FLAG = 0
    
    
def ddp_barrier():
    # Note: as long as this doesn't locally redefine MAIN_PROC_FLAG, 
    # it doesn't need to be marked as global, as it is not mutated.
    if MAIN_PROC_FLAG == 1:
        ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, a module-level flag is better here.

if if_main_process():
return function(*args, **kwargs)
result = function(*args, **kwargs)
asumagic marked this conversation as resolved.
Show resolved Hide resolved
os.environ[MAIN_PROC_ENV] = "0"
TParcollet marked this conversation as resolved.
Show resolved Hide resolved
return result

return main_proc_wrapped_func

Expand All @@ -114,7 +108,10 @@ def ddp_barrier():
torch.distributed.barrier() will block processes until the whole
group enters this function.
"""
if torch.distributed.is_initialized():
# Check if we're in a single-threaded section, skip barrier
if os.environ.get(MAIN_PROC_ENV, "0") == "1":
return
elif torch.distributed.is_initialized():
torch.distributed.barrier()


Expand Down