Adding no_sync and on_fit_batch_end method to core
RuABraun committed Jun 19, 2022
1 parent 4d4f744 commit edb7714
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions speechbrain/core.py
@@ -20,6 +20,7 @@
import argparse
import tempfile
import warnings
from contextlib import contextmanager
import speechbrain as sb
from datetime import date
from enum import Enum, auto
@@ -882,9 +883,14 @@ def fit_batch(self, batch):
self.optimizer.step()
self.optimizer.zero_grad()
self.optimizer_step += 1

if should_step:
self.on_fit_batch_end()
return loss.detach().cpu()

def on_fit_batch_end(self):
"""Called after ``fit_batch()``"""
pass
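
The new ``on_fit_batch_end()`` hook runs right after every ``fit_batch()`` call that performed an optimizer step, giving subclasses one place for per-step logic such as stepping a learning-rate scheduler. A minimal sketch of an override, assuming a hypothetical ``MyBrain`` subclass with an ``lr_scheduler`` attribute (neither is part of this commit):

import speechbrain as sb

class MyBrain(sb.Brain):
    def on_fit_batch_end(self):
        # Called by fit_batch() after each optimizer step (see the diff above).
        # Step a scheduler attached to the Brain; lr_scheduler is hypothetical.
        if hasattr(self, "lr_scheduler"):
            self.lr_scheduler.step()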

def check_gradients(self, loss):
"""Check if gradients are finite and not too large.
@@ -923,9 +929,10 @@ def check_gradients(self, loss):
return False

# Clip gradient norm
torch.nn.utils.clip_grad_norm_(
(p for p in self.modules.parameters()), self.max_grad_norm
)
if self.max_grad_norm != 0.0:
torch.nn.utils.clip_grad_norm_(
(p for p in self.modules.parameters()), self.max_grad_norm
)

return True
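
With the guard above, setting ``max_grad_norm`` to 0.0 now skips the clipping call entirely, whereas before the call ran unconditionally. A standalone sketch of the same pattern in plain PyTorch (the ``model`` and ``max_grad_norm`` names here are illustrative, not from this commit):

import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

max_grad_norm = 0.0  # 0.0 means "do not clip" under the new check
if max_grad_norm != 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)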

@@ -1281,6 +1288,38 @@ def update_average(self, loss, avg_loss):
avg_loss += float(loss) / self.step
return avg_loss

@contextmanager
def no_sync(self, use=True):
"""Copies pytorch's implementation for doing no_sync across all modules.
Explanation: nn.module.no_sync() is a context manager for when one does
not want to sync gradients, which happens when using both DDP and gradient accumulation.
Speechbrain brain's class can contain multiple modules and calling no_sync on these
individually would be very awkward, therefore this contextmanager exists.
Arguments
---------
use : bool
If set to `False` will still sync gradients.
"""
if use:
old_values_list = []
for module in self.modules.values():
if not hasattr(module, "require_backward_grad_sync"):
# if not using DDP
break
old_values_list.append(module.require_backward_grad_sync)
module.require_backward_grad_sync = False
yield
for module, old_value in zip(
self.modules.values(), old_values_list
):
if not hasattr(module, "require_backward_grad_sync"):
break
module.require_backward_grad_sync = old_value
else:
yield
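
A rough sketch of how the context manager could be combined with gradient accumulation in a ``fit_batch``-style loop; the ``AccumBrain`` subclass and its ``grad_accumulation_factor`` attribute are illustrative and not taken from this commit:

import speechbrain as sb

class AccumBrain(sb.Brain):
    def fit_batch(self, batch):
        # Only sync gradients (and step) every grad_accumulation_factor batches.
        should_step = self.step % self.grad_accumulation_factor == 0
        outputs = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
        # Skip DDP gradient synchronization on intermediate accumulation steps.
        with self.no_sync(use=not should_step):
            (loss / self.grad_accumulation_factor).backward()
        if should_step:
            if self.check_gradients(loss):
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.on_fit_batch_end()
        return loss.detach().cpu()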

@sb.utils.checkpoints.mark_as_saver
def _save(self, path):
save_dict = {
