Gradscaler flags #2281

Merged
merged 9 commits on Dec 13, 2023
Changes from 7 commits
81 changes: 80 additions & 1 deletion speechbrain/core.py
@@ -70,6 +70,10 @@
"compile_using_fullgraph": False,
"compile_using_dynamic_shape_tracing": False,
"precision": "fp32",
"gradscaler_init_scale": 65536.0,
"gradscaler_growth_factor": 2.0,
"gradscaler_backoff_factor": 0.5,
"gradscaler_growth_interval": 2000,
"auto_mix_prec": False,
"bfloat16_mix_prec": False,
"max_grad_norm": 5.0,
@@ -360,6 +364,29 @@ def parse_arguments(arg_list=None):
help="This flag enables training with automatic mixed-precision."
"It can be set to `fp32`, `fp16`, or `bf16`.",
)
parser.add_argument(
"--gradscaler_init_scale",
type=float,
help="GradScaler initial scale factor.",
)
parser.add_argument(
"--gradscaler_growth_factor",
type=float,
help="GradScaler factor by which the scale is multiplied during "
"`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.",
)
parser.add_argument(
"--gradscaler_backoff_factor",
type=float,
help="GradScaler factor by which the scale is multiplied during `update`"
"if inf/NaN gradients occur in an iteration.",
)
parser.add_argument(
"--gradscaler_growth_interval",
type=float,
help="Gradscaler number of consecutive iterations without inf/NaN gradients that must occur for the scale"
"to be multiplied by `growth_factor`.",
)
parser.add_argument(
"--auto_mix_prec",
default=None,
@@ -555,6 +582,14 @@ class and override any methods for which the default behavior does not
The location for performing computations.
precision (str)
One of ``fp32``, ``fp16``, ``bf16``.
gradscaler_init_scale (float)
Collaborator

I am not in favor of adding these parameters as basic Brain parameters. They will never be used by 99.99% of our users.

Initial scale for the GradScaler. Default: ``65536.0``.
gradscaler_growth_factor (float)
Growth factor for the GradScaler. Default: ``2.0``.
gradscaler_backoff_factor (float)
Backoff factor for the GradScaler. Default: ``0.5``.
gradscaler_growth_interval (int)
Growth interval for the GradScaler. Default: ``2000``.
auto_mix_prec (bool)
If ``True``, automatic mixed-precision (fp16) is used.
Activate it only with cuda. Note: this is a
@@ -753,7 +788,13 @@ def __init__( # noqa: C901
logger.info(
f"Gradscaler enabled: {gradscaler_enabled}. Using precision: {self.precision}."
)
self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=self.gradscaler_init_scale,
            growth_factor=self.gradscaler_growth_factor,
            backoff_factor=self.gradscaler_backoff_factor,
            growth_interval=self.gradscaler_growth_interval,
            enabled=gradscaler_enabled,
        )
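
Editorial aside: a simplified, torch-free sketch of how the four flags drive the scale, mirroring GradScaler's documented update rule (an illustration, not code from this PR).

# Simplified illustration of GradScaler's update rule and how the four flags
# above drive it (editorial sketch, not code from this PR).
def next_scale(scale, tracker, found_inf,
               growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    if found_inf:                      # inf/NaN gradients this iteration
        return scale * backoff_factor, 0
    tracker += 1                       # one more clean iteration
    if tracker == growth_interval:     # enough clean iterations: grow
        return scale * growth_factor, 0
    return scale, tracker

scale, tracker = 65536.0, 0                                  # gradscaler_init_scale
scale, tracker = next_scale(scale, tracker, found_inf=True)  # backoff -> 32768.0
for _ in range(2000):                                        # gradscaler_growth_interval clean steps
    scale, tracker = next_scale(scale, tracker, found_inf=False)
print(scale)                                                 # back to 65536.0 after one growth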

self.use_amp = False
if self.device == "cpu" and self.precision == "bf16":
@@ -1133,6 +1174,7 @@ def fit_batch(self, batch):
scaled_loss = self.scaler.scale(
loss / self.grad_accumulation_factor
)
self.check_loss_isfinite(scaled_loss)
Collaborator

Checking for all parameters to be finite is redundant with the GradScaler. It already does this. Also it can be crazy expensive for very large models. Checking if the loss is not finite makes sense, not the parameters.

Collaborator
@asumagic asumagic Dec 1, 2023

The grad scaler does this, but it does not care how long the parameters have gone non-finite. An idea I suggested yesterday was to occasionally check the GradScaler scale for insane values with a patience mechanism, as I've sometimes seen the scale vanish or explode when issues occurred.
That does induce a CPU-GPU sync, though I'm not sure how often we have one in the first place.

Collaborator Author

I agree with @TParcollet. I will remove the part where I'm checking the NaNs/inf in the weights and will only cover the loss part. BTW, the function was also intended for other non-GradScaler use cases like fp32.
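
Editorial note: a hedged sketch of the scale-watchdog idea suggested above (an illustration, not part of this PR; the helper name, bounds, and patience value are assumptions).

def check_scaler_health(scaler, state, lo=1e-6, hi=1e10, patience=5):
    """Raise if the GradScaler scale stays outside a sane range for too long."""
    scale = scaler.get_scale()  # reading the scale induces a CPU-GPU sync
    if scale < lo or scale > hi:
        state["bad"] = state.get("bad", 0) + 1
        if state["bad"] > patience:
            raise ValueError(
                f"GradScaler scale {scale} has been out of range for more "
                f"than {patience} consecutive checks."
            )
    else:
        state["bad"] = 0

# Possible call site (hypothetical), run only every N steps to limit the sync cost:
#     if self.step % 100 == 0:
#         check_scaler_health(self.scaler, self._scaler_watchdog_state)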

scaled_loss.backward()

if should_step:
@@ -1141,6 +1183,43 @@
self.on_fit_batch_end(batch, outputs, loss, should_step)
return loss.detach().cpu()

    def check_loss_isfinite(self, loss):
        """Check if the loss is finite.

        If the loss is not finite, log a helpful message and increment the
        `nonfinite_count`. If the `nonfinite_count` exceeds the
        `--nonfinite_patience` threshold, stop the training and raise an error.

        This check is particularly useful when the loss becomes NaN or inf,
        while the parameters and gradients remain finite. It helps prevent
        getting stuck in an infinite loop during training.

        Arguments
        ---------
        loss : tensor
            The (possibly scaled) loss tensor, checked before ``backward()``
            is called.
        """
        if not torch.isfinite(loss):
            self.nonfinite_count += 1

            # Print helpful debug info
            logger.warning(f"Loss is {loss}.")
            for p in self.modules.parameters():
                if not torch.isfinite(p).all():
                    logger.warning("Parameter is not finite: " + str(p))

            # Check if patience is exhausted
            if self.nonfinite_count > self.nonfinite_patience:
                raise ValueError(
                    "Loss is not finite and patience is exhausted. "
                    "To debug, wrap `fit()` with "
                    "autograd's `detect_anomaly()`, e.g.\n\nwith "
                    "torch.autograd.detect_anomaly():\n\tbrain.fit(...)"
                )
            else:
                logger.warning("Patience not yet exhausted.")

    def check_gradients(self):
        """Checks if the gradients are finite. If not, it will emit a warning and set them to zero."""
        for param in self.modules.parameters():