Skip to content

Commit

Permalink
Rename maybe_enable_tf32 to set_float32_precision (#661)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #661

The "maybe" and "conditional" wording is unnecessary.

Reviewed By: JKSenthil

Differential Revision: D52348852

fbshipit-source-id: 3bc6612215a3c8abf5d460058469eb3f9db86e02
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Dec 21, 2023
1 parent 9fa78c3 commit a22919e
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/source/utils/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Device Utils
record_data_in_stream
get_nvidia_smi_gpu_stats
get_psutil_cpu_stats
-   maybe_enable_tf32
+   set_float32_precision


Distributed Utils
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
get_device_from_env,
get_nvidia_smi_gpu_stats,
get_psutil_cpu_stats,
-    maybe_enable_tf32,
     record_data_in_stream,
+    set_float32_precision,
)


Expand Down Expand Up @@ -353,13 +353,13 @@ def test_record_data_in_stream_list(self) -> None:
@unittest.skipUnless(
condition=(cuda_available), reason="This test must run on a GPU host."
)
-    def test_maybe_enable_tf32(self) -> None:
-        maybe_enable_tf32("highest")
+    def test_set_float32_precision(self) -> None:
+        set_float32_precision("highest")
self.assertEqual(torch.get_float32_matmul_precision(), "highest")
self.assertFalse(torch.backends.cudnn.allow_tf32)
self.assertFalse(torch.backends.cuda.matmul.allow_tf32)

-        maybe_enable_tf32("high")
+        set_float32_precision("high")
self.assertEqual(torch.get_float32_matmul_precision(), "high")
self.assertTrue(torch.backends.cudnn.allow_tf32)
self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
4 changes: 2 additions & 2 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
get_nvidia_smi_gpu_stats,
get_psutil_cpu_stats,
GPUStats,
-    maybe_enable_tf32,
     record_data_in_stream,
+    set_float32_precision,
)
from .distributed import (
all_gather_tensors,
Expand Down Expand Up @@ -90,7 +90,7 @@
"get_nvidia_smi_gpu_stats",
"get_psutil_cpu_stats",
"GPUStats",
-    "maybe_enable_tf32",
+    "set_float32_precision",
"record_data_in_stream",
"all_gather_tensors",
"barrier",
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def collect_system_stats(device: torch.device) -> Dict[str, Any]:
return system_stats


-def maybe_enable_tf32(precision: str = "high") -> None:
-    """Conditionally sets the precision of float32 matrix multiplications and convolution operations.
+def set_float32_precision(precision: str = "high") -> None:
+    """Sets the precision of float32 matrix multiplications and convolution operations.
For more information, see the PyTorch docs:
- https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch
from torch.distributed.constants import default_pg_timeout
-from torchtnt.utils.device import get_device_from_env, maybe_enable_tf32
+from torchtnt.utils.device import get_device_from_env, set_float32_precision
from torchtnt.utils.distributed import (
get_file_init_method,
get_process_group_backend_from_device,
Expand Down Expand Up @@ -103,7 +103,7 @@ def init_from_env(
torch.distributed.init_process_group(
init_method=init_method, backend=pg_backend, timeout=pg_timeout
)
-    maybe_enable_tf32(float32_matmul_precision)
+    set_float32_precision(float32_matmul_precision)
return device


Expand Down

0 comments on commit a22919e

Please sign in to comment.