diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py
index 2dcffade4..da6cb53d7 100644
--- a/torchtune/utils/__init__.py
+++ b/torchtune/utils/__init__.py
@@ -46,7 +46,6 @@
     register_optim_in_bwd_hooks,
     set_activation_checkpointing,
 )
-from .metric_logging import compute_grad_norm
 from .precision import (
     get_dtype,
     list_dtypes,
@@ -84,5 +83,4 @@
     "register_optim_in_bwd_hooks",
     "profiler",
     "get_quantizer_mode",
-    "compute_grad_norm",
 ]
diff --git a/torchtune/utils/metric_logging.py b/torchtune/utils/metric_logging.py
index f631e4c2f..01e5e188e 100644
--- a/torchtune/utils/metric_logging.py
+++ b/torchtune/utils/metric_logging.py
@@ -12,7 +12,7 @@
 from numpy import ndarray
 from omegaconf import DictConfig, OmegaConf
 
-from torch import nn, Tensor
+from torch import Tensor
 from torchtune.utils import get_logger
 from torchtune.utils._distributed import get_world_size_and_rank
@@ -23,17 +23,6 @@
 log = get_logger("DEBUG")
 
 
-def compute_grad_norm(model: nn.Module) -> float:
-    """Compute models grad norm"""
-    total_norm = 0.0
-    for p in model.parameters():
-        if p.grad is not None:
-            param_norm = p.grad.detach().data.norm(2)
-            total_norm += param_norm.item() ** 2
-
-    return total_norm**0.5
-
-
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""
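For reference, the removed `compute_grad_norm` returned the global L2 norm over all parameter gradients, i.e. the square root of the sum of squared per-parameter gradient norms. Below is a minimal sketch of an equivalent computation using only built-in PyTorch ops; it is not part of this PR, and the function name `grad_l2_norm` is illustrative only.

```python
import torch
from torch import nn


def grad_l2_norm(model: nn.Module) -> float:
    """Global L2 norm over all parameter gradients.

    Mathematically the same quantity as the removed helper,
    sqrt(sum_i ||grad_i||_2^2), but with a single host sync via .item().
    """
    grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    per_param_norms = torch.stack([torch.linalg.vector_norm(g, 2) for g in grads])
    return torch.linalg.vector_norm(per_param_norms, 2).item()
```

In a training loop that already clips gradients, `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=...)` also returns the total gradient norm as a tensor, so the same metric can be logged without a separate helper.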