remove grad func
tcapelle committed Apr 29, 2024
1 parent b106099 commit d934ec8
Showing 2 changed files with 1 addition and 14 deletions.
2 changes: 0 additions & 2 deletions torchtune/utils/__init__.py
@@ -46,7 +46,6 @@
     register_optim_in_bwd_hooks,
     set_activation_checkpointing,
 )
-from .metric_logging import compute_grad_norm
 from .precision import (
     get_dtype,
     list_dtypes,
@@ -84,5 +83,4 @@
     "register_optim_in_bwd_hooks",
     "profiler",
     "get_quantizer_mode",
-    "compute_grad_norm",
 ]
13 changes: 1 addition & 12 deletions torchtune/utils/metric_logging.py
@@ -12,7 +12,7 @@
 
 from numpy import ndarray
 from omegaconf import DictConfig, OmegaConf
-from torch import nn, Tensor
+from torch import Tensor
 
 from torchtune.utils import get_logger
 from torchtune.utils._distributed import get_world_size_and_rank
@@ -23,17 +23,6 @@
 log = get_logger("DEBUG")
 
 
-def compute_grad_norm(model: nn.Module) -> float:
-    """Compute models grad norm"""
-    total_norm = 0.0
-    for p in model.parameters():
-        if p.grad is not None:
-            param_norm = p.grad.detach().data.norm(2)
-            total_norm += param_norm.item() ** 2
-
-    return total_norm**0.5
-
-
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""

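Note (not part of this commit): the removed compute_grad_norm helper summed the squared L2 norms of each parameter's gradient and returned the square root, i.e. the global gradient norm. A minimal sketch of how a recipe could still obtain that value with stock PyTorch after this removal, assuming torch.nn.utils.clip_grad_norm_ is acceptable (it returns the total norm it computes, and an infinite max_norm makes the clipping itself a no-op); the total_grad_norm name below is hypothetical, not a torchtune API:

import torch
from torch import nn


def total_grad_norm(model: nn.Module) -> float:
    # Hypothetical stand-in for the removed compute_grad_norm helper:
    # clip_grad_norm_ returns the global L2 norm over all parameter gradients,
    # and max_norm=inf keeps the gradients themselves unchanged.
    return torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=float("inf"), norm_type=2.0
    ).item()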