Gradscaler flags #2281
@@ -70,6 +70,10 @@
     "compile_using_fullgraph": False,
     "compile_using_dynamic_shape_tracing": False,
     "precision": "fp32",
+    "gradscaler_init_scale": 65536.0,
+    "gradscaler_growth_factor": 2.0,
+    "gradscaler_backoff_factor": 0.5,
+    "gradscaler_growth_interval": 2000,
     "auto_mix_prec": False,
     "bfloat16_mix_prec": False,
     "max_grad_norm": 5.0,
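These defaults match `torch.cuda.amp.GradScaler`'s own (an initial scale of 2**16, doubling after 2000 clean steps, halving on overflow), so existing recipes behave exactly as before unless a flag is overridden. A minimal sketch of overriding one through `run_opts`; the `MyBrain` subclass and `hparams` dict are placeholders, and it assumes the `run_opts` keys mirror the new flags:

```python
import speechbrain as sb

brain = MyBrain(                      # hypothetical Brain subclass
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts={
        "precision": "fp16",
        # Start from a smaller scale if the first steps keep overflowing.
        "gradscaler_init_scale": 1024.0,
    },
)
```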
@@ -360,6 +364,29 @@ def parse_arguments(arg_list=None):
         help="This flag enables training with automatic mixed-precision. "
         "It can be set to `fp32`, `fp16`, or `bf16`.",
     )
+    parser.add_argument(
+        "--gradscaler_init_scale",
+        type=float,
+        help="GradScaler initial scale factor.",
+    )
+    parser.add_argument(
+        "--gradscaler_growth_factor",
+        type=float,
+        help="GradScaler factor by which the scale is multiplied during "
+        "`update` if no inf/NaN gradients occur for `growth_interval` consecutive iterations.",
+    )
+    parser.add_argument(
+        "--gradscaler_backoff_factor",
+        type=float,
+        help="GradScaler factor by which the scale is multiplied during "
+        "`update` if inf/NaN gradients occur in an iteration.",
+    )
+    parser.add_argument(
+        "--gradscaler_growth_interval",
+        type=int,
+        help="GradScaler number of consecutive iterations without inf/NaN "
+        "gradients that must occur for the scale to be multiplied by `growth_factor`.",
+    )
     parser.add_argument(
         "--auto_mix_prec",
         default=None,
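A usage sketch of the new flags through the standard recipe entry point; the yaml path and the values are placeholders:

```python
import speechbrain as sb

# parse_arguments returns the hparams file plus run options that are later
# forwarded to the Brain constructor (standard SpeechBrain recipe pattern).
hparams_file, run_opts, overrides = sb.parse_arguments(
    [
        "hparams/train.yaml",
        "--precision", "fp16",
        "--gradscaler_init_scale", "1024.0",
        "--gradscaler_growth_interval", "1000",
    ]
)
```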
@@ -555,6 +582,14 @@ class and override any methods for which the default behavior does not
         The location for performing computations.
     precision (str)
         One of ``fp32``, ``fp16``, ``bf16``.
+    gradscaler_init_scale (float)
+        Initial scale for the GradScaler. Default: ``65536.0``.
+    gradscaler_growth_factor (float)
+        Growth factor for the GradScaler. Default: ``2.0``.
+    gradscaler_backoff_factor (float)
+        Backoff factor for the GradScaler. Default: ``0.5``.
+    gradscaler_growth_interval (int)
+        Growth interval for the GradScaler. Default: ``2000``.
     auto_mix_prec (bool)
         If ``True``, automatic mixed-precision (fp16) is used.
         Activate it only with cuda. Note: this is a
@@ -753,7 +788,13 @@ def __init__( # noqa: C901
         logger.info(
             f"Gradscaler enabled: {gradscaler_enabled}. Using precision: {self.precision}."
         )
-        self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
+        self.scaler = torch.cuda.amp.GradScaler(
+            init_scale=self.gradscaler_init_scale,
+            growth_factor=self.gradscaler_growth_factor,
+            backoff_factor=self.gradscaler_backoff_factor,
+            growth_interval=self.gradscaler_growth_interval,
+            enabled=gradscaler_enabled,
+        )

         self.use_amp = False
         if self.device == "cpu" and self.precision == "bf16":
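For background on what the four knobs control, a self-contained sketch of the standard dynamic loss-scaling loop in plain PyTorch, independent of the Brain class; the tiny model and the CUDA device are assumptions:

```python
import torch

model = torch.nn.Linear(10, 1).cuda()          # assumes a CUDA device
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same knobs the PR exposes: the scale starts at 2**16, is halved whenever a
# step yields inf/NaN grads (backoff_factor), and is doubled after 2000
# consecutive clean steps (growth_factor / growth_interval).
scaler = torch.cuda.amp.GradScaler(
    init_scale=65536.0,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
)

for _ in range(3):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 10, device="cuda")).mean()
    scaler.scale(loss).backward()  # backprop on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips step on inf/NaN
    scaler.update()                # adjusts the scale per the knobs above
```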
@@ -1133,6 +1174,7 @@ def fit_batch(self, batch):
                 scaled_loss = self.scaler.scale(
                     loss / self.grad_accumulation_factor
                 )
+                self.check_loss_isfinite(scaled_loss)
                 scaled_loss.backward()

             if should_step:

Review comment: Checking that all parameters are finite is redundant with the GradScaler; it already does this. It can also be extremely expensive for very large models. Checking that the loss is finite makes sense; checking the parameters does not.

Reply: The grad scaler does this, but it does not care how long the parameters have been non-finite. An idea I suggested yesterday was to occasionally check the GradScaler scale for insane values with a patience mechanism, as I have sometimes seen the scale vanish or explode when issues occurred.

Author: I agree with @TParcollet. I will remove the part that checks for NaNs/inf in the weights and keep only the loss check. By the way, the function is also intended for other, non-GradScaler use cases such as fp32.
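A minimal sketch of the scale-watchdog idea floated above; it is not part of this PR, and the function name, the `state` dict, and all thresholds are illustrative assumptions:

```python
import torch

def check_scaler_health(scaler, state, min_scale=1e-4,
                        max_scale=2.0 ** 24, patience=10):
    """Abort when the GradScaler scale sits at an implausible value for
    too long. All thresholds here are illustrative assumptions."""
    scale = scaler.get_scale()  # current dynamic loss scale
    if scale < min_scale or scale > max_scale:
        state["bad_scale_count"] = state.get("bad_scale_count", 0) + 1
        if state["bad_scale_count"] > patience:
            raise ValueError(
                f"GradScaler scale {scale} has been out of range for "
                f"more than {patience} consecutive checks."
            )
    else:
        state["bad_scale_count"] = 0
```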
@@ -1141,6 +1183,43 @@ def fit_batch(self, batch):
         self.on_fit_batch_end(batch, outputs, loss, should_step)
         return loss.detach().cpu()

+    def check_loss_isfinite(self, loss):
+        """Check if the loss is finite.
+
+        If the loss is not finite, log a helpful message and increment the `nonfinite_count`.
+        If the `nonfinite_count` exceeds the `--nonfinite_patience` threshold, stop the training
+        and raise an error.
+
+        This check is particularly useful when the loss becomes NaN or inf, while the
+        parameters and gradients remain finite. It helps prevent getting stuck in an
+        infinite loop during training.
+
+        Arguments
+        ---------
+        loss : tensor
+            The scaled loss tensor, checked before ``backward()``
+            and the optimizer's ``step()`` are called.
+        """
+        if not torch.isfinite(loss):
+            self.nonfinite_count += 1
+
+            # Print helpful debug info
+            logger.warning(f"Loss is {loss}.")
+            for p in self.modules.parameters():
+                if not torch.isfinite(p).all():
+                    logger.warning("Parameter is not finite: " + str(p))
+
+            # Check if patience is exhausted
+            if self.nonfinite_count > self.nonfinite_patience:
+                raise ValueError(
+                    "Loss is not finite and patience is exhausted. "
+                    "To debug, wrap `fit()` with "
+                    "autograd's `detect_anomaly()`, e.g.\n\nwith "
+                    "torch.autograd.detect_anomaly():\n\tbrain.fit(...)"
+                )
+            else:
+                logger.warning("Patience not yet exhausted.")
+
     def check_gradients(self):
         """Checks if the gradients are finite. If not, it will emit a warning and set them to zero."""
         for param in self.modules.parameters():
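For reference, the debugging pattern the error message points to; `brain`, `epoch_counter`, and the datasets stand in for whatever the recipe defines:

```python
import torch

# detect_anomaly makes autograd raise at the op that produced the first
# NaN/inf, at the cost of much slower training; use it only for debugging.
with torch.autograd.detect_anomaly():
    brain.fit(epoch_counter, train_set, valid_set)
```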
Review comment: I am not in favor of adding these parameters as basic Brain parameters. They will never be used by 99.99% of our users.