moves grad scaler util from autounit to precision utils (#658)
Summary:
Pull Request resolved: #658

Moves grad scaler util from `framework/auto_unit.py` to `utils/precision.py`, adds docstring, and new test

Reviewed By: galrotem

Differential Revision: D52264512

fbshipit-source-id: e5cecb2ca401b6c886f29ed8bb5bc417a6cbb599
JKSenthil authored and facebook-github-bot committed Dec 19, 2023
1 parent 799918a commit 2d008e2
Showing 3 changed files with 56 additions and 16 deletions.
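For orientation before the diffs: the moved utility can be called directly, as exercised by the new test below. A minimal sketch (the `torch.nn.Linear` module is only a stand-in):

import torch
from torchtnt.utils.precision import get_grad_scaler_from_precision

module = torch.nn.Linear(2, 2)

# fp32 needs no loss scaling, so no scaler is returned
assert get_grad_scaler_from_precision(torch.float32, module) is None

# fp16 returns a GradScaler (or a ShardedGradScaler when the module is FSDP-wrapped)
scaler = get_grad_scaler_from_precision(torch.float16, module)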
25 changes: 24 additions & 1 deletion tests/utils/test_precision.py
@@ -6,10 +6,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from unittest.mock import patch
 
 import torch
+from torch.cuda.amp.grad_scaler import GradScaler
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
-from torchtnt.utils.precision import convert_precision_str_to_dtype
+from torchtnt.utils.precision import (
+    convert_precision_str_to_dtype,
+    get_grad_scaler_from_precision,
+)
 
 
 class PrecisionTest(unittest.TestCase):
@@ -32,3 +38,20 @@ def test_convert_precision_str_to_dtype_throws(self) -> None:
             "Precision foo not supported. Please use one of .*",
         ):
             convert_precision_str_to_dtype("foo")
+
+    def test_get_grad_scaler_from_precision(self) -> None:
+        grad_scaler = get_grad_scaler_from_precision(
+            torch.float32, torch.nn.Linear(2, 2)
+        )
+        self.assertIsNone(grad_scaler)
+
+        grad_scaler = get_grad_scaler_from_precision(
+            torch.float16, torch.nn.Linear(2, 2)
+        )
+        self.assertTrue(isinstance(grad_scaler, GradScaler))
+
+        with patch("torchtnt.utils.precision._is_fsdp_module", return_value=True):
+            grad_scaler = get_grad_scaler_from_precision(
+                torch.float16, torch.nn.Linear(2, 2)
+            )
+            self.assertTrue(isinstance(grad_scaler, ShardedGradScaler))
20 changes: 5 additions & 15 deletions torchtnt/framework/auto_unit.py
@@ -34,7 +34,10 @@
 from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.precision import convert_precision_str_to_dtype
+from torchtnt.utils.precision import (
+    convert_precision_str_to_dtype,
+    get_grad_scaler_from_precision,
+)
 from torchtnt.utils.prepare_module import (
     _is_fsdp_module,
     ActivationCheckpointParams,
@@ -499,7 +502,7 @@ def __init__(
 
         self.grad_scaler: Optional[GradScaler] = None
         if self.precision:
-            self.grad_scaler = _get_grad_scaler_from_precision(
+            self.grad_scaler = get_grad_scaler_from_precision(
                 self.precision,
                 self.module,
             )
@@ -858,16 +861,3 @@ def _validate_torch_compile_available() -> None:
             "Torch compile support is available only in PyTorch 2.0 or higher. "
             "Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
         )
-
-
-def _get_grad_scaler_from_precision(
-    precision: torch.dtype, module: torch.nn.Module
-) -> Optional[GradScaler]:
-    if precision == torch.float16:
-        if _is_fsdp_module(module):
-            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-            return ShardedGradScaler()
-        else:
-            return GradScaler()
-    return None
27 changes: 27 additions & 0 deletions torchtnt/utils/precision.py
@@ -8,6 +8,8 @@
 from typing import Mapping, Optional
 
 import torch
+from torch.cuda.amp.grad_scaler import GradScaler
+from torchtnt.utils.prepare_module import _is_fsdp_module
 
 _DTYPE_STRING_TO_DTYPE_MAPPING: Mapping[str, Optional[torch.dtype]] = {
     "fp16": torch.float16,
@@ -32,3 +34,28 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:
             f"Precision {precision} not supported. Please use one of {list(_DTYPE_STRING_TO_DTYPE_MAPPING.keys())}"
         )
     return _DTYPE_STRING_TO_DTYPE_MAPPING[precision]
+
+
+def get_grad_scaler_from_precision(
+    precision: torch.dtype, module: torch.nn.Module
+) -> Optional[GradScaler]:
+    """
+    Returns the correct grad scaler to use based on the precision and whether
+    or not the model is FSDP.
+
+    Args:
+        precision: the precision being used
+        module: the module being trained
+
+    Returns:
+        The appropriate grad scaler to use, ``None`` if no grad scaler should be used.
+    """
+
+    if precision == torch.float16:
+        if _is_fsdp_module(module):
+            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+            return ShardedGradScaler()
+        else:
+            return GradScaler()
+    return None
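Inside AutoUnit the returned scaler drives fp16 training; outside of it, a caller would typically use the standard torch.cuda.amp pattern. A sketch, not code from this PR — model, optimizer, loss_fn, and the data are placeholders:

import torch
from torchtnt.utils.precision import get_grad_scaler_from_precision

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = get_grad_scaler_from_precision(torch.float16, model)

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = loss_fn(model(inputs), targets)

if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads, skips the step if inf/nan is found
    scaler.update()                # adjusts the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()
optimizer.zero_grad()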
