Sync before and after deleting #2268

Merged · 8 commits · Nov 30, 2023
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train.py
@@ -171,7 +171,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
@@ -173,7 +173,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/CommonVoice/ASR/transformer/train.py
@@ -183,7 +183,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
@@ -251,7 +251,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current

# report different epoch stages according current stage
2 changes: 1 addition & 1 deletion recipes/IWSLT22_lowresource/train.py
@@ -148,7 +148,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
current_epoch = self.hparams.epoch_counter.current

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current
old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
stage_stats["BLEU"]
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/ASR/transformer/train.py
@@ -221,7 +221,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/LM/train.py
@@ -81,7 +81,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transducer/train.py
@@ -256,7 +256,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# Perform end-of-iteration things, like annealing, logging, etc.
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transformer/train.py
@@ -194,7 +194,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/LM/train.py
@@ -72,7 +72,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
@@ -84,7 +84,7 @@ def compute_objectives(self, forward_outputs, batch, stage):
loss, accuracy = self.hparams.loss(embeddings, targets, negs)

# This is only used for logging purpose
- if stage != sb.Stage.TRAIN and sb.utils.distributed.if_main_process():
+ if stage != sb.Stage.TRAIN:
self.acc_metric.append(accuracy)

objectives = {
2 changes: 1 addition & 1 deletion recipes/Switchboard/ASR/transformer/train.py
@@ -247,7 +247,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/Switchboard/LM/train.py
@@ -69,7 +69,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
31 changes: 20 additions & 11 deletions recipes/TIMIT/ASR/transducer/train.py
@@ -21,7 +21,7 @@
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
- from speechbrain.utils.distributed import run_on_main, if_main_process
+ from speechbrain.utils.distributed import run_on_main

logger = logging.getLogger(__name__)

@@ -136,16 +136,25 @@ def on_stage_end(self, stage, stage_loss, epoch):
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats={"loss": stage_loss, "PER": per},
)
- if if_main_process():
- with open(self.hparams.test_wer_file, "w") as w:
- w.write("Transducer loss stats:\n")
- self.transducer_metrics.write_stats(w)
- w.write("\nPER stats:\n")
- self.per_metrics.write_stats(w)
- print(
- "Transducer and PER stats written to file",
- self.hparams.test_wer_file,
- )
+ run_on_main(
+ save_metrics_to_file,
+ args=[
+ self.hparams.test_wer_file,
+ self.transducer_metrics,
+ self.per_metrics,
+ ],
+ )


+ def save_metrics_to_file(wer_file, transducer_metrics, per_metrics):
+ with open(wer_file, "w") as w:
+ w.write("Transducer loss stats:\n")
+ transducer_metrics.write_stats(w)
+ w.write("\nPER stats:\n")
+ per_metrics.write_stats(w)
+ print(
+ "Transducer and PER stats written to file", wer_file,
+ )


def dataio_prep(hparams):
2 changes: 1 addition & 1 deletion recipes/Tedlium2/ASR/transformer/train.py
@@ -184,7 +184,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
- if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+ if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
16 changes: 15 additions & 1 deletion speechbrain/utils/checkpoints.py
@@ -60,7 +60,11 @@
import warnings
from packaging import version
import speechbrain.utils._workarounds as __wa
- from speechbrain.utils.distributed import main_process_only, if_main_process
+ from speechbrain.utils.distributed import (
+ main_process_only,
+ if_main_process,
+ ddp_barrier,
+ )

logger = logging.getLogger(__name__)

@@ -990,11 +994,21 @@ def delete_checkpoints(
)
)

+ # Sync before deleting to avoid another process saving at the same time.
+ # This has led to errors as documented here:
+ # https://github.com/speechbrain/speechbrain/issues/2250
+ ddp_barrier()
Review comment (Collaborator):

@pplantinga what will happen if torch_recovery is called outside of run_on_main? These barriers would be hit and MAIN_PROC_ENV wouldn't be 1?

Reply (Author):

Outside of run_on_main the program should be operating with multiple processes, so all of them should hit the barrier together. The only scenario where it would still freeze is if you are inside an `if if_main_process():` block, which we should discourage the use of.
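
To make the failure mode concrete, here is a minimal sketch (not code from this PR; end_of_epoch and end_of_epoch_safe are hypothetical helpers) of why a barrier reached inside an `if if_main_process():` block hangs under DDP, and the pattern the recipe changes above switch to instead:

# Hypothetical illustration, not part of this PR: under DDP, only rank 0
# enters the branch, so any barrier reached inside it (such as the ones this
# PR adds to delete_checkpoints) is never matched by the other ranks and the
# main process blocks forever.
from speechbrain.utils.distributed import if_main_process


def end_of_epoch(checkpointer):
    # Discouraged: branching on the main process around code that syncs.
    if if_main_process():
        # save_and_keep_only() ends up in delete_checkpoints(), which now
        # calls ddp_barrier() -> deadlock once torch.distributed is initialized.
        checkpointer.save_and_keep_only(meta={"epoch": 1})


def end_of_epoch_safe(checkpointer):
    # What the recipe changes above do instead: call it on every rank and let
    # the checkpointer's own @main_process_only / ddp_barrier logic handle
    # the synchronization.
    checkpointer.save_and_keep_only(meta={"epoch": 1})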

Reply (Collaborator):

Sounds good.

Reply (@Gastron, Nov 29, 2023):

I think a solution could be developed to catch the bugs where you branch based on the main process, but inside that branch you call some code which should hit a DDP barrier. This would not automatically solve the problems, but it should help catch bugs. It would replace if_main_process() (almost a drop-in replacement, it just adds indentation).

# Sketch assumes it lives in speechbrain/utils/distributed.py, next to
# if_main_process, torch, os, and the MAIN_PROC_ENV flag mentioned above.
BARRIER_PROTECTOR = "SPEECHBRAIN_DDP_BARRIER_PROTECTOR"
os.environ[BARRIER_PROTECTOR] = "0"  # environment values must be strings

class DDPProtector(object):
    """Protects from running into a DDP barrier in a code block that has already branched."""

    def __enter__(self):
        # Increment so that we can support nested protectors
        os.environ[BARRIER_PROTECTOR] = str(int(os.environ[BARRIER_PROTECTOR]) + 1)
        return self

    def on_main_process(self):
        # ...There would be a check here...
        # True if on main process, else False
        return if_main_process()

    def __exit__(self, exception_type, exception_value, traceback):
        # Decrement the nesting count; exceptions are not swallowed here
        os.environ[BARRIER_PROTECTOR] = str(int(os.environ[BARRIER_PROTECTOR]) - 1)
        return False

def ddp_barrier():
    """In DDP mode, this function will synchronize all processes.
    torch.distributed.barrier() will block processes until the whole
    group enters this function.
    """
    if int(os.environ[BARRIER_PROTECTOR]) > 0:
        raise RuntimeError(
            "DDP Barrier inside a main process only branch, "
            "this will create a deadlock or a subtle bug."
        )
    # Check if we're in a single-threaded section, skip barrier
    elif os.environ.get(MAIN_PROC_ENV, "0") == "1":
        return
    elif torch.distributed.is_initialized():
        torch.distributed.barrier()

This would simply be used to mark that you do not intend to run into DDP barriers in this part of the code:

with DDPProtector() as protector:
    if protector.on_main_process():
        ...

So when if_main_process() is replaced by this, we should catch some bugs more easily.
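
As a rough illustration of the intended benefit (assuming the guarded ddp_barrier sketched above and a checkpointer in scope), the protector would turn the silent deadlock into an explicit error:

# Hypothetical usage of the proposed DDPProtector (this suggestion is not part
# of the merged change). Hitting the guarded ddp_barrier() inside the protected
# block raises a RuntimeError instead of deadlocking silently.
with DDPProtector() as protector:
    if protector.on_main_process():
        # Any call that internally reaches ddp_barrier(), e.g. the checkpoint
        # deletion above, now fails fast with a clear error.
        checkpointer.save_and_keep_only(meta={"epoch": 1})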


# Delete unprotected checkpoints
for ckpt in potential_deletions:
if ckpt not in protected_checkpoints:
Checkpointer._delete_checkpoint(ckpt, verbosity=verbosity)

+ # Sync after deleting to avoid another process saving at the same time.
+ # This has led to errors as documented here:
+ # https://github.com/speechbrain/speechbrain/issues/2250
+ ddp_barrier()

@staticmethod
@main_process_only
def _delete_checkpoint(checkpoint, verbosity=logging.INFO):
38 changes: 20 additions & 18 deletions speechbrain/utils/distributed.py
@@ -3,12 +3,15 @@
Authors:
* Abdel Heba 2020
* Aku Rouhe 2020
+ * Peter Plantinga 2023
"""
import datetime
import os
import torch
from functools import wraps

+ MAIN_PROC_ONLY = 0


def run_on_main(
func,
@@ -54,27 +57,17 @@ def run_on_main(
if post_kwargs is None:
post_kwargs = {}

- if if_main_process():
- # Main comes here
- try:
- func(*args, **kwargs)
- finally:
- ddp_barrier()
- else:
- # Others go here
- ddp_barrier()
+ main_process_only(func)(*args, **kwargs)
+ ddp_barrier()

if post_func is not None:
if run_post_on_main:
# Just run on every process without any barrier.
post_func(*post_args, **post_kwargs)
- elif not if_main_process():
- # Others go here
- try:
- post_func(*post_args, **post_kwargs)
- finally:
- ddp_barrier()
else:
- # But main comes here
+ # Do the opposite of `run_on_main`
+ if not if_main_process():
+ post_func(*post_args, **post_kwargs)
ddp_barrier()


@@ -103,8 +96,14 @@ def main_process_only(function):
@wraps(function)
def main_proc_wrapped_func(*args, **kwargs):
"""This decorated function runs only if this is the main process."""
+ global MAIN_PROC_ONLY
+ MAIN_PROC_ONLY += 1
if if_main_process():
- return function(*args, **kwargs)
+ result = function(*args, **kwargs)
+ else:
+ result = None
+ MAIN_PROC_ONLY -= 1
+ return result

return main_proc_wrapped_func

@@ -114,7 +113,10 @@ def ddp_barrier():
torch.distributed.barrier() will block processes until the whole
group enters this function.
"""
- if torch.distributed.is_initialized():
+ # Check if we're in a single-threaded section, skip barrier
+ if MAIN_PROC_ONLY >= 1:
+ return
+ elif torch.distributed.is_initialized():
torch.distributed.barrier()

