Sync before and after deleting (#2268)
* Sync before and after deleting to prevent another process from writing at the same time

* Remove unneeded if_main_process check that conflicts with the barrier in checkpoint deletion (see the sketch below the changed-files summary)

* Add env variable to signify single-threaded execution

* Fix wrong call to run_on_main in the TIMIT recipe

* Fix bug in run_on_main logic

* Fix bug in run_on_main logic

* Convert the single-threaded flag to a module-level counter variable

* Fix logic of run_on_main post_func
pplantinga committed Nov 30, 2023
1 parent 5240bdc commit 3fcbbba
Showing 17 changed files with 69 additions and 44 deletions.
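To make the if_main_process/barrier conflict concrete, here is a minimal sketch of the hang these changes avoid. It is not part of the commit and assumes torch.distributed has been initialized for DDP; checkpoint_and_sync is a hypothetical stand-in for any routine that ends with a barrier, such as checkpoint saving or deletion.

import torch.distributed as dist


def checkpoint_and_sync(rank):
    if rank == 0:
        pass  # ... write or delete checkpoint files on the main process ...
    dist.barrier()  # every rank must reach this line


def old_pattern(rank):
    # Guarding the whole call, as the recipes did before this commit:
    # only rank 0 reaches the barrier inside checkpoint_and_sync(), so
    # the other ranks block at their next collective and training hangs.
    if rank == 0:
        checkpoint_and_sync(rank)


def new_pattern(rank):
    # Every rank makes the call; the main-process-only work is guarded
    # inside it, and the barrier is reached by all ranks.
    checkpoint_and_sync(rank)

This is why the recipe changes below simply drop the `and sb.utils.distributed.if_main_process()` condition: the checkpointer handles main-process-only work, and the necessary barriers, internally, as the checkpoints.py change near the bottom shows.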
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train.py
@@ -171,7 +171,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
@@ -173,7 +173,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/CommonVoice/ASR/transformer/train.py
@@ -183,7 +183,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

# report different epoch stages according current stage
current_epoch = self.hparams.epoch_counter.current
2 changes: 1 addition & 1 deletion recipes/Fisher-Callhome-Spanish/ST/transformer/train.py
@@ -251,7 +251,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["BLEU"] = self.bleu_metric.summarize("BLEU")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current

# report different epoch stages according current stage
2 changes: 1 addition & 1 deletion recipes/IWSLT22_lowresource/train.py
@@ -148,7 +148,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
current_epoch = self.hparams.epoch_counter.current

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:
current_epoch = self.hparams.epoch_counter.current
old_lr_adam, new_lr_adam = self.hparams.lr_annealing_adam(
stage_stats["BLEU"]
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/ASR/transformer/train.py
@@ -221,7 +221,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/KsponSpeech/LM/train.py
@@ -81,7 +81,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transducer/train.py
@@ -256,7 +256,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# Perform end-of-iteration things, like annealing, logging, etc.
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/ASR/transformer/train.py
@@ -194,7 +194,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/LibriSpeech/LM/train.py
@@ -72,7 +72,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
@@ -84,7 +84,7 @@ def compute_objectives(self, forward_outputs, batch, stage):
loss, accuracy = self.hparams.loss(embeddings, targets, negs)

# This is only used for logging purpose
-if stage != sb.Stage.TRAIN and sb.utils.distributed.if_main_process():
+if stage != sb.Stage.TRAIN:
self.acc_metric.append(accuracy)

objectives = {
2 changes: 1 addition & 1 deletion recipes/Switchboard/ASR/transformer/train.py
@@ -247,7 +247,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
2 changes: 1 addition & 1 deletion recipes/Switchboard/LM/train.py
@@ -69,7 +69,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats

-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:
if not (
isinstance(
self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
31 changes: 20 additions & 11 deletions recipes/TIMIT/ASR/transducer/train.py
@@ -21,7 +21,7 @@
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
-from speechbrain.utils.distributed import run_on_main, if_main_process
+from speechbrain.utils.distributed import run_on_main

logger = logging.getLogger(__name__)

@@ -136,16 +136,25 @@ def on_stage_end(self, stage, stage_loss, epoch):
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats={"loss": stage_loss, "PER": per},
)
-if if_main_process():
-with open(self.hparams.test_wer_file, "w") as w:
-w.write("Transducer loss stats:\n")
-self.transducer_metrics.write_stats(w)
-w.write("\nPER stats:\n")
-self.per_metrics.write_stats(w)
-print(
-"Transducer and PER stats written to file",
-self.hparams.test_wer_file,
-)
+run_on_main(
+save_metrics_to_file,
+args=[
+self.hparams.test_wer_file,
+self.transducer_metrics,
+self.per_metrics,
+],
+)
+
+
+def save_metrics_to_file(wer_file, transducer_metrics, per_metrics):
+with open(wer_file, "w") as w:
+w.write("Transducer loss stats:\n")
+transducer_metrics.write_stats(w)
+w.write("\nPER stats:\n")
+per_metrics.write_stats(w)
+print(
+"Transducer and PER stats written to file", wer_file,
+)


def dataio_prep(hparams):
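For context, a hedged sketch (not code from this commit) of how run_on_main is meant to be called after this fix: it wraps the given function with main_process_only and follows it with a barrier, so every rank must invoke it; calling it from inside an if_main_process() block would recreate the hang sketched above. The prepare_dataset helper below is hypothetical.

from speechbrain.utils.distributed import run_on_main


def prepare_dataset(save_folder):
    # Hypothetical helper: only the main process executes this body;
    # the other ranks skip it and wait at the barrier inside run_on_main.
    print("preparing data in", save_folder)


# Every rank executes this line; when it returns, all ranks can assume
# the prepared files exist, just like the metrics file written by
# save_metrics_to_file above.
run_on_main(prepare_dataset, kwargs={"save_folder": "data/"})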
2 changes: 1 addition & 1 deletion recipes/Tedlium2/ASR/transformer/train.py
@@ -184,7 +184,7 @@ def on_stage_end(self, stage, stage_loss, epoch):
stage_stats["WER"] = self.wer_metric.summarize("error_rate")

# log stats and save checkpoint at end-of-epoch
-if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
+if stage == sb.Stage.VALID:

lr = self.hparams.noam_annealing.current_lr
steps = self.optimizer_step
16 changes: 15 additions & 1 deletion speechbrain/utils/checkpoints.py
@@ -60,7 +60,11 @@
import warnings
from packaging import version
import speechbrain.utils._workarounds as __wa
-from speechbrain.utils.distributed import main_process_only, if_main_process
+from speechbrain.utils.distributed import (
+main_process_only,
+if_main_process,
+ddp_barrier,
+)

logger = logging.getLogger(__name__)

@@ -990,11 +994,21 @@ def delete_checkpoints(
)
)

+# Sync before deleting to avoid another process saving at the same time.
+# This has led to errors as documented here:
+# https://github.com/speechbrain/speechbrain/issues/2250
+ddp_barrier()
+
# Delete unprotected checkpoints
for ckpt in potential_deletions:
if ckpt not in protected_checkpoints:
Checkpointer._delete_checkpoint(ckpt, verbosity=verbosity)

+# Sync after deleting to avoid another process saving at the same time.
+# This has led to errors as documented here:
+# https://github.com/speechbrain/speechbrain/issues/2250
+ddp_barrier()
+
@staticmethod
@main_process_only
def _delete_checkpoint(checkpoint, verbosity=logging.INFO):
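A hedged sketch of the pattern the two new barriers establish around deletion; ddp_barrier and main_process_only are the SpeechBrain utilities imported above, while remove_old_files and delete_safely are hypothetical names used only for illustration.

import os

from speechbrain.utils.distributed import ddp_barrier, main_process_only


@main_process_only
def remove_old_files(paths):
    # Only the main process deletes; on every other rank this is a no-op.
    for path in paths:
        os.remove(path)


def delete_safely(paths):
    # Barrier before: no rank may still be writing these files when the
    # main process starts deleting them (the race behind issue #2250).
    ddp_barrier()
    remove_old_files(paths)
    # Barrier after: no rank starts saving a new checkpoint into the same
    # folder until the deletion has finished.
    ddp_barrier()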
38 changes: 20 additions & 18 deletions speechbrain/utils/distributed.py
@@ -3,12 +3,15 @@
Authors:
* Abdel Heba 2020
* Aku Rouhe 2020
+* Peter Plantinga 2023
"""
import datetime
import os
import torch
from functools import wraps

+MAIN_PROC_ONLY = 0
+

def run_on_main(
func,
@@ -54,27 +57,17 @@ def run_on_main(
if post_kwargs is None:
post_kwargs = {}

-if if_main_process():
-# Main comes here
-try:
-func(*args, **kwargs)
-finally:
-ddp_barrier()
-else:
-# Others go here
-ddp_barrier()
+main_process_only(func)(*args, **kwargs)
+ddp_barrier()

if post_func is not None:
if run_post_on_main:
# Just run on every process without any barrier.
post_func(*post_args, **post_kwargs)
-elif not if_main_process():
-# Others go here
-try:
-post_func(*post_args, **post_kwargs)
-finally:
-ddp_barrier()
else:
-# But main comes here
+# Do the opposite of `run_on_main`
+if not if_main_process():
+post_func(*post_args, **post_kwargs)
ddp_barrier()

@@ -103,8 +96,14 @@ def main_process_only(function):
@wraps(function)
def main_proc_wrapped_func(*args, **kwargs):
"""This decorated function runs only if this is the main process."""
+global MAIN_PROC_ONLY
+MAIN_PROC_ONLY += 1
if if_main_process():
-return function(*args, **kwargs)
+result = function(*args, **kwargs)
+else:
+result = None
+MAIN_PROC_ONLY -= 1
+return result

return main_proc_wrapped_func

@@ -114,7 +113,10 @@ def ddp_barrier():
torch.distributed.barrier() will block processes until the whole
group enters this function.
"""
-if torch.distributed.is_initialized():
+# Check if we're in a single-threaded section, skip barrier
+if MAIN_PROC_ONLY >= 1:
+return
+elif torch.distributed.is_initialized():
torch.distributed.barrier()
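The MAIN_PROC_ONLY counter is what replaced the environment-variable idea from the commit history above: while any main_process_only function is on the call stack the counter is positive and ddp_barrier() returns immediately, so code that only the main process executes can never strand the other ranks at a barrier. A hedged sketch of the effect (the prepare_manifests function is hypothetical, and torch.distributed is assumed to be initialized):

from speechbrain.utils.distributed import ddp_barrier, run_on_main


def prepare_manifests():
    # Only the main process runs this body, because run_on_main wraps it
    # with main_process_only. The counter is therefore > 0 here, so this
    # barrier is skipped instead of waiting for ranks that never arrive.
    ddp_barrier()
    print("manifests written")


# All ranks call run_on_main; once the wrapped call returns, the counter
# is back to 0, and the barrier inside run_on_main genuinely
# synchronizes every rank before training continues.
run_on_main(prepare_manifests)

With the corrected post_func logic, a post-function passed with the default run_post_on_main=False now runs only on the non-main ranks and is followed by a barrier, mirroring the behavior of the main call.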
