From 0f89e66d1745b8f4b304ebf46174bc726f0c28f5 Mon Sep 17 00:00:00 2001 From: Kurman Karabukaev Date: Mon, 17 Jun 2024 20:07:13 +0000 Subject: [PATCH 01/63] Validate logs are created by default (#128522) Summary: Make sure that logs are caputured in default settings Test Plan: ci Differential Revision: D58395812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128522 Approved by: https://github.com/d4l3k --- test/distributed/launcher/test_run.py | 31 ++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index ba58aec438715..f71bffd527c1e 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -6,6 +6,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import io import multiprocessing as mp import os import runpy @@ -14,7 +15,7 @@ import sys import tempfile import uuid -from contextlib import closing +from contextlib import closing, redirect_stderr, redirect_stdout from unittest import mock from unittest.mock import MagicMock, Mock, patch @@ -629,6 +630,34 @@ def test_init_method_env_with_torchelastic(self): ) # nothing to validate, just make sure it runs + def test_capture_logs_using_default_logs_specs(self): + run_id = str(uuid.uuid4().int) + nnodes = 1 + nproc_per_node = 4 + args = [ + f"--nnodes={nnodes}", + f"--nproc-per-node={nproc_per_node}", + f"--rdzv-id={run_id}", + "--redirect=3", + "--tee=3", + "--monitor-interval=1", + "--start-method=spawn", + "--no-python", + ] + + script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] + + captured_out = io.StringIO() + captured_err = io.StringIO() + with redirect_stdout(captured_out), redirect_stderr(captured_err): + with patch.dict( + os.environ, {"TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE": "[rank${rank}]: "} + ): + launch.main(args + script_args) + + for i in range(nproc_per_node): + self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue()) + if __name__ == "__main__": run_tests() From a59766ee058ba10d61e94c96daf2f7ded63efdb8 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 17 Jun 2024 20:50:22 +0000 Subject: [PATCH 02/63] replace `AT_ERROR(...)` with `TORCH_CHECK(false, ...)` (#128788) as per title. encountered the old-fashioned by chance Pull Request resolved: https://github.com/pytorch/pytorch/pull/128788 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 3a473495ff9f1..250fe68ff5e66 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1619,7 +1619,7 @@ Tensor alias_with_sizes_and_strides( Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { if (self.is_sparse()) { - AT_ERROR("reshape is not implemented for sparse tensors"); + TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } if (self.is_contiguous() && !self.is_mkldnn()) { @@ -1682,7 +1682,7 @@ Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_sha // minimize breakages. Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) { if (self.is_sparse()) { - AT_ERROR("reshape is not implemented for sparse tensors"); + TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } DimVector shape = infer_size_dv(proposed_shape, self.numel()); From 8c06eae17eb470e3eb97f58cf6c0eddad26937f6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 15 Jun 2024 18:30:41 -0700 Subject: [PATCH 03/63] [GPT-benchmark] Add metric: compilation time for GPT models (#128768) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128768 Approved by: https://github.com/Chillee --- benchmarks/gpt_fast/generate.py | 40 ++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 3ec72bf1e3195..92d00cb1bdb6b 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -27,6 +27,7 @@ class GPTModelConfig: quantizer: type token_per_sec: float memory_bandwidth: float + compilation_time: float def device_sync(device): @@ -190,6 +191,7 @@ def run_experiment( aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} start = -1 + compilation_time = None for i in range(start, num_samples): device_sync(device=device) # MKG @@ -200,7 +202,8 @@ def run_experiment( ) if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + compilation_time = time.perf_counter() - t0 + print(f"Compilation time: {compilation_time:.2f} seconds") continue device_sync(device=device) # MKG @@ -217,7 +220,7 @@ def run_experiment( print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - return token_per_sec, memory_bandwidth + return token_per_sec, memory_bandwidth, compilation_time # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. @@ -231,8 +234,9 @@ def run_llama2_7b_bf16(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 94, 1253, + 162, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -250,6 +254,14 @@ def run_llama2_7b_bf16(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] @@ -264,8 +276,9 @@ def run_llama2_7b_int8(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 144, 957, + 172, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -283,6 +296,14 @@ def run_llama2_7b_int8(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] @@ -298,8 +319,9 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): MixtralMoEWeightOnlyInt8QuantHandler, 175, 4129, + 162, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -317,4 +339,12 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] From a489792bb2d59ad7e36e0d3ae55074ce707b47e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 15 Jun 2024 18:30:44 -0700 Subject: [PATCH 04/63] [GPT-benchmark] Fix memory bandwidth for MoE (#128783) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128783 Approved by: https://github.com/Chillee ghstack dependencies: #128768 --- benchmarks/gpt_fast/generate.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 92d00cb1bdb6b..19c32d06be104 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -3,8 +3,9 @@ import time from typing import Optional, Tuple -from mixtral_moe_model import Transformer as MixtralMoE +from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE from mixtral_moe_quantize import ( + ConditionalFeedForwardInt8, WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, ) from model import Transformer as LLaMA @@ -154,6 +155,7 @@ def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16): return model.eval() +# Only count activated parameters and buffers. def _get_model_size(model): model_size = 0 for name, child in model.named_children(): @@ -164,6 +166,28 @@ def _get_model_size(model): for p in itertools.chain(child.parameters(), child.buffers()) ] ) + + # Remove the inactivated experts from the model size if this is mixture of experts + # architecture, since only activated experts are loaded. + if hasattr(model.config, "num_experts"): + config = model.config + for submodule in model.modules(): + if isinstance( + submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8) + ): + model_size -= ( + sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain( + submodule.parameters(), child.buffers() + ) + ] + ) + * (config.num_experts - config.num_activated_experts) + / config.num_experts + ) + return model_size @@ -318,7 +342,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): "int8", MixtralMoEWeightOnlyInt8QuantHandler, 175, - 4129, + 1280, 162, ) token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) From 8953725e6d68b3b7011626319a17fca5bd0b3e75 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 17 Jun 2024 21:10:55 +0000 Subject: [PATCH 05/63] [Inductor][FlexAttention] Tune backwards kernel block sizes (#128853) This replaces #128767 which somehow closed by mistake. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128853 Approved by: https://github.com/angelayi --- torch/_inductor/kernel/flex_attention.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 932bcd50b9203..987dc6d89328b 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -361,7 +361,7 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: return (64, 64, 4, 1) return (128, 128, 4, 3) elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 - return (32, 32, 4, 1) + return (64, 64, 4, 1) else: # modest hardware or extremely large head_dim return (16, 16, 4, 1) @@ -763,14 +763,13 @@ def flex_attention_backward(*args, **kwargs): configs: List[Tuple[int, int, int, int]] = [] configs.append(_get_default_config_bwd(query)) if config.max_autotune: - configs += [ - (128, 128, 4, 3), - (128, 128, 8, 1), - (64, 64, 4, 3), - (64, 64, 8, 1), - ] + for BLOCK1 in [32, 64]: + for BLOCK2 in [32, 64]: + for w in [4, 8]: + for s in [1, 3]: + configs.append((BLOCK1, BLOCK2, w, s)) - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + for BLOCK1, BLOCK2, num_warps, num_stages in configs: flex_attention_backward_template.maybe_append_choice( choices=choices, input_nodes=[ @@ -790,10 +789,10 @@ def flex_attention_backward(*args, **kwargs): call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, - BLOCK_M1=BLOCK_M, - BLOCK_N1=BLOCK_N, - BLOCK_M2=BLOCK_N, - BLOCK_N2=BLOCK_M, + BLOCK_M1=BLOCK1, + BLOCK_N1=BLOCK1, + BLOCK_M2=BLOCK2, + BLOCK_N2=BLOCK2, BLOCK_DMODEL=query.get_size()[-1], # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, From 163847b1bb5cc36a0915a189b2dd4cfbbfaf9c49 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 17 Jun 2024 21:25:55 +0000 Subject: [PATCH 06/63] [1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128675 Approved by: https://github.com/r-barnes --- test/onnx/test_export_modes.py | 20 --- test/onnx/test_operators.py | 6 +- test/onnx/test_pytorch_onnx_no_runtime.py | 80 +--------- test/onnx/test_utility_funs.py | 22 +-- test/quantization/core/test_quantized_op.py | 43 +----- test/test_jit.py | 52 +------ torch/_C/_onnx.pyi | 1 - torch/csrc/onnx/init.cpp | 2 - torch/onnx/__init__.py | 7 +- torch/onnx/_internal/jit_utils.py | 19 --- torch/onnx/symbolic_helper.py | 9 +- torch/onnx/symbolic_opset11.py | 25 ---- torch/onnx/symbolic_opset12.py | 2 - torch/onnx/symbolic_opset16.py | 3 - torch/onnx/symbolic_opset9.py | 137 ++---------------- torch/onnx/utils.py | 106 ++------------ torch/onnx/verification.py | 5 +- .../testing/_internal/common_quantization.py | 8 - torch/testing/_internal/common_utils.py | 15 -- 19 files changed, 33 insertions(+), 529 deletions(-) diff --git a/test/onnx/test_export_modes.py b/test/onnx/test_export_modes.py index 5bf84c1b409a0..6d48b2f4578de 100644 --- a/test/onnx/test_export_modes.py +++ b/test/onnx/test_export_modes.py @@ -86,26 +86,6 @@ def foo(a): x = torch.ones(3) torch.onnx.export(foo, (x,), f) - @common_utils.skipIfNoCaffe2 - @common_utils.skipIfNoLapack - def test_caffe2_aten_fallback(self): - class ModelWithAtenNotONNXOp(nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), - (x, y), - add_node_names=False, - do_constant_folding=False, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - - @common_utils.skipIfCaffe2 @common_utils.skipIfNoLapack def test_aten_fallback(self): class ModelWithAtenNotONNXOp(nn.Module): diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 99f0d533a61ca..87ec424cf65d5 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -39,7 +39,7 @@ parse_args, ) from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack +from torch.testing._internal.common_utils import skipIfNoLapack unittest.TestCase.maxDiff = None @@ -414,7 +414,6 @@ def test_maxpool_indices(self): x = torch.randn(20, 16, 50) self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x) - @skipIfCaffe2 def test_at_op(self): x = torch.randn(3, 4) @@ -694,7 +693,6 @@ def test_batchnorm_noaffine(self): keep_initializers_as_inputs=True, ) - @skipIfCaffe2 def test_embedding_bags(self): emb_bag = nn.EmbeddingBag(10, 8) input = torch.tensor([1, 2, 3, 4]).long() @@ -949,7 +947,6 @@ def forward(self, input, other): other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8) self.assertONNX(BiwiseAndModel(), (input, other), opset_version=18) - @skipIfCaffe2 def test_layer_norm_aten(self): model = torch.nn.LayerNorm([10, 10]) x = torch.randn(20, 5, 10, 10) @@ -1203,7 +1200,6 @@ def forward(self, x, y): torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. - @skipIfCaffe2 def test_aten_embedding_2(self): _onnx_opset_version = 12 diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 54fc178251539..324806eaf0adf 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -20,7 +20,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import OperatorExportTypes, symbolic_helper, utils +from torch.onnx import symbolic_helper, utils from torch.onnx._internal import registration from torch.testing._internal import common_quantization, common_utils, jit_utils @@ -394,7 +394,6 @@ def forward(self, input): for node in graph.nodes(): self.assertTrue(node.sourceRange()) - @common_utils.skipIfCaffe2 def test_clip_aten_fallback_due_exception(self): def bad_clamp(g, self, min, max): return symbolic_helper._onnx_unsupported("Bad boy!") @@ -411,7 +410,6 @@ def forward(self, x): ) self.assertAtenOp(onnx_model, "clamp", "Tensor") - @common_utils.skipIfCaffe2 def test_clip_aten_fallback_explicit_request(self): class MyClip(torch.nn.Module): def forward(self, x): @@ -961,60 +959,6 @@ def forward(self, x, w): torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) - @common_utils.skipIfNoCaffe2 - def test_caffe2_aten_fallback_must_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize - for operator_export_type in ( - OperatorExportTypes.ONNX_ATEN, - OperatorExportTypes.ONNX_ATEN_FALLBACK, - ): - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=operator_export_type, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfNoCaffe2 - def test_caffe2_onnx_aten_must_not_fallback(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize - for operator_export_type in ( - OperatorExportTypes.ONNX_ATEN_FALLBACK, - OperatorExportTypes.ONNX_ATEN, - ): - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=operator_export_type, - opset_version=10, # or higher - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - assert onnx_model.graph.node[0].op_type == "Mod" - - @common_utils.skipIfCaffe2 def test_aten_fallback_must_fallback(self): class ModelWithAtenNotONNXOp(torch.nn.Module): def forward(self, x, y): @@ -1037,7 +981,6 @@ def forward(self, x, y): onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertAtenOp(onnx_model, "linalg_qr") - @common_utils.skipIfCaffe2 def test_onnx_aten(self): class ModelWithAtenFmod(torch.nn.Module): def forward(self, x, y): @@ -1056,7 +999,6 @@ def forward(self, x, y): onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertAtenOp(onnx_model, "fmod", "Tensor") - @common_utils.skipIfCaffe2 def test_onnx_aten_fallback_must_not_fallback(self): # For BUILD_CAFFE2=0, aten fallback only when not exportable class ONNXExportable(torch.nn.Module): @@ -1233,26 +1175,6 @@ def _export_to_onnx(model, input, input_names): _export_to_onnx(model, data, input_names) - @common_quantization.skipIfNoFBGEMM - @common_utils.skipIfNoCaffe2 - def test_lower_graph_linear(self): - model = torch.ao.quantization.QuantWrapper( - torch.nn.Linear(5, 10, bias=True) - ).to(dtype=torch.float) - data_numpy = np.random.rand(1, 2, 5).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - - @common_quantization.skipIfNoFBGEMM - @common_utils.skipIfNoCaffe2 - def test_lower_graph_conv2d(self): - model = torch.ao.quantization.QuantWrapper( - torch.nn.Conv2d(3, 5, 2, bias=True) - ).to(dtype=torch.float) - data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - @common_quantization.skipIfNoFBGEMM @unittest.skip( "onnx opset9 does not support quantize_per_tensor and caffe2 \ diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 9ee4129879652..e7c8f40781033 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -17,7 +17,6 @@ skipIfUnsupportedMaxOpsetVersion, skipIfUnsupportedMinOpsetVersion, ) -from verify import verify import torch import torch.onnx @@ -26,7 +25,7 @@ from torch.onnx._globals import GLOBALS from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack +from torch.testing._internal.common_utils import skipIfNoLapack def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str: @@ -1623,25 +1622,6 @@ def forward(self, x): "Graph parameter names does not match model parameters.", ) - @skipIfNoCaffe2 - def test_modifying_params(self): - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.tensor([2.0])) - - def forward(self, x): - y = x * x - self.param.data.add_(1.0) - return y - - x = torch.tensor([1, 2]) - # Move import to local as caffe2 backend requires additional build flag, - # and is only used in this test case. - import caffe2.python.onnx.backend as backend - - verify(MyModel(), x, backend, do_constant_folding=False) - def test_fuse_conv_bn(self): class Fuse(torch.nn.Module): def __init__(self): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 5b86693e11c10..2e606938192dd 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -23,7 +23,7 @@ from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE +from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr @@ -4524,47 +4524,6 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight) - """ Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), - num_batches=st.integers(1, 5), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type): - pack_fn = torch.ops.quantized.embedding_bag_byte_prepack - unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack - - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type) - - """ Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), - optimized_qparams=st.booleans(), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type): - pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack - unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack - - # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1 - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type) - - """ Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0), - optimized_qparams=st.booleans(), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type): - pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack - unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack - - # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1 - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type) - def embedding_bag_rowwise_offsets_run( self, bit_rate, num_embeddings, diff --git a/test/test_jit.py b/test/test_jit.py index 13bdd07be6cd9..afecb5f390402 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -96,7 +96,7 @@ from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ - suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ + suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -15299,56 +15299,6 @@ def is_tensor_value(item): continue self.assertEqual(value, getattr(loaded, "_" + name)) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") - @unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support") - def test_old_models_bc(self): - model = { - 'archive/version': b'1', - 'archive/code/archive.py': - b''' - op_version_set = 0 - def forward(self, - _0: Tensor) -> Tensor: - _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu")) - result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"), - non_blocking=False, copy=False) - result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu")) - result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu")) - _2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1) - return _2 - ''', - 'archive/attributes.pkl': b'\x80\x02](e.', - 'archive/libs.py': b'op_version_set = 0\n', - 'archive/model.json': - b''' - { - "protoVersion":"2", - "mainModule":{ - "torchscriptArena":{ - "key":"code/archive.py" - }, - "name":"archive", - "optimize":true - }, - "producerName":"pytorch", - "producerVersion":"1.0", - "libs":{ - "torchscriptArena":{ - "key":"libs.py" - } - } - }'''} - with TemporaryFileName() as fname: - archive_name = os.path.basename(os.path.normpath(fname)) - with zipfile.ZipFile(fname, 'w') as archive: - for k, v in model.items(): - archive.writestr(k, v) - - with open(fname, "rb") as f: - fn = torch.jit.load(f) - - x = torch.zeros(10) - fn(x) def test_submodule_attribute_serialization(self): class S(torch.jit.ScriptModule): diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi index 2e8e5a0c66117..349e0b9ad12f0 100644 --- a/torch/_C/_onnx.pyi +++ b/torch/_C/_onnx.pyi @@ -2,7 +2,6 @@ from enum import Enum -_CAFFE2_ATEN_FALLBACK: bool PRODUCER_VERSION: str class TensorProtoDataType(Enum): diff --git a/torch/csrc/onnx/init.cpp b/torch/csrc/onnx/init.cpp index b8bef342323c5..6b06eb649cae0 100644 --- a/torch/csrc/onnx/init.cpp +++ b/torch/csrc/onnx/init.cpp @@ -292,7 +292,5 @@ void initONNXBindings(PyObject* module) { .value("TRAINING", TrainingMode::TRAINING); onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION); - - onnx.attr("_CAFFE2_ATEN_FALLBACK") = false; } } // namespace torch::onnx diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 2b2f2bdae0de3..3f013b1235842 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,12 +1,7 @@ # mypy: allow-untyped-defs from torch import _C from torch._C import _onnx as _C_onnx -from torch._C._onnx import ( - _CAFFE2_ATEN_FALLBACK, - OperatorExportTypes, - TensorProtoDataType, - TrainingMode, -) +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode from . import ( # usort:skip. Keep the order instead of sorting lexicographically _deprecation, diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py index 13ae4209da5dd..ed064f6f874d7 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/jit_utils.py @@ -12,7 +12,6 @@ import torch from torch import _C -from torch._C import _onnx as _C_onnx from torch.onnx._globals import GLOBALS from torch.onnx._internal import _beartype, registration @@ -329,14 +328,6 @@ def _scalar(x: torch.Tensor): return x[0] -@_beartype.beartype -def _is_caffe2_aten_fallback() -> bool: - return ( - GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and _C_onnx._CAFFE2_ATEN_FALLBACK - ) - - @_beartype.beartype def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): r"""Initializes the right attribute based on type of value.""" @@ -350,16 +341,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): if _is_onnx_list(value): kind += "s" - if aten and _is_caffe2_aten_fallback(): - if isinstance(value, torch.Tensor): - # Caffe2 proto does not support tensor attribute. - if value.numel() > 1: - raise ValueError("Should not pass tensor attribute") - value = _scalar(value) - if isinstance(value, float): - kind = "f" - else: - kind = "i" return getattr(node, f"{kind}_")(name, value) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 676c3d68048b0..6d876486f642c 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -537,10 +537,7 @@ def is_complex_value(x: _C.Value) -> bool: @_beartype.beartype def is_caffe2_aten_fallback() -> bool: - return ( - GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and _C_onnx._CAFFE2_ATEN_FALLBACK - ) + return False @_beartype.beartype @@ -592,9 +589,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]): @_beartype.beartype def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None: # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators - if _C_onnx._CAFFE2_ATEN_FALLBACK: - warnings.warn(f"ONNX export failed on {op} because {msg} not supported") - elif GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: _onnx_unsupported(f"{op}, {msg}", value) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index e562d5a47567c..90963c4f17fa7 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -211,10 +211,6 @@ def index_put( indices_list = symbolic_helper._unpack_list(indices_list_value) else: indices_list = [indices_list_value] - if symbolic_helper.is_caffe2_aten_fallback(): - args = [self] + indices_list + [values, accumulate] - return g.at("index_put", *args) - accumulate = symbolic_helper._parse_arg(accumulate, "b") if len(indices_list) == 0: @@ -398,8 +394,6 @@ def __interpolate( def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): if symbolic_helper._maybe_get_const(sparse_grad, "i"): return symbolic_helper._unimplemented("gather", "sparse_grad == True") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("gather", self, dim, index, sparse_grad) return g.op("GatherElements", self, index, axis_i=dim) @@ -407,8 +401,6 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter(g: jit_utils.GraphContext, self, dim, index, src): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("scatter", self, dim, index, src, overload_name="src") src_type = _type_utils.JitScalarType.from_value(src) src = symbolic_helper._maybe_get_scalar(src) if symbolic_helper._is_value(src): @@ -898,8 +890,6 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim): stop = g.op( "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.op("_caffe2::Range", stop) return arange(g, stop, 4, None, None, None) @@ -982,9 +972,6 @@ def mm(g: jit_utils.GraphContext, self, other): @_onnx_symbolic("aten::index") @_beartype.beartype def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index", self, index, overload_name="Tensor") - if symbolic_helper._is_packed_list(index): indices = symbolic_helper._unpack_list(index) else: @@ -1007,16 +994,6 @@ def index(g: jit_utils.GraphContext, self, index): @_beartype.beartype def index_fill(g: jit_utils.GraphContext, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "index_fill", - self, - index, - value, - overload_name="int_Scalar", - dim_i=dim_value, - ) - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -1030,8 +1007,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_beartype.beartype def index_copy(g: jit_utils.GraphContext, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index_copy", self, index, source, dim_i=dim_value) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 5a6bf720df36f..cf24fe43247ca 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -330,8 +330,6 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step): const_step ): return opset9.unfold(g, input, dimension, const_size, const_step) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) if sizedim is not None: diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index cd5829ada850d..8df3d954ba433 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -71,9 +71,6 @@ def grid_sampler( @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("scatter", self, dim, index, src, overload_name="src") - src_type = _type_utils.JitScalarType.from_value( src, _type_utils.JitScalarType.UNDEFINED ) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index b4c937ed3f66b..f43a09aa4b147 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -841,36 +841,18 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def cumsum(g: jit_utils.GraphContext, input, dim, dtype): - if symbolic_helper.is_caffe2_aten_fallback(): - if dtype.node().kind() != "prim::Constant": - return symbolic_helper._unimplemented("cumsum", "dtype", dtype) - return g.at("cumsum", input, dim_i=dim) - symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) @_onnx_symbolic("aten::_sample_dirichlet") @_beartype.beartype def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): - if symbolic_helper.is_caffe2_aten_fallback(): - if not symbolic_helper._is_none(generator): - return symbolic_helper._unimplemented( - "_sample_dirichlet", "We are not able to export generator", self - ) - return g.at("_sample_dirichlet", self) return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) @_onnx_symbolic("aten::_standard_gamma") @_beartype.beartype def _standard_gamma(g: jit_utils.GraphContext, self, generator): - if symbolic_helper.is_caffe2_aten_fallback(): - if not symbolic_helper._is_none(generator): - return symbolic_helper._unimplemented( - "_standard_gamma", "not able to export generator", self - ) - return g.at("_standard_gamma", self) - return symbolic_helper._onnx_unsupported("_standard_gamma", self) @@ -1007,19 +989,6 @@ def embedding_bag( return symbolic_helper._onnx_unsupported( "embedding_bag with per_sample_weights" ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "embedding_bag", - embedding_matrix, - indices, - offsets, - outputs=4, - scale_grad_by_freq_i=scale_grad_by_freq, - mode_i=mode, - sparse_i=sparse, - include_last_offset_i=include_last_offset, - padding_idx_i=padding_idx, - ) return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) @@ -1052,10 +1021,6 @@ def transpose(g: jit_utils.GraphContext, self, dim0, dim1): axes = list(range(rank)) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) - elif symbolic_helper.is_caffe2_aten_fallback(): - # if we don't have dim information we cannot - # output a permute so use ATen instead - return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1) else: raise errors.SymbolicValueError( "Unsupported: ONNX export of transpose for tensor of unknown rank.", @@ -2927,16 +2892,6 @@ def layer_norm( eps: float, cudnn_enable: bool, ) -> _C.Value: - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "layer_norm", - input, - weight, - bias, - normalized_shape_i=normalized_shape, - eps_f=eps, - cudnn_enable_i=cudnn_enable, - ) normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) return normalized @@ -3043,8 +2998,6 @@ def instance_norm( @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unfold(g: jit_utils.GraphContext, input, dimension, size, step): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) sizes = symbolic_helper._get_tensor_sizes(input) # FIXME(justinchuby): Get rid of the try catch here to improve readability try: @@ -3119,9 +3072,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum indices_list = symbolic_helper._unpack_list(indices_list_value) else: indices_list = [indices_list_value] - if symbolic_helper.is_caffe2_aten_fallback(): - args = [self] + indices_list + [values, accumulate] - return g.at("index_put", *args) accumulate = symbolic_helper._parse_arg(accumulate, "b") @@ -3136,16 +3086,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum @_beartype.beartype def index_fill(g: jit_utils.GraphContext, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "index_fill", - self, - index, - value, - overload_name="int_Scalar", - dim_i=dim_value, - ) - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -3160,8 +3100,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_beartype.beartype def index_copy(g: jit_utils.GraphContext, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index_copy", self, index, source, dim_i=dim_value) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -3220,10 +3158,6 @@ def type_as(g: jit_utils.GraphContext, self, other): to_i=other_dtype.onnx_type(), ) - if symbolic_helper.is_caffe2_aten_fallback(): - # We don't know the type of other, bail by emitting ATen - return g.at("type_as", self, other) - raise errors.SymbolicValueError( "Unsupported: ONNX export of type_as for tensor " "of unknown dtype. Please check if the dtype of the " @@ -3236,8 +3170,6 @@ def type_as(g: jit_utils.GraphContext, self, other): @symbolic_helper.parse_args("v", "v", "i", "f") @_beartype.beartype def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) cross = symbolic_helper._reducesum_helper( g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 ) @@ -3516,50 +3448,28 @@ def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("conv_tbc", input, weight, bias, pad_i=pad) - else: - # input must have 3 dimensions, see: - # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 - # input = (time, batch, in_channels) - # weight = (kernel_width, in_channels, out_channels) - # bias = (out_channels,) - input = g.op("Transpose", input, perm_i=[1, 2, 0]) - weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) - conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) - return g.op("Transpose", conv, perm_i=[2, 0, 1]) + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) @_onnx_symbolic("aten::_unique") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "_unique", - input, - sorted_i=sorted, - return_inverse_i=return_inverse, - outputs=2, - ) - else: - return symbolic_helper._onnx_unsupported("_unique", input) + return symbolic_helper._onnx_unsupported("_unique", input) @_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "_unique2", - input, - sorted_i=sorted, - return_inverse_i=return_inverse, - return_counts_i=return_counts, - outputs=3, - ) - symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) @@ -4973,11 +4883,8 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim): stop = g.op( "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.op("_caffe2::Range", stop) - else: - # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - return arange(g, stop, 4, None, None, None) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) @_onnx_symbolic("aten::detach") @@ -5543,9 +5450,6 @@ def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): @_onnx_symbolic("aten::arange") @_beartype.beartype def arange(g: jit_utils.GraphContext, *args): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("arange", *args) - @_beartype.beartype def _get_arange_dtype(dtype): dtype = symbolic_helper._maybe_get_const(dtype, "i") @@ -5665,9 +5569,6 @@ def masked_fill_(g: jit_utils.GraphContext, self, mask, value): @_onnx_symbolic("aten::index") @_beartype.beartype def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index", self, index, overload_name="Tensor") - if symbolic_helper._is_packed_list(index): indices = symbolic_helper._unpack_list(index) else: @@ -6083,17 +5984,6 @@ def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "no def group_norm( g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled ): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "group_norm", - input, - weight, - bias, - num_groups_i=num_groups, - eps_f=eps, - cudnn_enabled_i=cudnn_enabled, - ) - channel_size = symbolic_helper._get_tensor_dim_size(input, 1) if channel_size is not None: assert channel_size % num_groups == 0 @@ -6169,9 +6059,6 @@ def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): norm_v = norm(g, weight_v, 2, axes, 1) div = g.op("Div", weight_v, norm_v) return g.op("Mul", div, weight_g) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("_weight_norm", weight_v, weight_g, dim_i=dim) - raise errors.SymbolicValueError( "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", weight_v, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0d02fabd1beb5..870a599aebce2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -11,7 +11,6 @@ import inspect import io import re -import textwrap import typing import warnings from typing import ( @@ -681,27 +680,6 @@ def _optimize_graph( _C._jit_pass_onnx_unpack_quantized_weights( graph, params_dict, symbolic_helper.is_caffe2_aten_fallback() ) - if symbolic_helper.is_caffe2_aten_fallback(): - # Insert permutes before and after each conv op to ensure correct order. - _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict) - - # Find consecutive permutes that are no-ops and remove them. - _C._jit_pass_custom_pattern_based_rewrite_graph( - textwrap.dedent( - """\ - graph(%Pi): - %Pq = quantized::nhwc2nchw(%Pi) - %Pr = quantized::nchw2nhwc(%Pq) - return (%Pr)""" - ), - textwrap.dedent( - """\ - graph(%Ri): - return (%Ri)""" - ), - graph, - ) - # onnx only supports tensors, so we turn all out number types into tensors _C._jit_pass_erase_number_types(graph) if GLOBALS.onnx_shape_inference: @@ -734,18 +712,9 @@ def _optimize_graph( graph = _C._jit_pass_canonicalize(graph) _C._jit_pass_lint(graph) if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError as exc: - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and exc.args[0] - == "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!" - ): - # Caffe2 builds can have UNKNOWN_SCALAR for some tensors - pass + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) return graph @@ -783,17 +752,6 @@ def warn_on_static_input_change(input_states): @_beartype.beartype def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" - if ( - operator_export_type is not operator_export_type.ONNX - and _C_onnx._CAFFE2_ATEN_FALLBACK - ): - if arg_value is True: - warnings.warn( - f"'{arg_name}' can be set to True only when 'operator_export_type' is " - "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', " - f"'{arg_name}' argument will be ignored." - ) - arg_value = False return arg_value @@ -1298,18 +1256,9 @@ def _model_to_graph( _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError as exc: - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and exc.args[0] - == "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!" - ): - # Caffe2 builds can have UNKNOWN_SCALAR for some tensors - pass + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) @@ -1612,15 +1561,6 @@ def _export( if export_type is None: export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - # Discussed deprecation with Nikita Shulga and Sergii Dymchenko from Meta - if _C_onnx._CAFFE2_ATEN_FALLBACK: - warnings.warn( - "Caffe2 ONNX exporter is deprecated in version 2.0 and will be " - "removed in 2.2. Please use PyTorch 2.1 or older for this capability.", - category=FutureWarning, - stacklevel=2, - ) - if isinstance(model, torch.nn.DataParallel): raise ValueError( "torch.nn.DataParallel is not supported by ONNX " @@ -1655,10 +1595,7 @@ def _export( "no local function support. " ) if not operator_export_type: - if _C_onnx._CAFFE2_ATEN_FALLBACK: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - else: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX + operator_export_type = _C_onnx.OperatorExportTypes.ONNX # By default, training=TrainingMode.EVAL, # which is good because running a model in training mode could result in @@ -1904,21 +1841,12 @@ def _should_aten_fallback( is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) - is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK if not name.startswith("aten::"): return False - if is_caffe2_build: - if ( - is_onnx_aten_export or is_aten_fallback_export - ) and not is_exportable_aten_op: - return True - else: - if is_onnx_aten_export or ( - is_aten_fallback_export and not is_exportable_aten_op - ): - return True + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True return False @@ -1968,7 +1896,7 @@ def wrapper(graph_context: jit_utils.GraphContext, *args, **kwargs): def _get_aten_op_overload_name(n: _C.Node) -> str: # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds schema = n.schema() - if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback(): + if not schema.startswith("aten::"): return "" return _C.parse_schema(schema).overload_name @@ -2032,14 +1960,7 @@ def _run_symbolic_function( ) try: - # Caffe2-specific: Quantized op symbolics are registered for opset 9 only. - if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9: - symbolic_caffe2.register_quantized_ops("caffe2", opset_version) - - if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback(): - domain = "caffe2" - else: - domain = namespace + domain = namespace symbolic_function_name = f"{domain}::{op_name}" symbolic_function_group = registration.registry.get_function_group( @@ -2073,10 +1994,7 @@ def _run_symbolic_function( except RuntimeError: if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: return None - elif ( - operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and not symbolic_helper.is_caffe2_aten_fallback() - ): + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` attrs = { k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 95ed873bf6335..38a23893a8ba5 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -633,10 +633,7 @@ def _onnx_graph_from_model( utils._setup_trace_module_map(model, export_modules_as_functions) if not operator_export_type: - if _C_onnx._CAFFE2_ATEN_FALLBACK: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - else: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX + operator_export_type = _C_onnx.OperatorExportTypes.ONNX GLOBALS.export_onnx_opset_version = opset_version GLOBALS.operator_export_type = operator_export_type diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 0c27032b9871b..f7a5016bd8f41 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -329,14 +329,6 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if not torch.onnx._CAFFE2_ATEN_FALLBACK: - raise unittest.SkipTest(reason) - else: - fn(*args, **kwargs) - return wrapper - def withQNNPACKBackend(fn): # TODO(future PR): consider combining with skipIfNoQNNPACK, # will require testing of existing callsites diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bb5f3fa8e3308..2d5ea4a6c64ff 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1252,8 +1252,6 @@ def TemporaryDirectoryName(suffix=None): TEST_Z3 = _check_module_exists('z3') -BUILD_WITH_CAFFE2 = torch.onnx._CAFFE2_ATEN_FALLBACK - def split_if_not_empty(x: str): return x.split(",") if len(x) != 0 else [] @@ -1886,19 +1884,6 @@ def skipIfNotRegistered(op_name, message): """ return unittest.skip("Pytorch is compiled without Caffe2") -def _decide_skip_caffe2(expect_caffe2, reason): - def skip_dec(func): - @wraps(func) - def wrapper(self): - if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2: - raise unittest.SkipTest(reason) - return func(self) - return wrapper - return skip_dec - -skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2") -skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available") - def skipIfNoSciPy(fn): @wraps(fn) def wrapper(*args, **kwargs): From 1fd7496ab2e66ac116a801d9aef54915230dbe44 Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Mon, 17 Jun 2024 21:58:46 +0000 Subject: [PATCH 07/63] [MTIA] Fix synchronize API (#128714) Reviewed By: fenypatel99 Differential Revision: D58590313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128714 Approved by: https://github.com/aaronenyeshi --- torch/csrc/mtia/Module.cpp | 2 +- torch/mtia/__init__.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 84cc11f718759..63cfae1972552 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -56,7 +56,7 @@ void initModule(PyObject* module) { return at::detail::getMTIAHooks().getCurrentStream(device_index); }); - m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) { + m.def("_mtia_deviceSynchronize", []() { torch::utils::device_lazy_init(at::kMTIA); at::detail::getMTIAHooks().deviceSynchronize( at::detail::getMTIAHooks().getCurrentDevice()); diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index f9554a9bcb277..1bd7d2a9b7c6f 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -107,9 +107,10 @@ def is_available() -> bool: return device_count() > 0 -def synchronize() -> None: +def synchronize(device: Optional[_device_t] = None) -> None: r"""Waits for all jobs in all streams on a MTIA device to complete.""" - return torch._C._mtia_deviceSynchronize() + with torch.mtia.device(device): + return torch._C._mtia_deviceSynchronize() def device_count() -> int: From 7baf32b5e7440cb6c32b6ecf5dad0454bff39794 Mon Sep 17 00:00:00 2001 From: Shengbao Zheng Date: Mon, 17 Jun 2024 22:07:40 +0000 Subject: [PATCH 08/63] [c10d] fix p2p group commsplit (#128803) Summary: For PointToPoint(sendrecv), the deviceId is lower_rank:higher_rank. This means a p2p group cannot be created through commSplit since it cannot find a parent. Fix this by using the right device key of current rank. Differential Revision: D58631639 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128803 Approved by: https://github.com/shuqiangzhang --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e7699b5524514..d293c4d470b83 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2118,7 +2118,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( // Find a valid, healthy communicator to split from if possible. std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; - auto dit = other_comms.find(deviceKey); + auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { auto& parentComm = dit->second; if (parentComm != nullptr && !parentComm->isAborted()) { From 1835e3beab7e6e019b2a61137779297bfc3852ae Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 17 Jun 2024 22:20:33 +0000 Subject: [PATCH 09/63] Fix the inductor ci (#128879) Fix the torchbench+inductor ci on trunk due to recent upgrade to numpy 2.0.0rc1. We have to remove DALLE2_pytorch model, since it depends on embedding-reader, which is not compatible with numpy>2: https://github.com/rom1504/embedding-reader/blob/main/requirements.txt#L3 Fixes #128845 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128879 Approved by: https://github.com/eellison --- .ci/pytorch/common_utils.sh | 2 +- .ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_mlstm.sh | 2 +- .github/ci_commit_pins/torchbench.txt | 2 +- benchmarks/dynamo/Makefile | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 2f03e8c4255e6..91c2d1b5dd3bd 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -191,7 +191,7 @@ function clone_pytorch_xla() { function checkout_install_torchbench() { local commit commit=$(get_pinned_commit torchbench) - git clone https://github.com/eellison/benchmark torchbench + git clone https://github.com/pytorch/benchmark torchbench pushd torchbench git checkout "$commit" diff --git a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh index 70c4be781e288..72496691286e4 100644 --- a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh +++ b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh @@ -9,7 +9,7 @@ test_cpu_speed_mini_sequence_labeler () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh index 9633f7dfdfae3..1693b00f17e2d 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_cudnn_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh index b8548f8206a9c..2e26b9902b868 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh index e224dd27f74f4..a0617530194a1 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_mlstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt index 8779f5b61aa9b..4a60ff3d38d40 100644 --- a/.github/ci_commit_pins/torchbench.txt +++ b/.github/ci_commit_pins/torchbench.txt @@ -1 +1 @@ -pin_yolo_dep +0dab1dd97709096e8129f8a08115ee83f64f2194 diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile index dacddec4b2919..720542f28608b 100644 --- a/benchmarks/dynamo/Makefile +++ b/benchmarks/dynamo/Makefile @@ -10,7 +10,7 @@ clone-deps: && (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \ && (test -e FBGEMM || git clone --recursive https://github.com/pytorch/FBGEMM) \ && (test -e torchrec || git clone --recursive https://github.com/pytorch/torchrec) \ - && (test -e torchbenchmark || git clone --recursive https://github.com/eellison/benchmark torchbenchmark) \ + && (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \ ) pull-deps: clone-deps From 3b8c9b8ab11682b958dfe002d7106d94cf75ef7a Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 17 Jun 2024 22:51:12 +0000 Subject: [PATCH 10/63] [Docker Release] Test if pytorch was compiled with CUDA before pushing to repo (#128852) Related to: https://github.com/pytorch/pytorch/issues/125879 Would check if we are compiled with CUDA before publishing CUDA Docker nightly image Test ``` #18 [conda-installs 5/5] RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo "Is torch compiled with cuda: ${IS_CUDA}"; if test "${IS_CUDA}" != "True" -a ! -z "12.4.0"; then exit 1; fi #18 1.656 Is torch compiled with cuda: False #18 ERROR: process "/bin/sh -c IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo \"Is torch compiled with cuda: ${IS_CUDA}\"; if test \"${IS_CUDA}\" != \"True\" -a ! -z \"${CUDA_VERSION}\"; then \texit 1; fi" did not complete successfully: exit code: 1 ------ > [conda-installs 5/5] RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo "Is torch compiled with cuda: ${IS_CUDA}"; if test "${IS_CUDA}" != "True" -a ! -z "12.4.0"; then exit 1; fi: 1.656 Is torch compiled with cuda: False ------ Dockerfile:80 -------------------- 79 | RUN /opt/conda/bin/pip install torchelastic 80 | >>> RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())');\ 81 | >>> echo "Is torch compiled with cuda: ${IS_CUDA}"; \ 82 | >>> if test "${IS_CUDA}" != "True" -a ! -z "${CUDA_VERSION}"; then \ 83 | >>> exit 1; \ 84 | >>> fi 85 | -------------------- ERROR: failed to solve: process "/bin/sh -c IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo \"Is torch compiled with cuda: ${IS_CUDA}\"; if test \"${IS_CUDA}\" != \"True\" -a ! -z \"${CUDA_VERSION}\"; then \texit 1; fi" did not complete successfully: exit code: 1 (base) [ec2-user@ip-172-30-2-248 pytorch]$ docker buildx build --progress=plain --platform="linux/amd64" --target official -t ghcr.io/pytorch/pytorch:2.5.0.dev20240617-cuda12.4-cudnn9-devel --build-arg BASE_IMAGE=nvidia/cuda:12.4.0-devel-ubuntu22.04 --build-arg PYTHON_VERSION=3.11 --build-arg CUDA_VERSION= --build-arg CUDA_CHANNEL=nvidia --build-arg PYTORCH_VERSION=2.5.0.dev20240617 --build-arg INSTALL_CHANNEL=pytorch --build-arg TRITON_VERSION= --build-arg CMAKE_VARS="" . #0 building with "default" instance using docker driver ``` Please note looks like we are installing from pytorch rather then nighlty channel on PR hence cuda 12.4 is failing since its not in pytorch channel yet: https://github.com/pytorch/pytorch/actions/runs/9555354734/job/26338476741?pr=128852 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128852 Approved by: https://github.com/malfet --- Dockerfile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Dockerfile b/Dockerfile index ae88187972ef2..b751c64a8439e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,6 +77,11 @@ RUN case ${TARGETPLATFORM} in \ esac && \ /opt/conda/bin/conda clean -ya RUN /opt/conda/bin/pip install torchelastic +RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); \ + echo "Is torch compiled with cuda: ${IS_CUDA}"; \ + if test "${IS_CUDA}" != "True" -a ! -z "${CUDA_VERSION}"; then \ + exit 1; \ + fi FROM ${BASE_IMAGE} as official ARG PYTORCH_VERSION From 8415a4ba98f337e6d21a3c0b026917c03a19e955 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Mon, 17 Jun 2024 22:52:25 +0000 Subject: [PATCH 11/63] Back out "[ROCm] TunableOp for gemm_and_bias (#128143)" (#128815) Summary: Original commit changeset: 35083f04fdae Original Phabricator Diff: D58501726 This PR is bringing a large numerical gap. e.g. for 256 x 4096 x 4096 GEMM, if we enable tunable op + DISABLE_ADDMM_HIP_LT=0, the results are way off. Differential Revision: D58660832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128815 Approved by: https://github.com/mxz297, https://github.com/eqy, https://github.com/malfet --- aten/src/ATen/cuda/tunable/GemmCommon.h | 76 +----------- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 133 ++++----------------- aten/src/ATen/cuda/tunable/Tunable.cpp | 4 +- aten/src/ATen/cuda/tunable/TunableGemm.h | 68 +---------- aten/src/ATen/native/cuda/Blas.cpp | 63 ++-------- 5 files changed, 38 insertions(+), 306 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 64a482bc2781b..a2c7c734a551f 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -81,8 +81,7 @@ struct GemmParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k); } size_t GetSize(bool duplicate_inputs) const { @@ -144,73 +143,6 @@ struct GemmParams : OpParams { bool duplicate_inputs_; }; -template -struct GemmAndBiasParams : OpParams { - std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; - } - - size_t GetSize(bool duplicate_inputs) const { - size_t size = sizeof(T) * ldc * n; - if (duplicate_inputs) { - size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); - size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); - } - return size; - } - - GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { - GemmAndBiasParams* copy = new GemmAndBiasParams; - *copy = *this; - c10::DeviceIndex device = 0; - AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = ldc * n * sizeof(T); - copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); - AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( - copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); - if (duplicate_inputs) { - size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); - size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); - copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); - copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); - copy->duplicate_inputs_ = true; - } - return copy; - } - - // only call on object returned by DeepCopy - void Delete() { - c10::cuda::CUDACachingAllocator::raw_delete(c); - if (duplicate_inputs_) { - c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); - c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); - } - } - - TuningStatus NumericalCheck(GemmAndBiasParams *other) { - auto c_dtype = c10::CppTypeToScalarType::value; - return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; - } - - char transa; - char transb; - int64_t m; - int64_t n; - int64_t k; - at::opmath_type alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - T* c; - int64_t ldc; - const T* bias; - at::cuda::blas::GEMMAndBiasActivationEpilogue activation; -private: - bool duplicate_inputs_; -}; - template struct GemmStridedBatchedParams : OpParams { GemmStridedBatchedParams() { @@ -218,8 +150,7 @@ struct GemmStridedBatchedParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); } size_t GetSize(bool duplicate_inputs) const { @@ -292,8 +223,7 @@ struct ScaledGemmParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k); } size_t GetSize(bool duplicate_inputs) const { diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index ab1525bef6522..a9c420700275e 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -25,35 +25,35 @@ namespace at::cuda::tunable { template -constexpr hipblasDatatype_t HipDataTypeFor(); +constexpr hipblasDatatype_t HipBlasDataTypeFor(); template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_32F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_32F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_16F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_16BF; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16B; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_64F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_64F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIP_R_8F_E4M3_FNUZ; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIP_R_8F_E5M2_FNUZ; } @@ -62,11 +62,6 @@ int GetBatchFromParams(const GemmParams* params) { return 1; } -template -int GetBatchFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetBatchFromParams(const GemmStridedBatchedParams* params) { return params->batch; @@ -82,11 +77,6 @@ int GetStrideAFromParams(const GemmParams* params) { return 1; } -template -int GetStrideAFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideAFromParams(const GemmStridedBatchedParams* params) { return params->stride_a; @@ -102,11 +92,6 @@ int GetStrideBFromParams(const GemmParams* params) { return 1; } -template -int GetStrideBFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideBFromParams(const GemmStridedBatchedParams* params) { return params->stride_b; @@ -122,11 +107,6 @@ int GetStrideCFromParams(const GemmParams* params) { return 1; } -template -int GetStrideCFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideCFromParams(const GemmStridedBatchedParams* params) { return params->stride_c; @@ -142,11 +122,6 @@ float GetAlphaFromParams(const GemmParams* params) { return params->alpha; } -template -float GetAlphaFromParams(const GemmAndBiasParams* params) { - return params->alpha; -} - template float GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; @@ -162,11 +137,6 @@ float GetBetaFromParams(const GemmParams* params) { return params->beta; } -template -float GetBetaFromParams(const GemmAndBiasParams* params) { - return 0.0; -} - template float GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; @@ -182,11 +152,6 @@ const void* GetAScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -202,11 +167,6 @@ const void* GetBScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -222,11 +182,6 @@ const void* GetDScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -242,11 +197,6 @@ const void* GetBiasPointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { - return params->bias; -} - template const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -262,11 +212,6 @@ hipDataType GetBiasTypeFromParams(const GemmParams* params) { return HIP_R_32F; } -template -hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { - return HipDataTypeFor(); -} - template hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { return HIP_R_32F; @@ -277,26 +222,6 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); } -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { - return params->activation; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - static hipblasOperation_t _hipblasOpFromChar(char op) { switch (op) { case 'n': @@ -402,9 +327,9 @@ class HipblasltGemmOp : public Callable { TuningStatus Call(const ParamsT* params) override { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); - auto a_datatype = HipDataTypeFor(); - auto b_datatype = HipDataTypeFor(); - auto in_out_datatype = HipDataTypeFor(); + auto a_datatype = HipBlasDataTypeFor(); + auto b_datatype = HipBlasDataTypeFor(); + auto in_out_datatype = HipBlasDataTypeFor(); auto opa = _hipblasOpFromChar(params->transa); auto opb = _hipblasOpFromChar(params->transb); @@ -460,22 +385,13 @@ class HipblasltGemmOp : public Callable { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); - } - const void* bias_ptr = GetBiasPointerFromParams(params); - auto bias_datatype = GetBiasTypeFromParams(params); - if (bias_ptr) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); - auto activation = GetActivationFromParams(params); - if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); - } - else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); - } - else { + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); } } @@ -544,9 +460,9 @@ template (); - auto b_datatype = HipDataTypeFor(); - auto in_out_datatype = HipDataTypeFor(); + auto a_datatype = HipBlasDataTypeFor(); + auto b_datatype = HipBlasDataTypeFor(); + auto in_out_datatype = HipBlasDataTypeFor(); std::vector heuristic_result; hipblasLtHandle_t handle; @@ -589,11 +505,6 @@ auto GetHipBlasLtGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } -template -auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); -} - template auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index d3d2333323e7f..fc27fab77d790 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -376,8 +376,8 @@ void TuningContext::EnableNumericsCheck(bool value) { bool TuningContext::IsNumericsCheckEnabled() const { static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); - if (env != nullptr && strcmp(env, "1") == 0) { - return true; + if (env != nullptr && strcmp(env, "0") == 0) { + return false; } return numerics_check_enable_; } diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 6b02e26ade4d7..53e6154120c92 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -48,28 +48,6 @@ class DefaultGemmOp : public Callable> { } }; -static bool _transposeBoolFromChar(char op) { - return op == 't' || op == 'T'; -} - -template -class DefaultGemmAndBiasOp : public Callable> { - public: - TuningStatus Call(const GemmAndBiasParams* params) override { - at::cuda::blas::gemm_and_bias( - _transposeBoolFromChar(params->transa), - _transposeBoolFromChar(params->transb), - params->m, params->n, params->k, - params->alpha, - params->a, params->lda, - params->b, params->ldb, - params->bias, - params->c, params->ldc, - params->activation); - return OK; - } -}; - template class DefaultGemmStridedBatchedOp : public Callable> { public: @@ -287,45 +265,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } std::string Signature() override { - static std::string val = c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; - } -}; - -template -class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { - public: - GemmAndBiasTunableOp() { - this->RegisterOp(std::string("Default"), std::make_unique>()); - - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - -#if defined(USE_ROCM) - bool rocm_validators = false; - - static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { - rocm_validators = true; - // disallow tuning of hipblaslt with c10::complex - if constexpr ( - !std::is_same_v> && - !std::is_same_v>) { - for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } - } - AddHipblasltValidator(); - } - - if (rocm_validators) { - AddRocmValidator(); - } -#endif - } - - std::string Signature() override { - static std::string val = c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; + return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -368,8 +308,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } std::string Signature() override { - static std::string val = c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; + return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -391,12 +330,11 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> } std::string Signature() override { - static std::string val = c10::str("ScaledGemmTunableOp", + return c10::str("ScaledGemmTunableOp", "_", TypeName(AT{}), "_", TypeName(BT{}), "_", TypeName(CT{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 728f210b66ed0..ff8eb60b290ba 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -175,6 +175,12 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa static bool getDisableAddmmCudaLt() { static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); #ifdef USE_ROCM + // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) + // note the current tunable op is not the hipblaslt path (gemm_and_bias) + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + return true; + } // allow both CUDA and HIP env var names for ROCm builds // also, current default for ROCm builds is disable by default if (env_value == nullptr) { @@ -208,49 +214,6 @@ static bool isSupportedHipLtROCmArch(int index) { } #endif -template -static void launchTunableGemmAndBias(cublasCommonArgs &args, Tensor& result, const Tensor& self, bool is_rocm) { - bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); - bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); - at::cuda::tunable::GemmAndBiasParams params; - params.transa = args.transa; - params.transb = args.transb; - params.m = args.m; - params.n = args.n; - params.k = args.k; - params.a = args.mata->const_data_ptr(); - params.lda = args.lda; - params.b = args.matb->const_data_ptr(); - params.ldb = args.ldb; - if (is_rocm) { - params.bias = (&result != &self) ? self.const_data_ptr() : nullptr; - } - else { - params.bias = self.const_data_ptr(); - } - params.c = args.result->data_ptr(); - params.ldc = args.result_ld; - if (transa_ && transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (transa_ && !transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (!transa_ && transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (!transa_ && !transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else { - TORCH_CHECK(false, "unreachable"); - } -} - Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call @@ -378,11 +341,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - launchTunableGemmAndBias(args, result, self, true); - } - else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -401,7 +359,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_to_gemm_and_blas_arg(activation) ); - }}); + }); #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -419,11 +377,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - launchTunableGemmAndBias(args, result, self, false); - } - else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -440,7 +393,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_epilogue ); - }}); + }); #endif } else { From 95b5ea9cdef67d211ec2b1e7242100c7e2fad52a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 13 Jun 2024 12:25:02 -0700 Subject: [PATCH 12/63] Add mark_unbacked (#128638) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128638 Approved by: https://github.com/IvanKobzarev --- torch/_dynamo/decorators.py | 24 ++++++++++++++++++++++++ torch/_dynamo/variables/builder.py | 5 ++++- torch/fx/experimental/symbolic_shapes.py | 8 ++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index ec25d06281fc0..79bbb493865c8 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -184,6 +184,30 @@ class directly; instead, use :func:`mark_dynamic`. max: int +@forbid_in_graph +def mark_unbacked(t, index): + """ + Mark a tensor as having an unbacked dim. This changes the semantics of operations, + we will always report the size does not equal zero/one, we will turn asserts + on this index into runtime asserts, and if you try to get the real value we will + raise an exception. In other words, we will treat this dimension as if it was + data dependent (we do not know anything about its value.) + """ + # You could have copied the mark_dynamic behavior but I'm not convinced + # it's what you want + assert not is_traceable_wrapper_subclass(t), "not implemented yet" + + if isinstance(index, int): + if not hasattr(t, "_dynamo_unbacked_indices"): + t._dynamo_unbacked_indices = set() + t._dynamo_unbacked_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_unbacked(t, i) + + @forbid_in_graph def mark_dynamic(t, index, *, min=None, max=None): """ diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f36f53b6537aa..2097690b88b03 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2211,6 +2211,7 @@ def update_dim2constraint(dim, constraint_range, debug_name): constraint_dims = [] for i in range(e.dim()): # NB: mark dynamic has precedence over static + marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) marked_static = i in getattr(e, "_dynamo_static_indices", set()) @@ -2262,7 +2263,9 @@ def update_dim2constraint(dim, constraint_range, debug_name): constraint_dims.append(constraint_dim) # Now, figure out if the dim is dynamic/duck/static - if ( + if marked_unbacked: + dynamic = DimDynamic.SIZE_LIKE_UNBACKED + elif ( constraint_dim is not None or marked_dynamic or marked_weak_dynamic diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index fcfe7d9667daf..2994853408465 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1017,6 +1017,8 @@ class DimDynamic(Enum): DUCK = 1 # Treat the dimension statically based on its hint STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 # NB: These constraints affect both clients and backends: given some @@ -3433,6 +3435,12 @@ def create_symbol( ) -> "sympy.Expr": """Create a new symbol which is tracked by this ShapeEnv """ + if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED: + r = self.create_unbacked_symint().node.expr + self._constrain_range_for_size(r) + # TODO: maybe put the hint somewhere + return r + # check if constraint_dim is actually static integer if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper: dynamic_dim = DimDynamic.STATIC From b70440f0a7ff031decaf994c15474148007b5aa5 Mon Sep 17 00:00:00 2001 From: awayzjj Date: Mon, 17 Jun 2024 23:42:40 +0000 Subject: [PATCH 13/63] Document the torch.cuda.profiler.profile function (#128216) Fixes https://github.com/pytorch/pytorch/issues/127901 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128216 Approved by: https://github.com/malfet, https://github.com/eqy --- torch/cuda/profiler.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index f95aae0f85a7d..65269414f55a3 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -65,6 +65,18 @@ def stop(): @contextlib.contextmanager def profile(): + """ + Enable profiling. + + Context Manager to enabling profile collection by the active profiling tool from CUDA backend. + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> model = torch.nn.Linear(20, 30).cuda() + >>> inputs = torch.randn(128, 20).cuda() + >>> with torch.cuda.profiler.profile() as prof: + ... model(inputs) + """ try: start() yield From 11ff5345d249c27950a06a347cc70aa0047dd46e Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 17 Jun 2024 12:27:40 -0700 Subject: [PATCH 14/63] Changed colored logging to only be turned on if printing to interactive terminal (#128874) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128874 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 10 +--------- torch/fx/_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index c47552fc1b2a7..128b1fbbe4ecc 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -49,12 +49,7 @@ unsupported, xfailIfPy312, ) -from torch._dynamo.utils import ( - CompileProfiler, - counters, - ifdynstaticdefault, - strip_color_from_string, -) +from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -748,7 +743,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) # Check the graph under static shapes if torch._dynamo.config.assume_static_by_default: @@ -811,7 +805,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) self.assertExpectedInline( post_grad_graphs, """\ @@ -904,7 +897,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) self.assertExpectedInline( post_grad_graphs, """\ diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 36c831dfdee06..b27e1df553918 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import sys from typing import Dict, Optional import torch @@ -20,6 +21,9 @@ def format_name(): if "print_output" not in kwargs: kwargs["print_output"] = False + if "colored" in kwargs and not sys.stdout.isatty(): + kwargs["colored"] = False + return LazyString( lambda: _format_graph_code( f"===== {format_name()} =====\n", From beb29836cd1e5b30df8c5a3c1122c926ef4021bc Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 15/63] [Inductor][CPP] Add Min/Max with VecMask (#126841) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/126824 which is missing the support of `min/max` with `VecMask`. **TestPlan** ``` python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_max_cpu_bool python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_min_cpu_bool ``` Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126841 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 2 -- torch/_inductor/codegen/cpp.py | 31 +++++++++++++++++----- torch/_inductor/codegen/cpp_utils.py | 26 +++++++++++++++--- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 7998a3aff58d6..8c85e731e98c4 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -413,8 +413,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "addmv": {f16}, "argsort": {b8, f16, f32, f64, i32, i64}, "as_strided.partial_views": {f16}, - "clamp_max": {b8}, - "clamp_min": {b8}, "corrcoef": {f16}, "diff": {f16}, "einsum": {f16, i32}, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6b8574b9268ad..0b6dca7652651 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -64,7 +64,14 @@ OptimizationContext, ) -from .cpp_utils import cexpr, cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, value_to_cpp +from .cpp_utils import ( + cexpr, + cexpr_index, + DTYPE_TO_CPP, + INDEX_TYPE, + unify_mask_base_type, + value_to_cpp, +) schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") @@ -1311,11 +1318,21 @@ def truncdiv(a, b): @staticmethod def minimum(a, b): - return f"at::vec::minimum({a}, {b})" + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" @staticmethod def maximum(a, b): - return f"at::vec::maximum({a}, {b})" + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" @staticmethod def square(a): @@ -1326,10 +1343,10 @@ def where(a, b, c): assert isinstance(V.kernel, CppVecKernel) if b.dtype == torch.bool: assert c.dtype == torch.bool - blendv_a = f"{V.kernel._get_mask_cast(a, torch.float)}" - blendv_b = f"{V.kernel._get_mask_cast(b, torch.float)}" - blendv_c = f"{V.kernel._get_mask_cast(c, torch.float)}" - return f"decltype({b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" else: return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 336837328a0e5..66f2dfb54aac0 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -4,7 +4,7 @@ import math from collections import namedtuple -from typing import Dict, List +from typing import Dict, List, Tuple from unittest.mock import patch import sympy @@ -12,10 +12,11 @@ import torch from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import ir -from ..utils import sympy_index_symbol_with_prefix +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix from ..virtualized import V -from .common import ExprPrinter, Kernel +from .common import CSEVariable, ExprPrinter, Kernel + DTYPE_TO_CPP = { torch.float32: "float", @@ -421,3 +422,22 @@ def inner(index): return inner return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: Tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given list of cse variables, + Cast each to new mask base dtype and return casted cse variable. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars From c35ffaf954ffdfc76aac24e9c503fb0e5d190722 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 16/63] [Inductor][CPP] Add ne with VecMask (#126940) **Summary** Fix https://github.com/pytorch/pytorch/issues/126824#issuecomment-2125039161 which is missing the support of `ne` with `VecMask`. **Test Plan** ``` python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_ne_cpu_bool ``` Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126940 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 ghstack dependencies: #126841 --- aten/src/ATen/cpu/vec/vec_mask.h | 1 + test/inductor/test_torchinductor_opinfo.py | 1 - torch/_inductor/codegen/cpp.py | 9 +++++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index 6b773c40ca8c9..ebec8d4a3e3c5 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -259,6 +259,7 @@ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 8c85e731e98c4..5f97c2f0fd712 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -431,7 +431,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "maximum": {b8}, "min.binary": {b8}, "minimum": {b8}, - "ne": {b8}, "new_empty_strided": {f16}, "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 0b6dca7652651..2c800b41adff5 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1114,8 +1114,13 @@ def eq(x, y): def ne(x, y): assert isinstance(V.kernel, CppVecKernel) assert isinstance(x, CppCSEVariable) - assert x.dtype is not None - return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" @staticmethod def lt(x, y): From fbc7559ceb372d88b55c96ef6984accbaa0ec3ec Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 18 Jun 2024 00:55:48 +0000 Subject: [PATCH 17/63] [custom ops] convert string type annotation to real type (#128809) Fixes #105157 Bug source: `from __future__ import annotations` converts type annotation to strings to make forwards references easier. However, existing custom ops do not consider strings to be valid types. Fix: We check if the argument and return type annotation is string type. If so, we try to use `eval` to convert it to a type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128809 Approved by: https://github.com/zou3519 --- .../test_infer_schema_annotation.py | 207 ++++++++++++++++++ torch/_library/infer_schema.py | 28 ++- 2 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 test/custom_operator/test_infer_schema_annotation.py diff --git a/test/custom_operator/test_infer_schema_annotation.py b/test/custom_operator/test_infer_schema_annotation.py new file mode 100644 index 0000000000000..9de44224f1c03 --- /dev/null +++ b/test/custom_operator/test_infer_schema_annotation.py @@ -0,0 +1,207 @@ +# Owner(s): ["module: pt2-dispatcher"] +from __future__ import annotations + +import typing +from typing import List, Optional, Sequence, Union # noqa: F401 + +import torch +import torch._custom_op.impl +from torch import Tensor, types +from torch.testing._internal.common_utils import run_tests, TestCase + + +mutates_args = {} + + +class TestInferSchemaWithAnnotation(TestCase): + def test_tensor(self): + def foo_op(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.clone() + y + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor") + + def test_native_types(self): + def foo_op(x: int) -> int: + return x + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(SymInt x) -> SymInt") + + def foo_op_2(x: bool) -> bool: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(bool x) -> bool") + + def foo_op_3(x: str) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(str x) -> SymInt") + + def foo_op_4(x: float) -> float: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(float x) -> float") + + def test_torch_types(self): + def foo_op_1(x: torch.types.Number) -> torch.types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_2(x: torch.dtype) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(ScalarType x) -> SymInt") + + def foo_op_3(x: torch.device) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(Device x) -> SymInt") + + def test_type_variants(self): + def foo_op_1(x: typing.Optional[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + self.assertEqual(result, "(SymInt? x) -> SymInt") + + def foo_op_2(x: typing.Sequence[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_3(x: typing.List[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_5(x: typing.Optional[typing.List[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def test_no_library_prefix(self): + def foo_op(x: Tensor) -> Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_2(x: Tensor) -> torch.Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_3(x: torch.Tensor) -> Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_4(x: List[int]) -> types.Number: + return x[0] + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> Scalar") + + def foo_op_5(x: Optional[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + self.assertEqual(result, "(SymInt? x) -> SymInt") + + def foo_op_6(x: Sequence[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_7(x: List[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_8(x: Optional[Sequence[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_9(x: Optional[List[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_10(x: Union[int, float, bool]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_11(x: Union[int, bool, float]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def test_unsupported_annotation(self): + with self.assertRaisesRegex( + ValueError, + r"Unsupported type annotation D. It is not a type.", + ): + + def foo_op(x: D) -> Tensor: # noqa: F821 + return torch.Tensor(x) + + torch._custom_op.impl.infer_schema(foo_op, mutates_args) + + with self.assertRaisesRegex( + ValueError, + r"Unsupported type annotation E. It is not a type.", + ): + + def foo_op_2(x: Tensor) -> E: # noqa: F821 + return x + + torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 6305375e4433d..c4f7b8ee51e6c 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs import inspect import typing +from typing import List, Optional, Sequence, Union # noqa: F401 +import torch # noqa: F401 from .. import device, dtype, Tensor, types @@ -12,6 +14,9 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str: write custom ops in real life: - none of the outputs alias any of the inputs or each other. - only the args listed in mutates_args are being mutated. + - string type annotations "device, dtype, Tensor, types" without library specification + are assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + without library specification are assumed to be typing.*. Callers (e.g. the custom ops API) are responsible for checking these assumptions. """ @@ -22,6 +27,14 @@ def error_fn(what): f"infer_schema(func): {what} " f"Got func with signature {sig})" ) + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception as e: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." + ) + params = [] seen_args = set() saw_kwarg_only_arg = False @@ -38,13 +51,19 @@ def error_fn(what): if param.annotation is inspect.Parameter.empty: error_fn(f"Parameter {name} must have a type annotation.") - if param.annotation not in SUPPORTED_PARAM_TYPES.keys(): + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = param.annotation + if type(annotation_type) == str: + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): error_fn( f"Parameter {name} has unsupported type {param.annotation}. " f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." ) - schema_type = SUPPORTED_PARAM_TYPES[param.annotation] + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] if name in mutates_args: if not schema_type.startswith("Tensor"): error_fn( @@ -72,7 +91,10 @@ def error_fn(what): f"mutates_args should contain the names of all args that the " f"custom op mutates." ) - ret = parse_return(sig.return_annotation, error_fn) + return_annotation = sig.return_annotation + if type(return_annotation) == str: + return_annotation = convert_type_string(return_annotation) + ret = parse_return(return_annotation, error_fn) return f"({', '.join(params)}) -> {ret}" From 9e8443b56f5a83877803be3ba43f7941841904c9 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 18 Jun 2024 01:26:45 +0000 Subject: [PATCH 18/63] Remove dtype from gpt-fast micro benchmark experiments model name (#128789) Per comments on https://github.com/pytorch/test-infra/pull/5344, we already have a dtype column with the same information Pull Request resolved: https://github.com/pytorch/pytorch/pull/128789 Approved by: https://github.com/yanboliang --- benchmarks/gpt_fast/benchmark.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 16f3e55af17b0..1c09fe03a904a 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -76,7 +76,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"mlp_layer_norm_gelu_{dtype_str}", + "mlp_layer_norm_gelu", "flops_utilization", expected_flops_utilization, f"{flops_utilization:.02f}", @@ -113,7 +113,7 @@ def run_layer_norm(device: str = "cuda"): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"layer_norm_{dtype_str}", + "layer_norm", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", @@ -156,7 +156,7 @@ def gather_gemv(W, score_idxs, x): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"gather_gemv_{dtype_str}", + "gather_gemv", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", @@ -197,7 +197,7 @@ def gemv(W, x): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"gemv_{dtype_str}", + "gemv", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", From e12fa93b8bb3b7b7148f6111577e454bd3251223 Mon Sep 17 00:00:00 2001 From: Fuzzkatt Date: Tue, 18 Jun 2024 02:00:01 +0000 Subject: [PATCH 19/63] add is_big_gpu(0) check to test_select_algorithm tests in tests/inductor/test_cuda_cpp_wrapper.py (#128652) In NVIDIA internal CI, on Jetson devices we are seeing this failure for `python test/inductor/test_cuda_cpp_wrapper.py -k test_addmm_cuda_cuda_wrapper -k test_linear_relu_cuda_cuda_wrapper`: ``` /usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:132: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance. warnings.warn( W0613 20:57:17.722000 281473279256672 torch/_inductor/utils.py:902] [0/0] Not enough SMs to use max_autotune_gemm mode frames [('total', 1), ('ok', 1)] stats [('calls_captured', 2), ('unique_graphs', 1)] inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1), ('pattern_matcher_count', 1), ('pattern_matcher_nodes', 1)] aot_autograd [('total', 1), ('ok', 1)] F ====================================================================== FAIL: test_linear_relu_cuda_cuda_wrapper (__main__.TestCudaWrapper) ---------------------------------------------------------------------- Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 2759, in wrapper method(*args, **kwargs) File "/opt/pytorch/pytorch/test/inductor/test_torchinductor.py", line 9818, in new_test return value(self) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/pytorch/pytorch/test/inductor/test_cuda_cpp_wrapper.py", line 152, in fn _, code = test_torchinductor.run_and_get_cpp_code( File "/opt/pytorch/pytorch/test/inductor/test_torchinductor.py", line 356, in run_and_get_cpp_code result = fn(*args, **kwargs) File "/opt/pytorch/pytorch/test/inductor/test_select_algorithm.py", line 43, in wrapped return fn(*args, **kwargs) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/usr/lib/python3.10/unittest/mock.py", line 1379, in patched return func(*newargs, **newkeywargs) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/pytorch/pytorch/test/inductor/test_select_algorithm.py", line 62, in test_linear_relu_cuda self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 3642, in assertEqual raise error_metas.pop()[0].to_error( AssertionError: Scalars are not equal! Expected 1 but got 0. Absolute difference: 1 Relative difference: 1.0 ``` Looking into it, we see the failure is from https://github.com/pytorch/pytorch/blob/main/test/inductor/test_select_algorithm.py#L62. The warning `W0613 20:57:17.722000 281473279256672 torch/_inductor/utils.py:902] [0/0] Not enough SMs to use max_autotune_gemm ` is triggered from https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L973. Printing torch.cuda.get_device_properties(0).multi_processor_count returns 16 on the computelab AGX Orin; thus it makes sense that this check is failing, since the min_required_sms is 68, thus not letting it pick the autotune algorithm. Looking at the main for test_select_algorithm.py, we see that these tests should only be run if is_big_gpu(0) is true: https://github.com/pytorch/pytorch/blob/main/test/inductor/test_select_algorithm.py#L344. Thus this PR adds a similar check to the invocation of these tests in test_cuda_cpp_wrapper.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128652 Approved by: https://github.com/soulitzer, https://github.com/eqy --- test/inductor/test_cuda_cpp_wrapper.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index eaa0134be8f09..1289de2743659 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -194,14 +194,6 @@ class BaseTest(NamedTuple): "test_cat_slice_cat", tests=test_pattern_matcher.TestPatternMatcher(), ), - BaseTest( - "test_addmm", - tests=test_select_algorithm.TestSelectAlgorithm(), - ), - BaseTest( - "test_linear_relu", - tests=test_select_algorithm.TestSelectAlgorithm(), - ), # TODO: Re-enable this test after fixing cuda wrapper for conv Triton templates with dynamic shapes. # This test is unstable: it succeeds when an ATEN kernel is used, and fails when a Triton kernel is used. # Currently it passes on CI (an ATEN kernel is chosen) and fails locally (a Triton kernel is chosen). @@ -226,6 +218,21 @@ class BaseTest(NamedTuple): ]: make_test_case(item.name, item.device, item.tests) + from torch._inductor.utils import is_big_gpu + + if is_big_gpu(0): + for item in [ + BaseTest( + "test_addmm", + tests=test_select_algorithm.TestSelectAlgorithm(), + ), + BaseTest( + "test_linear_relu", + tests=test_select_algorithm.TestSelectAlgorithm(), + ), + ]: + make_test_case(item.name, item.device, item.tests) + test_torchinductor.copy_tests( CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper ) From 43998711a794b6c324a59397ded048786e0e9312 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Jun 2024 02:07:03 +0000 Subject: [PATCH 20/63] [CUDAGraph] add more docs for cudagraph trees (#127963) This PR adds more documentation for CUDAGraph Trees, including - Iteration Support - Input Mutation Support - Dynamic Shape Support - NCCL Support - Reasons for Skipping CUDAGraph Pull Request resolved: https://github.com/pytorch/pytorch/pull/127963 Approved by: https://github.com/eellison --- .../source/torch.compiler_cudagraph_trees.rst | 192 +++++++++++++++++- 1 file changed, 188 insertions(+), 4 deletions(-) diff --git a/docs/source/torch.compiler_cudagraph_trees.rst b/docs/source/torch.compiler_cudagraph_trees.rst index b1986dc0dc47f..360fbf0c5d9ce 100644 --- a/docs/source/torch.compiler_cudagraph_trees.rst +++ b/docs/source/torch.compiler_cudagraph_trees.rst @@ -1,7 +1,10 @@ CUDAGraph Trees ================ -CUDAGraph Background +**Background** +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +CUDAGraph -------------------- For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraphs `_. @@ -35,8 +38,8 @@ TorchDynamo Previous CUDA Graphs Integration Running with ``cudagraph_trees=False`` does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward. -CUDAGraph Trees Integration ---------------------------- +**CUDAGraph Trees Integration** +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Like Graph Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let’s take a look at an illustrative example: @@ -90,6 +93,181 @@ The second time we hit graph 3 we are warmed up and ready to record. We record g \\ \\ 4 4 + +Input Mutation Support +---------------------- + +Input mutation function refers to a function conducting in-place writes to an input tensor, +as illustrated below: + +.. code-block:: python + + def foo(x, y): + # mutates input x + x.add_(1) + return x + y + +Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static +CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may +allocate a static memory address x'. During execution, CUDAGraph Trees first copy the input +tensor x to the static memory address x', and then replay the recorded CUDAGraph. For input +mutation function, x' is in-place updated, which is not reflected on the input tensor x since +x and x' reside on different CUDA memory addresses. + +A closer look at input mutation functions reveals that there are three types of inputs: + +* **inputs from eager**: These tensors we assume will vary input tensor addresses from + execution to execution. Because cudagraphs freeze memory addresses, we need to copy these + inputs to a static address tensor prior to graph recording and execution. +* **Parameters and buffers**: These tensors we assume (and runtime-check) have the same tensor + addresses on every execution. We do not need to copy over their contents because the recorded + memory address will be the same as the executed memory address. +* **Tensors which are prior outputs from CUDAGraph Trees**: Because the output tensor addresses + of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from + CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and + buffers, do not require copying over to a static address tensor. We check to make sure that + these inputs are stable at runtime, and if they're not we will re-record. + +CUDAGraph Trees support input mutation on parameters and buffers, and tensors which are prior +outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees will run the +function without CUDAGraph and emit *skipping due to mutated inputs* log. The following example +shows CUDAGraph Trees' support for tensors which are prior outputs from CUDAGraph Trees. + + +.. code-block:: python + + import torch + + @torch.compile(mode="reduce-overhead") + def foo(x): + return x + 1 + + @torch.compile(mode="reduce-overhead") + def mut(x): + return x.add_(2) + + # Enable input mutation support + torch._inductor.config.triton.cudagraph_support_input_mutation = True + + for i in range(3): + torch.compiler.cudagraph_mark_step_begin() + inp = torch.rand([4], device="cuda") + + # CUDAGraph is applied since `foo` does not mutate `inp` + tmp = foo(inp) + # Although `mut` mutates `tmp`, which is an output of a CUDAGraph + # managed function. So CUDAGraph is still applied. + mut(tmp) + + + torch.compiler.cudagraph_mark_step_begin() + inp = torch.rand([4], device="cuda") + + tmp = foo(inp) + # While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()` + # is not. So CUDAGraph is not applied to `mut` and there is a log + # `skipping cudagraphs due to mutated inputs` + mut(tmp.clone()) + + +To enable CUDAGraph Trees for a function mutating inputs from eager, please re-write +the function to avoid input mutation. + +.. note:: Enable input mutation support by setting + `torch._inductor.config.cudagraph_support_input_mutation = True `_ + for "reduce-overhead" mode. + + +Dynamic Shape Support +--------------------- + +`Dynamic shape `_ +means that an input tensor has different shapes across function calls. Since CUDAGraph +requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique +shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph. +When there are limited shapes (e.g., batch sizes in inference), it is profitable to +re-record CUDAGraphs. However, if input tensor shapes change frequently or even on +every invocation, re-recording CUDAGraph may not be profitable. Nvidia uses 64 KB of +device memory per kernel launch in CUDAGraph, up until CUDA 12.4 and Driver Version 550+. +This memory cost can be significant with many CUDAGraph re-recordings. + +For functions with frequently changing input tensor shapes, we suggest padding input +tensors to a few fixed tensor shapes to still enjoy benefits from CUDAGraph. In addition, +setting `torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True `_ +allows to skip cudagraphing functions with dynamic shape inputs and only cudagraphing +functions with static input tensor shapes. + + +NCCL Support +------------ + +CUDAGraph Trees support functions with nccl operators. While CUDAGraph Trees perform per-device +record for CUDAGraph, NCCL support allows cross-device communication. + +.. code-block:: python + + @torch.compile(mode="reduce-overhead") + def func(x): + y = x * x + y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM) + x = torch.nn.functional.silu(x) + return x * y + + +Reasons for Skipping CUDAGraph +------------------------------ + +Since CUDAGraph has requirements such as static input tensor addresses and not supporting +CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and +may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph. + +* **Input mutation**: CUDAGraph Trees skip functions that in-place mutates eager input. + In-place mutating parameters and buffers, or output tensors from CUDAGraph Tree managed + functions are still supported. Please see *Input Mutation Support* section for more details. +* **CPU operators**: Functions containing CPU operator are skipped. Please split the + function into multiple functions and apply CUDAGraph Trees on functions with only GPU operators. +* **Multi-device operators**: A function is skipped if it contains operators on multiple + devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported + libraries such as NCCL for cross-device communication. Please see *NCCL Support* + section for more details. +* **Free unbacked symbols**: Free unbacked symbols usually happen during + `dynamic shapes `_. + CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shapes. + Please see *Dynamic Shape Support* for more details. +* **Incompatible operators**: CUDAGraph Trees skip a function if it contain incompatible + operators. Please replace these operators in a function with supported operators. We + show an exhaustive list of incompatible operators: + + +.. code-block:: python + + aten._fused_moving_avg_obs_fq_helper.default + aten._fused_moving_avg_obs_fq_helper_functional.default + aten.multinomial.default + fbgemm.dense_to_jagged.default + fbgemm.jagged_to_padded_dense.default + run_and_save_rng_state + run_with_rng_state + aten._local_scalar_dense + aten._assert_scalar + + +The following operators are incompatible when `torch.are_deterministic_algorithms_enabled() `_. + + +.. code-block:: python + + aten._fused_moving_avg_obs_fq_helper.default + aten._fused_moving_avg_obs_fq_helper_functional.default + aten.multinomial.default + fbgemm.dense_to_jagged.default + fbgemm.jagged_to_padded_dense.default + run_and_save_rng_state + run_with_rng_state + aten._local_scalar_dense + aten._assert_scalar + + Limitations ----------- @@ -112,8 +290,14 @@ Let’s say we are benchmarking running inference with the following code: print(y1) # RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. +In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph +Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want +to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for +torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics +are wrong, you can mark the start of a new iteration with +`torch.compiler.mark_step_begin() `_, or clone +tensors of a prior iteration (outside of torch.compile) before you begin the next run. -In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDA Graph Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics are wrong, you can mark the start of a new iteration with torch.compiler.mark_step_begin(), or clone tensors of a prior iteration (outside of torch.compile) before you begin the next run. Comparisons ----------- From 22f1793c0ac644a357ee44ccaa78e1252731f57e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:43:18 -0700 Subject: [PATCH 21/63] [dynamo][easy] Use LazyVariableTracker for UserDefinedObject var_getattr (#128877) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128877 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748 --- torch/_dynamo/variables/user_defined.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2b97d921b73b1..fb2b3c1b6ac4f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -941,8 +941,7 @@ def var_getattr(self, tx, name): ) ): if source: - install_guard(source.make_guard(GuardBuilder.HASATTR)) - return VariableBuilder(tx, source)(subobj) + return variables.LazyVariableTracker.create(subobj, source) elif ConstantVariable.is_literal(subobj): return ConstantVariable.create(subobj) elif ( From 4e97d37fd947236333d5ccb37c9d9382878b4003 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:43:21 -0700 Subject: [PATCH 22/63] [inlining-inbuilt-nn-modules][pre-grad] Adjust efficient_conv_bn_eval_graph for inlining (#128878) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128878 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748, #128877 --- .../fx_passes/efficient_conv_bn_eval.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index 7aecc3f15f33d..c8165a1a3926a 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -139,6 +139,97 @@ def efficient_conv_bn_eval_decomposed( return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) + + return + + @register_graph_pattern( CallFunctionVarArgs( [ From c017c97333dfb9d17f2e5357980241827e50e8d5 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:47:25 -0700 Subject: [PATCH 23/63] [dynamo][inlining-inbuilt-nn-modules] Update test output (#128880) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128880 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748, #128877, #128878 --- test/dynamo/test_structured_trace.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index f84e08b8f9cce..e3a82921a838b 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -379,18 +379,50 @@ def forward(self, x): {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_self_layers_0_weight": [1024, 1024], "l_self_layers_0_bias": [1024], "l_x_": [1024, 1024], "l_self_layers_1_weight": [1024, 1024], "l_self_layers_1_bias": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 3, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 8, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 3, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 8, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 4, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 9, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 4, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 9, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_self_modules_layers_modules_0_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_0_parameters_bias_": [1024], "l_x_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_bias_": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 31, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 31, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 32, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 32, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} From 4061b3b8225f522ae0ed6db00111441e7d3cc3d5 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 17 Jun 2024 17:06:46 -0400 Subject: [PATCH 24/63] Forward fix to skip ROCm tests for #122836 (#128891) Fixes broken ROCm tests from #122836. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128891 Approved by: https://github.com/huydhn ghstack dependencies: #127007, #128057, #122836 --- test/test_nestedtensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 86f58b5a0de3a..50d6deea92911 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5470,6 +5470,7 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): # shape (B, *, D) nt = random_nt_from_dims( @@ -5500,6 +5501,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 @@ -5536,6 +5538,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): # shape (B, *, D) # min seq len: 7 @@ -5572,6 +5575,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 From 17abbafdfc6935bcc133e5f43ba32d44914fe316 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 18 Jun 2024 03:25:20 +0000 Subject: [PATCH 25/63] [inductor] Fix some windows cpp builder issue (#128765) 1. fix some Windows build args. 2. fix c++20 likely issue on Windows, reference: https://github.com/pytorch/pytorch/pull/124997. 3. remove compiler return value check, different compilers return variant value, let's check exception to catch error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128765 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/codecache.py | 7 ++++--- torch/_inductor/cpp_builder.py | 9 +++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 3d265f181b159..422728a9e59ae 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1376,8 +1376,6 @@ def __bool__(self) -> bool: output_path = x86_isa_help_builder.get_target_file_path() if not os.path.isfile(output_path): status, target_file = x86_isa_help_builder.build() - if status: - return False # Check build result subprocess.check_call( @@ -2573,11 +2571,14 @@ class CppPythonBindingsCodeCache(CppCodeCache): #ifndef _MSC_VER #if __cplusplus < 202002L - // C++20 earlier code + // C++20 (earlier) code // https://en.cppreference.com/w/cpp/language/attributes/likely #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) #endif // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index f75f079d72db2..a574b33473425 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -345,7 +345,11 @@ def _get_optimization_cflags() -> List[str]: def _get_shared_cflag(compile_only: bool) -> List[str]: if _IS_WINDOWS: - SHARED_FLAG = ["DLL"] + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. + https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + SHARED_FLAG = ["DLL", "MD"] else: if compile_only: return ["fPIC"] @@ -567,7 +571,7 @@ def _get_torch_related_args(include_pytorch: bool, aot_mode: bool): ] libraries_dirs = [TORCH_LIB_PATH] libraries = [] - if sys.platform == "linux" and not config.is_fbcode(): + if sys.platform != "darwin" and not config.is_fbcode(): libraries = ["torch", "torch_cpu"] if not aot_mode: libraries.append("torch_python") @@ -663,6 +667,7 @@ def _get_openmp_args(cpp_compiler): # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL libs = [] else: if config.is_fbcode(): From 59b4983dc06f12eded69ab1471817c67c1fc72c0 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 18 Jun 2024 03:40:14 +0000 Subject: [PATCH 26/63] DebugPlane: add dump_traceback handler (#128904) This adds a `dump_traceback` handler so you can see all running threads for a job. This uses a temporary file as a buffer when calling `faulthandler.dump_traceback` and requires the GIL to be held during dumping. Test plan: ``` python test/distributed/elastic/test_control_plane.py -v -k traceback ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128904 Approved by: https://github.com/c-p-i-o --- build_variables.bzl | 1 + .../distributed/elastic/test_control_plane.py | 6 +++ .../c10d/control_plane/PythonHandlers.cpp | 44 +++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp diff --git a/build_variables.bzl b/build_variables.bzl index b4b4d1ab139cd..793b611a0a6f0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -927,6 +927,7 @@ libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + "torch/csrc/distributed/rpc/unpickled_python_call.cpp", "torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp", "torch/csrc/jit/runtime/register_distributed_ops.cpp", + "torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp", ] def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 775b062451b16..7d01bd9eb0300 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -92,6 +92,12 @@ def test_tcp(self) -> None: server.shutdown() + def test_dump_traceback(self) -> None: + with local_worker_server() as pool: + resp = pool.request("POST", "/handler/dump_traceback") + self.assertEqual(resp.status, 200) + self.assertIn(b"in test_dump_traceback\n", resp.data) + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp b/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp new file mode 100644 index 0000000000000..cc1539a9527b4 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp @@ -0,0 +1,44 @@ +#include + +#include +#include +#include + +#include +#include +#include + +namespace c10d::control_plane { +namespace { + +RegisterHandler tracebackHandler{ + "dump_traceback", + [](const Request&, Response& res) { + auto tmpfile = c10::make_tempfile("torch-dump_traceback"); + + auto cfile = ::fopen(tmpfile.name.c_str(), "w"); + if (!cfile) { + throw std::runtime_error("failed to open file for writing"); + } + + { + py::gil_scoped_acquire guard{}; + + auto faulthandler = py::module::import("faulthandler"); + faulthandler.attr("dump_traceback")(fileno(cfile), true); + } + + ::fclose(cfile); + + std::ifstream file(tmpfile.name); + std::string str; + std::string file_contents; + while (std::getline(file, str)) { + file_contents += str; + file_contents.push_back('\n'); + } + + res.setContent(std::move(file_contents), "text/plain"); + }}; +} +} // namespace c10d::control_plane From d9eaa224f2512639e55cb11b372fcd1983d22ea5 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Tue, 18 Jun 2024 03:44:38 +0000 Subject: [PATCH 27/63] Fixes #128429: NaN in triu op on MPS (#128575) Fixes triu op when k > 0 and the lower triangle of the input tensor contains inf leading to NaNs in the computation through complement. Fixed by using select API instead. Fixes #128429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128575 Approved by: https://github.com/kulinseth --- .../ATen/native/mps/operations/TriangularOps.mm | 15 ++++++++++----- test/test_mps.py | 8 ++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index 5fa0b22184535..dcea978655b85 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -35,11 +35,16 @@ if (k > 0) { MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; - MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor - numLowerTensor:minusOneTensor - numUpperTensor:diagMinusOneTensor - name:nil]; - outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:complementTensor name:nil]; + MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt32]; + onesTensor = [mpsGraph broadcastTensor:onesTensor toShape:inputTensor.shape name:nil]; + MPSGraphTensor* maskTensor = [mpsGraph bandPartWithTensor:onesTensor + numLowerTensor:minusOneTensor + numUpperTensor:diagMinusOneTensor + name:nil]; + outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor + truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType] + falsePredicateTensor:inputTensor + name:nil]; } else { MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; outputTensor = [mpsGraph bandPartWithTensor:inputTensor diff --git a/test/test_mps.py b/test/test_mps.py index 275013f20effc..311cf8245c4f3 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1617,6 +1617,14 @@ def test_exp(self, device="mps", dtype=torch.float): a = torch.tensor(v, dtype=dtype, device="mps") * b self.compare_with_numpy(torch.exp, np.exp, a) + def test_triu_inf(self, device="mps", dtype=torch.float): + for diag in [-1, 0, 1]: + mask = torch.full((3, 6, 6), float("-inf")) + mask_mps = mask.clone().detach().to('mps') + cpu_ref = torch.triu(mask, diagonal=diag) + mps_out = torch.triu(mask_mps, diagonal=diag) + self.assertEqual(cpu_ref, mps_out) + def test_exp1(self, device="mps", dtype=torch.float): input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype) output = torch.exp(input) From f7eae279463b719c5f25587aac225bd2be891373 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Mon, 17 Jun 2024 10:06:14 -0700 Subject: [PATCH 28/63] Pass params to dump_nccl_trace_pickle (#128781) Summary Pass parameters from request to dump_nccl_trace_pickle handler. The supported parameters + value are all lowercase. includecollectives={true, false} includestacktraces={true, false} onlyactive={true, false} Example post is: /handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true Test Plan: unit tests Differential Revision: [D58640474](https://our.internmc.facebook.com/intern/diff/D58640474) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128781 Approved by: https://github.com/d4l3k --- .../distributed/elastic/test_control_plane.py | 37 ++++++++++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 51 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 10 ---- .../c10d/control_plane/Handlers.hpp | 3 ++ .../c10d/control_plane/WorkerServer.cpp | 4 ++ 5 files changed, 95 insertions(+), 10 deletions(-) diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 7d01bd9eb0300..9eb57952e2bdf 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -80,6 +80,43 @@ def test_dump_nccl_trace_pickle(self) -> None: resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) + self.assertIsInstance(out, dict) + self.assertIn("version", out) + + @requires_cuda + def test_dump_nccl_trace_pickle_with_params(self) -> None: + with local_worker_server() as pool: + # bad key - not lower case + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" + ) + self.assertEqual(resp.status, 400) + # unknown key + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" + ) + self.assertEqual(resp.status, 400) + # bad value - not a bool + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" + ) + self.assertEqual(resp.status, 400) + # bad value - value not lowercase + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" + ) + self.assertEqual(resp.status, 400) + # good key and value + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" + ) + self.assertEqual(resp.status, 200) + # multiple good keys and values + resp = pool.request( + "POST", + "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", + ) + self.assertEqual(resp.status, 200) def test_tcp(self) -> None: import requests diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 6507fe6abc2a2..d3a997625e144 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,7 +1,10 @@ #include +#include +#include #include #include +#include #ifdef USE_C10D_NCCL #include @@ -238,6 +241,54 @@ std::string getNcclErrorDetailStr( return interpret + err; } +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + const auto params = req.params(); + size_t validParamCount = 0; + + // valid params + const std::string includeCollectivesStr = "includecollectives"; + const std::string includeStackTracesStr = "includestacktraces"; + const std::string onlyActiveStr = "onlyactive"; + + std::unordered_map expectedParams = { + {includeCollectivesStr, true}, + {includeStackTracesStr, true}, + {onlyActiveStr, false}}; + + for (const auto& [paramName, paramValue] : params) { + auto it = expectedParams.find(paramName); + if (it != expectedParams.end()) { + validParamCount++; + if (paramValue == "true") { + it->second = true; + } else if (paramValue == "false") { + it->second = false; + } else { + res.setStatus(400); + res.setContent( + "Invalid value for " + paramName + + " valid values are true or false", + "text/plain"); + return; + } + } + } + if (validParamCount < params.size()) { + res.setStatus(400); + res.setContent( + "Invalid parameters - unexpected param passed in", "text/plain"); + return; + } + res.setContent( + dump_nccl_trace( + expectedParams[includeCollectivesStr], + expectedParams[includeStackTracesStr], + expectedParams[onlyActiveStr]), + "application/octet-stream"); + }}; + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d293c4d470b83..06804a544a388 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include @@ -380,15 +379,6 @@ std::string dump_nccl_trace( } #endif -// TODO(c-p-i-o): add a JSON endpoint. -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. - res.setContent( - dump_nccl_trace(true, true, false), "application/octet-stream"); - }}; - std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 0c10630549312..f230e7a4c0e47 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -15,6 +16,8 @@ class TORCH_API Request { virtual ~Request() = default; virtual const std::string& body() = 0; + + virtual const std::multimap& params() const = 0; }; // Response represents a response to the handler. This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 947a281982f14..0e9de35322abb 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -23,6 +23,10 @@ class RequestImpl : public Request { return req_.body; } + const std::multimap& params() const override { + return req_.params; + } + private: const httplib::Request& req_; }; From e3a39d49a0b06399f074b30c4be6ef9670633185 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 18 Jun 2024 06:22:14 +0000 Subject: [PATCH 29/63] [Traceable FSDP][Compiled Autograd] Add queue_callback() support (#126366) Adds support for `Variable._execution_engine.queue_callback()`, which is used in FSDP2. Important tests: - `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_callback_graph_break_throws_error` - `pytest -rA test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_callback_adds_callback` - `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_callback_adds_callback` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126366 Approved by: https://github.com/xmfan --- test/inductor/test_compiled_autograd.py | 28 ++++++++++++++++- torch/_dynamo/compiled_autograd.py | 12 +++++++- torch/_dynamo/external_utils.py | 19 ++++++++++++ torch/_dynamo/side_effects.py | 12 ++++++++ torch/_dynamo/symbolic_convert.py | 2 ++ torch/_dynamo/variables/builder.py | 19 ++++++++++++ torch/_dynamo/variables/misc.py | 41 +++++++++++++++++++++++++ 7 files changed, 131 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index a3dfcb59f2fdd..91b8178ae6ccd 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1767,6 +1767,33 @@ def fn(inputs): out = compiled_fn(activations) self.assertTrue(len(activations) == 0) + def test_callback_graph_break_throws_error(self): + called = [0] + + def callback_final(): + called[0] += 1 + + class MyFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad): + torch.autograd.Variable._execution_engine.queue_callback(callback_final) + torch._dynamo.graph_break() + return grad + + a = torch.rand((3, 3), requires_grad=True) + with self.assertRaisesRegex( + AssertionError, + "only supported when Compiled Autograd is enabled with fullgraph=True", + ): + with compiled_autograd.enable(make_compiler_fn(fullgraph=False)): + b = MyFunc.apply(a) + b.sum().backward() + @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -2177,7 +2204,6 @@ def wrap_test_class(orig_cls): "test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable( "test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance "test_autograd_simple_views_python", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function - "test_callback_adds_callback", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable "test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method "test_custom_autograd_no_early_free", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients "test_custom_function_cycle", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 2570278ef4788..e72cf40d65ded 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -4,7 +4,11 @@ from typing import Dict, List, Optional, TYPE_CHECKING import torch -from torch._dynamo.external_utils import call_backward, call_hook +from torch._dynamo.external_utils import ( + call_backward, + call_hook, + FakeCompiledAutogradEngine, +) from torch._dynamo.source import GetItemSource, LocalSource from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal from torch._logging import getArtifactLogger, trace_structured @@ -255,6 +259,12 @@ def move_graph_nodes_to_cuda(self, graph) -> List[int]: return [] def end_capture(self, outputs): + self.fx_tracer.create_proxy( + "call_function", + FakeCompiledAutogradEngine._exec_final_callbacks_stub, + (), + {}, + ) self.stack.close() self.fx_tracer.create_node( "output", diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index caea92bc6be08..7d3b0fc6ada43 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -98,6 +98,25 @@ def untyped_storage_size(x: torch.Tensor): return x.untyped_storage().size() +class FakeCompiledAutogradEngine: + @staticmethod + def queue_callback(final_callbacks, cb): + final_callbacks.append(cb) + + @staticmethod + def exec_final_callbacks(final_callbacks): + i = 0 + while i < len(final_callbacks): + cb = final_callbacks[i] + cb() + i += 1 + final_callbacks.clear() + + @staticmethod + def _exec_final_callbacks_stub(): + pass + + def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs): return getattr(bw_state, hook_name)(*args, **kwargs) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 28ce9811b4c38..5689fa0977db8 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -89,6 +89,9 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} + # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. + # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. + self.ca_final_callbacks_var = None def __eq__(self, other: object) -> bool: assert isinstance(other, SideEffects) @@ -476,6 +479,15 @@ def codegen_hooks(self, cg): # be associated with the return value of register_hook(). This consumes the top of stack. cg.add_cache(handle) + def get_ca_final_callbacks_var(self): + from .variables.base import MutableLocal + + if self.ca_final_callbacks_var is None: + self.ca_final_callbacks_var = variables.ListVariable( + [], mutable_local=MutableLocal() + ) + return self.ca_final_callbacks_var + def codegen_update_mutated(self, cg: PyCodegen): suffixes = [] for var in self._get_modified_vars(): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 7e129a05a0905..6105ae466b012 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2305,6 +2305,7 @@ def __init__( self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} # Flag to indicate whether tracing is used for export. self.export = export + self.one_graph = False self.current_speculation = None @@ -2860,6 +2861,7 @@ def __init__( self.symbolic_result = None self.closure_cells = closure_cells self.nn_module_stack = parent.nn_module_stack.copy() + self.one_graph = parent.one_graph @property def fake_mode(self): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2097690b88b03..af91edb432c88 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -129,6 +129,7 @@ CollectiveFunctionRewriteVariable, FunctoolsPartialVariable, TritonKernelVariable, + UserFunctionVariable, UserMethodVariable, ) from .higher_order_ops import TorchHigherOrderOperatorVariable @@ -146,6 +147,7 @@ TupleVariable, ) from .misc import ( + AutogradEngineVariable, AutogradFunctionContextVariable, AutogradFunctionVariable, ComptimeVariable, @@ -726,6 +728,23 @@ def build_key_value(i, k, v): ), "apply", ) + elif isinstance(value, torch._C._ImperativeEngine): + self.install_guards(GuardBuilder.ID_MATCH) + return AutogradEngineVariable(value, source=self.source) + elif ( + value + is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub + ): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return LambdaVariable( + lambda: UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, + ).call_function( + self.tx, + (self.tx.output.side_effects.get_ca_final_callbacks_var(),), + {}, + ) + ) elif callable(value) and trace_rules.lookup_callable(value) is not None: if is_callable_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 179bb9a52bf98..0e54e0f613a34 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -643,6 +643,47 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class AutogradEngineVariable(UserDefinedObjectVariable): + """ + Represents a torch._C._ImperativeEngine instance. + """ + + def __init__( + self, + value, + value_type=None, + **kwargs, + ): + super().__init__(value=value, value_type=value_type, **kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "queue_callback": + if torch._dynamo.compiled_autograd.compiled_autograd_enabled: + assert ( + tx.one_graph + ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + return variables.UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, + source=self.source, + ).call_function( + tx, + (tx.output.side_effects.get_ca_final_callbacks_var(), *args), + kwargs, + ) + else: + unimplemented( + "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + ) + else: + unimplemented(f"torch._C._ImperativeEngine method: {name}") + + class LambdaVariable(VariableTracker): def __init__(self, fn, **kwargs): super().__init__(**kwargs) From 60baeee59f7a6ff610c42411bf2709d2bbd5bd2c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 13:31:23 -0700 Subject: [PATCH 30/63] [BE] Skip the test if CUDA is not available (#128885) As title Differential Revision: [D58690210](https://our.internmc.facebook.com/intern/diff/D58690210/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128885 Approved by: https://github.com/wz337 --- test/distributed/_tensor/debug/test_comm_mode.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 5483b3171f309..bd862220b210d 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -116,6 +116,9 @@ def f(x, y): @requires_nccl() def test_comm_mode_with_c10d(self): + if not torch.cuda.is_available(): + return + world_pg = self.world_pg inp = torch.rand(2, 8, 16).cuda() From 6e43897912d149d7dad676f496c608fe32a31978 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 09:40:54 -0700 Subject: [PATCH 31/63] [BE][ptd_fb_test][3/N] Enable TestSlide for MultiThreadedTestCase (#128843) Enabling testslide for MultiThreadedTestCase, similar to https://github.com/pytorch/pytorch/pull/127512. Differential Revision: [D58677457](https://our.internmc.facebook.com/intern/diff/D58677457/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128843 Approved by: https://github.com/wz337 --- torch/testing/_internal/common_distributed.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 473e5c35e07a1..a0a3429797c28 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -994,10 +994,14 @@ def wrapper(self): return types.MethodType(wrapper, self) - def __init__(self, method_name: str = "runTest") -> None: + def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName super().__init__(method_name) - test_fn = getattr(self, method_name, None) - setattr(self, method_name, self.join_or_run(test_fn)) + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) def perThreadSetUp(self): # super().setUp() # TestCase.setUp() calls torch.manual_seed() From 304c9345726e68c9bbd0ea370b3c056db6964c4b Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 32/63] Move MKLDNN Specific IR to Separate File (#126504) **Summary** Following the discussion in https://github.com/pytorch/pytorch/pull/122593#discussion_r1604144782, Move Inductor MKLDNN specific IRs to a separate file. Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126504 Approved by: https://github.com/desertfire, https://github.com/jgong5 ghstack dependencies: #126841, #126940 --- torch/_inductor/ir.py | 1632 -------------------------- torch/_inductor/mkldnn_ir.py | 1659 +++++++++++++++++++++++++++ torch/_inductor/mkldnn_lowerings.py | 26 +- 3 files changed, 1672 insertions(+), 1645 deletions(-) create mode 100644 torch/_inductor/mkldnn_ir.py diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 9e1c90e995378..898eed09268ea 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -81,7 +81,6 @@ get_kernel_metadata, is_dynamic, is_gpu, - pad_listlike, sympy_dot, sympy_index_symbol, sympy_index_symbol_with_prefix, @@ -5792,1637 +5791,6 @@ def get_inputs_that_alias_output(self): ] -def _prepare_convolution_fusion_create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding: List[int], - stride: List[int], - dilation: List[int], - groups: int, - transposed: bool = False, - output_padding: Optional[List[int]] = None, -): - """ - This function is a helper function to prepare inputs, layout and constant args - for convolution post-op fusion's create function, including deciding the output - layout (channels first or channels last), realizing inputs and make them etc. The - function only supports the CPU device since conv post-op fusion kernel is only - supported on CPU right now. - """ - - # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size - def _conv_input_size( - output_size, weight_size, padding, output_padding, stride, dilation, groups - ): - assert len(output_size) == len(weight_size), "Expect input dim == weight dim" - dim = len(output_size) - assert dim > 2, "Expect input dim > 2" - - BATCH_DIM = 0 - WEIGHT_INPUT_CHANNELS_DIM = 1 - input_size = [] - input_size.append(output_size[BATCH_DIM]) - input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) - for d in range(2, dim): - kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 - input_size_d = ( - (output_size[d] - 1) * stride[d - 2] - - (padding[d - 2] * 2) - + kernel - + output_padding[d - 2] - ) - input_size.append(input_size_d) - return list(map(int, input_size)) - - # The size of prepacked_weight is the prepacked weight size of deconv: - # Groups > 1: [g*o, i/g, ...] - # Groups == 1: [o, i, ...] - # Returns original weight size in [i, o, ...] - def _original_deconv_weight_size( - prepacked_weight, - groups, - ): - prepacked_weight_size = prepacked_weight.size() - dim = len(prepacked_weight_size) - assert dim > 2, "Expect weight dim > 2" - if groups > 1: - weight_size = [] - weight_size.append(prepacked_weight_size[1] * groups) - weight_size.append(prepacked_weight_size[0] / groups) - for d in range(2, dim): - weight_size.append(prepacked_weight_size[d]) - else: - weight_size = prepacked_weight.transpose(0, 1).size() - return weight_size - - x.realize() - weight.realize() - if bias is not None: - bias.realize() - with V.graph.fake_mode: - # TODO cleaned up the fake_tensor trace as Linear implementation - x_fake = ir_node_to_tensor(x, guard_shape=True) - weight_fake = ir_node_to_tensor(weight, guard_shape=True) - dims = len(x_fake.size()) - 2 - assert 0 < len(padding) <= dims - assert 0 < len(dilation) <= dims - assert 0 < len(stride) <= dims - padding = pad_listlike(padding, dims) - dilation = pad_listlike(dilation, dims) - stride = pad_listlike(stride, dims) - if output_padding is None: - output_padding = pad_listlike([0], dims) - else: - assert 0 < len(output_padding) <= dims - output_padding = pad_listlike(output_padding, dims) - assert isinstance(groups, int) - if transposed: - # When transposed, the size of the prepacked oneDNN weight is different - # from the PyTorch weight. We're not able to run aten conv with such - # size. We infer the output size from the input params here: - weight_size = _original_deconv_weight_size(weight_fake, groups) - input_size = x_fake.size() - output_size = _conv_input_size( - input_size, - weight_size, - padding, - output_padding, - stride, - dilation, - groups, - ) - else: - bias_fake = ( - ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias - ) - output = torch.ops.aten.convolution( - x_fake, - weight_fake, - bias_fake, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - output_size = output.size() - - req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) - req_stride_order = [len(req_stride_order)] + req_stride_order - - x = cls.require_stride_order(x, req_stride_order) - - # We won't do weight prepack for Conv if dynamic_shapes. - # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. - # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), - # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order - # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, - # this tensor is considered as channels first and the output will be in contiguous format. - # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. - dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) - if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = FlexibleLayout.contiguous_strides(output_size) - else: - output_stride = make_channels_last_strides_for(output_size) - - assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] - - kernel_layout = FixedLayout( - x.get_device(), - x.get_dtype(), - convert_shape_to_inductor(output_size), - convert_shape_to_inductor(output_stride), - ) - constant_args = [padding, stride, dilation, groups] - if transposed: - constant_args.insert(1, output_padding) - - if bias is not None: - inputs.append(bias) - else: - constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout, req_stride_order - - -def _prepare_linear_fusion_create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", -): - """ - This function is a helper function to prepare inputs, layout and constant args - for linear post-op fusion's create function. The function only supports the CPU device - since linear post-op fusion kernel is only supported on CPU right now. - """ - x.realize() - weight.realize() - if bias is not None: - bias.realize() - - *m, _ = x.get_size() - # The weight has been transposed during the qlinear weight prepack process. - # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ - # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 - _, oc = weight.get_size() - output_size = list(m) + [oc] - req_stride_order = list(reversed(range(len(x.get_size())))) - - x = cls.require_stride_order(x, req_stride_order) - assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] - - output_stride = FlexibleLayout.contiguous_strides(output_size) - kernel_layout = FixedLayout( - x.get_device(), - x.get_dtype(), - output_size, - output_stride, - ) - constant_args: List[Any] = [] - - if bias is not None: - inputs.append(bias) - else: - constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout, req_stride_order - - -class ConvolutionUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise", - cpp_kernel_name="mkldnn::_convolution_pointwise", - ) - self.cpp_kernel_key = "convolution_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - attr, - scalars: Optional[List[Any]], - algorithm, - ): - (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - constant_args = constant_args + [ - attr, - may_convert_to_optional(scalars), - algorithm, - ] - return ConvolutionUnary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class ConvolutionBinary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - cpp_constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", - cpp_kernel_name="mkldnn::_convolution_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "convolution_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& other_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional unary_attr, - torch::List> unary_scalars, - c10::optional unary_algorithm)""" - self.cpp_constant_args = cpp_constant_args - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - other: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - binary_attr: str, - binary_alpha: Optional[float], - unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], - unary_algorithm: Optional[str], - ): - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - other = cls.require_stride_order(other, req_stride_order) - inputs.insert(1, other) - constant_args = constant_args + [ - binary_attr, - binary_alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - return ConvolutionBinary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class ConvolutionBinaryInplace(ExternKernelAlloc): - def __init__( - self, - kernel_layout, - inputs, - constant_args=(), - ): - # Due to constrain of op.call, other (Tensor&) should be at input[0] - reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] - - super().__init__( - kernel_layout, - reordered_inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", - cpp_kernel_name="mkldnn::_convolution_pointwise_", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "convolution_pointwise_binary_" - # TODO: op.call: input[0] should be at::Tensor& - self.cpp_op_schema = """ - at::Tensor&( - at::Tensor& other_t, - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional unary_attr, - torch::List> unary_scalars, - c10::optional unary_algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - - def get_mutation_names(self): - return [self.inputs[0].get_name()] - - def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: - return set() - - @classmethod - def create( - cls, - x: "TensorBox", - other: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - binary_attr: str, - binary_alpha: Optional[float], - unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], - unary_algorithm: Optional[str], - ): - ( - inputs, - constant_args, - _, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - other = cls.require_stride_order(other, req_stride_order) - inputs.insert(1, other) - constant_args = constant_args + [ - binary_attr, - binary_alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - packed = ConvolutionBinaryInplace( - kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] - inputs=inputs, - constant_args=constant_args, - ) - mark_node_as_mutating(packed, inputs[1]) - # This op mutates in place which means that the result is not the - # target but rather the input that is being mutated - # init reorders the inputs, so inputs[1] becomes packed.inputs[0] - return packed.inputs[0] - - -class MKLPackedLinear(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkl._mkl_linear", - cpp_kernel_name="mkl::_mkl_linear", - ) - self.cpp_kernel_key = "mkl_linear" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& self, - const at::Tensor& mkl_weight_t, - const at::Tensor& origin_weight_t, - const c10::optional& bias_opt, - const int64_t prepack_batch_size)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create(cls, x, packed_w, orig_w, B, batch_size): - x = cls.require_stride1(cls.realize_input(x)) - orig_w = cls.require_stride1(cls.realize_input(orig_w)) - *m, _ = x.get_size() - oc, _ = orig_w.get_size() - output_size = list(m) + [oc] - output_stride = FlexibleLayout.contiguous_strides(output_size) - inputs = [x, packed_w, orig_w] - constant_args = [batch_size] - if B is not None: - inputs += [B] - else: - constant_args.insert(0, None) - - return MKLPackedLinear( - layout=FixedLayout( - x.get_device(), x.get_dtype(), output_size, output_stride - ), - inputs=inputs, - constant_args=constant_args, - ) - - -class LinearUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._linear_pointwise", - cpp_kernel_name="mkldnn::_linear_pointwise", - ) - self.cpp_kernel_key = "linear_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create(cls, x, w, B, attr, scalars, algorithm): - x = cls.require_contiguous(cls.realize_input(x)) - w = cls.require_contiguous(cls.realize_input(w)) - - *m, ic = x.get_size() - oc, ic = w.get_size() - inputs = [x, w] - constant_args = [attr, scalars if scalars else [-1], algorithm] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) - else: - constant_args.insert(0, None) - - return LinearUnary( - layout=FlexibleLayout( - device=x.get_device(), - dtype=x.get_dtype(), - size=list(m) + [oc], - ), - inputs=inputs, - constant_args=constant_args, - ) - - def apply_constraint(self): - pass - - -class LinearBinary(ExternKernelAlloc): - kernel = "torch.ops.mkldnn._linear_pointwise.binary" - - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", - cpp_kernel_name="mkldnn::_linear_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "linear_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& other_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - c10::string_view attr) - """ - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - - @classmethod - def create(cls, x, y, w, B, attr): - x = cls.require_contiguous(cls.realize_input(x)) - y = cls.require_contiguous(cls.realize_input(y)) - w = cls.require_contiguous(cls.realize_input(w)) - - *m, ic = x.get_size() - oc, ic = w.get_size() - - inputs = [x, y, w] - constant_args = [attr] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) - else: - constant_args.insert(0, B) - - return LinearBinary( - layout=FlexibleLayout( - device=x.get_device(), - dtype=x.get_dtype(), - size=list(m) + [oc], - ), - inputs=inputs, - constant_args=constant_args, - ) - - def apply_constraint(self): - pass - - -class ConvolutionTransposeUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", - cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", - ) - self.cpp_kernel_key = "convolution_transpose_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef output_padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - output_padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups_: int, - attr, - scalars: Optional[List[Any]], - algorithm, - ): - transposed = True - ( - inputs, - constant_args, - kernel_layout, - _, - ) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups_, - transposed, - output_padding_, - ) - constant_args = constant_args + [ - attr, - may_convert_to_optional(scalars), - algorithm, - ] - return ConvolutionTransposeUnary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class MkldnnRnnLayer(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="aten.mkldnn_rnn_layer", - cpp_kernel_name="at::mkldnn_rnn_layer", - ) - - @classmethod - def create( - cls, - x: "TensorBox", - w0: "TensorBox", - w1: "TensorBox", - w2: "TensorBox", - w3: "TensorBox", - hx: "TensorBox", - cx: "TensorBox", - reverse: bool, - batch_sizes: List[int], - mode: int, - hidden_size: int, - num_layers: int, - has_biases: bool, - bidirectional: bool, - batch_first: bool, - train: bool, - ): - x = cls.require_stride1(cls.realize_input(x)) - # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. - # Make sure x is contiguous in batch_first case. - x.freeze_layout() - w0 = cls.require_stride1(cls.realize_input(w0)) - w1 = cls.require_stride1(cls.realize_input(w1)) - w2 = cls.require_stride1(cls.realize_input(w2)) - w3 = cls.require_stride1(cls.realize_input(w3)) - hx = cls.require_stride1(cls.realize_input(hx)) - hx.freeze_layout() - cx = cls.require_stride1(cls.realize_input(cx)) - cx.freeze_layout() - - input_size = x.get_size() - assert len(input_size) == 3, "Expect lstm input to be 3D" - # batch_first is handled in the lstm OP. When entering - # rnn_layer here, we'll always have batch_first = False - seq_length, mini_batch, input_size = input_size - output_shape = [seq_length, mini_batch, hidden_size] - - hy_shape = hx.get_size() - cy_shape = cx.get_size() - - res: List[IRNode] = [] - - inputs = [x, w0, w1, w2, w3, hx, cx] - constant_args = [ - reverse, - batch_sizes, - mode, - hidden_size, - num_layers, - has_biases, - bidirectional, - batch_first, - train, - ] - - packed = MkldnnRnnLayer( - MultiOutputLayout(x.get_device()), - inputs=inputs, - constant_args=constant_args, - ) - - def get_strides_of_lstm_output(output_shape, batch_first): - assert len(output_shape) == 3, "Expect output_shape to be 3D" - return FlexibleLayout.contiguous_strides(output_shape) - - output_sizes = [output_shape, hy_shape, cy_shape] - output_strides = [ - get_strides_of_lstm_output(output_shape, batch_first), - FlexibleLayout.contiguous_strides(hy_shape), - FlexibleLayout.contiguous_strides(cy_shape), - ] - output_ir = [ - MultiOutput( - FixedLayout( - x.get_device(), - x.get_dtype(), - output_size, - output_stride, - ), - packed, - [(tuple, i)], - ) - for i, (output_size, output_stride) in enumerate( - zip(output_sizes, output_strides) - ) - ] - - return output_ir - - -class QConvPointWisePT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp] - - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp] - - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = len(inputs) == 5 - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.onednn.qconv2d_pointwise", - cpp_kernel_name="onednn::qconv2d_pointwise", - ) - self.cpp_kernel_key = "qconv2d_pointwise" - self.cpp_op_schema = """ - at::Tensor( - at::Tensor act, - double act_scale, - int64_t act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - torch::List stride, - torch::List padding, - torch::List dilation, - int64_t groups, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp = args[-2], args[-1] - ( - stride, - padding, - dilation, - groups, - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-12:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - stride, - padding, - dilation, - groups, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - stride_: List[int], - padding_: List[int], - dilation_: List[int], - groups: int, - o_inv_scale: float, - output_zero_point: int, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ): - transposed = False - output_padding = None - (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups, - transposed, - output_padding, - ) - # swap padding and stride to align with functional conv arg order - if bias is None: - constant_args[1], constant_args[2] = constant_args[2], constant_args[1] - else: - constant_args[0], constant_args[1] = constant_args[1], constant_args[0] - - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - x_scale, - x_zp, - o_inv_scale, - output_zero_point, - output_dtype, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. - kernel_layout.dtype = output_dtype - - return QConvPointWisePT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class QConvPointWiseBinaryPT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - """ - Needs input/weight/output qparams - if bias is not None - - inputs = [x, w, b, accum, w_scale, w_zp] - - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, accum, w_scale, w_zp] - - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, - accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = len(inputs) == 6 - self.idx_for_inplace_sum = 3 if self.has_bias else 2 - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", - cpp_kernel_name="onednn::qconv2d_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "qconv2d_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - at::Tensor act, - double act_scale, - int64_t act_zero_point, - at::Tensor accum, - double accum_scale, - int64_t accum_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - torch::List stride, - torch::List padding, - torch::List dilation, - int64_t groups, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - accum, w_scale, w_zp = args[-3], args[-2], args[-1] - ( - stride, - padding, - dilation, - groups, - x_scale, - x_zp, - accum_scale, - accum_zp, - o_inv_scale, - o_zp, - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-16:] - conv_args = ( - x, - x_scale, - x_zp, - accum, - accum_scale, - accum_zp, - packed_weight, - w_scale, - w_zp, - bias, - stride, - padding, - dilation, - groups, - o_inv_scale, - o_zp, - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - conv_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - def get_mutation_names(self): - return [self.inputs[self.idx_for_inplace_sum].get_name()] - - def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: - return set() - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale, - x_zp, - accum: "TensorBox", - accum_scale, - accum_zp, - weight: "TensorBox", # packed_weight - w_scale, - w_zp, - bias: "TensorBox", - stride_: List[int], - padding_: List[int], - dilation_: List[int], - groups: int, - o_inv_scale: "TensorBox", - output_zero_point: "TensorBox", - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ): - transposed = False - output_padding = None - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups, - transposed, - output_padding, - ) - - accum = cls.require_stride_order(accum, req_stride_order) - inputs.append(accum) - - # swap padding and stride to align with functional conv arg order - if bias is None: - constant_args[1], constant_args[2] = constant_args[2], constant_args[1] - else: - constant_args[0], constant_args[1] = constant_args[1], constant_args[0] - - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - x_scale, - x_zp, - accum_scale, - accum_zp, - o_inv_scale, - output_zero_point, - output_dtype, - binary_attr, - alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - assert ( - binary_attr == "sum" - ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." - - packed = QConvPointWiseBinaryPT2E( - layout=NoneLayout(accum.get_device()), - inputs=inputs, - constant_args=constant_args, - ) - mark_node_as_mutating(packed, accum) - - # Return accum since it has been inplace changed. - return packed.inputs[packed.idx_for_inplace_sum] - - -class QLinearPointwisePT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - has_bias=True, - x_scale_zp_are_tensors=False, - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp] - - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp] - - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name=( - "torch.ops.onednn.qlinear_pointwise.tensor" - if x_scale_zp_are_tensors - else "torch.ops.onednn.qlinear_pointwise.default" - ), - cpp_kernel_name="onednn::qlinear_pointwise", - ) - self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" - self.cpp_kernel_key = "qlinear_pointwise" - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") - ) - self.cpp_op_schema = f""" - at::Tensor( - at::Tensor act, - {x_scale_type_str} act_scale, - {x_zp_type_str} act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view post_op_name, - torch::List> post_op_args, - c10::string_view post_op_algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp = args[-2], args[-1] - if self.x_scale_zp_are_tensors: - assert len(args) >= 4 - x_scale, x_zp = args[-4], args[-3] - ( - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-6:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-8:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - o_inv_scale: float, - output_zero_point: int, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ): - (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( - cls, - x, - weight, - bias, - ) - - if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): - x_scale.realize() - x_zp.realize() - inputs = inputs + [x_scale, x_zp] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zp, int) - constant_args = constant_args + [x_scale, x_zp] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - o_inv_scale, - output_zero_point, - output_dtype, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set fp32_output, the output buf should be dtype float32 instead of uint8. - kernel_layout.dtype = output_dtype - - return QLinearPointwisePT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - - -class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - has_bias=True, - x_scale_zp_are_tensors=False, - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp, x2] - - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp, x2] - - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name=( - "torch.ops.onednn.qlinear_pointwise.binary_tensor" - if x_scale_zp_are_tensors - else "torch.ops.onednn.qlinear_pointwise.binary" - ), - cpp_kernel_name="onednn::qlinear_pointwise", - ) - self.cpp_kernel_overload_name = ( - "binary_tensor" if x_scale_zp_are_tensors else "binary" - ) - self.cpp_kernel_key = "qlinear_pointwise_binary" - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") - ) - self.cpp_op_schema = f""" - at::Tensor( - at::Tensor act, - {x_scale_type_str} act_scale, - {x_zp_type_str} act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - double inv_output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::optional other, - double other_scale, - int64_t other_zero_point, - c10::string_view binary_post_op, - double binary_alpha, - c10::string_view unary_post_op, - torch::List> unary_post_op_args, - c10::string_view unary_post_op_algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp, other = args[-3], args[-2], args[-1] - if self.x_scale_zp_are_tensors: - assert len(args) >= 5 - x_scale, x_zp = args[-5], args[-4] - ( - o_inv_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-10:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-12:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - o_inv_scale, - o_zp, - output_dtype, - other, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - o_inv_scale: float, - output_zero_point: int, - output_dtype, - other: "TensorBox", - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ): - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_linear_fusion_create( - cls, - x, - weight, - bias, - ) - - if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): - x_scale.realize() - x_zp.realize() - inputs = inputs + [x_scale, x_zp] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zp, int) - constant_args = constant_args + [x_scale, x_zp] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - if binary_attr == "sum": - other = cls.require_stride_order(other, req_stride_order) - inputs.append(other) - constant_args = constant_args + [ - o_inv_scale, - output_zero_point, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if binary_attr == "sum": - packed = QLinearPointwiseBinaryPT2E( - layout=NoneLayout(other.get_device()), - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - mark_node_as_mutating(packed, other) - # Return other since it has been inplace changed. - return packed.inputs[-1] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set fp32_output, the output buf should be dtype float32 instead of uint8. - kernel_layout.dtype = output_dtype - - return QLinearPointwiseBinaryPT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - - @dataclasses.dataclass class MutableBox(IRNode): """ diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py new file mode 100644 index 0000000000000..36be03772e899 --- /dev/null +++ b/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1659 @@ +# mypy: allow-untyped-defs +from typing import Any, List, Optional, Set + +import sympy + +import torch + +from torch._prims_common import make_channels_last_strides_for + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + ir_node_to_tensor, + IRNode, + is_contiguous_storage_and_layout, + Layout, + mark_node_as_mutating, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + NoneLayout, + TensorBox, +) + +from .utils import convert_shape_to_inductor, pad_listlike + +from .virtualized import V + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] + def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, int) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. + dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if dynamic_shapes and is_contiguous_storage_and_layout(x): + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. + # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_key = "convolution_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ): + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise_", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary_" + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, inputs[1]) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", + cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", + ) + self.cpp_kernel_key = "convolution_transpose_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_key = "qconv2d_pointwise" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + x_scale, + x_zero_point, + output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "qconv2d_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + accum_scale, + accum_zp, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-16:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale, + x_zero_point, + qaccum: "TensorBox", + accum_scale, + accum_zero_point, + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + + qaccum = cls.require_stride_order(qaccum, req_stride_order) + inputs.append(qaccum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + x_scale, + x_zero_point, + accum_scale, + accum_zero_point, + output_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, qaccum) + + # Return accum since it has been inplace changed. + return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkl._mkl_linear", + cpp_kernel_name="mkl::_mkl_linear", + ) + self.cpp_kernel_key = "mkl_linear" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const c10::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "linear_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + @classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.default" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" + self.cpp_kernel_key = "qlinear_pointwise" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + ( + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.binary_tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.binary" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = ( + "binary_tensor" if x_scale_zp_are_tensors else "binary" + ) + self.cpp_kernel_key = "qlinear_pointwise_binary" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::optional other, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + ( + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + other, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other: "TensorBox", + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + unary_post_op_args, + unary_post_op_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + if binary_post_op == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + mark_node_as_mutating(packed, other) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="aten.mkldnn_rnn_layer", + cpp_kernel_name="at::mkldnn_rnn_layer", + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 809c8f10b711d..c006af0095e6c 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -4,7 +4,7 @@ import torch import torch.utils._pytree as pytree from torch._inductor.kernel.mm_common import mm_args -from . import ir +from . import ir, mkldnn_ir from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox from .lowering import ( @@ -173,13 +173,13 @@ def register_onednn_fusion_ops(): torch.ops.mkldnn._linear_pointwise, "mkldnn::_linear_pointwise", has_out_variant=False, - kernel_creator=ir.LinearUnary.create, + kernel_creator=mkldnn_ir.LinearUnary.create, ) aten_mkldnn_linear_binary = ExternKernelChoice( torch.ops.mkldnn._linear_pointwise.binary, "mkldnn::_linear_pointwise", has_out_variant=False, - kernel_creator=ir.LinearBinary.create, + kernel_creator=mkldnn_ir.LinearBinary.create, ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, @@ -204,7 +204,7 @@ def convolution_unary( algorithm, ): return TensorBox.create( - ir.ConvolutionUnary.create( + mkldnn_ir.ConvolutionUnary.create( x, weight, bias, @@ -235,7 +235,7 @@ def convolution_binary( unary_algorithm, ): return TensorBox.create( - ir.ConvolutionBinary.create( + mkldnn_ir.ConvolutionBinary.create( x, other, weight, @@ -269,7 +269,7 @@ def convolution_binary_inplace( unary_algorithm, ): return TensorBox.create( - ir.ConvolutionBinaryInplace.create( + mkldnn_ir.ConvolutionBinaryInplace.create( x, other, weight, @@ -429,7 +429,7 @@ def convolution_transpose_unary( algorithm, ): return TensorBox.create( - ir.ConvolutionTransposeUnary.create( + mkldnn_ir.ConvolutionTransposeUnary.create( x, weight, bias, @@ -465,7 +465,7 @@ def mkldnn_rnn_layer( ): return pytree.tree_map( TensorBox.create, - ir.MkldnnRnnLayer.create( + mkldnn_ir.MkldnnRnnLayer.create( x, w0, w1, @@ -506,7 +506,7 @@ def qconvolution_unary( algorithm, ): return TensorBox.create( - ir.QConvPointWisePT2E.create( + mkldnn_ir.QConvPointWisePT2E.create( x, x_scale, x_zp, @@ -566,7 +566,7 @@ def qconvolution_binary( # we will do accum dtype convertion here. accum = to_dtype(accum, output_dtype) return TensorBox.create( - ir.QConvPointWiseBinaryPT2E.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( x, x_scale, x_zp, @@ -609,7 +609,7 @@ def qlinear_unary( algorithm, ): return TensorBox.create( - ir.QLinearPointwisePT2E.create( + mkldnn_ir.QLinearPointwisePT2E.create( x, x_scale, x_zp, @@ -668,7 +668,7 @@ def qlinear_binary( x2.get_dtype() == output_dtype ), "dtype of accum for qlinear post op sum should be the same as output" return TensorBox.create( - ir.QLinearPointwiseBinaryPT2E.create( + mkldnn_ir.QLinearPointwiseBinaryPT2E.create( x, x_scale, x_zp, @@ -695,7 +695,7 @@ def qlinear_binary( torch.ops.mkl._mkl_linear, "mkl::_mkl_linear", has_out_variant=False, - kernel_creator=ir.MKLPackedLinear.create, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) From 3dd5f0ecbbb71e3f8edc134baf5fe9fcf638ad07 Mon Sep 17 00:00:00 2001 From: Ahmed Gheith Date: Tue, 18 Jun 2024 12:30:13 +0000 Subject: [PATCH 33/63] Remove circular import (#128875) Summary: A spurious import is causing circular dependency errors Test Plan: phabricator signals Differential Revision: D58685676 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128875 Approved by: https://github.com/kit1980 --- torch/optim/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 341d07b1a2e82..f794a1eafe243 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -37,6 +37,3 @@ del optimizer # type: ignore[name-defined] # noqa: F821 del nadam # type: ignore[name-defined] # noqa: F821 del lbfgs # type: ignore[name-defined] # noqa: F821 - - -import torch.optim._multi_tensor From f2805a0408cfeb01c4d77a960dd4ca8e9a49db49 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 17 Jun 2024 13:18:20 -0700 Subject: [PATCH 34/63] [FSDP2] Added APIs for explicit fwd/bwd prefetching (#128884) This PR adds two APIs `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch` to enable explicit forward/backward all-gather prefetching, respectively. ``` def set_modules_to_forward_prefetch(self, modules: List[FSDPModule]): -> None def set_modules_to_backward_prefetch(self, modules: List[FSDPModule]): -> None ``` **Motivation** FSDP2 implements _reasonable defaults_ for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order. However, there may be cases where with expert knowledge, we can reduce communication bubbles by moving all-gathers manually. One way to expose such behavior is to expose _prefetching limits_, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. IMIHO, this leans toward _easy_, not _simple_ (see [PyTorch design principles](https://pytorch.org/docs/stable/community/design.html#principle-2-simple-over-easy)). The crux of the problem is that there may be special cases where manual intervention can give better performance. Exposing a prefetching limit and allowing users to pass a value >1 just smooths over the problem since such a limit would generally apply over the entire model even though it possibly should not. Then, expert users will see a specific all-gather that they want to deviate from this limit, and there is little we can do. Thus, we instead choose to expose the most primitive extension point: namely, every `FSDPModule` gives an opportunity to prefetch other all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. Implementing the prefetch limit can be done using this extension point (e.g. record the post-forward order yourself using forward hooks, iterate over that order, and call the `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` APIs). Differential Revision: [D58700346](https://our.internmc.facebook.com/intern/diff/D58700346) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128884 Approved by: https://github.com/ckluk2, https://github.com/weifengpy --- .../_composable/fsdp/test_fully_shard_comm.py | 205 +++++++++++++++++- .../fsdp/test_fully_shard_training.py | 40 +++- .../_composable/fsdp/_fsdp_param_group.py | 29 ++- .../_composable/fsdp/_fsdp_state.py | 11 +- .../_composable/fsdp/fully_shard.py | 48 +++- torch/testing/_internal/common_fsdp.py | 13 ++ 6 files changed, 334 insertions(+), 12 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 5acb9d895b413..c0e3fbc9aea88 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -43,6 +43,7 @@ FSDPTestMultiThread, MLP, patch_post_backward, + patch_reshard, patch_unshard, ) from torch.testing._internal.common_utils import run_tests @@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self): ) -class TestFullyShardBackwardPrefetch(FSDPTest): +class TestFullyShardPrefetch(FSDPTest): @property def world_size(self) -> int: return min(4, torch.cuda.device_count()) @@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward( self.assertEqual(events, expected_events) events.clear() + @skip_if_lt_x_gpu(2) + def test_set_modules_to_forward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure forward prefetching: + # each transformer block (layer) prefetches for the next few + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_backward_events = [ + # Default backward prefetching + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), patch_post_backward(post_backward_with_record): + set_forward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + set_forward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` and `layers.i+2` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + @skip_if_lt_x_gpu(2) + def test_set_modules_to_backward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure backward prefetching: + # each transformer block (layer) prefetches for the previous few + for i, layer in enumerate(model.layers): + if i < num_to_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_forward_events = [ + # Default forward prefetching + ("unshard", "", TrainingState.FORWARD), # root + ("unshard", "layers.0", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), patch_post_backward(post_backward_with_record): + set_backward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for `layers.i-1` (same as default) + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + + set_backward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for `layers.i-1` and `layers.i-2` + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + def _init_transformer( self, n_layers: int, @@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs): return unshard_with_record + def _get_reshard_with_record( + self, orig_reshard: Callable, events: List[EventType] + ) -> Callable: + def reshard_with_record(self, *args, **kwargs): + nonlocal events + if ( + self._training_state == TrainingState.FORWARD + and not self._reshard_after_forward + ): # skip no-ops + return + events.append(("reshard", self._module_fqn, self._training_state)) + return orig_reshard(self, *args, **kwargs) + + return reshard_with_record + def _get_post_backward_with_record( self, orig_post_backward: Callable, events: List[EventType] ) -> Callable: diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 836013f7fb243..3dbaa65243794 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -3,6 +3,7 @@ import contextlib import copy import functools +import itertools import unittest from typing import Iterable, List, Tuple, Type, Union @@ -337,7 +338,6 @@ def _test_train_parity_multi_group( return assert device_type in ("cuda", "cpu"), f"{device_type}" torch.manual_seed(42) - lin_dim = 32 vocab_size = 1024 model_args = ModelArgs( n_layers=3, @@ -494,6 +494,44 @@ def forward(self, x): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_explicit_prefetching(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=8, dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + num_to_forward_prefetch = num_to_backward_prefetch = 2 + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_forward_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + for i, layer in enumerate(model.layers): + if i < num_to_backward_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + for iter_idx in range(10): + losses: List[torch.Tensor] = [] + for _model, _optim in ((ref_model, ref_optim), (model, optim)): + _optim.zero_grad() + losses.append(_model(inp).sum()) + losses[-1].backward() + _optim.step() + self.assertEqual(losses[0], losses[1]) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 63142466f001f..06fa90e060e70 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -283,14 +283,15 @@ def _record_post_forward(self) -> None: self.comm_ctx.post_forward_order.append(self) self._post_forward_indices.append(post_forward_index) - def pre_backward(self, *unused: Any): + def pre_backward(self, default_prefetch: bool, *unused: Any): if self._training_state == TrainingState.PRE_BACKWARD: return with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched self.wait_for_unshard() - self._prefetch_unshard() + if default_prefetch: + self._backward_prefetch() def post_backward(self, *unused: Any): self._training_state = TrainingState.POST_BACKWARD @@ -348,7 +349,7 @@ def finalize_backward(self): fsdp_param.grad_offload_event = None self._post_forward_indices.clear() - def _prefetch_unshard(self): + def _backward_prefetch(self) -> None: if self._training_state == TrainingState.PRE_BACKWARD: if not self._post_forward_indices: # Can be cleared if running multiple `backward`s @@ -360,11 +361,23 @@ def _prefetch_unshard(self): # have mistargeted prefetches if not all modules used in forward # are used in this backward target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] - target_fqn = target_fsdp_param_group._module_fqn - with record_function( - self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}") - ), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD): - target_fsdp_param_group.unshard() + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = target_fsdp_param_group._module_fqn + with record_function( + f"FSDP::{pass_type}_prefetch for {target_fqn}" + ), target_fsdp_param_group.use_training_state(training_state): + target_fsdp_param_group.unshard() # Utilities # def _to_sharded(self): diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index f080e75503384..79a09342704ff 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -56,6 +56,8 @@ def __init__(self): self._state_ctx = FSDPStateContext() self._comm_ctx = FSDPCommContext() self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: List[FSDPState] = [] + self._states_to_backward_prefetch: List[FSDPState] = [] # Define a separate init since `__init__` is called in the contract def init( @@ -171,6 +173,9 @@ def _pre_forward( args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) if self._fsdp_param_group: args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") return args, kwargs @disable_if_config_true @@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: self._training_state = TrainingState.PRE_BACKWARD self._register_root_post_backward_final_callback() if self._fsdp_param_group: - self._fsdp_param_group.pre_backward() + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") return grad def _root_post_backward_final_callback(self) -> None: diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index d3e70b38eac91..61b7878d467ff 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools -from typing import Any, cast, NoReturn, Optional, Union +from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch import torch.nn as nn @@ -270,6 +270,46 @@ def set_reshard_after_backward( if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.reshard_after_backward = reshard_after_backward + def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap and + will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_forward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in backward. This overrides the default backward + pretching implementation that prefetches the next FSDP module based on + the reverse post-forward order. + + Passing a singleton list containing the previous FSDP module gives the + same all-gather overlap behavior as the default overlap behavior. + Passing a list with at least length two is required for more aggressive + overlap and will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_backward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") @@ -350,3 +390,9 @@ def wrapped_method(self, *args, **kwargs): method_name, wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] ) + + +def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 2b5fdc613c2e2..cfa16307da334 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -997,6 +997,19 @@ def patch_unshard(new_unshard: Callable): FSDPParamGroup.unshard = orig_unshard +@no_type_check +@contextlib.contextmanager +def patch_reshard(new_reshard: Callable): + orig_reshard = FSDPParamGroup.reshard + dist.barrier() + FSDPParamGroup.reshard = new_reshard + try: + yield + finally: + dist.barrier() + FSDPParamGroup.reshard = orig_reshard + + @no_type_check @contextlib.contextmanager def patch_post_backward(new_post_backward: Callable): From e6d4451ae8987bf8d6ad85eb7cde685fac746f6f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 14:31:38 +0800 Subject: [PATCH 35/63] [BE][Easy] enable UFMT for `torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/` (#128866) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866 Approved by: https://github.com/fegin --- .lintrunner.toml | 66 -------------- torch/distributed/algorithms/__init__.py | 4 +- .../_checkpoint/checkpoint_wrapper.py | 11 ++- .../algorithms/_comm_hooks/__init__.py | 2 +- .../algorithms/_comm_hooks/default_hooks.py | 45 ++++++---- .../_optimizer_overlap/optimizer_overlap.py | 23 ++--- .../algorithms/_quantization/quantization.py | 53 ++++++------ .../algorithms/ddp_comm_hooks/__init__.py | 13 +-- .../ddp_comm_hooks/ddp_zero_hook.py | 55 +++++++----- .../ddp_comm_hooks/debugging_hooks.py | 1 + .../ddp_comm_hooks/default_hooks.py | 1 + .../ddp_comm_hooks/mixed_precision_hooks.py | 25 +++--- .../ddp_comm_hooks/optimizer_overlap_hooks.py | 33 ++++--- .../ddp_comm_hooks/post_localSGD_hook.py | 6 +- .../ddp_comm_hooks/powerSGD_hook.py | 47 +++++----- torch/distributed/algorithms/join.py | 26 +++--- .../algorithms/model_averaging/averagers.py | 21 +++-- .../hierarchical_model_averager.py | 22 +++-- .../algorithms/model_averaging/utils.py | 28 ++++-- torch/distributed/autograd/__init__.py | 18 ++-- .../benchmarks/benchmark_ddp_rpc.py | 13 ++- .../checkpoint/_dedup_save_plans.py | 1 + .../distributed/checkpoint/_dedup_tensors.py | 1 + .../checkpoint/_fsspec_filesystem.py | 1 + torch/distributed/checkpoint/_nested_dict.py | 1 + .../checkpoint/_sharded_tensor_utils.py | 1 + .../distributed/checkpoint/_storage_utils.py | 1 - torch/distributed/checkpoint/_traverse.py | 1 + torch/distributed/checkpoint/api.py | 1 + .../distributed/checkpoint/default_planner.py | 1 + .../examples/fsdp_checkpoint_example.py | 2 +- torch/distributed/checkpoint/filesystem.py | 2 +- torch/distributed/checkpoint/logger.py | 1 + torch/distributed/checkpoint/metadata.py | 1 + torch/distributed/checkpoint/optimizer.py | 1 + torch/distributed/checkpoint/planner.py | 1 - .../distributed/checkpoint/planner_helpers.py | 3 +- torch/distributed/checkpoint/resharding.py | 1 + torch/distributed/checkpoint/staging.py | 2 +- torch/distributed/checkpoint/state_dict.py | 1 + .../checkpoint/state_dict_loader.py | 1 + .../checkpoint/state_dict_saver.py | 1 - torch/distributed/checkpoint/storage.py | 2 +- torch/distributed/checkpoint/utils.py | 1 + torch/distributed/elastic/agent/server/api.py | 76 ++++++++++++----- .../agent/server/health_check_server.py | 1 + .../agent/server/local_elastic_agent.py | 59 ++++++++----- torch/distributed/elastic/control_plane.py | 1 + torch/distributed/elastic/events/__init__.py | 4 +- torch/distributed/elastic/events/api.py | 5 +- torch/distributed/elastic/metrics/__init__.py | 6 +- torch/distributed/elastic/metrics/api.py | 20 ++++- .../elastic/multiprocessing/__init__.py | 4 +- .../elastic/multiprocessing/api.py | 79 ++++++++++++----- .../multiprocessing/errors/__init__.py | 15 +++- .../multiprocessing/errors/error_handler.py | 26 +++--- .../multiprocessing/errors/handlers.py | 4 +- .../elastic/multiprocessing/redirects.py | 1 + .../subprocess_handler/__init__.py | 1 + .../subprocess_handler/handlers.py | 1 + .../subprocess_handler/subprocess_handler.py | 2 +- .../elastic/multiprocessing/tail_log.py | 8 +- .../elastic/rendezvous/__init__.py | 4 +- torch/distributed/elastic/rendezvous/api.py | 17 +++- .../rendezvous/c10d_rendezvous_backend.py | 16 ++-- .../elastic/rendezvous/dynamic_rendezvous.py | 47 ++++++---- .../elastic/rendezvous/etcd_rendezvous.py | 33 ++++--- .../rendezvous/etcd_rendezvous_backend.py | 11 ++- .../elastic/rendezvous/etcd_server.py | 1 + .../elastic/rendezvous/etcd_store.py | 4 +- .../elastic/rendezvous/registry.py | 11 ++- .../rendezvous/static_tcp_rendezvous.py | 7 +- torch/distributed/elastic/rendezvous/utils.py | 20 +++-- torch/distributed/elastic/timer/__init__.py | 14 ++- torch/distributed/elastic/timer/api.py | 25 ++++-- .../elastic/timer/debug_info_logging.py | 1 + .../elastic/timer/file_based_local_timer.py | 85 +++++++++++++------ .../distributed/elastic/timer/local_timer.py | 6 +- torch/distributed/elastic/utils/api.py | 2 +- .../distributed/elastic/utils/distributed.py | 13 ++- torch/distributed/elastic/utils/store.py | 13 +-- 81 files changed, 729 insertions(+), 456 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 08e434e8f143b..2ea1579ee64c2 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1430,77 +1430,11 @@ exclude_patterns = [ 'torch/distributed/_sharding_spec/__init__.py', 'torch/distributed/_tools/__init__.py', 'torch/distributed/_tools/memory_tracker.py', - 'torch/distributed/algorithms/__init__.py', - 'torch/distributed/algorithms/_checkpoint/__init__.py', - 'torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py', - 'torch/distributed/algorithms/_comm_hooks/__init__.py', - 'torch/distributed/algorithms/_comm_hooks/default_hooks.py', - 'torch/distributed/algorithms/_optimizer_overlap/__init__.py', - 'torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py', - 'torch/distributed/algorithms/_quantization/__init__.py', - 'torch/distributed/algorithms/_quantization/quantization.py', - 'torch/distributed/algorithms/ddp_comm_hooks/__init__.py', - 'torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py', - 'torch/distributed/algorithms/join.py', - 'torch/distributed/algorithms/model_averaging/__init__.py', - 'torch/distributed/algorithms/model_averaging/averagers.py', - 'torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py', - 'torch/distributed/algorithms/model_averaging/utils.py', 'torch/distributed/argparse_util.py', - 'torch/distributed/autograd/__init__.py', - 'torch/distributed/benchmarks/benchmark_ddp_rpc.py', 'torch/distributed/c10d_logger.py', 'torch/distributed/collective_utils.py', 'torch/distributed/constants.py', 'torch/distributed/distributed_c10d.py', - 'torch/distributed/elastic/__init__.py', - 'torch/distributed/elastic/agent/__init__.py', - 'torch/distributed/elastic/agent/server/__init__.py', - 'torch/distributed/elastic/agent/server/api.py', - 'torch/distributed/elastic/agent/server/local_elastic_agent.py', - 'torch/distributed/elastic/events/__init__.py', - 'torch/distributed/elastic/events/api.py', - 'torch/distributed/elastic/events/handlers.py', - 'torch/distributed/elastic/metrics/__init__.py', - 'torch/distributed/elastic/metrics/api.py', - 'torch/distributed/elastic/multiprocessing/__init__.py', - 'torch/distributed/elastic/multiprocessing/api.py', - 'torch/distributed/elastic/multiprocessing/errors/__init__.py', - 'torch/distributed/elastic/multiprocessing/errors/error_handler.py', - 'torch/distributed/elastic/multiprocessing/errors/handlers.py', - 'torch/distributed/elastic/multiprocessing/redirects.py', - 'torch/distributed/elastic/multiprocessing/tail_log.py', - 'torch/distributed/elastic/rendezvous/__init__.py', - 'torch/distributed/elastic/rendezvous/api.py', - 'torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py', - 'torch/distributed/elastic/rendezvous/dynamic_rendezvous.py', - 'torch/distributed/elastic/rendezvous/etcd_rendezvous.py', - 'torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py', - 'torch/distributed/elastic/rendezvous/etcd_server.py', - 'torch/distributed/elastic/rendezvous/etcd_store.py', - 'torch/distributed/elastic/rendezvous/registry.py', - 'torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py', - 'torch/distributed/elastic/rendezvous/utils.py', - 'torch/distributed/elastic/timer/__init__.py', - 'torch/distributed/elastic/timer/api.py', - 'torch/distributed/elastic/timer/file_based_local_timer.py', - 'torch/distributed/elastic/timer/local_timer.py', - 'torch/distributed/elastic/utils/__init__.py', - 'torch/distributed/elastic/utils/api.py', - 'torch/distributed/elastic/utils/data/__init__.py', - 'torch/distributed/elastic/utils/data/cycling_iterator.py', - 'torch/distributed/elastic/utils/data/elastic_distributed_sampler.py', - 'torch/distributed/elastic/utils/distributed.py', - 'torch/distributed/elastic/utils/log_level.py', - 'torch/distributed/elastic/utils/logging.py', - 'torch/distributed/elastic/utils/store.py', 'torch/distributed/examples/memory_tracker_example.py', 'torch/distributed/launch.py', 'torch/distributed/launcher/__init__.py', diff --git a/torch/distributed/algorithms/__init__.py b/torch/distributed/algorithms/__init__.py index a07470a0cfd40..06c8142956994 100644 --- a/torch/distributed/algorithms/__init__.py +++ b/torch/distributed/algorithms/__init__.py @@ -1,3 +1 @@ -from .join import Join -from .join import Joinable -from .join import JoinHook +from .join import Join, Joinable, JoinHook diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 86ab1de003db4..8cc15f4aba311 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -10,6 +10,7 @@ from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint + _CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" _CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." @@ -286,8 +287,12 @@ def apply_activation_checkpointing( """ # TODO: Importing inside function to avoid circular import issue between FSDP and # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. - from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy, _Policy from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply + from torch.distributed.fsdp.wrap import ( + _Policy, + _recursive_wrap, + lambda_auto_wrap_policy, + ) policy = ( auto_wrap_policy @@ -302,7 +307,9 @@ def apply_activation_checkpointing( target_module_to_kwargs = policy._run_policy( model, ignored_modules=set(), root_kwargs={} ) - wrap_fn = _construct_wrap_fn(model, target_module_to_kwargs, checkpoint_wrapper_fn) + wrap_fn = _construct_wrap_fn( + model, target_module_to_kwargs, checkpoint_wrapper_fn + ) _post_order_apply(model, wrap_fn) return diff --git a/torch/distributed/algorithms/_comm_hooks/__init__.py b/torch/distributed/algorithms/_comm_hooks/__init__.py index d07adc17247b7..7b57a075ad729 100644 --- a/torch/distributed/algorithms/_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/_comm_hooks/__init__.py @@ -1,6 +1,6 @@ - from . import default_hooks as default + LOW_PRECISION_HOOKS = [ default.fp16_compress_hook, default.bf16_compress_hook, diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index d370fabafc371..0acafd6868d3b 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import functools +from typing import Optional + import torch import torch.distributed as dist -from typing import Optional class DefaultState: @@ -17,13 +18,10 @@ class DefaultState: "process_group", "world_size", "gradient_predivide_factor", - "gradient_postdivide_factor" + "gradient_postdivide_factor", ] - def __init__( - self, - process_group: dist.ProcessGroup - ): + def __init__(self, process_group: dist.ProcessGroup): if process_group is None: raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") self.process_group = process_group @@ -33,7 +31,9 @@ def __init__( self.gradient_predivide_factor = self._get_gradient_predivide_factor( self.world_size ) - self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor + self.gradient_postdivide_factor = ( + self.world_size / self.gradient_predivide_factor + ) @staticmethod def _get_gradient_predivide_factor(world_size: int) -> float: @@ -42,6 +42,7 @@ def _get_gradient_predivide_factor(world_size: int) -> float: factor *= 2 return float(factor) + class LowPrecisionState(DefaultState): r""" Stores state needed to perform gradient communication in a lower precision within a communication hook. @@ -82,12 +83,15 @@ def _decompress(state: LowPrecisionState, grad: torch.Tensor): device_type = grad.device.type backend = getattr(torch, device_type) except AttributeError as e: - raise AttributeError(f"Device {grad.device} does not have a \ - corresponding backend registered as 'torch.device_type'.") from e + raise AttributeError( + f"Device {grad.device} does not have a \ + corresponding backend registered as 'torch.device_type'." + ) from e # Don't let this memory get reused until after the transfer. orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type] + def allreduce_hook(state: DefaultState, grad: torch.Tensor): r""" Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients. @@ -106,6 +110,7 @@ def allreduce_hook(state: DefaultState, grad: torch.Tensor): if state.gradient_postdivide_factor > 1: grad.div_(state.gradient_postdivide_factor) + def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor): r""" Implement the FSDP communication hook for ``reduce_scatter`` algorithm. @@ -121,14 +126,18 @@ def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.T # Average grad by pre-division factor. if state.gradient_predivide_factor > 1: grad.div_(state.gradient_predivide_factor) - dist.reduce_scatter_tensor( - output, grad, group=state.process_group - ) + dist.reduce_scatter_tensor(output, grad, group=state.process_group) # Average grad's shard by post-division factor. if state.gradient_postdivide_factor > 1: output.div_(state.gradient_postdivide_factor) -def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor): + +def _low_precision_hook( + prec: torch.dtype, + state: LowPrecisionState, + grad: torch.Tensor, + output: torch.Tensor, +): if grad.dtype != prec: grad.data = grad.data.to(prec) if output is not None: @@ -140,7 +149,10 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch allreduce_hook(state, grad) _decompress(state, grad) -def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): + +def fp16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): r""" Implement FSDP communication hook for a simple gradient compression approach. Casts ``grad`` to half-precision floating-point format (``torch.float16``). @@ -158,7 +170,10 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Opt fp16_hook = functools.partial(_low_precision_hook, torch.float16) return fp16_hook(state, grad, output) -def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): + +def bf16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): r""" Implement FSDP communication hook for a simple gradient compression approach . Casts ``grad`` to half-precision floating-point format. diff --git a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py index 1afbb8d7967fc..ada39ca24d970 100644 --- a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py +++ b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -1,19 +1,18 @@ # mypy: allow-untyped-defs -from abc import ABC, abstractmethod import inspect +from abc import ABC, abstractmethod from typing import Dict, Type -from torch.distributed.fsdp import FullyShardedDataParallel -from torch.nn.parallel import DistributedDataParallel -from torch.optim import Optimizer -from torch.distributed.optim import as_functional_optim - from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook - from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _hook_then_optimizer, _OptimizerHookState, - _hook_then_optimizer ) +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.optim import as_functional_optim +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer + # Contains the mappings between the regular and overlapped optimizer types. _registered_overlapped_optims: Dict[Type, Type] = {} @@ -29,6 +28,7 @@ def decorator(target_overlapped_optim_cls): ) _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls return target_overlapped_optim_cls + return decorator @@ -71,7 +71,7 @@ def register_ddp(self, ddp_inst: DistributedDataParallel): # yet supported. ddp_inst.register_comm_hook( # type: ignore[operator] None, # wrapped hook state - _hook_then_optimizer(allreduce_hook, self._opt_hook_state) + _hook_then_optimizer(allreduce_hook, self._opt_hook_state), ) # TODO: register_fsdp once FSDP supports communication hook. @@ -81,11 +81,14 @@ def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: f"{self.__class__.__name__} does not support overlapped FSDP." ) + def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs): """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" for clz in inspect.getmro(optim_cls): try: - return _registered_overlapped_optims[clz](optim_cls, params, *args, **kwargs) + return _registered_overlapped_optims[clz]( + optim_cls, params, *args, **kwargs + ) except KeyError: pass diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index c421076bde3ec..a579a0a02feae 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -1,22 +1,23 @@ # mypy: allow-untyped-defs import functools +from enum import Enum + import torch import torch.distributed as dist -from enum import Enum - - TORCH_HALF_MIN = torch.finfo(torch.float16).min TORCH_HALF_MAX = torch.finfo(torch.float16).max + class DQuantType(Enum): """ Different quantization methods for auto_quantize API are identified here. auto_quantize API currently supports fp16 and bfp16 methods. """ - FP16 = "fp16", + + FP16 = ("fp16",) BFP16 = "bfp16" def __str__(self) -> str: @@ -26,6 +27,7 @@ def __str__(self) -> str: def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + def _quantize_tensor(tensor, qtype): if not isinstance(tensor, torch.Tensor): raise RuntimeError( @@ -36,9 +38,8 @@ def _quantize_tensor(tensor, qtype): elif qtype == DQuantType.BFP16: return torch.ops.quantization._FloatToBfloat16Quantized(tensor) else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + raise RuntimeError(f"Quantization type {qtype} is not supported") + def _quantize_tensor_list(tensor_list, qtype): if not isinstance(tensor_list, list) or not all( @@ -50,6 +51,7 @@ def _quantize_tensor_list(tensor_list, qtype): quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] return quantized_tensor_list + def _dequantize_tensor(tensor, qtype, quant_loss=None): if not isinstance(tensor, torch.Tensor): raise RuntimeError( @@ -72,9 +74,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None): else: return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor) else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + raise RuntimeError(f"Quantization type {qtype} is not supported") def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): @@ -103,20 +103,21 @@ def auto_quantize(func, qtype, quant_loss=None): Returns: (Callable): the same collective as func but enables automatic quantization/dequantization. """ + @functools.wraps(func) def wrapper(*args, **kwargs): - group = kwargs.get('group', None) - async_op = kwargs.get('async_op', False) + group = kwargs.get("group", None) + async_op = kwargs.get("async_op", False) if async_op is True: - raise RuntimeError( - 'The async_op=True mode is not supported yet.' - ) + raise RuntimeError("The async_op=True mode is not supported yet.") if func == dist.all_gather: tensors = args[0] input_tensors = _quantize_tensor(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) - for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)): + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t elif func == dist.all_to_all: @@ -124,22 +125,26 @@ def wrapper(*args, **kwargs): input_tensors = _quantize_tensor_list(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) - for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)): + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t elif func == dist.all_to_all_single: tensors = args[0] - out_splits = kwargs.get('out_splits', None) - in_splits = kwargs.get('in_splits', None) + out_splits = kwargs.get("out_splits", None) + in_splits = kwargs.get("in_splits", None) # Quantizing the input/output tensor input_tensors = _quantize_tensor(args[1], qtype) out_tensors = _quantize_tensor(tensors, qtype) - dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group) - for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)): + dist.all_to_all_single( + out_tensors, input_tensors, out_splits, in_splits, group=group + ) + for i, t in enumerate( + _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t else: - raise RuntimeError( - f"The collective op {func} is not supported yet" - ) + raise RuntimeError(f"The collective op {func} is not supported yet") return wrapper diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 2366a9d28c138..a1d1ffd2fc877 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -7,12 +7,14 @@ from . import ( debugging_hooks as debugging, default_hooks as default, + optimizer_overlap_hooks as optimizer_overlap, powerSGD_hook as powerSGD, quantization_hooks as quantization, - optimizer_overlap_hooks as optimizer_overlap, ) -__all__ = ['DDPCommHookType', 'register_ddp_comm_hook'] + +__all__ = ["DDPCommHookType", "register_ddp_comm_hook"] + def _ddp_comm_hook_wrapper(comm_hook, model, state): model.register_comm_hook(state, comm_hook) @@ -86,13 +88,12 @@ class DDPCommHookType(Enum): matrix_approximation_rank=2, ) NOOP = partial( - _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook, + _ddp_comm_hook_wrapper, + comm_hook=debugging.noop_hook, ) -def register_ddp_comm_hook( - comm_hook_type: DDPCommHookType, model, state=None -): +def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None): """ Register ``ddp_comm_hooks`` to DDP model. diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 8ab58cb584421..6db6d1831b1fd 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -5,11 +5,10 @@ import torch import torch.distributed as dist from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.distributed.optim.zero_redundancy_optimizer import ( - _OverlapStatus, -) +from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus from torch.nn.parallel.distributed import DistributedDataParallel + __all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"] # Functional optimizers require passing a list of gradients to their `step()` @@ -39,22 +38,25 @@ def _perform_local_step( """ overlap_info = zero._overlap_info bucket_index = bucket.index() - assert len(zero.optim.param_groups) == 1, \ - "Overlapping DDP with ZeRO only supports a single parameter group" + assert ( + len(zero.optim.param_groups) == 1 + ), "Overlapping DDP with ZeRO only supports a single parameter group" # Construct the `gradients` input for the local optimizer step, which # expects `None` in a list position to indicate that the corresponding # parameter should not be updated num_local_optim_params = len(zero.optim.param_groups[0]["params"]) - gradients: List[Optional[torch.Tensor]] = \ - [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)] - assert bucket_index in overlap_info.offsets, \ - f"Bucket index {bucket_index} was not assigned to rank {rank}" + gradients: List[Optional[torch.Tensor]] = [ + _NO_PARAM_UPDATE for _ in range(num_local_optim_params) + ] + assert ( + bucket_index in overlap_info.offsets + ), f"Bucket index {bucket_index} was not assigned to rank {rank}" gradients_offset = overlap_info.offsets[bucket_index] bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] bucket_offset = bucket_assignment.offset length = len(bucket_assignment.parameters) - bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length] + bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length] for i, grad in enumerate(bucket_gradients): gradients[gradients_offset + i] = grad @@ -75,12 +77,14 @@ def _broadcast_bucket( :class:`ZeroRedundancyOptimizer` instance. """ overlap_info = zero._overlap_info - assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ - "`assigned_ranks_per_bucket` is not fully constructed" + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" # Sort to ensure the same ordering across ranks assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) - assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \ - "assigned to at least one rank" + assert len(assigned_ranks) > 0, ( + f"Bucket {bucket_index} should be " "assigned to at least one rank" + ) for assigned_rank in assigned_ranks: bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] if bucket_index in bucket_assignments: @@ -229,7 +233,7 @@ def hook_with_zero_step( # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): + if (pg != dist.Backend.NCCL) and (pg != "hccl"): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL/HCCL backend to avoid hangs" @@ -267,9 +271,12 @@ def hook_with_zero_fn( rank = zero.global_rank assert overlap_info.status == _OverlapStatus.INITIALIZED - assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ - "`assigned_ranks_per_bucket` is not fully constructed" - assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" + assigned_to_bucket = ( + rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + ) # Save the bucket reference and all-reduce future for the final bucket if assigned_to_bucket: @@ -279,8 +286,9 @@ def hook_with_zero_fn( # Check that buckets are indexed incrementally starting from 0 in the # order of their autograd hooks firing if len(overlap_info.bucket_indices_seen) > 0: - assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, \ - "Bucket indices are not in incremental order" + assert ( + overlap_info.bucket_indices_seen[-1] == bucket_index - 1 + ), "Bucket indices are not in incremental order" else: assert bucket_index == 0, "Bucket indices do not start from 0" overlap_info.bucket_indices_seen.append(bucket_index) @@ -302,9 +310,10 @@ def hook_with_zero_fn( if rank in assigned_ranks: # Wait on the bucket's all-reduce future to ensure correct # gradients - assert bucket_index in overlap_info.bucket_index_to_future, \ - f"All-reduce future for bucket {bucket_index} not saved " \ + assert bucket_index in overlap_info.bucket_index_to_future, ( + f"All-reduce future for bucket {bucket_index} not saved " f"on rank {rank}" + ) allreduce_future = overlap_info.bucket_index_to_future[bucket_index] allreduce_future.wait() @@ -386,7 +395,7 @@ def hook_with_zero_step_interleaved( # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): + if (pg != dist.Backend.NCCL) and (pg != "hccl"): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL/HCCL backend to avoid hangs" diff --git a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py index a552f9a359f7e..53a184839a06f 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -3,6 +3,7 @@ import torch from torch.distributed import GradBucket + __all__ = ["noop_hook"] diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 621e46fc19896..b1296ae712f0c 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist + __all__ = [ "allreduce_hook", "fp16_compress_hook", diff --git a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py index 31b243d44e0fd..4727bbf9d45e6 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -1,11 +1,12 @@ +from dataclasses import dataclass +from typing import Any, no_type_check + import torch import torch.distributed as dist from torch.autograd import Variable - -from dataclasses import dataclass -from typing import Any, no_type_check from torch.distributed.utils import _free_storage + @dataclass class _AllreduceUpcastHookState: """ @@ -19,6 +20,7 @@ class _AllreduceUpcastHookState: upcast_stream: torch.cuda.Stream wait_for_stream_enqueued: bool = False + @no_type_check def _reducer_allreduce_and_upcast_hook( hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket @@ -35,10 +37,13 @@ def _reducer_allreduce_and_upcast_hook( gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view # Cast bucket if different than param_dtype. if ( - ddp_weakref().mixed_precision.param_dtype != ddp_weakref().mixed_precision.reduce_dtype + ddp_weakref().mixed_precision.param_dtype + != ddp_weakref().mixed_precision.reduce_dtype ): # Cast bucket tensor to reduce_dtype - bucket.set_buffer(bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)) + bucket.set_buffer( + bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype) + ) fut = reducer._run_allreduce_hook(bucket) ret_fut = torch.futures.Future() stream = hook_state.upcast_stream @@ -66,19 +71,17 @@ def wait_for_stream_cb(): # by hook above as they don't have a grad hook installed, so cast them # back here. for n, p in ddp_weakref().module.named_parameters(): - if hasattr(p, '_ddp_mp_hook_state'): + if hasattr(p, "_ddp_mp_hook_state"): p._ddp_mp_hook_state[1].remove() - delattr(p, '_ddp_mp_hook_state') - if not p.requires_grad and not hasattr(p, '_ddp_ignored'): + delattr(p, "_ddp_mp_hook_state") + if not p.requires_grad and not hasattr(p, "_ddp_ignored"): p.data = p._fp_param # reset for next backward pass hook_state.wait_for_stream_enqueued = False if not hook_state.wait_for_stream_enqueued: - Variable._execution_engine.queue_callback( - wait_for_stream_cb - ) + Variable._execution_engine.queue_callback(wait_for_stream_cb) # mark that the callback is enqueued hook_state.wait_for_stream_enqueued = True diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index 76d4cd6de2bdc..5ae242b04a9c5 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -1,16 +1,18 @@ # mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial from typing import Any, Callable, List, no_type_check import torch import torch.distributed as dist from torch.autograd import Variable -from functools import partial -from dataclasses import dataclass + __all__: List[str] = [] _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" + class _OptimizerHookState: """ Holds state for running optimizer in-line after DDP communication hook. @@ -42,9 +44,10 @@ class _OptimInBackwardHookState: optim_stream: torch.cuda.Stream wait_for_optim_stream_enqueued: bool + @no_type_check def _apply_optim_in_backward_hook( - gradient_is_bucket_view: bool + gradient_is_bucket_view: bool, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r""" Register hook to apply the optimizer in backward. @@ -59,7 +62,9 @@ def _apply_optim_in_backward_hook( ) def apply_optim_in_backward_hook( - hook_state: Any, bucket: dist.GradBucket, optim_stream_state, + hook_state: Any, + bucket: dist.GradBucket, + optim_stream_state, ) -> torch.futures.Future[torch.Tensor]: # Run original hook ddp_weakref = hook_state @@ -78,7 +83,7 @@ def apply_optim_in_backward_hook( # TODO (rohan-varma): upcast as needed for DDP mixed precision, # once optimizer in backward + DDP mixed precision is supported. for p, g in zip(model_params, grads): - if hasattr(p, '_in_backward_optimizers'): + if hasattr(p, "_in_backward_optimizers"): # Note: need to set grad to the bucket's grad, because # running allreduce results in the bucket's grad being # reduced, but not grad field. @@ -94,21 +99,17 @@ def apply_optim_in_backward_hook( # enqueue a callback to wait for this optimizer stream at the end of # backward and set all DDP managed grads to None. def wait_for_optim_stream_callback(): - torch.cuda.current_stream().wait_stream( - optim_stream_state.optim_stream - ) + torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream) # Set DDP managed grads to None for param in ddp_inst._get_data_parallel_params(ddp_inst.module): - if hasattr(param, '_in_backward_optimizers'): + if hasattr(param, "_in_backward_optimizers"): param.grad = None # reset for the next backwards pass optim_stream_state.wait_for_optim_stream_enqueued = False if not optim_stream_state.wait_for_optim_stream_enqueued: - Variable._execution_engine.queue_callback( - wait_for_optim_stream_callback - ) + Variable._execution_engine.queue_callback(wait_for_optim_stream_callback) # mark that the callback is enqueued optim_stream_state.wait_for_optim_stream_enqueued = True @@ -123,13 +124,14 @@ def wait_for_optim_stream_callback(): return comm_hook + def _hook_then_optimizer( hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], optimizer_state: _OptimizerHookState, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r"""Run optimizer in a functional fashion after DDP communication hook.""" has_set_params = ( - hasattr(optimizer_state, 'params_to_optimize') + hasattr(optimizer_state, "params_to_optimize") and optimizer_state.params_to_optimize is not None ) @@ -143,7 +145,10 @@ def optimizer_step(fut): gradient_tensors = bucket.gradients() model_params = bucket.parameters() for grad_tensor, model_param in zip(gradient_tensors, model_params): - if not has_set_params or model_param in optimizer_state.params_to_optimize: + if ( + not has_set_params + or model_param in optimizer_state.params_to_optimize + ): optimizer_state.functional_optimizer.step_param( model_param, grad_tensor, diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py index 3528f3987479f..d8da01e6e1fe2 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -6,6 +6,7 @@ from . import default_hooks as default + logger = logging.getLogger(__name__) @@ -62,9 +63,8 @@ def maybe_increase_iter(self, bucket): self.iter += 1 if self.iter == self.start_localSGD_iter: - logger.info( - "Start to apply local SGD after %s iterations.", self.iter - ) + logger.info("Start to apply local SGD after %s iterations.", self.iter) + def post_localSGD_hook( state: PostLocalSGDState, bucket: dist.GradBucket diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index fbc3b9e8739e4..96b3b888511ae 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -1,18 +1,17 @@ # mypy: allow-untyped-defs -from collections import defaultdict import logging import math +from collections import defaultdict from typing import Dict import torch import torch.distributed as dist +from torch.distributed import distributed_c10d from . import default_hooks as default -from torch.distributed import distributed_c10d -__all__ = [ - "PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook" -] + +__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"] logger = logging.getLogger(__name__) @@ -35,10 +34,13 @@ def _orthogonalize(matrices, epsilon=0): matrices, out=( matrices, - torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype) - ) + torch.empty( + num_matrices, rank, rank, device=matrices.device, dtype=dtype + ), + ), ) + def _orthogonalize_gram_schmidt(matrices, epsilon=0): """ Apply Gram-Schmidt procedure to orthogonalize a batch of matrices. @@ -103,14 +105,15 @@ def _should_compress( def _report_compression_stats(bucket, state): """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state.""" - if ( - bucket.is_last() - and state.iter >= state.next_stats_report - ): + if bucket.is_last() and state.iter >= state.next_stats_report: stats = state.compression_stats() logger.info( "Compression stats: iter %s, total before compression %s, total after compression %s, " - "rate %s", state.iter, stats[1], stats[2], stats[0] + "rate %s", + state.iter, + stats[1], + stats[2], + stats[0], ) state.next_stats_report = state.iter + state.compression_stats_logging_frequency @@ -244,6 +247,7 @@ def __init__( # If the same random projection is used, # there will be differences between the gradients that are never synchronized. import numpy as np + self.rng = np.random.RandomState(random_seed) # Since there is only a single state instance for all the input buckets, # need to maintain a dictionary that maps each bucket index to the local error. @@ -280,7 +284,8 @@ def __getstate__(self): ) return { slot: getattr(self, slot) - for slot in self.__slots__ if slot != "process_group" + for slot in self.__slots__ + if slot != "process_group" } def __setstate__(self, state): @@ -305,9 +310,7 @@ def maybe_increase_iter(self, bucket): self.iter += 1 if self.iter == self.start_powerSGD_iter: - logger.info( - "Start to apply PowerSGD after %s iterations.", self.iter - ) + logger.info("Start to apply PowerSGD after %s iterations.", self.iter) def compression_stats(self): r""" @@ -420,7 +423,7 @@ def powerSGD_hook( else: logger.info( "A zero tensor of length %s that represents local error is created.", - total_length + total_length, ) state.error_dict[bucket_index] = torch.zeros( total_length, device=device, dtype=dtype @@ -478,7 +481,8 @@ def powerSGD_hook( if state.warm_start: logger.info( "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.", - total_Ps_size, total_Qs_size + total_Ps_size, + total_Qs_size, ) state.p_memory_dict[bucket_index] = torch.empty( total_Ps_size, device=device, dtype=dtype @@ -724,7 +728,7 @@ def batched_powerSGD_hook( state.total_numel_after_compression += ( square_side_length * state.matrix_approximation_rank * 2 ) - padded_total_length = square_side_length ** 2 + padded_total_length = square_side_length**2 input_tensor.resize_(padded_total_length) input_tensor[total_length:padded_total_length].fill_(0) @@ -739,7 +743,7 @@ def batched_powerSGD_hook( else: logger.info( "A zero tensor of length %s that represents local error is created.", - padded_total_length + padded_total_length, ) state.error_dict[bucket_index] = torch.zeros( padded_total_length, device=device, dtype=input_tensor.dtype @@ -759,7 +763,8 @@ def batched_powerSGD_hook( if state.warm_start: logger.info( "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.", - square_side_length, state.matrix_approximation_rank + square_side_length, + state.matrix_approximation_rank, ) def create_low_rank_tensor(fill_random_values, rng): diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 2936747a1c6ec..140844851938b 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -7,7 +7,9 @@ import torch import torch.distributed as dist -__all__ = ['JoinHook', 'Joinable', 'Join'] + +__all__ = ["JoinHook", "Joinable", "Join"] + class JoinHook: r""" @@ -97,13 +99,10 @@ def construct_disabled_join_config(): e.g. if the caller is not in a join context manager. """ return _JoinConfig( - enable=False, - throw_on_early_termination=False, - is_first_joinable=False + enable=False, throw_on_early_termination=False, is_first_joinable=False ) - class Join: r""" This class defines the generic join context manager, which allows custom hooks to be called after a process joins. @@ -176,7 +175,9 @@ def __init__( if len(joinables) == 0: raise ValueError("The join context manager requires at least one joinable") self._joinables = joinables - self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables] + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] self._enable = enable self._throw_on_early_termination = throw_on_early_termination self._set_joinable_configs() @@ -190,7 +191,7 @@ def _set_joinable_configs(self) -> None: joinable._join_config = _JoinConfig( enable=self._enable, throw_on_early_termination=self._throw_on_early_termination, - is_first_joinable=is_first_joinable + is_first_joinable=is_first_joinable, ) is_first_joinable = False @@ -215,7 +216,9 @@ def _extract_dist_info(self) -> None: if process_group is None: process_group = joinable.join_process_group elif process_group != joinable.join_process_group: - raise ValueError("Using join context manager with multiple process groups") + raise ValueError( + "Using join context manager with multiple process groups" + ) if device is None: device = joinable.join_device self._process_group = process_group @@ -229,7 +232,7 @@ def __exit__( self, type: Optional[Type[BaseException]], value: Optional[BaseException], - traceback: Optional[TracebackType] + traceback: Optional[TracebackType], ): r""" Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. @@ -318,9 +321,10 @@ def notify_join_context(joinable: Joinable): manager that the process has not yet joined if ``joinable`` is the first one passed into the context manager; ``None`` otherwise. """ - assert hasattr(joinable, "_join_config"), \ - f"Check that the {type(joinable)} constructor calls the " \ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " "``Joinable`` constructor" + ) join_config = joinable._join_config # First joinable is responsible for the collective communications diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index 178efd1dbad92..e15154e3f8578 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -1,12 +1,15 @@ # mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod -from typing import Union, Iterable, Dict +from typing import Dict, Iterable, Union + import torch import torch.distributed as dist import torch.distributed.algorithms.model_averaging.utils as utils -__all__ = ['ModelAverager', 'PeriodicModelAverager'] + +__all__ = ["ModelAverager", "PeriodicModelAverager"] + class ModelAverager(ABC): r"""Base class for all model averagers. @@ -82,12 +85,7 @@ class PeriodicModelAverager(ModelAverager): >>> averager.average_parameters(model.parameters()) """ - def __init__( - self, - period, - warmup_steps=0, - process_group=None - ): + def __init__(self, period, warmup_steps=0, process_group=None): super().__init__(process_group) if warmup_steps < 0: raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") @@ -103,7 +101,12 @@ def __init__( ) self.period = period - def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + ): """ Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index 02802466ab62a..a27f3b762a9e3 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -3,13 +3,14 @@ import logging import warnings from collections import OrderedDict -from typing import Union, Iterable, Dict +from typing import Dict, Iterable, Union import torch import torch.distributed as dist import torch.distributed.algorithms.model_averaging.averagers as averagers import torch.distributed.algorithms.model_averaging.utils as utils + logger = logging.getLogger(__name__) @@ -103,7 +104,9 @@ def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=No raise ValueError("Arg ``period_group_size_dict`` must not be empty.") self._periods = list(period_group_size_dict.keys()) if self._periods[0] <= 0: - raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.") + raise ValueError( + "The minimum period in arg ``period_group_size_dict`` must be a positive value." + ) elif self._periods[-1] == 1: warnings.warn( "When the maximum period in arg ``period_group_size_dict`` is 1, " @@ -124,10 +127,14 @@ def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=No for period, group_size in period_group_size_dict.items(): logger.info( "\tEach group that has %s processes average parameters every %s iterations, " - "if no higher-level averaging.", group_size, period) + "if no higher-level averaging.", + group_size, + period, + ) if group_size != overall_group_size: self.period_process_group_dict[period], _ = dist.new_subgroups( - group_size=group_size, group=self.process_group) + group_size=group_size, group=self.process_group + ) else: self.period_process_group_dict[period] = self.process_group @@ -149,7 +156,12 @@ def _find_process_group(self): return self.period_process_group_dict[period] return None - def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + ): """ Averages parameters or parameter groups of an optimizer. diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index de1977959d21c..20f75152f0b87 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,16 +1,23 @@ # mypy: allow-untyped-defs # flake8: noqa C101 import itertools -from typing import Union, Iterable, Dict, Iterator +from typing import Dict, Iterable, Iterator, Union import torch import torch.distributed as dist + # The two imports below are not always available depending on the # USE_DISTRIBUTED compile flag. Make sure they raise import error # if we're trying to use them. -from torch.distributed import ProcessGroup, group +from torch.distributed import group, ProcessGroup + + +__all__ = [ + "average_parameters", + "get_params_to_average", + "average_parameters_or_parameter_groups", +] -__all__ = ["average_parameters", "get_params_to_average", "average_parameters_or_parameter_groups"] def average_parameters( params: Iterator[torch.nn.Parameter], process_group: ProcessGroup @@ -43,7 +50,9 @@ def average_parameters( offset += p.numel() -def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): +def get_params_to_average( + params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]] +): """ Return a list of parameters that need to average. @@ -64,10 +73,17 @@ def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[D if param_data.grad is not None: filtered_params.append(param_data) else: - raise NotImplementedError(f"Parameter input of type {type(param)} is not supported") + raise NotImplementedError( + f"Parameter input of type {type(param)} is not supported" + ) return filtered_params -def average_parameters_or_parameter_groups(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]], process_group: ProcessGroup): +def average_parameters_or_parameter_groups( + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + process_group: ProcessGroup, +): """Averages parameters of a model or parameter groups of an optimizer.""" average_parameters(iter(get_params_to_average(params)), process_group) diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index 6546c38a37b99..b1cf0aec6140f 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs -import sys import torch @@ -13,22 +12,22 @@ def is_available(): if is_available(): from torch._C._distributed_autograd import ( - get_gradients, - backward, + _current_context, + _get_debug_info, + _get_max_id, _init, + _is_valid_context, _new_context, _release_context, - _get_max_id, - _is_valid_context, _retrieve_context, - _current_context, - _get_debug_info, + backward, DistAutogradContext, + get_gradients, ) class context: - ''' + """ Context object to wrap forward and backward passes when using distributed autograd. The ``context_id`` generated in the ``with`` statement is required to uniquely identify a distributed backward pass @@ -44,7 +43,8 @@ class context: >>> t2 = torch.rand((3, 3), requires_grad=True) >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() >>> dist_autograd.backward(context_id, [loss]) - ''' + """ + def __enter__(self): self.autograd_context = _new_context() return self.autograd_context._context_id() diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 60f71e12213be..d36568953d6b5 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -8,12 +8,13 @@ import time import numpy as np + import torch -import torch.nn as nn import torch.distributed as dist import torch.distributed.autograd as dist_autograd import torch.distributed.rpc as rpc import torch.multiprocessing as mp +import torch.nn as nn import torch.optim as optim from torch.distributed.optim import DistributedOptimizer from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions @@ -210,14 +211,13 @@ def run_worker(rank, world_size): # Rank 16. Master if rank == (NUM_TRAINERS + NUM_PS): - rpc.init_rpc( - "master", rank=rank, + "master", + rank=rank, backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] - world_size=world_size + world_size=world_size, ) - # Build the Embedding tables on the Parameter Servers. emb_rref_list = [] index = 0 @@ -256,7 +256,6 @@ def run_worker(rank, world_size): # Rank 0-7. Trainers elif rank >= 0 and rank < NUM_PS: - # Initialize process group for Distributed DataParallel on trainers. dist.init_process_group( backend=dist.Backend.GLOO, @@ -292,7 +291,7 @@ def run_worker(rank, world_size): if __name__ == "__main__": - """ Initializing the distributed environment. """ + """Initializing the distributed environment.""" output = _run_printable("nvidia-smi topo -m") print("-------------------------------------------") diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py index 16d46e73baffd..dd37634a0aa64 100644 --- a/torch/distributed/checkpoint/_dedup_save_plans.py +++ b/torch/distributed/checkpoint/_dedup_save_plans.py @@ -5,6 +5,7 @@ from torch.distributed.checkpoint.planner import SavePlan, WriteItem + if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import MetadataIndex diff --git a/torch/distributed/checkpoint/_dedup_tensors.py b/torch/distributed/checkpoint/_dedup_tensors.py index 7689b9452e8cc..687afb287b3c7 100644 --- a/torch/distributed/checkpoint/_dedup_tensors.py +++ b/torch/distributed/checkpoint/_dedup_tensors.py @@ -5,6 +5,7 @@ from torch.distributed.checkpoint.planner import SavePlan + if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import MetadataIndex diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index 7fdd04dff311c..b57df9c3456ca 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -17,6 +17,7 @@ FileSystemWriter, ) + __all__ = [ "FsspecWriter", "FsspecReader", diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py index 527a67e6892fe..3347ea8bc432a 100644 --- a/torch/distributed/checkpoint/_nested_dict.py +++ b/torch/distributed/checkpoint/_nested_dict.py @@ -5,6 +5,7 @@ from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict + """ TODO: Need to add ability to handle tuple, OrderedDict, NamedTuple. diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py index f71f129e127c7..a68bcddeb7f9d 100644 --- a/torch/distributed/checkpoint/_sharded_tensor_utils.py +++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -11,6 +11,7 @@ from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict from .utils import _element_wise_add, _normalize_device_info + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata diff --git a/torch/distributed/checkpoint/_storage_utils.py b/torch/distributed/checkpoint/_storage_utils.py index 0f5205a1f2030..194c9c8c4b9b1 100644 --- a/torch/distributed/checkpoint/_storage_utils.py +++ b/torch/distributed/checkpoint/_storage_utils.py @@ -2,7 +2,6 @@ from typing import List, Type, Union from .filesystem import FileSystemReader, FileSystemWriter - from .storage import StorageReader, StorageWriter diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 5d5e87bf13087..8bcb832c71980 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -17,6 +17,7 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + PATH_ITEM = Union[str, int] OBJ_PATH = Tuple[PATH_ITEM, ...] T = TypeVar("T") diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 660196bc28de8..e587580617a1b 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -2,6 +2,7 @@ import traceback as tb from typing import Any, Dict, Tuple + WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary] __all__ = ["CheckpointException"] diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 83b76718a6b7a..cbf855d51417f 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -46,6 +46,7 @@ ) from torch.distributed.checkpoint.utils import find_state_dict_object + logger: logging.Logger = logging.getLogger(__name__) diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py index 38c637d3a4fd1..f2f03840b0d57 100644 --- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py +++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -16,10 +16,10 @@ import torch.distributed.checkpoint as dist_cp import torch.multiprocessing as mp from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint" diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 4d512891f1223..859476d71e16f 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -32,7 +32,6 @@ from torch import Tensor from torch._utils import _get_available_device_type, _get_device_module from torch.distributed._shard._utils import narrow_tensor_by_index - from torch.distributed.checkpoint.metadata import ( Metadata, MetadataIndex, @@ -58,6 +57,7 @@ from torch.distributed.checkpoint.utils import _create_file_view from torch.futures import Future + __all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"] _metadata_fn: str = ".metadata" diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 270240490c99d..c210819ec5ad7 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -7,6 +7,7 @@ import torch.distributed.c10d_logger as c10d_logger from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + __all__: List[str] = [] global _dcp_logger diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index b3bc7a580dad0..d1f87e2d9cba8 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -7,6 +7,7 @@ import torch from torch.distributed.checkpoint.stateful import StatefulT + __all__ = [ "ChunkStorageMetadata", "TensorStorageMetadata", diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index 26468d046f29a..220ca22f703e5 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -40,6 +40,7 @@ from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device + STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index 5eec8bf754665..d3e79950e0c76 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -7,7 +7,6 @@ from typing import Any, List, Optional, Tuple, Union import torch - from torch.distributed.checkpoint.metadata import ( ChunkStorageMetadata, Metadata, diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 4bbe26876c881..56e17281d4e4b 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -4,13 +4,11 @@ import torch import torch.distributed as dist from torch._utils import _get_device_module - from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed.checkpoint.planner import _Checkpointable - from torch.utils._pytree import tree_map_only from .metadata import ( @@ -35,6 +33,7 @@ _shards_get_overlap_region_wrt_saved_tensor, ) + __all__: List[str] = ["create_read_items_for_chunk_list"] diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py index a1bf112f17950..0e5153df8da0a 100644 --- a/torch/distributed/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -3,6 +3,7 @@ from torch.distributed.checkpoint.metadata import ChunkStorageMetadata + __all__: List[str] = [] diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index dba7ea0b41361..40f2fbdf0a0d9 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -6,9 +6,9 @@ _create_cpu_state_dict, _offload_state_dict_to_cpu, ) - from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + __all__ = ["AsyncStager", "BlockingAsyncStager"] diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index c906ff3dcc202..16a1ddde21586 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -54,6 +54,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils._pytree import tree_map_only + __all__ = [ "FQNS_T", "PrimitiveType", diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index f443f73f02d6d..c4d1d853e9c66 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -16,6 +16,7 @@ from .storage import StorageReader from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile + __all__ = ["load_state_dict", "load"] diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index ba1695d832122..20abc2212f5e1 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu - from torch.distributed.checkpoint._storage_utils import _storage_setup from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.logger import _dcp_method_logger diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py index bd786671c4526..dd46fe9246fd4 100644 --- a/torch/distributed/checkpoint/storage.py +++ b/torch/distributed/checkpoint/storage.py @@ -10,9 +10,9 @@ SavePlan, SavePlanner, ) - from torch.futures import Future + __all__ = ["WriteResult", "StorageWriter", "StorageReader"] diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 0efba34a551bc..32649455163e6 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -25,6 +25,7 @@ ) from .metadata import MetadataIndex, STATE_DICT_TYPE + __all__ = ["find_tensor_shard", "find_state_dict_object"] T = TypeVar("T") diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 28937ca47b1a7..50c69b8e96274 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -14,6 +14,7 @@ import time import traceback import warnings +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -21,16 +22,13 @@ import torch.distributed.elastic.rendezvous as rdzv import torch.distributed.elastic.utils.store as store_util -from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.events import Event, EventSource, record from torch.distributed.elastic.metrics import prof, put_metric -from torch.distributed.elastic.multiprocessing import ( - ProcessFailure, - SignalException, -) -from collections import defaultdict +from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.utils.logging import get_logger + __all__ = [ "WorkerSpec", "Worker", @@ -250,7 +248,16 @@ class WorkerGroup: group contains cross instance workers or not depends on the implementation of the agent. """ - __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state", "master_addr", "master_port"] + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] def __init__(self, spec: WorkerSpec): self.spec = spec @@ -450,7 +457,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: raise NotImplementedError @abc.abstractmethod - def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: r"""Stop all workers in the given worker group. Implementors must deal with workers in all states defined by @@ -468,7 +477,9 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: raise NotImplementedError @abc.abstractmethod - def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None: + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: """Clean up any resources that were allocated during the agent's work. Args: @@ -499,7 +510,9 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: self._store = store with self.record_duration("ASSIGN_WORKER_RANKS"): - workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec) + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) worker_group.workers = workers worker_group.store = store worker_group.group_rank = group_rank @@ -532,8 +545,8 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: "role_ranks": [worker.role_rank for worker in workers], "global_ranks": [worker.global_rank for worker in workers], "role_world_sizes": [worker.role_world_size for worker in workers], - "global_world_sizes": [worker.world_size for worker in workers] - } + "global_world_sizes": [worker.world_size for worker in workers], + }, ) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator @@ -612,9 +625,12 @@ def _assign_worker_ranks( store.multi_set(keys, values) # get will block until the data is available in the store. - base_global_rank, global_world_size, base_role_rank, role_world_size = json.loads( - store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}") - ) + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) workers = [] for local_rank in range(spec.local_world_size): @@ -733,7 +749,11 @@ def record_duration(self, state: str): finally: end_time = time.perf_counter() duration_ms = (end_time - start_time) * 1000 - record(self._construct_event(state=state, source=EventSource.AGENT, duration_ms=duration_ms)) + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ) + ) def _construct_event( self, @@ -844,7 +864,8 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: logger.info( "[%s] worker group successfully finished." " Waiting %s seconds for other agents to finish.", - role, self._exit_barrier_timeout + role, + self._exit_barrier_timeout, ) self._exit_barrier() return run_result @@ -854,7 +875,10 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: "[%s] Worker group %s. " "%s/%s attempts left;" " will restart worker group", - role, state.name, self._remaining_restarts, spec.max_restarts + role, + state.name, + self._remaining_restarts, + spec.max_restarts, ) self._remaining_restarts -= 1 self._restart_workers(self._worker_group) @@ -871,11 +895,15 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: "[%s] Detected %s " "new nodes from group_rank=%s; " "will restart worker group", - role, num_nodes_waiting, group_rank + role, + num_nodes_waiting, + group_rank, ) self._restart_workers(self._worker_group) else: - raise Exception(f"[{role}] Worker group in {state.name} state") # noqa: TRY002 + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) def _exit_barrier(self): """ @@ -889,7 +917,8 @@ def _exit_barrier(self): logger.info( "Local worker group finished (%s). " "Waiting %s seconds for other agents to finish", - self._worker_group.state, self._exit_barrier_timeout + self._worker_group.state, + self._exit_barrier_timeout, ) start = time.time() try: @@ -900,7 +929,8 @@ def _exit_barrier(self): barrier_timeout=self._exit_barrier_timeout, ) logger.info( - "Done waiting for other agents. Elapsed: %s seconds", time.time() - start + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, ) except SignalException as e: logger.warning("Got termination signal: %s", e.sigval) @@ -908,5 +938,5 @@ def _exit_barrier(self): except Exception: logger.exception( "Error waiting on exit barrier. Elapsed: %s seconds", - time.time() - start + time.time() - start, ) diff --git a/torch/distributed/elastic/agent/server/health_check_server.py b/torch/distributed/elastic/agent/server/health_check_server.py index 0016073055152..d54915f746168 100644 --- a/torch/distributed/elastic/agent/server/health_check_server.py +++ b/torch/distributed/elastic/agent/server/health_check_server.py @@ -10,6 +10,7 @@ from torch.distributed.elastic.utils.logging import get_logger + log = get_logger(__name__) __all__ = ["HealthCheckServer", "create_healthcheck_server"] diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 232f28234e653..9423ef16f5c01 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -12,14 +12,13 @@ import os import signal import socket -from string import Template import time import uuid +from string import Template from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING import torch.distributed.elastic.timer as timer from torch.distributed.elastic import events - from torch.distributed.elastic.agent.server.api import ( RunResult, SimpleElasticAgent, @@ -32,10 +31,15 @@ HealthCheckServer, ) from torch.distributed.elastic.metrics.api import prof -from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs +from torch.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger + if TYPE_CHECKING: from torch.distributed.elastic.events.api import EventMetadataValue @@ -52,6 +56,7 @@ TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + class LocalElasticAgent(SimpleElasticAgent): """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. @@ -158,7 +163,6 @@ def __init__( self._logs_specs = logs_specs self._health_check_server: Optional[HealthCheckServer] = None - def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER watchdog_enabled = os.getenv(enable_watchdog_env_name) @@ -169,8 +173,10 @@ def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) if not envs: - logger.warning("Empty envs variables, using empty run_id for FileTimerServer") - run_id = '' + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" else: run_id = envs[0]["TORCHELASTIC_RUN_ID"] self._worker_watchdog = timer.FileTimerServer( @@ -178,11 +184,15 @@ def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: run_id=run_id, max_interval=0.1, daemon=True, - log_event=self._log_watchdog_event) + log_event=self._log_watchdog_event, + ) self._worker_watchdog.start() logger.info("FileTimerServer started") else: - logger.info("Environment variable '%s' not found. Do not start FileTimerServer.", enable_watchdog_env_name) + logger.info( + "Environment variable '%s' not found. Do not start FileTimerServer.", + enable_watchdog_env_name, + ) # Propagate the watchdog file env to worker processes if watchdog_file_path is not None: for worker_env in envs.values(): @@ -202,7 +212,9 @@ def _setup_healthcheck(self) -> None: healthcheck_port, ) if self._worker_watchdog is None: - logger.info("FileTimerServer doesn't exist, using current time as dummy callback") + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) alive_callback = LocalElasticAgent._get_current_time_secs else: alive_callback = self._worker_watchdog.get_last_progress_time @@ -219,7 +231,6 @@ def _setup_healthcheck(self) -> None: healthcheck_port_env_name, ) - def _get_fq_hostname(self) -> str: return socket.getfqdn(socket.gethostname()) @@ -230,9 +241,7 @@ def _log_watchdog_event( ) -> None: wg = self._worker_group spec = wg.spec - md = { - "watchdog_event": name - } + md = {"watchdog_event": name} if request is not None: md["worker_pid"] = str(request.worker_pid) md["scope_id"] = request.scope_id @@ -264,7 +273,9 @@ def _log_watchdog_event( # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof - def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: self._shutdown(is_restart=is_restart) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator @@ -280,7 +291,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: args: Dict[int, Tuple] = {} envs: Dict[int, Dict[str, str]] = {} - log_line_prefixes: Optional[Dict[int, str]] = {} if self._log_line_prefix_template else None + log_line_prefixes: Optional[Dict[int, str]] = ( + {} if self._log_line_prefix_template else None + ) for worker in worker_group.workers: local_rank = worker.local_rank worker_env = { @@ -306,12 +319,14 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: if "OMP_NUM_THREADS" in os.environ: worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] - if self._log_line_prefix_template: - log_line_prefix = Template(self._log_line_prefix_template).safe_substitute( + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( role_name=spec.role, rank=worker.global_rank, - local_rank=local_rank,) + local_rank=local_rank, + ) log_line_prefixes[local_rank] = log_line_prefix envs[local_rank] = worker_env @@ -336,7 +351,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: return self._pcontext.pids() - def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None: + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: if self._worker_watchdog is not None: self._worker_watchdog.stop() self._worker_watchdog = None @@ -360,7 +377,9 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: logger.error( "[%s] worker pids do not match process_context pids." " Expected: %s, actual: %s", - role, worker_pids, pc_pids + role, + worker_pids, + pc_pids, ) return RunResult(state=WorkerState.UNKNOWN) diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py index 160383637865b..e778c08683847 100644 --- a/torch/distributed/elastic/control_plane.py +++ b/torch/distributed/elastic/control_plane.py @@ -4,6 +4,7 @@ from torch.distributed.elastic.multiprocessing.errors import record + __all__ = [ "worker_main", ] diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index 9f6e1733518af..5e4a0a6f23691 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -24,7 +24,6 @@ import os import socket import traceback -from enum import Enum from typing import Dict, Optional from torch.distributed.elastic.events.handlers import get_logging_handler @@ -37,8 +36,10 @@ RdzvEvent, ) + _events_loggers: Dict[str, logging.Logger] = {} + def _get_or_create_logger(destination: str = "null") -> logging.Logger: """ Construct python logger based on the destination type or extends if provided. @@ -71,6 +72,7 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger: def record(event: Event, destination: str = "null") -> None: _get_or_create_logger(destination).info(event.serialize()) + def record_rdzv_event(event: RdzvEvent) -> None: _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 082499b3af638..c610cfd4cb354 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -10,9 +10,10 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Dict, Union, Optional +from typing import Dict, Optional, Union -__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent'] + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] EventMetadataValue = Union[str, int, float, bool, None] diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index d8bea0b3c0791..4b72dcd7c6020 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -139,14 +139,14 @@ def emit(self, metric_data): from typing import Optional from .api import ( # noqa: F401 + configure, ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, MetricData, MetricHandler, MetricsConfig, NullMetricHandler, - configure, - get_elapsed_time_ms, - getStream, prof, profile, publish_metric, diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 7b6d8295ef051..2c07d3b5c47bb 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -14,9 +14,22 @@ from typing import Dict, Optional from typing_extensions import deprecated -__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', - 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', - 'MetricData'] + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) @@ -150,6 +163,7 @@ def profile(group=None): @metrics.profile("my_metric_group") def some_function(): """ + def wrap(func): @wraps(func) def wrapper(*args, **kwargs): diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index 4e26ab1744a98..21cb5e47d4419 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -62,8 +62,7 @@ def trainer(a, b, c): implementations of the parent :class:`api.PContext` class. """ -import os -from typing import Callable, Dict, Optional, Tuple, Union, Set +from typing import Callable, Dict, Optional, Tuple, Union from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 _validate_full_rank, @@ -81,6 +80,7 @@ def trainer(a, b, c): ) from torch.distributed.elastic.utils.logging import get_logger + __all__ = [ "start_processes", "MultiprocessContext", diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 5d294a7d08021..8968dbdc8e6db 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -17,13 +17,13 @@ import sys import tempfile import time +from abc import ABC, abstractmethod from contextlib import nullcontext from dataclasses import dataclass, field from enum import IntFlag from multiprocessing import synchronize from types import FrameType from typing import Any, Callable, Dict, Optional, Set, Tuple, Union -from abc import ABC, abstractmethod import torch.multiprocessing as mp from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record @@ -31,10 +31,13 @@ redirect_stderr, redirect_stdout, ) - -from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler, get_subprocess_handler +from torch.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) from torch.distributed.elastic.multiprocessing.tail_log import TailLog + IS_WINDOWS = sys.platform == "win32" IS_MACOS = sys.platform == "darwin" @@ -55,6 +58,7 @@ "LogsSpecs", ] + class SignalException(Exception): """ Exception is raised inside the torchelastic agent process by the termination handler @@ -178,6 +182,7 @@ class LogsDest: """ For each log type, holds mapping of local rank ids to file paths. """ + stdouts: Dict[int, str] = field(default_factory=dict) stderrs: Dict[int, str] = field(default_factory=dict) tee_stdouts: Dict[int, str] = field(default_factory=dict) @@ -215,7 +220,10 @@ def __init__( self._local_ranks_filter = local_ranks_filter @abstractmethod - def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: """ Given the environment variables, builds destination of log files for each of the local ranks. @@ -229,6 +237,7 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: def root_log_dir(self) -> str: pass + class DefaultLogsSpecs(LogsSpecs): """ Default LogsSpecs implementation: @@ -236,6 +245,7 @@ class DefaultLogsSpecs(LogsSpecs): - `log_dir` will be created if it doesn't exist - Generates nested folders for each attempt and rank. """ + def __init__( self, log_dir: Optional[str] = None, @@ -266,7 +276,10 @@ def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): logger.info("log directory set to: %s", dir) return dir - def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: """ Uses following scheme to build log destination paths: @@ -279,7 +292,9 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: if nprocs > 0: global_env = envs[0] else: - logger.warning("Empty envs map provided when defining logging destinations.") + logger.warning( + "Empty envs map provided when defining logging destinations." + ) # Keys are always defined, but values can be missing in unit tests run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") @@ -321,7 +336,6 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: error_files = {} for local_rank in range(nprocs): - if attempt_log_dir == os.devnull: tee_stdouts[local_rank] = os.devnull tee_stderrs[local_rank] = os.devnull @@ -343,7 +357,10 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: if t & Std.ERR == Std.ERR: tee_stderrs[local_rank] = stderrs[local_rank] - if self._local_ranks_filter and local_rank not in self._local_ranks_filter: + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): # If stream is tee'd, only write to file, but don't tail if local_rank in tee_stdouts: tee_stdouts.pop(local_rank, None) @@ -358,7 +375,9 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: error_file = os.path.join(clogdir, "error.json") error_files[local_rank] = error_file - logger.info("Setting worker%s reply file to: %s", local_rank, error_file) + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) @@ -423,7 +442,6 @@ def __init__( envs: Dict[int, Dict[str, str]], logs_specs: LogsSpecs, log_line_prefixes: Optional[Dict[int, str]] = None, - ): self.name = name # validate that all mappings have the same number of keys and @@ -444,8 +462,12 @@ def __init__( self.error_files = logs_dest.error_files self.nprocs = nprocs - self._stdout_tail = TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes) - self._stderr_tail = TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes) + self._stdout_tail = TailLog( + name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes + ) + self._stderr_tail = TailLog( + name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes + ) def start(self) -> None: """Start processes using parameters defined in the constructor.""" @@ -678,7 +700,9 @@ def _poll(self) -> Optional[RunProcsResult]: # But the child process might still have not exited. Wait for them. # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. while not self._pc.join(): - logger.debug("entrypoint fn finished, waiting for all child procs to exit...") + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." + ) _validate_full_rank( self._return_values, self.nprocs, "return_value queue" @@ -704,8 +728,10 @@ def _poll(self) -> Optional[RunProcsResult]: " local_rank: %s (pid: %s)" " of fn: %s (start_method: %s)", failed_proc.exitcode, - failed_local_rank, e.pid, - fn_name, self.start_method, + failed_local_rank, + e.pid, + fn_name, + self.start_method, ) self.close() @@ -731,7 +757,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: return for proc in self._pc.processes: if proc.is_alive(): - logger.warning("Closing process %s via signal %s", proc.pid, death_sig.name) + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) try: os.kill(proc.pid, death_sig) except ProcessLookupError: @@ -748,7 +776,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: if proc.is_alive(): logger.warning( "Unable to shutdown process %s via %s, forcefully exiting via %s", - proc.pid, death_sig, _get_kill_signal() + proc.pid, + death_sig, + _get_kill_signal(), ) try: os.kill(proc.pid, _get_kill_signal()) @@ -758,6 +788,7 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: pass proc.join() + class SubprocessContext(PContext): """``PContext`` holding worker processes invoked as a binary.""" @@ -769,7 +800,6 @@ def __init__( envs: Dict[int, Dict[str, str]], logs_specs: LogsSpecs, log_line_prefixes: Optional[Dict[int, str]] = None, - ): super().__init__( name, @@ -834,7 +864,10 @@ def _poll(self) -> Optional[RunProcsResult]: "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of binary: %s", - first_failure.exitcode, first_failure.local_rank, first_failure.pid, self.entrypoint + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, ) else: # Populate return with dummy values. This provides consistency with MultiprocessingHandler @@ -856,7 +889,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: for handler in self.subprocess_handlers.values(): if handler.proc.poll() is None: logger.warning( - "Sending process %s closing signal %s", handler.proc.pid, death_sig.name + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, ) handler.close(death_sig=death_sig) end = time.monotonic() + timeout @@ -874,7 +909,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: if handler.proc.poll() is None: logger.warning( "Unable to shutdown process %s via %s, forcefully exiting via %s", - handler.proc.pid, death_sig, _get_kill_signal() + handler.proc.pid, + death_sig, + _get_kill_signal(), ) handler.close(death_sig=_get_kill_signal()) handler.proc.wait() diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index d63c283b4c35e..2f5ed2d1ab0b8 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -66,7 +66,14 @@ from .error_handler import ErrorHandler # noqa: F401 from .handlers import get_error_handler # noqa: F401 -__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"] + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] logger = get_logger(__name__) @@ -113,7 +120,8 @@ def __post_init__(self): with open(self.error_file) as fp: self.error_file_data = json.load(fp) logger.debug( - "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2) + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), ) self.message, self.timestamp = self._get_error_data( self.error_file_data @@ -264,7 +272,6 @@ def format_msg(self, boarder_delim="=", section_delim="-"): def _format_failure( self, idx: int, rank: int, failure: ProcessFailure ) -> Tuple[str, int]: - # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) # or a dict (json) of the form # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} @@ -363,7 +370,7 @@ def wrapper(*args, **kwargs): "local_rank %s FAILED with no error file." " Decorate your entrypoint fn with @record for traceback info." " See: https://pytorch.org/docs/stable/elastic/errors.html", - rank + rank, ) ) raise diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 34d6229dda3b6..89e7fffdd5c7d 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -15,7 +15,8 @@ import warnings from typing import Any, Dict, Optional -__all__ = ['ErrorHandler'] + +__all__ = ["ErrorHandler"] logger = logging.getLogger(__name__) @@ -93,13 +94,14 @@ def override_error_code_in_rootcause_data( logger.warning( "child error file (%s) does not have field `message`. \n" "cannot override error code: %s", - rootcause_error_file, error_code + rootcause_error_file, + error_code, ) elif isinstance(rootcause_error["message"], str): logger.warning( "child error file (%s) has a new message format. \n" "skipping error code override", - rootcause_error_file + rootcause_error_file, ) else: rootcause_error["message"]["errorCode"] = error_code @@ -111,11 +113,13 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): # Override error code since the child process cannot capture the error code if it # is terminated by signals like SIGSEGV. if error_code: - self.override_error_code_in_rootcause_data(rootcause_error_file, rootcause_error, error_code) + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) logger.debug( - "child error file (%s) contents:\n" - "%s", - rootcause_error_file, json.dumps(rootcause_error, indent=2) + "child error file (%s) contents:\n" "%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), ) my_error_file = self._get_error_file_path() @@ -135,7 +139,8 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): logger.info("dumped error file to parent's %s", my_error_file) else: logger.error( - "no error file defined for parent, to copy child error file (%s)", rootcause_error_file + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, ) def _rm(self, my_error_file): @@ -148,13 +153,14 @@ def _rm(self, my_error_file): "%s already exists" " and will be overwritten." " Original contents:\n%s", - my_error_file, original + my_error_file, + original, ) except json.decoder.JSONDecodeError: logger.warning( "%s already exists" " and will be overwritten." " Unable to load original contents:\n", - my_error_file + my_error_file, ) os.remove(my_error_file) diff --git a/torch/distributed/elastic/multiprocessing/errors/handlers.py b/torch/distributed/elastic/multiprocessing/errors/handlers.py index 09b2aca55f16a..b8a78e73702fd 100644 --- a/torch/distributed/elastic/multiprocessing/errors/handlers.py +++ b/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -11,7 +11,9 @@ from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler -__all__ = ['get_error_handler'] + +__all__ = ["get_error_handler"] + def get_error_handler(): return ErrorHandler() diff --git a/torch/distributed/elastic/multiprocessing/redirects.py b/torch/distributed/elastic/multiprocessing/redirects.py index 8ad3e2edf1c15..057013fbb9e5b 100644 --- a/torch/distributed/elastic/multiprocessing/redirects.py +++ b/torch/distributed/elastic/multiprocessing/redirects.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from functools import partial + IS_WINDOWS = sys.platform == "win32" IS_MACOS = sys.platform == "darwin" diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py index 4c335964c7322..f56d423ce080f 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -12,4 +12,5 @@ SubprocessHandler, ) + __all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index e122f89a94f77..2660be5af399a 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -12,6 +12,7 @@ SubprocessHandler, ) + __all__ = ["get_subprocess_handler"] diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 7cacf98685750..c548d09209226 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -9,9 +9,9 @@ import signal import subprocess import sys - from typing import Any, Dict, Optional, Tuple + __all__ = ["SubprocessHandler"] IS_WINDOWS = sys.platform == "win32" diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 804e2e5a6323d..2c814ffb7be99 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -14,6 +14,7 @@ from threading import Event from typing import Dict, List, Optional, TextIO, TYPE_CHECKING + if TYPE_CHECKING: from concurrent.futures._base import Future @@ -25,7 +26,6 @@ def tail_logfile( header: str, file: str, dst: TextIO, finished: Event, interval_sec: float ): - while not os.path.exists(file): if finished.is_set(): return @@ -143,8 +143,10 @@ def stop(self) -> None: except Exception as e: logger.error( "error in log tailor for %s%s. %s: %s", - self._name, local_rank, - e.__class__.__qualname__, e, + self._name, + local_rank, + e.__class__.__qualname__, + e, ) if self._threadpool: diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index f6ec6a6eb62f6..62a31adab27b0 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -128,8 +128,8 @@ class that implements the rendezvous mechanism described above. It is a backend- ) """ - from .api import ( + rendezvous_handler_registry, RendezvousClosedError, RendezvousConnectionError, RendezvousError, @@ -142,9 +142,7 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStateError, RendezvousStoreInfo, RendezvousTimeoutError, - rendezvous_handler_registry, ) - from .registry import _register_default_handlers diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 7ddcd7c70b9af..9cde6758981ab 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -6,7 +6,6 @@ # LICENSE file in the root directory of this source tree. import socket - from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Callable, ClassVar, Dict, Optional @@ -51,15 +50,18 @@ class RendezvousConnectionError(RendezvousError): class RendezvousStateError(RendezvousError): """Raised when the state of a rendezvous is corrupt.""" + class RendezvousGracefulExitError(RendezvousError): """Raised when node wasn't not included in rendezvous and gracefully exits. Exception is a mechanism to exit the stack, however does not mean a failure. """ + @dataclass class RendezvousStoreInfo: """Store address and port that can be used to bootstrap trainer distributed comms""" + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" master_addr: str @@ -79,13 +81,22 @@ def build(rank: int, store: Store) -> "RendezvousStoreInfo": store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") - port = int(store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")) + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) return RendezvousStoreInfo(master_addr=addr, master_port=port) class RendezvousInfo: """Holds the information about the rendezvous.""" - def __init__(self, store: Store, rank: int, world_size: int, bootstrap_store_info: RendezvousStoreInfo): + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): self._store = store self._rank = rank self._world_size = world_size diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 4c1c687411ef2..26c3153d9785b 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -11,13 +11,10 @@ import tempfile from base64 import b64decode, b64encode from datetime import timedelta -from typing import Any, Optional, Tuple, cast +from typing import Any, cast, Optional, Tuple from torch.distributed import FileStore, Store, TCPStore -from torch.distributed.elastic.events import ( - NodeState, - construct_and_record_rdzv_event, -) +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState from .api import ( RendezvousConnectionError, @@ -28,6 +25,7 @@ from .dynamic_rendezvous import RendezvousBackend, Token from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + logger = logging.getLogger(__name__) @@ -96,7 +94,9 @@ def set_state( else: token = self._NULL_SENTINEL - base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str) + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) state_token_pair = self._decode_state(base64_state) if state_token_pair is None: @@ -256,7 +256,9 @@ def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, elif store_type == "tcp": store = _create_tcp_store(params) else: - raise ValueError("Invalid store type given. Currently only supports file and tcp.") + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) backend = C10dRendezvousBackend(store, params.run_id) diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index ad45077d3943e..31627cf0a0b27 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -20,7 +20,6 @@ from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple import torch.distributed as dist - from torch.distributed import Store from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState @@ -37,6 +36,7 @@ ) from .utils import _delay, _PeriodicTimer + __all__ = [ "RendezvousBackend", "RendezvousTimeout", @@ -57,6 +57,7 @@ def get_method_name(depth=2): Token = Any """Represent an opaque fencing token used by the rendezvous backend.""" + class RendezvousBackend(ABC): """Represent a backend that holds the rendezvous state.""" @@ -157,7 +158,9 @@ def __init__( close: Optional[timedelta] = None, heartbeat: Optional[timedelta] = None, ) -> None: - self._set_timeouts(join=join, last_call=last_call, close=close, heartbeat=heartbeat) + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) @property def join(self) -> timedelta: @@ -311,7 +314,9 @@ def __init__(self) -> None: self.last_heartbeats = {} -def _remove_participant_epilogue(state: _RendezvousState, settings: RendezvousSettings) -> None: +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: if state.complete: # If we do not have any participants left, move to the next round. if not state.participants: @@ -424,7 +429,9 @@ def sync(self) -> Optional[bool]: if self._cache_duration > 0: # Avoid overloading the backend if we are asked to retrieve the # state repeatedly. Try to serve the cached state. - if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0): + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): return None get_response = self._backend.get_state() @@ -917,14 +924,19 @@ def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: if ctx.node not in state.wait_list: return _Action.ADD_TO_WAIT_LIST elif len(state.participants) >= ctx.settings.max_nodes: - if ctx.node not in state.redundancy_list and ctx.node not in state.wait_list: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): return _Action.ADD_TO_REDUNDANCY_LIST elif is_participant: # If the rendezvous has enough number of participants including us, # check whether we have passed the rendezvous deadline. If yes, # complete it. - if len(state.participants) >= ctx.settings.min_nodes and \ - len(state.participants) <= ctx.settings.max_nodes: + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + ): if cast(datetime, state.deadline) < datetime.utcnow(): msg = ( f"The node '{ctx.node}' marking the rendezvous complete, " @@ -1143,10 +1155,7 @@ def next_rendezvous(self) -> RendezvousInfo: deadline = self._get_deadline(self._settings.timeout.join) self._op_executor.run(exit_op, deadline) - self._op_executor.run( - join_op, - deadline, - self._get_deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) self._start_heartbeats() @@ -1182,7 +1191,9 @@ def next_rendezvous(self) -> RendezvousInfo: if isinstance(self._store, dist.TCPStore): addr = self._store.host port = self._store.port - self._bootstrap_store_info = RendezvousStoreInfo(master_addr=addr, master_port=port) + self._bootstrap_store_info = RendezvousStoreInfo( + master_addr=addr, master_port=port + ) if rank == 0: self._shared_tcp_store_server = self._store else: @@ -1190,7 +1201,9 @@ def next_rendezvous(self) -> RendezvousInfo: # bootstrapping info across ranks self._bootstrap_store_info = RendezvousStoreInfo.build(rank, store) if rank == 0: - self._shared_tcp_store_server = self._create_tcp_store_server(self._bootstrap_store_info) + self._shared_tcp_store_server = self._create_tcp_store_server( + self._bootstrap_store_info + ) assert self._bootstrap_store_info is not None if rank == 0: @@ -1321,7 +1334,9 @@ def _start_heartbeats(self) -> None: self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) ) - self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}") + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) self._keep_alive_timer.start() @@ -1337,7 +1352,9 @@ def _get_world(self) -> Tuple[int, int]: return state.participants[self._this_node], len(state.participants) def _wrap_store(self, store: Store) -> Store: - key_prefix = f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + key_prefix = ( + f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) return dist.PrefixStore(key_prefix, store) diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 1a371b74275a1..fe6170ede0159 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -15,6 +15,7 @@ from typing import Optional import etcd # type: ignore[import] + from torch.distributed.elastic.rendezvous import ( RendezvousClosedError, RendezvousError, @@ -25,15 +26,16 @@ RendezvousTimeoutError, ) +from .etcd_store import cas_delay, EtcdStore from .utils import parse_rendezvous_endpoint -from .etcd_store import EtcdStore, cas_delay + __all__ = [ "EtcdRendezvousRetryableFailure", "EtcdRendezvousRetryImmediately", "EtcdRendezvousHandler", "EtcdRendezvous", - "create_rdzv_handler" + "create_rdzv_handler", ] _log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") @@ -373,7 +375,9 @@ def join_phase(self, expected_version): state = json.loads(active_version.value) logger.info( "Joined rendezvous version %s as rank %s. Full state: %s", - state["version"], this_rank, state + state["version"], + this_rank, + state, ) # If this worker was first to reach num_min_workers requirement, @@ -418,7 +422,8 @@ def confirm_phase(self, expected_version, this_rank): logger.info( "Rendezvous version %s is complete. Final state: %s", - state["version"], state + state["version"], + state, ) # Rendezvous version number; our rank in it; world size @@ -436,12 +441,13 @@ def handle_existing_rendezvous(self, expected_version): # 2. if keep alives are missing, destroy it and bail out. active_state = self.announce_self_waiting(expected_version) logger.info( - "Added self to waiting list. Rendezvous full state: %s", - active_state.value + "Added self to waiting list. Rendezvous full state: %s", active_state.value ) self.wait_for_rendezvous_to_free(expected_version) - logger.info("Previously existing rendezvous state changed. Will re-try joining.") + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) def try_create_rendezvous(self): """ @@ -688,8 +694,7 @@ def wait_for_rendezvous_to_free(self, expected_version): # rendezvous version as dead (but only if it hadn't changed) logger.info("Keep-alive key %s is not renewed.", key) logger.info( - "Rendezvous version %s is incomplete. ", - expected_version + "Rendezvous version %s is incomplete. ", expected_version ) logger.info("Attempting to destroy it.") @@ -703,7 +708,7 @@ def wait_for_rendezvous_to_free(self, expected_version): logger.info( "Destroyed rendezvous version %s successfully.", - expected_version + expected_version, ) # We can return (and retry) immediately @@ -770,7 +775,9 @@ def handle_join_last_call(self, expected_version, deadline): # We successfully made this rendezvous frozen. return except etcd.EtcdCompareFailed: - logger.info("Join last-call transition CAS unsuccessful. Will retry") + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) cas_delay() active_version, state = self.get_rdzv_state() continue @@ -1051,6 +1058,8 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: num_min_workers=params.min_nodes, num_max_workers=params.max_nodes, timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), - last_call_timeout=params.get_as_int("last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), ) return EtcdRendezvousHandler(rdzv_impl=rdzv) diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index c9d60abdc2369..75ae347293c8f 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -7,17 +7,18 @@ import binascii from base64 import b64decode, b64encode -from typing import Optional, Tuple, cast +from typing import cast, Optional, Tuple import urllib3.exceptions # type: ignore[import] -from etcd import Client as EtcdClient # type: ignore[import] -from etcd import ( +from etcd import ( # type: ignore[import] + Client as EtcdClient, EtcdAlreadyExist, EtcdCompareFailed, EtcdException, EtcdKeyNotFound, EtcdResult, ) + from torch.distributed import Store from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError @@ -207,7 +208,9 @@ def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, """ client = _create_etcd_client(params) - backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous") + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) store = EtcdStore(client, "/torch/elastic/store") diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index 891858534c565..8af8c01c028ae 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -17,6 +17,7 @@ import time from typing import Optional, TextIO, Union + try: import etcd # type: ignore[import] except ModuleNotFoundError: diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index 6055964756864..4fa1bef06857d 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -178,7 +178,9 @@ def _try_wait_get(self, b64_keys, override_timeout=None): # Read whole directory (of keys), filter only the ones waited for all_nodes = self.client.get(key=self.prefix) req_nodes = { - node.key: node.value for node in all_nodes.children if node.key in b64_keys + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys } if len(req_nodes) == len(b64_keys): diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index eaa5bcfd80e24..1a91d0a8ff794 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,11 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .api import RendezvousHandler, RendezvousParameters -from .api import rendezvous_handler_registry as handler_registry +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) from .dynamic_rendezvous import create_handler -__all__ = ['get_rendezvous_handler'] + +__all__ = ["get_rendezvous_handler"] + def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: from . import static_tcp_rendezvous diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index ace82d0a22267..5d2679d9fb4a0 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -11,15 +11,16 @@ import logging from typing import cast, Optional -from torch.distributed import Store, TCPStore, PrefixStore +from torch.distributed import PrefixStore, Store, TCPStore from torch.distributed.elastic.rendezvous import ( - RendezvousInfo, RendezvousHandler, + RendezvousInfo, + RendezvousParameters, RendezvousStoreInfo, - RendezvousParameters ) from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + __all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] logger = logging.getLogger(__name__) diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index 8419051d29f82..a93c9c39e5863 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -15,7 +15,9 @@ from threading import Event, Thread from typing import Any, Callable, Dict, Optional, Tuple, Union -__all__ = ['parse_rendezvous_endpoint'] + +__all__ = ["parse_rendezvous_endpoint"] + def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: """Extract key-value pairs from a rendezvous configuration string. @@ -62,7 +64,9 @@ def _try_parse_port(port_str: str) -> Optional[int]: return None -def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]: +def parse_rendezvous_endpoint( + endpoint: Optional[str], default_port: int +) -> Tuple[str, int]: """Extract the hostname and the port number from a rendezvous endpoint. Args: @@ -92,7 +96,7 @@ def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tup if len(rest) == 1: port = _try_parse_port(rest[0]) - if port is None or port >= 2 ** 16: + if port is None or port >= 2**16: raise ValueError( f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " "between 0 and 65536." @@ -135,10 +139,7 @@ def _matches_machine_hostname(host: str) -> bool: except (ValueError, socket.gaierror) as _: host_addr_list = [] - host_ip_list = [ - host_addr_info[4][0] - for host_addr_info in host_addr_list - ] + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] this_host = socket.gethostname() if host == this_host: @@ -246,7 +247,10 @@ def start(self) -> None: raise RuntimeError("The timer has already started.") self._thread = Thread( - target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, ) # We avoid using a regular finalizer (a.k.a. __del__) for stopping the diff --git a/torch/distributed/elastic/timer/__init__.py b/torch/distributed/elastic/timer/__init__.py index ea4b2a46c4231..b9c2ea349cc67 100644 --- a/torch/distributed/elastic/timer/__init__.py +++ b/torch/distributed/elastic/timer/__init__.py @@ -39,6 +39,16 @@ def trainer_func(message_queue): complete, then the worker process is killed and the agent retries the worker group. """ -from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401 +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 -from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401 diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 77fcaaceed4f2..fe8d440b1afb8 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -12,10 +12,19 @@ from inspect import getframeinfo, stack from typing import Any, Dict, List, Optional, Set -__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires'] + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] logger = logging.getLogger(__name__) + class TimerRequest: """ Data object representing a countdown timer acquisition and release @@ -192,9 +201,9 @@ def _run_watchdog(self): reaped_worker_ids = set() for worker_id, expired_timers in self.get_expired_timers(now).items(): logger.info( - "Reaping worker_id=[%s]." - " Expired timers: %s", - worker_id, self._get_scopes(expired_timers) + "Reaping worker_id=[%s]." " Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), ) if self._reap_worker_no_throw(worker_id): logger.info("Successfully reaped worker=[%s]", worker_id) @@ -210,10 +219,10 @@ def _get_scopes(self, timer_requests): def start(self) -> None: logger.info( - "Starting %s..." - " max_interval=%s," - " daemon=%s", - type(self).__name__, self._max_interval, self._daemon + "Starting %s..." " max_interval=%s," " daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, ) self._watchdog_thread = threading.Thread( target=self._watchdog_loop, daemon=self._daemon diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py index 55a1a9e9bcdf7..3dce543220d83 100644 --- a/torch/distributed/elastic/timer/debug_info_logging.py +++ b/torch/distributed/elastic/timer/debug_info_logging.py @@ -11,6 +11,7 @@ from torch.distributed.elastic.utils.logging import get_logger + logger = get_logger(__name__) __all__ = ["log_debug_info_for_expired_timers"] diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index fce46f053a7e7..74da756d58c99 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -16,13 +16,17 @@ from typing import Callable, Dict, List, Optional, Set, Tuple from torch.distributed.elastic.timer.api import TimerClient, TimerRequest -from torch.distributed.elastic.timer.debug_info_logging import log_debug_info_for_expired_timers +from torch.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) from torch.distributed.elastic.utils.logging import get_logger + __all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] logger = get_logger(__name__) + class FileTimerRequest(TimerRequest): """ Data object representing a countdown timer acquisition and release @@ -35,7 +39,9 @@ class FileTimerRequest(TimerRequest): __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] - def __init__(self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0) -> None: + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: self.version = 1 self.worker_pid = worker_pid self.scope_id = scope_id @@ -60,7 +66,7 @@ def to_json(self) -> str: "pid": self.worker_pid, "scope_id": self.scope_id, "expiration_time": self.expiration_time, - "signal": self.signal + "signal": self.signal, }, ) @@ -83,8 +89,12 @@ class FileTimerClient(TimerClient): signal: signal, the signal to use to kill the process. Using a negative or zero signal will not kill the process. """ - def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else - signal.CTRL_C_EVENT)) -> None: # type: ignore[attr-defined] + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: super().__init__() self._file_path = file_path self.signal = signal @@ -103,7 +113,9 @@ def _send_request(self, request: FileTimerRequest) -> None: # be raised if the server is not there. file = self._open_non_blocking() if file is None: - raise BrokenPipeError("Could not send the FileTimerRequest because FileTimerServer is not available.") + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) with file: json_request = request.to_json() # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. @@ -120,17 +132,14 @@ def acquire(self, scope_id: str, expiration_time: float) -> None: worker_pid=os.getpid(), scope_id=scope_id, expiration_time=expiration_time, - signal=self.signal + signal=self.signal, ), ) def release(self, scope_id: str) -> None: self._send_request( request=FileTimerRequest( - worker_pid=os.getpid(), - scope_id=scope_id, - expiration_time=-1, - signal=0 + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 ), ) @@ -161,7 +170,7 @@ def __init__( run_id: str, max_interval: float = 10, daemon: bool = True, - log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None + log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, ) -> None: self._file_path = file_path self._run_id = run_id @@ -177,18 +186,21 @@ def __init__( self._request_count = 0 # For test only. Process all requests and stop the server. self._run_once = False - self._log_event = log_event if log_event is not None else lambda name, request: None + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) self._last_progress_time = int(time.time()) - def start(self) -> None: logger.info( - "Starting %s..." - " max_interval=%s," - " daemon=%s", - type(self).__name__, self._max_interval, self._daemon + "Starting %s..." " max_interval=%s," " daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon ) - self._watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=self._daemon) logger.info("Starting watchdog thread...") self._watchdog_thread.start() self._log_event("watchdog started", None) @@ -255,11 +267,18 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: all_expired_timers = self.get_expired_timers(now) log_debug_info_for_expired_timers( self._run_id, - {pid: self._get_scopes(expired_timers) for pid, expired_timers in all_expired_timers.items()}, + { + pid: self._get_scopes(expired_timers) + for pid, expired_timers in all_expired_timers.items() + }, ) for worker_pid, expired_timers in all_expired_timers.items(): - logger.info("Reaping worker_pid=[%s]. Expired timers: %s", worker_pid, self._get_scopes(expired_timers)) + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) reaped_worker_pids.add(worker_pid) # In case we have multiple expired timers, we find the first timer # with a valid signal (>0) in the expiration time order. @@ -273,19 +292,28 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: expired_timer = timer break if signal <= 0: - logger.info("No signal specified with worker=[%s]. Do not reap it.", worker_pid) + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) continue if self._reap_worker(worker_pid, signal): - logger.info("Successfully reaped worker=[%s] with signal=%s", worker_pid, signal) + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) self._log_event("kill worker process", expired_timer) else: - logger.error("Error reaping worker=[%s]. Will retry on next watchdog.", worker_pid) + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) self.clear_timers(reaped_worker_pids) def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: return [r.scope_id for r in timer_requests] - def _get_requests(self, fd: io.TextIOWrapper, max_interval: float) -> List[FileTimerRequest]: + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> List[FileTimerRequest]: start = time.time() requests = [] while not self._stop_signaled or self._run_once: @@ -309,7 +337,10 @@ def _get_requests(self, fd: io.TextIOWrapper, max_interval: float) -> List[FileT signal = request["signal"] requests.append( FileTimerRequest( - worker_pid=pid, scope_id=scope_id, expiration_time=expiration_time, signal=signal + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, ) ) now = time.time() @@ -333,7 +364,7 @@ def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: self._timers[key] = request def clear_timers(self, worker_pids: Set[int]) -> None: - for (pid, scope_id) in list(self._timers.keys()): + for pid, scope_id in list(self._timers.keys()): if pid in worker_pids or not FileTimerServer.is_process_running(pid): del self._timers[(pid, scope_id)] diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index b6a54896fc5ef..fe784b7de46d2 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -14,10 +14,12 @@ from .api import RequestQueue, TimerClient, TimerRequest, TimerServer -__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer'] + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] logger = logging.getLogger(__name__) + class LocalTimerClient(TimerClient): """ Client side of ``LocalTimerServer``. This client is meant to be used @@ -101,7 +103,7 @@ def register_timers(self, timer_requests: List[TimerRequest]) -> None: self._timers[(pid, scope_id)] = request def clear_timers(self, worker_ids: Set[int]) -> None: - for (pid, scope_id) in list(self._timers.keys()): + for pid, scope_id in list(self._timers.keys()): if pid in worker_ids: self._timers.pop((pid, scope_id)) diff --git a/torch/distributed/elastic/utils/api.py b/torch/distributed/elastic/utils/api.py index e0607e9c0d5dc..bdb8f02e0176f 100644 --- a/torch/distributed/elastic/utils/api.py +++ b/torch/distributed/elastic/utils/api.py @@ -9,7 +9,7 @@ import os import socket from string import Template -from typing import List, Any +from typing import Any, List def get_env_variable_or_raise(env_name: str) -> str: diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 04ff2fe680f1e..1a7ea81451f7c 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -16,6 +16,7 @@ from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.elastic.utils.store import barrier + __all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] logger = get_logger(__name__) @@ -58,7 +59,12 @@ def create_c10d_store( " is_server : %s\n" " timeout(sec): %s\n" " use_libuv : %s\n", - server_addr, port, world_size, is_server, timeout, use_libuv, + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, ) try: @@ -90,7 +96,10 @@ def create_c10d_store( if str(e) == _ADDRESS_IN_USE: # this will only happen on the server if attempt < retries: logger.warning( - "port: %s already in use, attempt: [%s/%s]", port, attempt, retries + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, ) attempt += 1 else: diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 6d2e1f046502b..a94010c432b18 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -7,15 +7,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from contextlib import contextmanager from datetime import timedelta from typing import List -from contextlib import contextmanager + _NUM_MEMBERS = "/num_members" _LAST_MEMBER_CHECKIN = "/last_member" __all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + @contextmanager def store_timeout(store, timeout: float): """ @@ -52,9 +54,7 @@ def get_all(store, rank: int, prefix: str, world_size: int): value3 = values[2] # retrieves the data for key torchelastic/data2 """ - data_arr = store.multi_get( - [f"{prefix}{idx}" for idx in range(world_size)] - ) + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) barrier_key = _barrier_nonblocking( store=store, @@ -101,7 +101,6 @@ def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: num_members_key = key_prefix + _NUM_MEMBERS last_member_key = key_prefix + _LAST_MEMBER_CHECKIN - idx = store.add(num_members_key, 1) if idx == world_size: store.set(last_member_key, "") @@ -126,5 +125,7 @@ def barrier( """ with store_timeout(store, barrier_timeout): - last_member_key = _barrier_nonblocking(store=store, world_size=world_size, key_prefix=key_prefix) + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) store.get(last_member_key) From 22d258427baf226fe67f888de044a62941c66dd7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 14:31:39 +0800 Subject: [PATCH 36/63] [BE][Easy] enable UFMT for `torch/distributed/_shard/` (#128867) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128867 Approved by: https://github.com/fegin ghstack dependencies: #128866 --- .lintrunner.toml | 34 -- torch/distributed/_shard/__init__.py | 7 +- torch/distributed/_shard/_utils.py | 18 +- torch/distributed/_shard/api.py | 107 +++--- .../distributed/_shard/checkpoint/__init__.py | 4 +- torch/distributed/_shard/common_op_utils.py | 9 +- torch/distributed/_shard/metadata.py | 18 +- torch/distributed/_shard/op_registry_utils.py | 11 +- .../_shard/sharded_optim/__init__.py | 12 +- torch/distributed/_shard/sharded_optim/api.py | 16 +- .../_shard/sharded_tensor/__init__.py | 192 ++++++----- .../_shard/sharded_tensor/_ops/__init__.py | 14 +- .../_shard/sharded_tensor/_ops/_common.py | 9 +- .../_shard/sharded_tensor/_ops/binary_cmp.py | 32 +- .../_shard/sharded_tensor/_ops/init.py | 31 +- .../_shard/sharded_tensor/_ops/misc_ops.py | 5 +- .../_shard/sharded_tensor/_ops/tensor_ops.py | 10 +- .../distributed/_shard/sharded_tensor/api.py | 320 ++++++++++-------- .../_shard/sharded_tensor/logger.py | 5 +- .../_shard/sharded_tensor/logging_handlers.py | 1 + .../_shard/sharded_tensor/metadata.py | 22 +- .../_shard/sharded_tensor/reshard.py | 19 +- .../_shard/sharded_tensor/shard.py | 16 +- .../_shard/sharded_tensor/utils.py | 154 ++++++--- torch/distributed/_shard/sharder.py | 3 + .../_shard/sharding_plan/__init__.py | 5 +- torch/distributed/_shard/sharding_plan/api.py | 6 +- .../_shard/sharding_spec/__init__.py | 10 +- .../_shard/sharding_spec/_internals.py | 25 +- torch/distributed/_shard/sharding_spec/api.py | 77 +++-- .../sharding_spec/chunk_sharding_spec.py | 77 +++-- torch/distributed/_spmd/api.py | 3 - torch/distributed/_spmd/batch_dim_utils.py | 5 +- torch/distributed/_spmd/config.py | 1 + torch/distributed/_spmd/data_parallel.py | 6 +- torch/distributed/_spmd/distribute.py | 2 - torch/distributed/_spmd/experimental_ops.py | 2 +- torch/distributed/_spmd/graph_optimization.py | 1 + torch/distributed/_spmd/parallel_mode.py | 1 - torch/distributed/_spmd/partial_lower.py | 3 +- 40 files changed, 736 insertions(+), 557 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 2ea1579ee64c2..dc9f9ddd46c7c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1392,40 +1392,6 @@ exclude_patterns = [ 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable 'torch/distributed/__init__.py', 'torch/distributed/_composable_state.py', - 'torch/distributed/_shard/__init__.py', - 'torch/distributed/_shard/_utils.py', - 'torch/distributed/_shard/api.py', - 'torch/distributed/_shard/checkpoint/__init__.py', - 'torch/distributed/_shard/common_op_utils.py', - 'torch/distributed/_shard/metadata.py', - 'torch/distributed/_shard/op_registry_utils.py', - 'torch/distributed/_shard/sharded_optim/__init__.py', - 'torch/distributed/_shard/sharded_optim/api.py', - 'torch/distributed/_shard/sharded_tensor/__init__.py', - 'torch/distributed/_shard/sharded_tensor/_ops/__init__.py', - 'torch/distributed/_shard/sharded_tensor/_ops/_common.py', - 'torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py', - 'torch/distributed/_shard/sharded_tensor/_ops/init.py', - 'torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py', - 'torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py', - 'torch/distributed/_shard/sharded_tensor/api.py', - 'torch/distributed/_shard/sharded_tensor/logger.py', - 'torch/distributed/_shard/sharded_tensor/logging_handlers.py', - 'torch/distributed/_shard/sharded_tensor/metadata.py', - 'torch/distributed/_shard/sharded_tensor/reshard.py', - 'torch/distributed/_shard/sharded_tensor/shard.py', - 'torch/distributed/_shard/sharded_tensor/utils.py', - 'torch/distributed/_shard/sharder.py', - 'torch/distributed/_shard/sharding_plan/__init__.py', - 'torch/distributed/_shard/sharding_plan/api.py', - 'torch/distributed/_shard/sharding_spec/__init__.py', - 'torch/distributed/_shard/sharding_spec/_internals.py', - 'torch/distributed/_shard/sharding_spec/api.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py', 'torch/distributed/_sharded_tensor/__init__.py', 'torch/distributed/_sharding_spec/__init__.py', 'torch/distributed/_tools/__init__.py', diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py index 34539d633f8fa..85a313c779e7a 100644 --- a/torch/distributed/_shard/__init__.py +++ b/torch/distributed/_shard/__init__.py @@ -1,6 +1 @@ -from .api import ( - _shard_tensor, - load_with_process_group, - shard_module, - shard_parameter, -) +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/torch/distributed/_shard/_utils.py b/torch/distributed/_shard/_utils.py index 26305b99cce30..d06fc4dc96144 100644 --- a/torch/distributed/_shard/_utils.py +++ b/torch/distributed/_shard/_utils.py @@ -1,10 +1,17 @@ +from typing import Sequence + import torch from torch.distributed._shard.metadata import ShardMetadata -from typing import Sequence + DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." -def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]) -> torch.Tensor: + +def narrow_tensor_by_index( + tensor: torch.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> torch.Tensor: """ Narrow the tensor according to ``offsets`` and ``sizes``. """ @@ -14,13 +21,10 @@ def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: # Reshape to get shard for this rank and we don't want autograd # recording here for the narrow op and 'local_shard' should be a # leaf variable in the autograd graph. - narrowed_tensor = narrowed_tensor.narrow( - idx, - offset, - size - ) + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) return narrowed_tensor + def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: """ Narrow the tensor according to the metadata diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 441bb421b195b..5d1e1179c9a78 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -1,21 +1,17 @@ # mypy: allow-untyped-defs from contextlib import contextmanager from typing import Optional + import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import distributed_c10d -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, -) -from .sharding_spec import ( - ShardingSpec, - ChunkShardingSpec -) -from .sharding_plan import ( - ShardingPlan -) +from torch.distributed._shard.sharded_tensor import ShardedTensor + from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + def _shard_tensor( tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None @@ -47,9 +43,13 @@ def _shard_tensor( currently supported as the ``sharding_spec``. """ if not tensor.is_contiguous(): - raise ValueError('input tensor is not a contiguous Tensor') + raise ValueError("input tensor is not a contiguous Tensor") - pg = process_group if process_group is not None else distributed_c10d._get_default_group() + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) world_size = dist.get_world_size(pg) current_rank = dist.get_rank(pg) @@ -60,23 +60,27 @@ def _shard_tensor( for idx, entry in enumerate(gathered_list): if src_rank != entry[0]: # type: ignore[index] raise ValueError( - f'src_rank={src_rank} on rank: {current_rank} does not ' # type: ignore[index] - f'match with src_rank={entry[0]} on rank: {idx}') + f"src_rank={src_rank} on rank: {current_rank} does not " + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) if sharding_spec != entry[1]: # type: ignore[index] raise ValueError( - f'sharding_spec={sharding_spec} on rank: {current_rank} does not ' # type: ignore[index] - f'match with sharding_spec={entry[1]} on rank: {idx}') + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) return st + def shard_parameter( - module: torch.nn.Module, - param_name: str, - sharding_spec: ShardingSpec, - src_rank=0, - process_group=None): + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): """ Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that module, it shards that parameter according to the provided @@ -107,23 +111,27 @@ def shard_parameter( """ # Perform some validation first. if not hasattr(module, param_name): - raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`') + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") tensor = getattr(module, param_name) if not isinstance(tensor, torch.Tensor): - raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') + raise ValueError(f"param: {param_name} is not a contiguous Tensor") st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) # Replace param with ShardedTensor. module.register_parameter(param_name, nn.Parameter(st)) + # Tracks the current process group in the load context manager. _CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None + @contextmanager def load_with_process_group(process_group): """ @@ -133,13 +141,15 @@ def load_with_process_group(process_group): if _CURRENT_PROCESS_GROUP is not None: raise RuntimeError( 'ProcessGroup already set by previous "load_with_process_group" ' - 'context manager') + "context manager" + ) _CURRENT_PROCESS_GROUP = process_group try: yield process_group finally: _CURRENT_PROCESS_GROUP = None + def _get_current_process_group(): """ Retrieves the current process group set by ``load_with_process_group``. @@ -151,9 +161,10 @@ def _get_current_process_group(): else: return _CURRENT_PROCESS_GROUP + def _reshard_output( - module: torch.nn.Module, - resharding_spec: ShardingSpec) -> torch.nn.Module: + module: torch.nn.Module, resharding_spec: ShardingSpec +) -> torch.nn.Module: """ Hook a module with output resharding in the forward pass according to the given ``resharding_spec``. @@ -166,13 +177,16 @@ def _reshard_output( Returns: A :class:`torch.nn.Module` object with reshard API hooked. """ + def hook_func(_module, _input, output): if isinstance(output, ShardedTensor): return output.reshard(resharding_spec) return output + module.register_forward_hook(hook_func) return module + def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module: """ Hook a module with local shards collection in the forward pass. @@ -196,21 +210,20 @@ def hook_func(_module, _input, output): local_tensor = output.local_tensor() # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec sharding_spec = output._sharding_spec - if isinstance(sharding_spec, ChunkShardingSpec) \ - and local_tensor.size(sharding_spec.dim) == 1: # type: ignore[attr-defined, arg-type] + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): local_tensor = local_tensor.squeeze( output._sharding_spec.dim # type: ignore[attr-defined] ) return local_tensor + module.register_forward_hook(hook_func) return module -def shard_module( - module: nn.Module, - plan: ShardingPlan, - src_rank=0, - process_group=None -): + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): """ Shards a given module according to the provided sharding `plan`. This method first shards all the parameters according to the given sharding `plan`. Then if @@ -249,18 +262,16 @@ def shard_module( for sharder_path in sharder_paths: if module_path.startswith(sharder_path): - raise RuntimeError(f"ShardingPlan is in-valid, trying to shard a parameter: {name}," - f" but there's already a Sharder entry for module {sharder_path}," - f" parameter sharding should not conflict with the submodule tree" - f" that a Sharder is working with!") + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) mod = module.get_submodule(module_path) shard_parameter( - mod, - param_name, - spec, - src_rank=src_rank, - process_group=process_group + mod, param_name, spec, src_rank=src_rank, process_group=process_group ) elif isinstance(spec, Sharder): parent_mod_path, _, mod_name = name.rpartition(".") @@ -272,7 +283,9 @@ def shard_module( # swap this submodule with the sharded module parent_mod.mod_name = sharded_mod else: - raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'") + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) # reshard output if there's an entry in `reshard_output` for this module if plan.output_plan is not None: @@ -281,7 +294,9 @@ def shard_module( mod = module.get_submodule(module_path) _reshard_output(mod, output_spec) else: - raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'") + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) # convert the output back to data parallel for the modules appears in # `return_local_tensor` of the plan, we will call `_collect_local_shard` # to collect the local tensor for output of modules diff --git a/torch/distributed/_shard/checkpoint/__init__.py b/torch/distributed/_shard/checkpoint/__init__.py index 161a43f276d66..85915636a0146 100644 --- a/torch/distributed/_shard/checkpoint/__init__.py +++ b/torch/distributed/_shard/checkpoint/__init__.py @@ -1,9 +1,9 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed.checkpoint` package. import sys -import torch import warnings +import torch from torch.distributed.checkpoint import * # noqa: F403 @@ -16,4 +16,4 @@ stacklevel=2, ) -sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint +sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 7506f17b046d4..e2573998712b5 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch.utils import _pytree as pytree -from typing import Optional + def _basic_validation(op, args=(), kwargs=None): """ @@ -37,14 +39,15 @@ def validate_pg(e): if isinstance(e, ShardedTensor): if cur_pg is not None and e._process_group is not cur_pg: raise RuntimeError( - 'All distributed tensors should use the ' - 'same ProcessGroup if used together in an op.' + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." ) cur_pg = e._process_group pytree.tree_map_(validate_pg, args) pytree.tree_map_(validate_pg, kwargs) + def _register_default_op(op, decorator): @decorator(op) def tensor_default_op(types, args=(), kwargs=None, pg=None): diff --git a/torch/distributed/_shard/metadata.py b/torch/distributed/_shard/metadata.py index 850b065e4dab0..2611d13ef3aaf 100644 --- a/torch/distributed/_shard/metadata.py +++ b/torch/distributed/_shard/metadata.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs from dataclasses import dataclass -from typing import List, Union, Optional from functools import reduce +from typing import List, Optional, Union from torch.distributed.remote_device import _remote_device + @dataclass class ShardMetadata: """ @@ -22,7 +23,7 @@ class ShardMetadata: Specifies the placement of this shard. """ - __slots__ = ['shard_offsets', 'shard_sizes', 'placement'] + __slots__ = ["shard_offsets", "shard_sizes", "placement"] shard_offsets: List[int] shard_sizes: List[int] @@ -32,7 +33,7 @@ def __init__( self, shard_offsets: List[int], shard_sizes: List[int], - placement: Optional[Union[str, _remote_device]] = None + placement: Optional[Union[str, _remote_device]] = None, ): self.shard_offsets = shard_offsets self.shard_sizes = shard_sizes @@ -42,15 +43,16 @@ def __init__( self.placement = placement if len(self.shard_offsets) != len(self.shard_sizes): raise ValueError( - f'shard_offsets and shard_sizes should have ' - f'the same number of elements, found {len(self.shard_offsets)} ' - f'and {self.shard_sizes} respectively') + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) for i in range(len(self.shard_offsets)): if self.shard_offsets[i] < 0: - raise ValueError('shard_offsets should be >=0') + raise ValueError("shard_offsets should be >=0") if self.shard_sizes[i] < 0: - raise ValueError('shard_sizes should be >= 0') + raise ValueError("shard_sizes should be >= 0") def __hash__(self): def _hash_reduce(a, b): diff --git a/torch/distributed/_shard/op_registry_utils.py b/torch/distributed/_shard/op_registry_utils.py index 033dc7c58e0ad..12e0b1895e2f0 100644 --- a/torch/distributed/_shard/op_registry_utils.py +++ b/torch/distributed/_shard/op_registry_utils.py @@ -1,13 +1,16 @@ # mypy: allow-untyped-defs import functools from inspect import signature + from .common_op_utils import _basic_validation + """ Common utilities to register ops on ShardedTensor and PartialTensor. """ + def _register_op(op, func, op_table): """ Performs basic validation and registers the provided op in the given @@ -15,12 +18,14 @@ def _register_op(op, func, op_table): """ if len(signature(func).parameters) != 4: raise TypeError( - f'Custom sharded op function expects signature: ' - f'(types, args, kwargs, process_group), but received ' - f'signature: {signature(func)}') + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) op_table[op] = func + def _decorator_func(wrapped_func, op, op_table): """ Decorator function to register the given ``op`` in the provided diff --git a/torch/distributed/_shard/sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py index 172213fb0c171..d1508208c1690 100644 --- a/torch/distributed/_shard/sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -1,18 +1,16 @@ from typing import Iterator, Tuple, Union -from .api import ShardedOptimizer import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor -) def named_params_with_sharded_tensor( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ShardedTensor]]]: - r"""Returns an iterator over module parameters (together with the ShardedTensor parameters), yielding both the name of the parameter as well as the parameter itself. This is typically passed to a @@ -46,7 +44,7 @@ def named_params_with_sharded_tensor( for name, val in vars(mod).items(): if isinstance(val, ShardedTensor) and val not in memo: memo.add(val) - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val # find all nn.Parameters diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index e1acf7dc17a87..1c7c632f22b59 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Union, Mapping, Dict, Any +from typing import Any, Dict, List, Mapping, Union import torch.optim as optim from torch import Tensor @@ -12,7 +12,7 @@ def __init__( named_params: Mapping[str, Union[Tensor, ShardedTensor]], optimizer_class, *optimizer_args, - **optimizer_kwargs + **optimizer_kwargs, ): """ ShardedOptimizer collects all tensors and local shard tensors of @@ -80,7 +80,6 @@ def state_dict(self) -> Dict[str, Any]: # TODO: implement state_dict raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") - def load_state_dict(self, state_dict: Mapping[str, Any]): r"""Loads the ShardedOptimizer state. @@ -89,10 +88,13 @@ def load_state_dict(self, state_dict: Mapping[str, Any]): from a call to :meth:`state_dict`. """ # TODO: implement load_state_dict - raise NotImplementedError("ShardedOptimizer load_state_dict not implemented yet!") + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) def add_param_group(self, param_group: Any): - r"""Add a new param group - """ + r"""Add a new param group""" # TODO: implement add_param_group - raise NotImplementedError("ShardedOptimizer add_param_group not implemented yet!") + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 1b846a8dabb49..db7090820ea0a 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -3,34 +3,37 @@ from typing import List, TYPE_CHECKING import torch - -if TYPE_CHECKING: - from torch.distributed._shard.sharding_spec import ShardingSpec -else: - ShardingSpec = "ShardingSpec" +from torch.distributed._shard.op_registry_utils import _decorator_func from .api import ( _CUSTOM_SHARDED_OPS, _SHARDED_OPS, Shard, - ShardedTensorBase, ShardedTensor, + ShardedTensorBase, ShardedTensorMetadata, TensorProperties, ) from .metadata import ShardMetadata # noqa: F401 -from torch.distributed._shard.op_registry_utils import _decorator_func -def empty(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` filled with uninitialized data. Needs to be called on all ranks in an SPMD fashion. @@ -74,15 +77,18 @@ def empty(sharding_spec: ShardingSpec, init_rrefs=init_rrefs, ) -def ones(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` with the scalar value 1. Needs to be called on all ranks in an SPMD fashion. @@ -122,18 +128,21 @@ def ones(sharding_spec: ShardingSpec, pin_memory=pin_memory, memory_format=memory_format, process_group=process_group, - init_rrefs=init_rrefs + init_rrefs=init_rrefs, ) -def zeros(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` filled with the scalar value 0. Needs to be called on all ranks in an SPMD fashion. @@ -173,20 +182,23 @@ def zeros(sharding_spec: ShardingSpec, pin_memory=pin_memory, memory_format=memory_format, process_group=process_group, - init_rrefs=init_rrefs + init_rrefs=init_rrefs, ) -def full(sharding_spec: ShardingSpec, - size, - fill_value, - *, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype is inferred from fill_value. If dtype is specified, it will override the @@ -229,15 +241,18 @@ def full(sharding_spec: ShardingSpec, torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] return sharded_tensor -def rand(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the interval :math:`[0, 1)`. The shape of the tensor is defined by the @@ -282,15 +297,18 @@ def rand(sharding_spec: ShardingSpec, torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] return sharded_tensor -def randn(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution with mean `0` and variance `1` (also called standard normal distribution). The shape @@ -336,11 +354,10 @@ def randn(sharding_spec: ShardingSpec, torch.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] return sharded_tensor + def init_from_local_shards( - local_shards: List[Shard], - *global_size, - process_group=None, - init_rrefs=False) -> ShardedTensor: + local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: """ Creates an :class:`ShardedTensor` from local shards and the global metadata. Needs to be called on all ranks in an SPMD fashion. @@ -388,12 +405,10 @@ def init_from_local_shards( >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) """ return ShardedTensor._init_from_local_shards( - local_shards, - *global_size, - process_group=process_group, - init_rrefs=init_rrefs + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs ) + def state_dict_hook(module, destination, prefix, local_metadata): """ Hook to add ShardedTensor to Module's ``state_dict``. Needs to be @@ -404,21 +419,32 @@ def state_dict_hook(module, destination, prefix, local_metadata): for attr_name, attr in submodule.__dict__.items(): if isinstance(attr, ShardedTensor): mod_prefix = prefix + submodule_name - key = mod_prefix + ('.' if mod_prefix else '') + attr_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name destination[key] = attr -def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): """ Pre-load state dict hook to add ShardedTensor to the module. """ for submodule_name, submodule in module.named_modules(): for attr_name in submodule.__dict__.keys(): mod_prefix = prefix + submodule_name - key = mod_prefix + ('.' if mod_prefix else '') + attr_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name if key in state_dict: if isinstance(state_dict[key], ShardedTensor): setattr(submodule, attr_name, state_dict[key]) + def custom_sharded_op_impl(func): """ Provides a way for users to write their own custom sharded operator. This @@ -450,21 +476,15 @@ def custom_sharded_op_impl(func): func(Callable): Torch function for which we want to provide a sharded implementation (ex: torch.nn.functional.linear) """ - return functools.partial( - _decorator_func, - op=func, - op_table=_CUSTOM_SHARDED_OPS - ) + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + def _sharded_op_impl(func): """ Decorator to register a default sharded op. """ - return functools.partial( - _decorator_func, - op=func, - op_table=_SHARDED_OPS - ) + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + # Import all builtin sharded ops from ._ops import * # noqa: F403 diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py index c233840f1ecce..be6d01fc8e54e 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -1,9 +1,13 @@ import torch.distributed._shard.sharded_tensor._ops.misc_ops import torch.distributed._shard.sharded_tensor._ops.tensor_ops -from .binary_cmp import equal, allclose -from .init import kaiming_uniform_, normal_, uniform_, constant_ - # Import all ChunkShardingSpec ops -from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding -from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index 4d35d24ecafca..502e0ac9a8552 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import functools + +from torch.distributed._shard.common_op_utils import _basic_validation from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, Shard, ShardedTensor, ) -from torch.distributed._shard.common_op_utils import _basic_validation + def _sharded_op_common(op, early_stop_func, extra_check): """ @@ -35,6 +37,7 @@ def _sharded_op_common(op, early_stop_func, extra_check): func (Callable): Torch function for which we want to provide a sharded implementation (ex: torch.transpose) """ + def decorator_sharded_func(wrapped_func): @functools.wraps(wrapped_func) def wrapper(types, args=(), kwargs=None, pg=None): @@ -55,6 +58,7 @@ def wrapper(types, args=(), kwargs=None, pg=None): return decorator_sharded_func + def _register_sharded_op_on_local_shards( op, early_stop_func=None, extra_check=None, customized_func=None ): @@ -84,6 +88,7 @@ def _register_sharded_op_on_local_shards( func (Callable): registered implementation for sharded op for ``__torch_function__`` dispatch. """ + @_sharded_op_impl(op) @_sharded_op_common(op, early_stop_func, extra_check) def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): @@ -104,5 +109,5 @@ def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): st_metadata, process_group=pg, init_rrefs=st._init_rrefs, - sharding_spec=st.sharding_spec() + sharding_spec=st.sharding_spec(), ) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index 034f914981612..f8db8b6ebe96f 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -2,10 +2,8 @@ import torch import torch.distributed as dist import torch.distributed.distributed_c10d as distributed_c10d -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, - _sharded_op_impl -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor + def _communicate_result(result, pg): # Gather results from all ranks. @@ -16,26 +14,35 @@ def _communicate_result(result, pg): dist.all_reduce(result_tensor, group=pg) - expected_result = torch.ones(1, device=torch.device(torch.cuda.current_device())) * dist.get_world_size(pg) + expected_result = torch.ones( + 1, device=torch.device(torch.cuda.current_device()) + ) * dist.get_world_size(pg) return torch.equal(result_tensor, expected_result) + def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): if len(args) != 2: - raise ValueError(f'Expected two arguments for torch.{cmp_fun.__name__}') + raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}") result = True st1 = args[0] st2 = args[1] if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): - raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor') + raise TypeError( + f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor" + ) # Verify same PG if st1._process_group != st2._process_group: return False - if distributed_c10d._rank_not_in_group(st1._process_group) or distributed_c10d._rank_not_in_group(st2._process_group): - return distributed_c10d._rank_not_in_group(st1._process_group) == distributed_c10d._rank_not_in_group(st2._process_group) + if distributed_c10d._rank_not_in_group( + st1._process_group + ) or distributed_c10d._rank_not_in_group(st2._process_group): + return distributed_c10d._rank_not_in_group( + st1._process_group + ) == distributed_c10d._rank_not_in_group(st2._process_group) # Verify metadata if st1.metadata() != st2.metadata(): @@ -54,16 +61,19 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): for idx in range(len(st1_local_shards)): if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata: return _communicate_result(False, st1._process_group) - if not cmp_fun(st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs): + if not cmp_fun( + st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs + ): return _communicate_result(False, st1._process_group) - return _communicate_result(True, st1._process_group) + @_sharded_op_impl(torch.equal) def equal(types, args, kwargs, process_group): return binary_cmp(torch.equal, types, args, kwargs, process_group) + @_sharded_op_impl(torch.allclose) def allclose(types, args, kwargs, process_group): return binary_cmp(torch.allclose, types, args, kwargs, process_group) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py index 736190d491e1e..71a9c20b45352 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-defs import torch import torch.distributed._shard.sharded_tensor as sharded_tensor -from torch.distributed._shard.sharded_tensor import ( - _sharded_op_impl, -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + def validate_param(param, param_name): if param is None: raise ValueError(f"param: {param_name} shouldn't be None!") + @_sharded_op_impl(torch.nn.init.uniform_) def uniform_(types, args=(), kwargs=None, pg=None): r""" @@ -22,15 +22,16 @@ def uniform_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - a = kwargs['a'] + a = kwargs["a"] validate_param(a, "a") - b = kwargs['b'] + b = kwargs["b"] validate_param(b, "b") for shard in sharded_tensor.local_shards(): torch.nn.init.uniform_(shard.tensor, a=a, b=b) return sharded_tensor + @_sharded_op_impl(torch.nn.init.normal_) def normal_(types, args=(), kwargs=None, pg=None): r""" @@ -44,15 +45,16 @@ def normal_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - mean = kwargs['mean'] + mean = kwargs["mean"] validate_param(mean, "mean") - std = kwargs['std'] + std = kwargs["std"] validate_param(std, "std") for shard in sharded_tensor.local_shards(): torch.nn.init.normal_(shard.tensor, mean=mean, std=std) return sharded_tensor + @_sharded_op_impl(torch.nn.init.kaiming_uniform_) def kaiming_uniform_(types, args=(), kwargs=None, pg=None): r""" @@ -78,17 +80,20 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - a = kwargs['a'] + a = kwargs["a"] validate_param(a, "a") - mode = kwargs['mode'] + mode = kwargs["mode"] validate_param(mode, "mode") - nonlinearity = kwargs['nonlinearity'] + nonlinearity = kwargs["nonlinearity"] validate_param(nonlinearity, "nonlinearity") for shard in sharded_tensor.local_shards(): - torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity) + torch.nn.init.kaiming_uniform_( + shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity + ) return sharded_tensor + @_sharded_op_impl(torch.nn.init.constant_) def constant_(types, args=(), kwargs=None, pg=None): r""" @@ -100,12 +105,13 @@ def constant_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - val = kwargs['val'] + val = kwargs["val"] validate_param(val, "val") for shard in sharded_tensor.local_shards(): torch.nn.init.constant_(shard.tensor, val=val) return sharded_tensor + tensor_like_creation_op_map = { torch.full_like: sharded_tensor.full, torch.empty_like: sharded_tensor.empty, @@ -115,6 +121,7 @@ def constant_(types, args=(), kwargs=None, pg=None): torch.randn_like: sharded_tensor.randn, } + # tensor ops that behave the same as the default tensor def register_tensor_creation_op(op): @_sharded_op_impl(op) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py index 82737f82de533..8b84c1684c324 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -1,8 +1,7 @@ # mypy: allow-untyped-defs import torch -from torch.distributed._shard.sharded_tensor import ( - _sharded_op_impl, -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + # This is used by `_apply()` within module.py to set new # parameters after apply a certain method, we should follow diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index 7de78bf61f3f1..93902d6f314c5 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -1,15 +1,15 @@ # mypy: allow-untyped-defs import copy + import torch +from torch.distributed._shard.common_op_utils import _register_default_op from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, Shard, ShardedTensor, ) -from ._common import ( - _register_sharded_op_on_local_shards, -) -from torch.distributed._shard.common_op_utils import _register_default_op + +from ._common import _register_sharded_op_on_local_shards # Tensor properties access @@ -33,6 +33,7 @@ _register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[union-attr] _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + # device property is ambiguous as from a global prospective, # ShardedTensor.device consists of multiple devices (might even across hosts) # We choose to return the current device of the local tensor to represent @@ -52,6 +53,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None): dev = torch.device(torch.cuda.current_device()) return dev + @_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] def st_is_meta(types, args=(), kwargs=None, pg=None): return args[0].local_tensor().is_meta diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index bf5db21b9a16c..68df582cd5145 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1,57 +1,48 @@ # mypy: allow-untyped-defs from __future__ import annotations # type: ignore[attr-defined] -from dataclasses import dataclass -from typing import ( - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - cast, - TYPE_CHECKING, -) -from typing_extensions import deprecated + import copy +import operator +import threading import warnings -from functools import reduce import weakref +from dataclasses import dataclass +from functools import reduce +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import deprecated -import threading import torch import torch.distributed as dist -from torch.distributed import rpc -from torch.distributed import distributed_c10d import torch.distributed._shard.sharding_spec as shard_spec -from torch.distributed._shard.sharding_spec.api import ( - _dispatch_custom_op, - _has_custom_op, -) +from torch.distributed import distributed_c10d, rpc +from torch.distributed._shard._utils import DEPRECATE_MSG from torch.distributed._shard.sharding_spec._internals import ( check_tensor, validate_non_overlapping_shards_metadata, ) -from torch.distributed._shard._utils import ( - DEPRECATE_MSG, +from torch.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, ) +from torch.distributed.remote_device import _remote_device +from torch.utils import _pytree as pytree -from .metadata import TensorProperties, ShardedTensorMetadata +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard from .shard import Shard -from .reshard import reshuffle_local_shard, reshard_local_shard from .utils import ( _flatten_tensor_size, _parse_and_validate_remote_device, _validate_output_tensor_for_gather, + build_global_metadata, build_metadata_from_local_shards, - build_global_metadata ) -from torch.distributed.remote_device import _remote_device -from torch.utils import _pytree as pytree -import operator + if TYPE_CHECKING: from torch.distributed._shard.metadata import ShardMetadata + # Tracking for sharded tensor objects. _sharded_tensor_lock = threading.Lock() _sharded_tensor_current_id = 0 @@ -63,18 +54,23 @@ # Customized user ops _CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {} -def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int): + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int +): with _sharded_tensor_lock: if sharded_tensor_id not in _sharded_tensor_map: raise RuntimeError( - f'Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}') + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() if sharded_tensor is None: - raise RuntimeError('ShardedTensor weakref has been deallocated') + raise RuntimeError("ShardedTensor weakref has been deallocated") else: sharded_tensor._register_remote_shards(rrefs, rpc_rank) + class ShardedTensorBase(torch.Tensor): _sharding_spec: shard_spec.ShardingSpec _metadata: ShardedTensorMetadata @@ -191,6 +187,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): "but the there is no custom __torch_dispatch__ implementation for it." ) + class ShardedTensor(ShardedTensorBase): """ ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded @@ -239,6 +236,7 @@ class ShardedTensor(ShardedTensorBase): individual GPU, via ``torch.cuda.set_device()`` """ + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): self = super().__new__(cls, sharding_spec, *size, **kwargs) return self @@ -260,22 +258,26 @@ def __init__( self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) if layout != torch.strided: - raise ValueError('Only torch.strided layout is currently supported') + raise ValueError("Only torch.strided layout is currently supported") if memory_format != torch.contiguous_format: - raise ValueError('Only torch.contiguous_format memory_format is currently supported') + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) self._metadata.tensor_properties.memory_format = memory_format current_rank = dist.get_rank() # global rank for shard_metadata in self._metadata.shards_metadata: - rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement) + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) if rank == current_rank: local_tensor = _create_tensor_from_params( shard_metadata.shard_sizes, local_device=device, - tensor_properties=self._metadata.tensor_properties + tensor_properties=self._metadata.tensor_properties, ) self._local_shards.append(Shard(local_tensor, shard_metadata)) @@ -300,8 +302,9 @@ def _post_init(self): if not rpc._is_current_rpc_agent_set(): raise RuntimeError( - 'RPC Framework needs to be initialized using' - ' torch.distributed.rpc.init_rpc if init_rrefs is set to True') + "RPC Framework needs to be initialized using" + " torch.distributed.rpc.init_rpc if init_rrefs is set to True" + ) self._init_rpc() def __del__(self): @@ -320,9 +323,9 @@ def _init_rpc(self): rpc_rank = rpc.get_worker_info().id if pg_rank != rpc_rank: raise ValueError( - f'Default ProcessGroup and RPC ranks must be ' - f'the same for ShardedTensor, found process group rank: ' - f'{pg_rank} and RPC rank: {rpc_rank}' + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" ) self._remote_shards = {} @@ -347,11 +350,14 @@ def _init_rpc(self): continue if len(self.local_shards()) != 0: - rrefs: List[rpc.RRef[Shard]] = [rpc.RRef(shard) for shard in self.local_shards()] + rrefs: List[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] fut = rpc.rpc_async( rank, _register_remote_shards, - args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank)) + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) futs.append(fut) torch.futures.wait_all(futs) @@ -394,6 +400,7 @@ def gather( # type: ignore[override] dtype (torch.dtype): Force the gathered tensors to be this dtype. Default: ``None`` """ + def shard_size(shard_md): return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] @@ -429,7 +436,10 @@ def shard_size(shard_md): # enforce_dtype is deprecated. Do it for backward compatibility. dtype = out.dtype # TODO make it as a view of out tensor - gather_list = [torch.empty((max_rank_size,), device=out.device, dtype=dtype) for _ in range(world_size)] + gather_list = [ + torch.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] else: gather_list = None @@ -437,15 +447,19 @@ def shard_size(shard_md): if enforce_dtype and len(local_shards) > 0: # enforce_dtype is deprecated. Do it for backward compatibility. dtype = local_shards[0].tensor.dtype - data = torch.empty(max_rank_size, device=self._get_preferred_device(), dtype=dtype) + data = torch.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) for shard in local_shards: src = shard.tensor.flatten() - if src.nelement() == 0 : - warnings.warn("Gathering a tensor with zero elements on rank " + str(rank)) + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank) + ) return shard_offset = shard_placement[shard.metadata][1] - data[shard_offset: shard_offset + src.numel()].copy_(src) + data[shard_offset : shard_offset + src.numel()].copy_(src) dist.gather( tensor=data, @@ -478,9 +492,7 @@ def shard_size(shard_md): out_narrow_view.copy_(tensor) def cpu( - self, - memory_format=torch.preserve_format, - process_group=None + self, memory_format=torch.preserve_format, process_group=None ) -> ShardedTensor: """ Returns a copy of this object in CPU memory. @@ -495,13 +507,17 @@ def cpu( """ # TODO: make this a __torch_function__ op once ShardedTensor becomes a # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 - if memory_format != torch.preserve_format and \ - memory_format != torch.contiguous_format: - raise RuntimeError("Only `torch.contiguous_format` or " - "`torch.preserve_format` is supported!") + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) all_on_cpu = True for meta in self.metadata().shards_metadata: - all_on_cpu &= (meta.placement.device().type == "cpu") # type: ignore[union-attr] + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] # if every shard is already on CPU, return the original object if all_on_cpu: @@ -514,9 +530,7 @@ def cpu( cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] metadata = copy.deepcopy(shard.metadata) metadata.placement._device = torch.device("cpu") # type: ignore[union-attr] - list_shards.append( - Shard(cpu_tensor, metadata) - ) + list_shards.append(Shard(cpu_tensor, metadata)) st_meta = copy.deepcopy(self.metadata()) for meta in st_meta.shards_metadata: @@ -528,7 +542,7 @@ def cpu( list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_cpu @@ -537,7 +551,7 @@ def cuda( device=None, non_blocking=False, memory_format=torch.preserve_format, - process_group=None + process_group=None, ) -> ShardedTensor: """ Returns a copy of this object in CUDA memory, if the original ShardedTensor @@ -551,15 +565,21 @@ def cuda( it is the user's responsiblity to explicitly pass in a new process_group that is compatible with GPU. """ - if memory_format != torch.preserve_format and \ - memory_format != torch.contiguous_format: - raise RuntimeError("Only `torch.contiguous_format` or " - "`torch.preserve_format` is supported!") + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) if device is not None: device = torch.device(device) if isinstance(device, str) else device - assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \ - '''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!''' + assert ( + isinstance(device, torch.device) + and device.index == torch.cuda.current_device() + ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" current_device = torch.device(torch.cuda.current_device()) # returns a copy of ShardedTensor on CUDA current device @@ -571,14 +591,12 @@ def cuda( cuda_tensor = shard.tensor.cuda( device=current_device, non_blocking=non_blocking, - memory_format=memory_format + memory_format=memory_format, ) # type: ignore[call-arg] metadata = copy.deepcopy(shard.metadata) metadata.placement._device = current_device # type: ignore[union-attr] - list_shards.append( - Shard(cuda_tensor, metadata) - ) + list_shards.append(Shard(cuda_tensor, metadata)) st_meta = copy.deepcopy(self.metadata()) for meta in st_meta.shards_metadata: @@ -592,7 +610,7 @@ def cuda( list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_cuda @@ -625,15 +643,19 @@ def to(self, *args, **kwargs) -> ShardedTensor: dtype_to = kwargs.get("dtype", current_dtype) device_to = kwargs.get("device", current_device) - device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + device_to = ( + torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) if device_to.type == "cuda": # if device_to set to cuda, set to current device even # if user specify the device index. current_idx = torch.cuda.current_device() if device_to.index != current_idx: - warnings.warn("ShardedTensor.to only move tensor to its current device" - "If you want to put to different device, use `reshard` instead.") + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead." + ) device_to = torch.device(current_idx) copy_tensor = kwargs.get("copy", False) @@ -641,7 +663,11 @@ def to(self, *args, **kwargs) -> ShardedTensor: memory_format = kwargs.get("memory_format", torch.preserve_format) process_group = kwargs.get("process_group", None) - if not copy_tensor and dtype_to == current_dtype and device_to == current_device: + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): # already have correct dtype and device, return itself return self @@ -654,7 +680,7 @@ def to(self, *args, **kwargs) -> ShardedTensor: dtype=dtype_to, non_blocking=non_blocking, copy=copy_tensor, - memory_format=memory_format + memory_format=memory_format, ) metadata = copy.deepcopy(shard.metadata) if metadata.placement is not None: @@ -674,12 +700,14 @@ def to(self, *args, **kwargs) -> ShardedTensor: list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_to @classmethod - def _normalize_pg(cls, process_group: Optional[dist.ProcessGroup]) -> dist.ProcessGroup: + def _normalize_pg( + cls, process_group: Optional[dist.ProcessGroup] + ) -> dist.ProcessGroup: if process_group is not None: return process_group return distributed_c10d._get_default_group() @@ -701,8 +729,9 @@ def _init_from_local_shards( global_tensor_size = _flatten_tensor_size(global_size) if len(local_shards) > 0: - local_sharded_tensor_metadata = \ - build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group) + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) # STEP 2. Validate metadata across ranks, and build a global sharded tensor # metadata by gathering local ShardedTensorMetadata @@ -711,9 +740,7 @@ def _init_from_local_shards( gathered_metadatas = [None for _ in range(world_size)] dist.all_gather_object( - gathered_metadatas, - local_sharded_tensor_metadata, - group=process_group + gathered_metadatas, local_sharded_tensor_metadata, group=process_group ) else: gathered_metadatas = [local_sharded_tensor_metadata] @@ -726,13 +753,15 @@ def _init_from_local_shards( spec = shard_spec._infer_sharding_spec_from_shards_metadata( global_sharded_tensor_metadata.shards_metadata ) - sharded_tensor = cls.__new__(cls, - spec, - global_sharded_tensor_metadata.size, - dtype=tensor_properties.dtype, - layout=tensor_properties.layout, - pin_memory=tensor_properties.pin_memory, - requires_grad=tensor_properties.requires_grad) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) # attach local_shards to the ShardedTensor created @@ -809,7 +838,7 @@ def _init_from_local_tensor( sharding spec. """ if not local_tensor.is_contiguous(): - raise ValueError('local_tensor is not a contiguous Tensor.') + raise ValueError("local_tensor is not a contiguous Tensor.") global_tensor_size = _flatten_tensor_size(global_size) tensor_properties = TensorProperties( @@ -817,10 +846,10 @@ def _init_from_local_tensor( layout=local_tensor.layout, requires_grad=local_tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=local_tensor.is_pinned()) + pin_memory=local_tensor.is_pinned(), + ) sharded_tensor_metadata = sharding_spec.build_metadata( - global_tensor_size, - tensor_properties + global_tensor_size, tensor_properties ) process_group = cls._normalize_pg(process_group) @@ -828,7 +857,9 @@ def _init_from_local_tensor( local_shards: List[Shard] = [] for shard_metadata in sharded_tensor_metadata.shards_metadata: - rank, device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) + rank, device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) if rank == current_rank: local_shards.append(Shard(local_tensor, shard_metadata)) @@ -868,16 +899,18 @@ def _init_from_local_shards_and_global_metadata( # type: ignore[override] # collect local shard metadatas from the global sharded_tensor_metadata for shard_metadata in shards_metadata: # type: ignore[attr-defined] - rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) if current_rank == rank: local_shard_metadatas.append(shard_metadata) if len(local_shards) != len(local_shard_metadatas): raise RuntimeError( - f'Number of local shards ({len(local_shards)}) does not match number of local ' - f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) ' - f'on rank ({current_rank}) ' + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " ) shards_metadata = sharded_tensor_metadata.shards_metadata @@ -1056,12 +1089,11 @@ def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 """ - if ( - not isinstance(resharding_spec, shard_spec.ChunkShardingSpec) or - not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec) - ): + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") - if (len(self.local_shards()) != 1): + if len(self.local_shards()) != 1: raise NotImplementedError("Only single local shard supported for reshard.") if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] @@ -1110,12 +1142,7 @@ def dispatch(st: ShardedTensor, func: Callable): # Dispatch to custom sharding spec op if it has one. if _has_custom_op(st._sharding_spec, func): return _dispatch_custom_op( - st._sharding_spec, - func, - types, - args, - kwargs, - st._process_group + st._sharding_spec, func, types, args, kwargs, st._process_group ) if func in _SHARDED_OPS: @@ -1123,7 +1150,8 @@ def dispatch(st: ShardedTensor, func: Callable): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ShardedTensor!") + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) # Find ShardedTensor instance to get process_group and sharding_spec. st_instance = None @@ -1141,7 +1169,8 @@ def find_sharded_tensor(e): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ShardedTensor!") + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) def is_pinned(self) -> bool: # type: ignore[override] """ @@ -1149,7 +1178,9 @@ def is_pinned(self) -> bool: # type: ignore[override] """ return self._metadata.tensor_properties.pin_memory - def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int): + def _register_remote_shards( + self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int + ): self._remote_shards[rpc_rank] = remote_shards def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: @@ -1162,7 +1193,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: """ if not self._init_rrefs: raise RuntimeError( - 'ShardedTensor created with init_rrefs=False, no RRefs to remote shards available' + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" ) return self._remote_shards @@ -1170,13 +1201,14 @@ def __hash__(self): return id(self) def __repr__(self): - return f'ShardedTensor({self._metadata})' + return f"ShardedTensor({self._metadata})" @dataclass class ProcessGroupState: """ State for ser-de of process group """ + local_rank: int global_rank: int local_world_size: int @@ -1190,51 +1222,71 @@ def __getstate__(self): distributed_c10d.get_world_size(), ) - return self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) def __setstate__(self, state): self._sharded_tensor_id = None if not distributed_c10d.is_initialized(): raise RuntimeError( - 'Need to initialize default process group using ' - '"init_process_group" before loading ShardedTensor') + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) - self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs = state + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state # Setup process group from torch.distributed._shard.api import _get_current_process_group + self._process_group = _get_current_process_group() # Validate process group. local_rank = distributed_c10d.get_rank(self._process_group) if pg_state.local_rank != local_rank: raise RuntimeError( - f'Local rank at save time was {pg_state.local_rank}, but at ' - f'load time was {local_rank}') + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) global_rank = distributed_c10d.get_rank() if pg_state.global_rank != global_rank: raise RuntimeError( - f'Global rank at save time was {pg_state.global_rank}, but at ' - f'load time was {global_rank}') + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) local_world_size = distributed_c10d.get_world_size(self._process_group) if pg_state.local_world_size != local_world_size: raise RuntimeError( - f'Local world size at save time was {pg_state.local_world_size}, ' - f'but at load time was {local_world_size}') + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) global_world_size = distributed_c10d.get_world_size() if pg_state.global_world_size != global_world_size: raise RuntimeError( - f'Global world size at save time was {pg_state.global_world_size}, ' - f'but at load time was {global_world_size}') + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) self._post_init() -def _create_tensor_from_params(*size, local_device, tensor_properties: TensorProperties): - """ Helper to construct tensor from size, device and common params. """ +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" dtype = tensor_properties.dtype layout = tensor_properties.layout requires_grad = tensor_properties.requires_grad @@ -1242,7 +1294,11 @@ def _create_tensor_from_params(*size, local_device, tensor_properties: TensorPro pin_memory = tensor_properties.pin_memory return torch.empty( - *size, dtype=dtype, layout=layout, - device=local_device, requires_grad=requires_grad, - memory_format=memory_format, pin_memory=pin_memory + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, ) diff --git a/torch/distributed/_shard/sharded_tensor/logger.py b/torch/distributed/_shard/sharded_tensor/logger.py index 87cb74fbd01d2..ebb749dc7d5c7 100644 --- a/torch/distributed/_shard/sharded_tensor/logger.py +++ b/torch/distributed/_shard/sharded_tensor/logger.py @@ -9,9 +9,8 @@ import logging from typing import List, Tuple -from torch.distributed._shard.sharded_tensor.logging_handlers import ( - _log_handlers, -) +from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + __all__: List[str] = [] diff --git a/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/torch/distributed/_shard/sharded_tensor/logging_handlers.py index 3c607fe45da77..021ad100f06a8 100644 --- a/torch/distributed/_shard/sharded_tensor/logging_handlers.py +++ b/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -9,6 +9,7 @@ import logging from typing import Dict, List + __all__: List[str] = [] _log_handlers: Dict[str, logging.Handler] = { diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index 8b3257240e383..e53ac25fa55d9 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -6,14 +6,16 @@ import torch from torch.distributed._shard.metadata import ShardMetadata + class MEM_FORMAT_ENCODING(Enum): TORCH_CONTIGUOUS_FORMAT = 0 TORCH_CHANNELS_LAST = 1 TORCH_PRESERVE_FORMAT = 2 + @dataclass class TensorProperties: - """ Properties used to create :class:`Tensor` """ + """Properties used to create :class:`Tensor`""" # Regular tensor fields dtype: torch.dtype = field(default=torch.get_default_dtype()) @@ -32,7 +34,7 @@ def __getstate__(self): elif memory_format == torch.preserve_format: mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT else: - raise RuntimeError(f'Invalid torch.memory_format: {memory_format}') + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") return ( self.dtype, @@ -46,7 +48,13 @@ def __setstate__( self, state, ): - (self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory) = state + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: memory_format = torch.contiguous_format @@ -55,7 +63,9 @@ def __setstate__( elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: memory_format = torch.preserve_format else: - raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}') + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) self.memory_format = memory_format @@ -66,8 +76,10 @@ def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned() + pin_memory=tensor.is_pinned(), ) + + @dataclass class ShardedTensorMetadata: """ diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index 549dde38cdf8a..9a82012d59cd3 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -4,19 +4,14 @@ import torch import torch.distributed as dist -from torch._C._distributed_c10d import ( - ProcessGroup, -) import torch.distributed._shard.sharding_spec as shard_spec +from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharding_spec._internals import ( - get_split_size, get_chunked_dim_size, + get_split_size, ) -from torch.distributed.nn.functional import ( - all_to_all, - all_to_all_single, -) -from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.nn.functional import all_to_all, all_to_all_single from .shard import Shard @@ -42,7 +37,7 @@ def get_idx_from_placements(placements, current_rank) -> int: for idx, placement in enumerate(placements): # type: ignore[attr-defined] if current_rank == placement.rank(): # type: ignore[union-attr] return idx - raise RuntimeError('current_rank not in the placement.') + raise RuntimeError("current_rank not in the placement.") def build_reshard_metadata( @@ -138,7 +133,9 @@ def reshuffle_local_shard( local_shard = local_shard.transpose(0, reshard_dim).contiguous() gathered_input_size = list(local_shard.size()) gathered_input_size[0] = sharded_dim_size - gathered_input = torch.empty(gathered_input_size, device=local_shard.device, dtype=local_shard.dtype) + gathered_input = torch.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) # all2all. local_shard = all_to_all_single( gathered_input, diff --git a/torch/distributed/_shard/sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py index ac1e881370e81..dcb6b3b5d6267 100644 --- a/torch/distributed/_shard/sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -18,7 +18,8 @@ class Shard: metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): The metadata for the shard, including offsets, lengths and device placement. """ - __slots__ = ['tensor', 'metadata'] + + __slots__ = ["tensor", "metadata"] tensor: torch.Tensor metadata: ShardMetadata @@ -31,7 +32,10 @@ def __post_init__(self): f"metadata.shard_lengths: {self.metadata.shard_sizes}, " ) placement_device = self.metadata.placement - if placement_device is not None and placement_device.device() != self.tensor.device: + if ( + placement_device is not None + and placement_device.device() != self.tensor.device + ): raise ValueError( f"Local shard tensor device does not match with local Shard's placement! " f"Found local shard tensor device: {self.tensor.device}, " @@ -39,7 +43,9 @@ def __post_init__(self): ) @classmethod - def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int): + def from_tensor_and_offsets( + cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int + ): """ Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. @@ -52,8 +58,6 @@ def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], shard_sizes = list(tensor.size()) placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") shard_meta = ShardMetadata( - shard_offsets=shard_offsets, - shard_sizes=shard_sizes, - placement=placement + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement ) return Shard(tensor, shard_meta) diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index 782def0e4d4c2..a6954813f82b2 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -1,22 +1,23 @@ # mypy: allow-untyped-defs import collections.abc import copy -from typing import Optional, List, Sequence, TYPE_CHECKING +from typing import List, Optional, Sequence, TYPE_CHECKING import torch -from torch.distributed import distributed_c10d as c10d -from torch.distributed import rpc +from torch.distributed import distributed_c10d as c10d, rpc from torch.distributed._shard.sharding_spec._internals import ( check_tensor, validate_non_overlapping_shards_metadata, ) -from .metadata import TensorProperties, ShardedTensorMetadata +from .metadata import ShardedTensorMetadata, TensorProperties from .shard import Shard + if TYPE_CHECKING: from torch.distributed._shard.metadata import ShardMetadata + def _parse_and_validate_remote_device(pg, remote_device): if remote_device is None: raise ValueError("remote device is None") @@ -48,6 +49,7 @@ def _parse_and_validate_remote_device(pg, remote_device): return rank, device + def _validate_output_tensor_for_gather( my_rank: int, dst_rank: int, @@ -66,10 +68,10 @@ def _validate_output_tensor_for_gather( ) elif dst_tensor: raise ValueError( - "Argument ``dst_tensor`` must NOT be specified " - "on non-destination ranks." + "Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks." ) + def _flatten_tensor_size(size) -> torch.Size: """ Checks if tensor size is valid, then flatten/return a torch.Size object. @@ -81,33 +83,37 @@ def _flatten_tensor_size(size) -> torch.Size: for dim in dims: if not isinstance(dim, int): - raise TypeError(f'size has to be a sequence of ints, found: {dims}') + raise TypeError(f"size has to be a sequence of ints, found: {dims}") return torch.Size(dims) + def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): if is_local: assert isinstance(ranks, int) if expected != actual: - raise ValueError(f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " - f"Found one local shard tensor {prop_name}={expected}, " - f"the other local shard tensor {prop_name}={actual}.") + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) else: # compare failure check across ranks, ranks list should have two rank assert len(ranks) == 2 if expected != actual: - raise ValueError(f"ShardedTensor {prop_name} property does not match from different ranks! " - f"Found {prop_name}={expected} on rank:{ranks[0]}, " - f"and {prop_name}={actual} on rank:{ranks[1]}.") + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) def build_metadata_from_local_shards( local_shards: List[Shard], global_size: torch.Size, current_rank: int, - pg: c10d.ProcessGroup + pg: c10d.ProcessGroup, ) -> ShardedTensorMetadata: - assert len(local_shards) > 0, "must have local shards!" local_shard_metadatas: List[ShardMetadata] = [] @@ -121,21 +127,28 @@ def build_metadata_from_local_shards( local_shard_tensor = local_shard.tensor local_shard_meta = local_shard.metadata local_shard_metadatas.append(local_shard_meta) - rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) - if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout: + if ( + local_shard_tensor.layout != torch.strided + or local_shard_tensor.layout != first_shard_layout + ): raise ValueError( - f'Only torch.strided layout is currently supported, but found ' - f'{local_shard_tensor.layout} on rank:{current_rank}!' + f"Only torch.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" ) if not local_shard_tensor.is_contiguous(): - raise ValueError('Only torch.contiguous_format memory_format is currently supported!') + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported!" + ) if rank != current_rank: raise ValueError( f"Local shard metadata's rank does not match with the rank in its process group! " - f'Found current rank in the process group: {current_rank}, ' + f"Found current rank in the process group: {current_rank}, " f"local ShardMetadata placement's rank: {rank}" ) if local_shard_tensor.device != local_device: @@ -145,10 +158,27 @@ def build_metadata_from_local_shards( f"local shard metadata placement device: {local_device}" ) - _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank) - _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank) - _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank) - _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank) + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then # do all_gather to collect local_sharded_tensor_metadata from all ranks @@ -157,18 +187,21 @@ def build_metadata_from_local_shards( layout=first_shard_layout, requires_grad=first_shard_requires_grad, memory_format=torch.contiguous_format, - pin_memory=first_shard_is_pinned + pin_memory=first_shard_is_pinned, ) local_sharded_tensor_metadata = ShardedTensorMetadata( shards_metadata=local_shard_metadatas, size=global_size, - tensor_properties=local_tensor_properties) + tensor_properties=local_tensor_properties, + ) return local_sharded_tensor_metadata -def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]): +def build_global_metadata( + gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]], +): global_sharded_tensor_metadata = None global_metadata_rank = 0 @@ -180,39 +213,54 @@ def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMet global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) global_metadata_rank = rank else: - _raise_if_mismatch(global_sharded_tensor_metadata.size, - rank_metadata.size, - "global_size", - [global_metadata_rank, rank], - is_local=False) + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) # don't need to check layout and memory format as we already checked in local shards validation stage - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype, - rank_metadata.tensor_properties.dtype, - "dtype", - [global_metadata_rank, rank], - is_local=False) - - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad, - rank_metadata.tensor_properties.requires_grad, - "requires_grad", - [global_metadata_rank, rank], - is_local=False) - - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory, - rank_metadata.tensor_properties.pin_memory, - "pin_memory", - [global_metadata_rank, rank], - is_local=False) + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) # pass all validations, extend shards metadata - global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata) + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) if global_sharded_tensor_metadata is not None: # check if shards_metadata have overlap shards - validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata) + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) # check if the shards_metadata is compatible with global size of the sharded tensor. - check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size) + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) else: raise ValueError("ShardedTensor have no local shards on all ranks!") diff --git a/torch/distributed/_shard/sharder.py b/torch/distributed/_shard/sharder.py index bf3b3596d1bee..6fbf6a2e5ff9d 100644 --- a/torch/distributed/_shard/sharder.py +++ b/torch/distributed/_shard/sharder.py @@ -1,6 +1,8 @@ import abc + import torch.nn as nn + class Sharder(abc.ABC): """ This is an interface which allows user to create more advanced @@ -11,6 +13,7 @@ class Sharder(abc.ABC): take an object of the `Sharder` and call `shard` to shard the module, then replace the original module with sharded module returned. """ + @abc.abstractmethod def shard(self, module: nn.Module) -> nn.Module: """ diff --git a/torch/distributed/_shard/sharding_plan/__init__.py b/torch/distributed/_shard/sharding_plan/__init__.py index 269dfd8af7605..325f7d7eb47b9 100644 --- a/torch/distributed/_shard/sharding_plan/__init__.py +++ b/torch/distributed/_shard/sharding_plan/__init__.py @@ -1,4 +1 @@ -from .api import ( - ShardingPlan, - ShardingPlanner -) +from .api import ShardingPlan, ShardingPlanner diff --git a/torch/distributed/_shard/sharding_plan/api.py b/torch/distributed/_shard/sharding_plan/api.py index fa92bf7078887..a7552c5a68f88 100644 --- a/torch/distributed/_shard/sharding_plan/api.py +++ b/torch/distributed/_shard/sharding_plan/api.py @@ -1,12 +1,12 @@ import abc -import torch.nn as nn - from dataclasses import dataclass from typing import Dict, List, Optional, Union +import torch.nn as nn from torch.distributed._shard.sharder import Sharder from torch.distributed._shard.sharding_spec import ShardingSpec + @dataclass class ShardingPlan: """ @@ -61,6 +61,7 @@ class ShardingPlan: >>> return_local_tensor=["fc2"] >>> ) """ + plan: Dict[str, Union[ShardingSpec, Sharder]] output_plan: Optional[Dict[str, ShardingSpec]] = None return_local_tensor: Optional[List[str]] = None @@ -71,6 +72,7 @@ class ShardingPlanner(abc.ABC): Default ShardingPlanner interface, can be extended and implement advanced sharding strategies. """ + @abc.abstractmethod def build_plan(self, module: nn.Module) -> ShardingPlan: """ diff --git a/torch/distributed/_shard/sharding_spec/__init__.py b/torch/distributed/_shard/sharding_spec/__init__.py index 8dd38105c53ba..bfd3f0a7581e8 100644 --- a/torch/distributed/_shard/sharding_spec/__init__.py +++ b/torch/distributed/_shard/sharding_spec/__init__.py @@ -1,12 +1,10 @@ +from torch.distributed._shard.metadata import ShardMetadata + from .api import ( + _infer_sharding_spec_from_shards_metadata, DevicePlacementSpec, EnumerableShardingSpec, PlacementSpec, ShardingSpec, - _infer_sharding_spec_from_shards_metadata, ) -from .chunk_sharding_spec import ( - ChunkShardingSpec as ChunkShardingSpec, -) - -from torch.distributed._shard.metadata import ShardMetadata +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 07d3c2e19bc00..8a439c447eff0 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -86,8 +86,8 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): for dim in range(len(shards[0].shard_offsets)): for i in range(1, len(shards)): if ( - shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] or - shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] ): sharded_dims.append(dim) break @@ -108,7 +108,7 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): pair = _find_nd_overlapping_shards(shards, sharded_dims) if pair: - raise ValueError(f'Shards {shards[pair[0]]} and {shards[pair[1]]} overlap') + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") def check_tensor(shards_metadata, tensor_dims) -> None: @@ -130,7 +130,9 @@ def check_tensor(shards_metadata, tensor_dims) -> None: tensor_rank = len(tensor_dims) shards_rank = len(shards_metadata[0].shard_offsets) if tensor_rank != shards_rank: - raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}') + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) total_shard_volume = 0 for shard in shards_metadata: @@ -139,8 +141,9 @@ def check_tensor(shards_metadata, tensor_dims) -> None: shard_volume *= shard_length if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: raise ValueError( - f'Shard offset {shard.shard_offsets[i]} and length ' - f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}') + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) total_shard_volume += shard_volume tensor_volume = 1 @@ -150,9 +153,11 @@ def check_tensor(shards_metadata, tensor_dims) -> None: if total_shard_volume != tensor_volume: # TODO: Can we improve this error message to point out the gaps? raise ValueError( - f'Total volume of shards: {total_shard_volume} ' - f'does not match tensor volume: {tensor_volume}, in other words ' - f'all the individual shards do not cover the entire tensor') + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + def get_split_size(dim_size, chunks): """ @@ -167,6 +172,7 @@ def get_split_size(dim_size, chunks): """ return (dim_size + chunks - 1) // chunks + def get_chunked_dim_size(dim_size, split_size, idx): """ Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` @@ -182,6 +188,7 @@ def get_chunked_dim_size(dim_size, split_size, idx): """ return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): """ Generate the start pos and offset length for the current rank for diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index 7493eccdf0158..e22e8b569e03c 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -1,34 +1,36 @@ # mypy: allow-untyped-defs +import functools +import operator from abc import ABC, abstractmethod from dataclasses import dataclass -import functools from typing import Callable, Dict, List, TYPE_CHECKING import torch +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.op_registry_utils import _decorator_func from ._internals import ( check_tensor, get_chunked_dim_size, get_split_size, - validate_non_overlapping_shards_metadata + validate_non_overlapping_shards_metadata, ) -from torch.distributed._shard.metadata import ShardMetadata -import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta -from torch.distributed._shard.op_registry_utils import _decorator_func -import operator if TYPE_CHECKING: # Only include ShardedTensor when do type checking, exclude it # from run-time to resolve circular dependency. from torch.distributed._shard.sharded_tensor import ShardedTensor + class PlacementSpec(ABC): # noqa: B024 """ Base class representing the placement of an entity. Subclasses of this class can be used to specify customized placements which might not be covered by existing APIs. """ + pass @@ -47,15 +49,18 @@ def __post_init__(self): if not isinstance(self.device, torch.distributed._remote_device): self.device = torch.distributed._remote_device(self.device) + class ShardingSpec(ABC): """ Base class representing sharding specifications. """ + @abstractmethod - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: """ Given a global tensor size, define how to shard a tensor like this shape across ranks, return ShardedTensorMetadata @@ -71,7 +76,9 @@ def build_metadata(self, """ @abstractmethod - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": """ Given a global tensor on src_rank, shard this tensor across ranks within the process group, return a ShardedTensor. @@ -88,26 +95,35 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> A :class:`ShardedTensor` sharded from the given tensor. """ + # Ops customized for a particular ShardingSpec. _CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {} + def _has_custom_op(sharding_spec, op): """ Returns whether or not the ShardingSpec has a custom op implementation. """ class_name = type(sharding_spec).__qualname__ - return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + -def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group): +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): """ Calls the custom op for this ShardingSpec if it exists. """ class_name = type(sharding_spec).__qualname__ if not _has_custom_op(sharding_spec, op): - raise RuntimeError(f'Custom op: {op} not registered for {class_name}') + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] return func(types, args, kwargs, process_group) + def custom_sharding_spec_op(sharding_spec_class, func): """ Decorator to allow custom registration of ops. @@ -119,9 +135,7 @@ def custom_sharding_spec_op(sharding_spec_class, func): if class_name not in _CUSTOM_SHARDING_SPEC_OPS: _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} return functools.partial( - _decorator_func, - op=func, - op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] ) @@ -140,30 +154,33 @@ class EnumerableShardingSpec(ShardingSpec): def __post_init__(self): if len(self.shards) == 0: - raise ValueError(f'Empty shard list provided: {self.shards}') + raise ValueError(f"Empty shard list provided: {self.shards}") # Validate each shard has same rank. rank = -1 for shard in self.shards: if rank != -1 and rank != len(shard.shard_offsets): - raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}') + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) rank = len(shard.shard_offsets) validate_non_overlapping_shards_metadata(self.shards) - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: # check if shards form a valid tensor check_tensor(self.shards, tensor_sizes) return sharded_tensor_meta.ShardedTensorMetadata( - self.shards, - tensor_sizes, - tensor_properties + self.shards, tensor_sizes, tensor_properties ) - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") @@ -216,10 +233,14 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata): if chunk_sharding_dim is not None: # Ensure we infer the correct placement order from offsets placements = [ - x for _, x in sorted(zip(chunk_offset_list, placements), key=operator.itemgetter(0)) + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) ] from .chunk_sharding_spec import ChunkShardingSpec + chunk_spec = ChunkShardingSpec( dim=chunk_sharding_dim, placements=placements, diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index bd2c960f7f60c..dd0e354dfc25c 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,28 +1,28 @@ # mypy: allow-untyped-defs from dataclasses import dataclass +from typing import cast, List, Optional, TYPE_CHECKING, Union + import torch +import torch.distributed as dist import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard._utils import narrow_tensor from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharded_tensor.utils import ( - _parse_and_validate_remote_device -) -from torch.distributed._shard._utils import narrow_tensor -import torch.distributed as dist -import torch.distributed.distributed_c10d as distributed_c10d -from typing import cast, List, Optional, Union, TYPE_CHECKING -from ._internals import ( - get_chunked_dim_size, - get_split_size, + _parse_and_validate_remote_device, ) +from ._internals import get_chunked_dim_size, get_split_size from .api import ShardingSpec + if TYPE_CHECKING: # Only include ShardedTensor when do type checking, exclude it # from run-time to resolve circular dependency. from torch.distributed._shard.sharded_tensor import ShardedTensor + @dataclass class ChunkShardingSpec(ShardingSpec): """ @@ -71,14 +71,13 @@ def _verify_dim(dim): ) if not isinstance(dim, int): - raise ValueError( - f"Sharding dim needs to be an integer, found: {dim}" - ) + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: tensor_num_dim = len(tensor_sizes) self._verify_dim(self.dim) @@ -105,13 +104,12 @@ def build_metadata(self, shards_metadata.append(shard_metadata) return sharded_tensor_meta.ShardedTensorMetadata( - shards_metadata, - tensor_sizes, - tensor_properties + shards_metadata, tensor_sizes, tensor_properties ) - - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": """ Args: src_rank: group rank relative to ``process_group`` @@ -119,15 +117,14 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> N.B. If ``process_group`` is None, ``src_rank`` is a global rank. """ # relative imports to avoid circular dependency - from torch.distributed._shard.sharded_tensor import ( - ShardedTensor - ) + from torch.distributed._shard.sharded_tensor import ShardedTensor + tensor_properties = sharded_tensor_meta.TensorProperties( dtype=tensor.dtype, layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned() + pin_memory=tensor.is_pinned(), ) current_rank = dist.get_rank(process_group) current_global_rank = dist.get_rank() @@ -147,7 +144,9 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> scatter_shape[self.dim] = split_size # type: ignore[index] for shard_meta in tensor_meta.shards_metadata: - remote_global_rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement) + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) if current_rank == src_rank: # Reshape to get shard for this rank and we don't want autograd # recording here for the narrow op and 'local_shard' should be a @@ -158,7 +157,9 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> # resize the narrowed tensor to the same size and use it for # the scatter collective as dist.scatter requires same size # inputs on every rank - tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape) + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) else: tensor_to_scatter = narrowed_tensor.detach().clone().contiguous() @@ -168,7 +169,11 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> if current_global_rank == remote_global_rank: local_tensor = torch.empty( - scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device) + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) local_metadata = shard_meta # each rank should have local_tensor and local_metadata initialized if we build @@ -179,14 +184,19 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> # Scatter the shards to all ranks in the pg # scatter takes the global rank as ``src`` src_for_scatter = src_rank - if process_group is not None and process_group is not distributed_c10d._get_default_group(): - src_for_scatter = distributed_c10d.get_global_rank(process_group, src_for_scatter) + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) dist.scatter( local_tensor, scatter_list=tensors_to_scatter if current_rank == src_rank else None, src=src_for_scatter, - group=process_group + group=process_group, ) if list(local_tensor.size()) != local_metadata.shard_sizes: @@ -199,9 +209,8 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) st = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards, - tensor_meta, - process_group=process_group) + local_shards, tensor_meta, process_group=process_group + ) # Manually set sharding_spec st._sharding_spec = self diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index ce9984efac6e8..ab5136978f668 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -13,12 +13,9 @@ import torch.distributed._functional_collectives import torch.nn as nn import torch.utils._pytree as pytree - from functorch import make_fx - from torch import fx from torch._decomp.decompositions import native_layer_norm_backward - from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._spmd.data_parallel import gradients_tagging from torch.distributed._spmd.parallel_mode import ( diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index d3c39295c0e66..244cc26c55ed4 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -2,17 +2,14 @@ from typing import Callable, Dict, List, Set import torch - import torch.fx as fx - import torch.utils._pytree as pytree - from torch import Tensor - from torch.distributed._tensor import DeviceMesh, Replicate, Shard from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim from torch.distributed._tensor.placement_types import _Partial, DTensorSpec + aten = torch.ops.aten diff --git a/torch/distributed/_spmd/config.py b/torch/distributed/_spmd/config.py index 73ee19e803dc8..3fc45bc27a3a1 100644 --- a/torch/distributed/_spmd/config.py +++ b/torch/distributed/_spmd/config.py @@ -4,6 +4,7 @@ from types import ModuleType from typing import Set + # log level (levels print what it says + all levels listed below it) # DEBUG print full traces <-- lowest level + print tracing of every instruction # INFO print compiler functions + distributed graphs diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py index 8b18c6c86763a..835cdb9fa8efd 100644 --- a/torch/distributed/_spmd/data_parallel.py +++ b/torch/distributed/_spmd/data_parallel.py @@ -2,17 +2,13 @@ import operator from contextlib import contextmanager from enum import Enum - from typing import Any, cast, Dict, List, Optional, Tuple import torch - import torch.fx as fx import torch.library import torch.nn as nn - import torch.utils._pytree as pytree - from torch.distributed._spmd.batch_dim_utils import BatchDimAnalyzer from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor._op_schema import ( @@ -22,7 +18,6 @@ TupleStrategy, ) from torch.distributed._tensor._redistribute import redistribute_local_tensor - from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.placement_types import _Partial, DTensorSpec, Placement from torch.fx import GraphModule @@ -30,6 +25,7 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata from torch.nn.utils._named_member_accessor import NamedMemberAccessor + aten = torch.ops.aten # Dummy op used by data parallel to tag gradients. diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py index 5fb5ff766799a..839b58bf03e03 100644 --- a/torch/distributed/_spmd/distribute.py +++ b/torch/distributed/_spmd/distribute.py @@ -9,11 +9,9 @@ import torch import torch.distributed._spmd.experimental_ops import torch.fx as fx - from torch.distributed._spmd.comm_tensor import _get_tracer from torch.distributed._spmd.graph_utils import OP from torch.distributed._spmd.log_utils import get_logger - from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor._op_schema import OpSchema from torch.distributed._tensor._redistribute import redistribute_local_tensor diff --git a/torch/distributed/_spmd/experimental_ops.py b/torch/distributed/_spmd/experimental_ops.py index 94a0da8224496..7039822c41ccd 100644 --- a/torch/distributed/_spmd/experimental_ops.py +++ b/torch/distributed/_spmd/experimental_ops.py @@ -6,7 +6,6 @@ from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.common_rules import pointwise_rule from torch.distributed._tensor.ops.utils import register_prop_rule - from torch.distributed._tensor.placement_types import ( _Partial, DTensorSpec, @@ -16,6 +15,7 @@ TensorMeta, ) + aten = torch.ops.aten # pyre-ignore diff --git a/torch/distributed/_spmd/graph_optimization.py b/torch/distributed/_spmd/graph_optimization.py index 4a5cad7917d88..a50e266eb1282 100644 --- a/torch/distributed/_spmd/graph_optimization.py +++ b/torch/distributed/_spmd/graph_optimization.py @@ -37,6 +37,7 @@ from torch.utils import _pytree as pytree from torch.utils._pytree import tree_flatten, tree_unflatten + logger: logging.Logger = logging.getLogger("graph_optimization") aten = torch.ops.aten fake_tensor_mode = FakeTensorMode() diff --git a/torch/distributed/_spmd/parallel_mode.py b/torch/distributed/_spmd/parallel_mode.py index 2e9c15258d092..65a55377ac82f 100644 --- a/torch/distributed/_spmd/parallel_mode.py +++ b/torch/distributed/_spmd/parallel_mode.py @@ -11,7 +11,6 @@ ) from torch.distributed._spmd.distribute import _convert_to_distributed, Schema from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard - from torch.fx import GraphModule diff --git a/torch/distributed/_spmd/partial_lower.py b/torch/distributed/_spmd/partial_lower.py index bb1f1e2e085fc..7899a8e143f62 100644 --- a/torch/distributed/_spmd/partial_lower.py +++ b/torch/distributed/_spmd/partial_lower.py @@ -7,12 +7,11 @@ from typing import Callable, List, Optional, Set, Tuple import torch - from functorch import make_fx - from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.decomposition import select_decomp_table + MIN_ATEN_OPS_TO_LOWER = 10 logger: logging.Logger = logging.getLogger(__name__) From 4817180601016f706ee0cce76b6d52b9cfc51ef5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 17 Jun 2024 23:06:28 +0000 Subject: [PATCH 37/63] make fallback for aten.argsort.stable (#128907) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128907 Approved by: https://github.com/lezcano ghstack dependencies: #128343 --- test/inductor/test_torchinductor_opinfo.py | 1 - torch/_inductor/lowering.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 5f97c2f0fd712..29be591dc006c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -411,7 +411,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "_segment_reduce.lengths": {f16}, "_segment_reduce.offsets": {f16}, "addmv": {f16}, - "argsort": {b8, f16, f32, f64, i32, i64}, "as_strided.partial_views": {f16}, "corrcoef": {f16}, "diff": {f16}, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 449e512352fa1..44a6d05d1f44d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2183,6 +2183,7 @@ def is_aligned(x): # Sorting / Sorting-like make_fallback(aten.sort) make_fallback(aten.sort.stable) +make_fallback(aten.argsort.stable) make_fallback(aten.kthvalue) make_fallback(aten.topk) make_fallback(aten.mode) From 108318ad1038f4f3ad0da4f54f53effdd9ef365a Mon Sep 17 00:00:00 2001 From: David Berard Date: Tue, 18 Jun 2024 15:40:45 +0000 Subject: [PATCH 38/63] [BE][JIT] Handle case where codegen object can be unset (#128951) Summary: Unblocks a test that's failing. `codegen` can be unset until `compile` is called. If `codegen` is not set, then just use the kernel name directly. Test Plan: ``` buck2 run //caffe2/test:tensorexpr -- --regex test_simple_add ``` Differential Revision: D58727391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128951 Approved by: https://github.com/aaronenyeshi --- torch/csrc/jit/tensorexpr/kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index d7c737d8f8f2c..e5ea5bb46e0e7 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -181,7 +181,7 @@ class TORCH_API TensorExprKernel { } const std::string& getKernelName() const { - return codegen_->kernel_func_name(); + return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_); } const std::vector& getSymbolicShapeInputs() const { From ec616da51848bcfa9d0bd9c693c62b50fbe84c0f Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 16:16:38 +0000 Subject: [PATCH 39/63] RNN API cleanup for cuDNN 9.1 (#122011) Can potentially avoid a bit of boilerplate if we move directly to cuDNN 9.1's RNN API... Co-authored-by: Aaron Gokaslan Pull Request resolved: https://github.com/pytorch/pytorch/pull/122011 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cudnn/RNN.cpp | 32 +++++++++++------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 55c666eeca83c..c90a6fd7a6c9c 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -614,8 +614,6 @@ void add_projection_weights( /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( /*handle=*/handle, @@ -626,8 +624,8 @@ void add_projection_weights( /*linLayerID=*/linear_id, /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif cudnnDataType_t data_type; @@ -735,8 +733,6 @@ get_parameters( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; for (int stateless = 0; stateless < 100; stateless++) { if (cudnn_method) { // matrix @@ -749,8 +745,8 @@ get_parameters( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -759,8 +755,8 @@ get_parameters( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -922,8 +918,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; if (cudnn_method) { // matrix AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -935,8 +929,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -945,8 +939,8 @@ std::vector get_expected_data_ptrs( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -972,8 +966,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -985,8 +977,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif data_ptrs.push_back(matrix_pointer); } From 9818283da18de00047760ec4431870d3f8e620a6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 14 Jun 2024 19:12:10 +0000 Subject: [PATCH 40/63] re-enable jacrev/jacfwd/hessian after #128028 landed (#128622) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128622 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 69 ------------------- ...ion_no_setup_context_transform_hessian_cpu | 0 ...tion_no_setup_context_transform_jacfwd_cpu | 0 ...essianCPU.test_jacfwd_different_levels_cpu | 0 test/functorch/test_eager_transforms.py | 4 +- torch/_functorch/eager_transforms.py | 4 -- torch/testing/_internal/common_utils.py | 1 - 7 files changed, 2 insertions(+), 76 deletions(-) create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu create mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index dca6d28d1912d..f2df33bdda67c 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2746,26 +2746,6 @@ def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm - def test_hessian_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.hessian(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_hessian(self): counters.clear() @@ -2900,7 +2880,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_hessian_argnums(self): counters.clear() @@ -3046,7 +3025,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """ return (unflatten,)""", ) - @unittest.expectedFailure def test_hessian_disable_capture(self): counters.clear() @@ -3073,26 +3051,6 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) - def test_jacrev_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacrev(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacrev(self): counters.clear() @@ -3169,7 +3127,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacrev_two_tensors_argnums(self): counters.clear() @@ -3252,7 +3209,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_has_aux(self): counters.clear() @@ -3337,7 +3293,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_disable_capture(self): counters.clear() @@ -4284,26 +4239,6 @@ def wrapper_fn(x, y): self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) - def test_jacfwd_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacfwd(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacfwd(self): counters.clear() @@ -4387,7 +4322,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacfwd_two_tensors_argnums(self): counters.clear() @@ -4477,7 +4411,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_has_aux(self): counters.clear() @@ -4572,7 +4505,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_randomness(self): counters.clear() @@ -4676,7 +4608,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_disable_capture(self): counters.clear() diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 8107f865f7bc5..c767810beb85a 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -77,6 +77,7 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -2341,8 +2342,7 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - # it won't fail as jacrev/jacfwd were not inlined (see #128255) - # @xfailIfTorchDynamo + @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fbea5164014bc..fff6bd67838f0 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -767,8 +767,6 @@ def compute_jacobian_preallocate_and_copy(): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn @@ -1350,8 +1348,6 @@ def push_jvp(basis): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2d5ea4a6c64ff..8daeefdee9d85 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -5008,7 +5008,6 @@ def repl_frame(m): return m.group(0) s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s) - s = re.sub(r'( Date: Tue, 18 Jun 2024 17:15:05 +0000 Subject: [PATCH 41/63] [EZ] Fix typos in RELEASE.md (#128769) This PR fixes typo in `RELEASE.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128769 Approved by: https://github.com/yumium, https://github.com/mikaylagawarecki --- RELEASE.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 3c9d68f9a6cdc..7091052c85bd1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -290,7 +290,7 @@ After the final RC is created. The following tasks should be performed : * Create validation issue for the release, see for example [Validations for 2.1.2 release](https://github.com/pytorch/pytorch/issues/114904) and perform required validations. -* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no prerformance regressions. +* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no performance regressions. * Prepare and stage PyPI binaries for promotion. This is done with this script: [`pytorch/builder:release/pypi/promote_pypi_to_staging.sh`](https://github.com/pytorch/builder/blob/main/release/pypi/promote_pypi_to_staging.sh) @@ -429,12 +429,12 @@ need to support these particular versions of software. ## Operating Systems Supported OS flavors are summarized in the table below: -| Operating System family | Architectrue | Notes | +| Operating System family | Architecture | Notes | | --- | --- | --- | | Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. | | MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). | | MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 | -| Windows | x86_64 | Buils are compatible with Windows-10 or newer. | +| Windows | x86_64 | Builds are compatible with Windows-10 or newer. | # Submitting Tutorials From 4e03263224af813fbf5e0e745e84c13268c48dc7 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 17:26:23 +0000 Subject: [PATCH 42/63] [CUDA][Convolution] Add missing launch bounds to `vol2col_kernel` (#128740) Fix "too many resources requested" that can happen with recent toolkits on V100. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128740 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/cuda/vol2col.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh index 98ec2c3522d54..222270e862160 100644 --- a/aten/src/ATen/native/cuda/vol2col.cuh +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -14,6 +14,7 @@ using namespace at::cuda::detail; // Kernel for fast unfold+copy on volumes template +C10_LAUNCH_BOUNDS_1(1024) __global__ void vol2col_kernel( const int64_t n, const T* data_vol, From 84c86e56bd8b86ae47c18b77141c1fe46188c5b7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 18 Jun 2024 17:48:47 +0000 Subject: [PATCH 43/63] Update tracker issues after successfully cherry-picking a PR (#128924) This extends the capacity of the cherry-pick bot to automatically update the tracker issue with the information. For this to work, the tracker issue needs to be an open one with a `release tracker` label, i.e. https://github.com/pytorch/pytorch/issues/128436. The version from the release branch, i.e. `release/2.4`, will be match with the title of the tracker issue, i.e. `[v.2.4.0] Release Tracker` or `[v.2.4.1] Release Tracker` ### Testing `python cherry_pick.py --onto-branch release/2.4 --classification release --fixes "DEBUG DEBUG" --github-actor huydhn 128718` * On the PR https://github.com/pytorch/pytorch/pull/128718#issuecomment-2174846771 * On the tracker issue https://github.com/pytorch/pytorch/issues/128436#issuecomment-2174846757 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128924 Approved by: https://github.com/atalman --- .github/scripts/cherry_pick.py | 114 ++++++++++++++++++++++++++++---- .github/scripts/github_utils.py | 9 +++ 2 files changed, 111 insertions(+), 12 deletions(-) diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py index 4c892de21da8a..2650a5060d0ff 100755 --- a/.github/scripts/cherry_pick.py +++ b/.github/scripts/cherry_pick.py @@ -3,11 +3,11 @@ import json import os import re -from typing import Any, Optional +from typing import Any, cast, Dict, List, Optional from urllib.error import HTTPError -from github_utils import gh_fetch_url, gh_post_pr_comment +from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo from trymerge import get_pr_commit_sha, GitHubPR @@ -19,6 +19,7 @@ "critical", "fixnewfeature", } +RELEASE_BRANCH_REGEX = re.compile(r"release/(?P.+)") def parse_args() -> Any: @@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]: return commit_sha if pr.is_closed() else None +def get_release_version(onto_branch: str) -> Optional[str]: + """ + Return the release version if the target branch is a release branch + """ + m = re.match(RELEASE_BRANCH_REGEX, onto_branch) + return m.group("version") if m else "" + + +def get_tracker_issues( + org: str, project: str, onto_branch: str +) -> List[Dict[str, Any]]: + """ + Find the tracker issue from the repo. The tracker issue needs to have the title + like [VERSION] Release Tracker following the convention on PyTorch + """ + version = get_release_version(onto_branch) + if not version: + return [] + + tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"]) + if not tracker_issues: + return [] + + # Figure out the tracker issue from the list by looking at the title + return [issue for issue in tracker_issues if version in issue.get("title", "")] + + def cherry_pick( github_actor: str, repo: GitRepo, @@ -77,17 +105,49 @@ def cherry_pick( ) try: + org, project = repo.gh_owner_and_name() + + cherry_pick_pr = "" if not dry_run: - org, project = repo.gh_owner_and_name() cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch) - msg = f"The cherry pick PR is at {cherry_pick_pr}" - if fixes: - msg += f" and it is linked with issue {fixes}" - elif classification in REQUIRES_ISSUE: - msg += f" and it is recommended to link a {classification} cherry pick PR with an issue" + tracker_issues_comments = [] + tracker_issues = get_tracker_issues(org, project, onto_branch) + for issue in tracker_issues: + issue_number = int(str(issue.get("number", "0"))) + if not issue_number: + continue + + res = cast( + Dict[str, Any], + post_tracker_issue_comment( + org, + project, + issue_number, + pr.pr_num, + cherry_pick_pr, + classification, + fixes, + dry_run, + ), + ) + + comment_url = res.get("html_url", "") + if comment_url: + tracker_issues_comments.append(comment_url) - post_comment(org, project, pr.pr_num, msg) + msg = f"The cherry pick PR is at {cherry_pick_pr}" + if fixes: + msg += f" and it is linked with issue {fixes}." + elif classification in REQUIRES_ISSUE: + msg += f" and it is recommended to link a {classification} cherry pick PR with an issue." + + if tracker_issues_comments: + msg += " The following tracker issues are updated:\n" + for tracker_issues_comment in tracker_issues_comments: + msg += f"* {tracker_issues_comment}\n" + + post_pr_comment(org, project, pr.pr_num, msg, dry_run) finally: if current_branch: @@ -159,7 +219,9 @@ def submit_pr( raise RuntimeError(msg) from error -def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: +def post_pr_comment( + org: str, project: str, pr_num: int, msg: str, dry_run: bool = False +) -> List[Dict[str, Any]]: """ Post a comment on the PR itself to point to the cherry picking PR when success or print the error when failure @@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: comment = "\n".join( (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}") ) - gh_post_pr_comment(org, project, pr_num, comment) + return gh_post_pr_comment(org, project, pr_num, comment, dry_run) + + +def post_tracker_issue_comment( + org: str, + project: str, + issue_num: int, + pr_num: int, + cherry_pick_pr: str, + classification: str, + fixes: str, + dry_run: bool = False, +) -> List[Dict[str, Any]]: + """ + Post a comment on the tracker issue (if any) to record the cherry pick + """ + comment = "\n".join( + ( + "Link to landed trunk PR (if applicable):", + f"* https://github.com/{org}/{project}/pull/{pr_num}", + "", + "Link to release branch PR:", + f"* {cherry_pick_pr}", + "", + "Criteria Category:", + " - ".join((classification.capitalize(), fixes.capitalize())), + ) + ) + return gh_post_pr_comment(org, project, issue_num, comment, dry_run) def main() -> None: @@ -214,7 +304,7 @@ def main() -> None: except RuntimeError as error: if not args.dry_run: - post_comment(org, project, pr_num, str(error)) + post_pr_comment(org, project, pr_num, str(error)) else: raise error diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index d76d32f624d8a..f804c6e197dd4 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> ) else: raise + + +def gh_query_issues_by_labels( + org: str, repo: str, labels: List[str], state: str = "open" +) -> List[Dict[str, Any]]: + url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues" + return gh_fetch_json( + url, method="GET", params={"labels": ",".join(labels), "state": state} + ) From 77830d509fcae41be37f5b3a2fa05faabc778e29 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 18:11:43 +0000 Subject: [PATCH 44/63] Revert "Introduce a prototype for SymmetricMemory (#128582)" This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1. Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232)) --- .lintrunner.toml | 1 - BUILD.bazel | 1 - build_variables.bzl | 2 - c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 - test/distributed/test_symmetric_memory.py | 156 ----- torch/_C/_distributed_c10d.pyi | 30 - .../distributed/c10d/CUDASymmetricMemory.cu | 539 ------------------ .../distributed/c10d/CUDASymmetricMemory.cuh | 109 ---- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 - .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ------ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 ----- torch/csrc/distributed/c10d/init.cpp | 39 -- .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +++- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 111 insertions(+), 1254 deletions(-) delete mode 100644 test/distributed/test_symmetric_memory.py delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index dc9f9ddd46c7c..a7bbdc884415e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,7 +68,6 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', - 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index c563c52d861e6..10c065f5084c7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,7 +744,6 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 793b611a0a6f0..ceb28707897e5 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,7 +501,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", - "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -685,7 +684,6 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index cbbdf16823ec7..43bcbd1d70bac 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,17 +18,14 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8426741609fe7..89c31fab11347 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,7 +560,6 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py deleted file mode 100644 index a768e059044f7..0000000000000 --- a/test/distributed/test_symmetric_memory.py +++ /dev/null @@ -1,156 +0,0 @@ -# Owner(s): ["module: c10d"] - -import torch - -import torch.distributed as dist -from torch._C._distributed_c10d import _SymmetricMemory -from torch.distributed.distributed_c10d import _get_process_group_store - -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - run_tests, - skip_but_pass_in_sandcastle_if, - skipIfRocm, -) - - -def requires_cuda_p2p_access(): - cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 - ) - num_devices = torch.cuda.device_count() - for i in range(num_devices - 1): - for j in range(i + 1, num_devices): - if not torch.cuda.can_device_access_peer(i, j): - cuda_p2p_access_available = False - break - if not cuda_p2p_access_available: - break - - return skip_but_pass_in_sandcastle_if( - not cuda_p2p_access_available, - "cuda p2p access is not available", - ) - - -@instantiate_parametrized_tests -@requires_cuda_p2p_access() -class SymmetricMemoryTest(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - @property - def world_size(self) -> int: - return 2 - - @property - def device(self) -> torch.device: - return torch.device(f"cuda:{self.rank}") - - def _init_process(self): - torch.cuda.set_device(self.device) - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - _SymmetricMemory.set_group_info( - "0", - self.rank, - self.world_size, - _get_process_group_store(dist.GroupMember.WORLD), - ) - - def _verify_symmetric_memory(self, symm_mem): - self.assertEqual(symm_mem.world_size, 2) - - buf = symm_mem.get_buffer(0, (64, 64), torch.float32) - if symm_mem.rank == 0: - symm_mem.wait_signal(src_rank=1) - self.assertTrue(buf.eq(42).all()) - else: - buf.fill_(42) - symm_mem.put_signal(dst_rank=0) - - symm_mem.barrier() - - if symm_mem.rank == 0: - symm_mem.barrier() - self.assertTrue(buf.eq(43).all()) - else: - buf.fill_(43) - symm_mem.barrier() - - symm_mem.barrier() - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name) - - t = torch.empty(shape, dtype=dtype, device=device) - with self.assertRaises(RuntimeError): - _SymmetricMemory.rendezvous(t) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - symm_mem = _SymmetricMemory.rendezvous(t) - - del t - self._verify_symmetric_memory(symm_mem) - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p_persistent(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - alloc_id = 42 # Persistent allocation - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name, alloc_id) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - data_ptr = t.data_ptr() - - # Verify that persistent allocation would fail if there's an active - # allocation with the same alloc_id. - with self.assertRaises(RuntimeError): - _SymmetricMemory.empty_strided_p2p(*alloc_args) - - # Verify that persistent allocation would succeed in lieu of activate - # allocations with the same alloc_id, and the returned tensor would - # have the same data pointer. - del t - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - self.assertEqual(t.data_ptr(), data_ptr) - - # Verify that get_symmetric_memory would fail if called before - # rendezvous. - with self.assertRaises(RuntimeError): - _SymmetricMemory.get_symmetric_memory(t) - - symm_mem_0 = _SymmetricMemory.rendezvous(t) - symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) - self.assertEqual(id(symm_mem_0), id(symm_mem_1)) - - self._verify_symmetric_memory(symm_mem_0) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 0095b5af434b5..cffbf22219c8e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,33 +637,3 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... - -class _SymmetricMemory: - @staticmethod - def set_group_info( - group_name: str, rank: int, world_size: int, store: Store - ) -> None: ... - @staticmethod - def empty_strided_p2p( - size: torch.types._size, - stride: torch.types._size, - dtype: torch.dtype, - device: torch.device, - group_name: str, - ) -> torch.Tensor: ... - @property - def rank(self) -> int: ... - @property - def world_size(self) -> int: ... - @staticmethod - def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... - def get_buffer( - self, - rank: int, - sizes: torch.Size, - dtype: torch.dtype, - storage_offset: Optional[int] = 0, - ) -> torch.Tensor: ... - def barrier(self, channel: int = 0) -> None: ... - def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... - def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu deleted file mode 100644 index f27db85f7ff85..0000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ /dev/null @@ -1,539 +0,0 @@ -#include - -#include -#include -#include -#include - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include -#endif - -#include -#include - -namespace { - -constexpr size_t signal_pad_size = 2048; -const std::string store_comm_prefix = "CUDASymmetricMemory"; - -static size_t store_comm_seq_id = 0; - -template -std::vector store_all_gather( - const c10::intrusive_ptr& store, - int rank, - int world_size, - T val) { - static_assert(std::is_trivially_copyable_v); - - std::vector peer_keys; - for (int r = 0; r < world_size; ++r) { - std::ostringstream oss; - oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; - peer_keys.push_back(oss.str()); - } - ++store_comm_seq_id; - - { - std::vector payload( - reinterpret_cast(&val), - reinterpret_cast(&val) + sizeof(T)); - store->set(peer_keys[rank], payload); - } - - std::vector peer_vals; - for (int r = 0; r < world_size; ++r) { - if (r == rank) { - peer_vals.push_back(val); - continue; - } - store->wait({peer_keys[r]}); - auto payload = store->get(peer_keys[r]); - TORCH_CHECK(payload.size() == sizeof(T)); - T peer_val{}; - std::memcpy(&peer_val, payload.data(), sizeof(T)); - peer_vals.push_back(peer_val); - } - return peer_vals; -} - -void store_barrier( - const c10::intrusive_ptr& store, - int rank, - int world_size) { - store_all_gather(store, rank, world_size, 0); -} - -int import_remote_fd(int pid, int fd) { -#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) - int pidfd = syscall(SYS_pidfd_open, pid, 0); - return syscall(SYS_pidfd_getfd, pidfd, fd, 0); -#else - TORCH_CHECK( - false, - "CUDASymmetricMemory requires pidfd_open ", - "and pidfd_getfd support"); -#endif -} - -void map_block( - void** ptr, - c10d::symmetric_memory::HandleType handle, - size_t size, - int device_idx) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - auto dev_ptr = reinterpret_cast(ptr); - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); - - CUmemAccessDesc desc; - desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - desc.location.id = static_cast(device_idx); - desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -CUDASymmetricMemory::CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size) - : handles_(std::move(handles)), - block_size_(block_size), - buffers_(std::move(buffers)), - signal_pads_(std::move(signal_pads)), - buffer_size_(buffer_size), - local_device_idx_(local_device_idx), - rank_(rank), - world_size_(world_size) { - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - - c10::cuda::CUDAGuard guard(local_device_idx); - AT_CUDA_CHECK(cudaMemcpy( - buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); - AT_CUDA_CHECK(cudaMemcpy( - signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); -} - -CUDASymmetricMemory::~CUDASymmetricMemory() { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - c10::cuda::CUDAGuard guard(local_device_idx_); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - auto driver_api = c10::cuda::DriverAPI::get(); - for (int r = 0; r < world_size_; ++r) { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(buffers_[r]), block_size_)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); - } - c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); - c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -std::vector CUDASymmetricMemory::get_buffer_ptrs() { - return buffers_; -} - -std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { - return signal_pads_; -} - -void** CUDASymmetricMemory::get_buffer_ptrs_dev() { - return buffers_dev_; -} - -void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { - return signal_pads_dev_; -} - -size_t CUDASymmetricMemory::get_buffer_size() { - return buffer_size_; -} - -size_t CUDASymmetricMemory::get_signal_pad_size() { - return signal_pad_size; -} - -at::Tensor CUDASymmetricMemory::get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - const auto numel = - std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "CUDASymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storage_offset) - .options(options) - .target_device(device) - .make_tensor(); -} - -void check_channel(int channel, int world_size) { - TORCH_CHECK( - channel >= 0, - "channel for barrier(), put_signal() and wait_signal() ", - "must be greater than 0 (got ", - channel, - ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; - TORCH_CHECK( - static_cast(channel) < num_channels, - "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", - num_channels - 1, - " (got ", - channel, - ")"); -} - -__device__ __forceinline__ void release_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); -#endif -} - -__device__ __forceinline__ void acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); -#endif -} - -static __global__ void barrier_kernel( - uint32_t** signal_pads, - int channel, - int rank, - int world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + world_size * channel + rank); - acquire_signal(signal_pads[rank] + world_size * channel + target_rank); - } -} - -void CUDASymmetricMemory::barrier(int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void put_signal_kernel( - uint32_t** signal_pads, - int dst_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - release_signal(signal_pads[dst_rank] + world_size * channel + rank); - } -} - -void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - dst_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void wait_signal_kernel( - uint32_t** signal_pads, - int src_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - acquire_signal(signal_pads[rank] + world_size * channel + src_rank); - } - __threadfence_system(); -} - -void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - src_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -int CUDASymmetricMemory::get_rank() { - return rank_; -} - -int CUDASymmetricMemory::get_world_size() { - return world_size_; -} - -void* CUDASymmetricMemoryAllocator::alloc( - size_t size, - int device_idx, - const std::string& group_name) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - prop.location.id = device_idx; - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; - - size_t granularity; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - block_size = at::round_up(block_size, granularity); - - HandleType handle; - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); - - void* ptr = nullptr; - map_block(&ptr, handle, block_size, device_idx); - - c10::cuda::CUDAGuard guard(device_idx); - AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); - - auto block = c10::make_intrusive( - handle, device_idx, block_size, size, signal_pad_offset, group_name); - { - std::unique_lock lock(mutex_); - ptr_to_block_.emplace(ptr, std::move(block)); - } - return ptr; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -void CUDASymmetricMemoryAllocator::free(void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - if (block == nullptr) { - return; - } - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. - if (block->symm_mem == nullptr) { - auto driver_api = c10::cuda::DriverAPI::get(); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(ptr), block->block_size)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); - } - { - std::unique_lock lock(mutex_); - ptr_to_block_.erase(ptr); - } -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->buffer_size; -} - -struct RendezvousRequest { - int device_idx; - int block_fd; - int pid; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; -}; - -void validate_rendezvous_requests( - const std::vector reqs, - int world_size) { - TORCH_CHECK(reqs.size() == (size_t)world_size); - - std::unordered_set device_indices; - device_indices.reserve(world_size); - for (auto req : reqs) { - device_indices.insert(req.device_idx); - } - if (device_indices.size() < (size_t)world_size) { - TORCH_CHECK( - false, - "CUDASymmetricMemoryAllocator::rendezvous: ", - "detected allocations from overlapping devices ", - "from different ranks."); - } - - for (int r = 1; r < world_size; ++r) { - TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); - TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); - TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); - } -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( - void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - - if (block->symm_mem != nullptr) { - return block->symm_mem; - } - - auto group_info = get_group_info(block->group_name); - auto driver_api = c10::cuda::DriverAPI::get(); - int block_fd; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); - - auto local_req = RendezvousRequest{ - .device_idx = block->device_idx, - .block_fd = block_fd, - .pid = getpid(), - .block_size = block->block_size, - .buffer_size = block->buffer_size, - .signal_pad_offset = block->signal_pad_offset}; - auto reqs = store_all_gather( - group_info.store, group_info.rank, group_info.world_size, local_req); - validate_rendezvous_requests(reqs, group_info.world_size); - - std::vector handles(group_info.world_size); - std::vector buffers(group_info.world_size, nullptr); - std::vector signal_pads(group_info.world_size, nullptr); - for (int r = 0; r < group_info.world_size; ++r) { - if (r == group_info.rank) { - handles[r] = block->handle; - buffers[r] = ptr; - signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); - continue; - } - int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)(uintptr_t)imported_fd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - map_block(&buffers[r], handles[r], block->block_size, block->device_idx); - signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); - close(imported_fd); - } - store_barrier(group_info.store, group_info.rank, group_info.world_size); - close(block_fd); - - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. So that outstanding - // references to the CUDASymmetricMemory object can keep the allocation - // alive. - block->symm_mem = c10::make_intrusive( - std::move(handles), - block->block_size, - std::move(buffers), - std::move(signal_pads), - block->buffer_size, - block->device_idx, - group_info.rank, - group_info.world_size); - return block->symm_mem; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->symm_mem != nullptr; -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { - std::shared_lock lock(mutex_); - auto it = ptr_to_block_.find(ptr); - if (it == ptr_to_block_.end()) { - return nullptr; - } - return it->second; -} - -struct RegisterCUDASymmetricMemoryAllocator { - RegisterCUDASymmetricMemoryAllocator() { - register_allocator( - c10::DeviceType::CUDA, - c10::make_intrusive()); - } -}; - -static RegisterCUDASymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh deleted file mode 100644 index 0e0e40a6bd091..0000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10d { -namespace symmetric_memory { - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -using HandleType = CUmemGenericAllocationHandle; -#else -using HandleType = void*; -#endif - -class CUDASymmetricMemory : public SymmetricMemory { - public: - CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size); - - ~CUDASymmetricMemory() override; - - std::vector get_buffer_ptrs() override; - std::vector get_signal_pad_ptrs() override; - void** get_buffer_ptrs_dev() override; - void** get_signal_pad_ptrs_dev() override; - size_t get_buffer_size() override; - size_t get_signal_pad_size() override; - - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) override; - - void barrier(int channel) override; - void put_signal(int dst_rank, int channel) override; - void wait_signal(int src_rank, int channel) override; - - int get_rank() override; - int get_world_size() override; - - private: - std::vector handles_; - size_t block_size_; - std::vector buffers_; - std::vector signal_pads_; - size_t buffer_size_; - int local_device_idx_; - int rank_; - int world_size_; - void** buffers_dev_; - void** signal_pads_dev_; - std::optional> finalizer_; -}; - -struct Block : public c10::intrusive_ptr_target { - HandleType handle; - int device_idx; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; - std::string group_name; - c10::intrusive_ptr symm_mem = nullptr; - - Block( - HandleType handle, - int device_idx, - size_t block_size, - size_t buffer_size, - size_t signal_pad_offset, - const std::string& group_name) - : handle(handle), - device_idx(device_idx), - block_size(block_size), - buffer_size(buffer_size), - signal_pad_offset(signal_pad_offset), - group_name(group_name), - symm_mem(nullptr) {} -}; - -class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { - public: - void* alloc( - size_t size, - int device_idx, - const std::string& group_name) override; - - void free(void *ptr) override; - size_t get_alloc_size(void* ptr) override; - c10::intrusive_ptr rendezvous(void* ptr) override; - bool is_rendezvous_completed(void* ptr) override; - - private: - c10::intrusive_ptr find_block(void* ptr); - - std::shared_mutex mutex_; - std::unordered_map> ptr_to_block_; -}; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index 7c41414c4e4e1..cff4ad09b7064 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,7 +10,6 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp deleted file mode 100644 index b3d9f31bb0342..0000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include - -namespace { - -using namespace c10d::symmetric_memory; - -class AllocatorMap { - public: - static AllocatorMap& get() { - static AllocatorMap instance; - return instance; - } - - void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - map_[device_type] = std::move(allocator); - } - - c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - auto it = map_.find(device_type); - TORCH_CHECK( - it != map_.end(), - "SymmetricMemory does not support device type ", - device_type); - return it->second; - } - - ~AllocatorMap() { - for (auto& it : map_) { - it.second.release(); - } - } - - private: - AllocatorMap() = default; - AllocatorMap(const AllocatorMap&) = delete; - AllocatorMap& operator=(const AllocatorMap&) = delete; - - std::unordered_map< - c10::DeviceType, - c10::intrusive_ptr> - map_; -}; - -static std::unordered_map group_info_map{}; - -// Data structures for tracking persistent allocations -static std::unordered_map alloc_id_to_dev_ptr{}; -static std::unordered_map> - alloc_id_to_storage{}; - -static at::Tensor empty_strided_p2p_persistent( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - uint64_t alloc_id) { - // Make the allocation fails if a previous allocation with the same alloc_id - // is still active. - auto storage = alloc_id_to_storage.find(alloc_id); - if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { - TORCH_CHECK( - false, - "SymmetricMemory::empty_strided_p2p_persistent: ", - "can not allocate with alloc_id == ", - alloc_id, - " because a previous allocation with the same alloc_id " - "is still active."); - } - - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = nullptr; - if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { - dev_ptr = alloc_id_to_dev_ptr[alloc_id]; - TORCH_CHECK( - alloc_size == allocator->get_alloc_size(dev_ptr), - "SymmetricMemory::empty_strided_p2p_persistent: ", - "requested allocation size (", - alloc_size, - ") is different from the size of a previous allocation ", - "with the same alloc_id ", - allocator->get_alloc_size(dev_ptr)); - } else { - dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - alloc_id_to_dev_ptr[alloc_id] = dev_ptr; - } - - auto options = at::TensorOptions().dtype(dtype).device(device); - auto allocated = at::from_blob(dev_ptr, size, stride, options); - - // Track the allocation's activeness - alloc_id_to_storage.erase(alloc_id); - alloc_id_to_storage.emplace( - alloc_id, allocated.storage().getWeakStorageImpl()); - return allocated; -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - return AllocatorMap::get().register_allocator( - device_type, std::move(allocator)); -} - -c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - return AllocatorMap::get().get_allocator(device_type); -} - -void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store) { - TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); - GroupInfo group_info; - group_info.rank = rank; - group_info.world_size = world_size; - group_info.store = std::move(store); - group_info_map.emplace(group_name, std::move(group_info)); -} - -const GroupInfo& get_group_info(const std::string& group_name) { - TORCH_CHECK( - group_info_map.find(group_name) != group_info_map.end(), - "get_group_info: no group info associated with the group name ", - group_name); - return group_info_map[group_name]; -} - -at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id) { - if (alloc_id.has_value()) { - return empty_strided_p2p_persistent( - size, stride, dtype, device, group_name, *alloc_id); - } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::from_blob( - dev_ptr, - size, - stride, - [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, - options); -} - -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - return allocator->rendezvous(tensor.data_ptr()); -} - -c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - TORCH_CHECK( - allocator->is_rendezvous_completed(tensor.data_ptr()), - "SymmetricMemory: must invoke rendezvous on a tensor ", - "before calling get_symmetric_memory on it"); - return allocator->rendezvous(tensor.data_ptr()); -} - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp deleted file mode 100644 index 344b86ea5c7e3..0000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ /dev/null @@ -1,152 +0,0 @@ -#pragma once - -#include -#include - -namespace c10d { -namespace symmetric_memory { - -// SymmetricMemory represents symmetric allocations across a group of devices. -// The allocations represented by a SymmetricMemory object are accessible by -// all devices in the group. The class can be used for op-level custom -// communication patterns (via the get_buffer APIs and the synchronization -// primitives), as well as custom communication kernels (via the buffer and -// signal_pad device pointers). -// -// To acquire a SymmetricMemory object, each rank first allocates -// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes -// SymmetricMemoryAllocator::rendezvous() on the memory to establish the -// association across peer buffers. The rendezvous is a one-time process, and -// the mapping between a local memory memory and the associated SymmetricMemory -// object is unique. -// -// NOTE [symmetric memory signal pad] -// Signal pads are P2P-accessible memory regions designated for -// synchronization. SymmetricMemory offers built-in synchronization primitives -// such as barriers, put_signal, and wait_signal, which are all based on signal -// pads. Users may utilize signal pads for their own synchronization logic, -// provided that the signal pads remain zero-filled following successful -// synchronization. -// -// NOTE [symmetric memory synchronization channel] -// Synchronization channels allow users to use a single SymmetricMemory object -// to perform isolated synchronizations on different streams. For example, -// consider the case in which two barriers are issued on two streams for -// different purposes. Without the concept of channels, we cannot guarantee the -// correctness of the barriers since signals issued from barrier on stream A -// can be received by the barrier on stream B. By specifying different channels -// for these two barriers, they can operate correctly in parallel. -class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemory() {} - - virtual std::vector get_buffer_ptrs() = 0; - virtual std::vector get_signal_pad_ptrs() = 0; - - // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer - // to a device array of size world_size, containing buffer pointers and - // signal pad pointers, respectively. - virtual void** get_buffer_ptrs_dev() = 0; - virtual void** get_signal_pad_ptrs_dev() = 0; - virtual size_t get_buffer_size() = 0; - virtual size_t get_signal_pad_size() = 0; - - virtual at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) = 0; - - virtual void barrier(int channel) = 0; - virtual void put_signal(int dst_rank, int channel) = 0; - virtual void wait_signal(int src_rank, int channel) = 0; - - virtual int get_rank() = 0; - virtual int get_world_size() = 0; -}; - -class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemoryAllocator(){}; - - virtual void* alloc( - size_t size, - int device_idx, - const std::string& group_name) = 0; - - virtual void free(void* ptr) = 0; - virtual size_t get_alloc_size(void* ptr) = 0; - virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; - virtual bool is_rendezvous_completed(void* ptr) = 0; -}; - -C10_EXPORT void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator); - -C10_EXPORT c10::intrusive_ptr get_allocator( - c10::DeviceType device_type); - -// Set a store for rendezvousing symmetric allocations on a group of devices -// identified by `group_name`. The concept of groups is logical; users can -// utilize predefined groups (e.g., a group of device identified by a -// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator -// backends might employ a more efficient communication channel for the actual -// rendezvous process and only use the store for bootstrapping purposes. -TORCH_API void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store); - -struct GroupInfo { - int rank; - int world_size; - c10::intrusive_ptr store; -}; - -C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); - -// Identical to empty_strided, but allows symmetric memory access to be -// established for the allocated tensor via SymmetricMemory::rendezvous(). This -// function itself is not a collective operation. It invokes -// SymmetricMemoryAllocator::alloc() for the requested device under the hood. -// -// NOTE [symmetric memory persistent allocation] -// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent -// allocation. This makes the function cache allocated memory and ensure that -// invocations with the same `alloc_id` receive tensors backed by the same -// memory address. For safety, if a previous persistent allocation is still -// active (i.e., the storage of the returned tensor is still alive), persistent -// allocations with the same `alloc_id` will fail. This determinism coupled -// with memory planning of communication buffers (e.g., by Inductor) allows -// communication algorithms to reliably reuse previously established remote -// memory access. -TORCH_API at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id); - -// Establishes symmetric memory access on tensors allocated via -// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a -// one-time process, and the mapping between a local memory region and the -// associated SymmetricMemory object is unique. Subsequent calls to -// rendezvous() with the same tensor, or tensors allocated with -// empty_strided_p2p_persistent() using the same alloc_id, will receive the -// cached SymmetricMemory object. -// -// The function has a collective semantic and must be invoked simultaneously -// from all rendezvous participants. -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor); - -// Returns the SymmetricMemory object associated with the tensor. It can only -// be invoked after rendezvous() but does not need to be invoked collectively. -TORCH_API c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor); - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index db5778efcf354..6f1b28886b989 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include @@ -976,44 +975,6 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); - using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; - py::class_>( - module, "_SymmetricMemory") - .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) - .def_static( - "empty_strided_p2p", - ::c10d::symmetric_memory::empty_strided_p2p, - py::arg("size"), - py::arg("stride"), - py::arg("dtype"), - py::arg("device"), - py::arg("group_name"), - py::arg("alloc_id") = py::none()) - .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) - .def_static( - "get_symmetric_memory", - &::c10d::symmetric_memory::get_symmetric_memory) - .def_property_readonly("rank", &SymmetricMemory::get_rank) - .def_property_readonly("world_size", &SymmetricMemory::get_world_size) - .def( - "get_buffer", - &SymmetricMemory::get_buffer, - py::arg("rank"), - py::arg("sizes"), - py::arg("dtype"), - py::arg("storage_offset") = 0) - .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) - .def( - "put_signal", - &SymmetricMemory::put_signal, - py::arg("dst_rank"), - py::arg("channel") = 0) - .def( - "wait_signal", - &SymmetricMemory::wait_signal, - py::arg("src_rank"), - py::arg("channel") = 0); - auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 9d7ba5abf951d..85136a91e0256 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,8 +218,23 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - auto allocator = get_allocator(c10::DeviceType::CUDA); - allocator->free(symmetricMemoryPtr_); + // Intentionally releasing resources without synchronizing devices. The + // teardown logic is safe for propoerly sync'd user program. We don't want + // improperly sync'd user program to hang here. + for (size_t r = 0; r < worldSize_; ++r) { + if (r == rank_) { + continue; + } + AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); + AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); + } + AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); + AT_CUDA_CHECK(cudaFree(buffers_[rank_])); + if (topoInfo_ != nullptr) { + AT_CUDA_CHECK(cudaFree(topoInfo_)); + } + AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); + AT_CUDA_CHECK(cudaFree(buffersDev_)); } bool IntraNodeComm::isEnabled() { @@ -329,19 +344,83 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - set_group_info("IntraNodeComm", rank_, worldSize_, store_); - auto allocator = get_allocator(c10::DeviceType::CUDA); - symmetricMemoryPtr_ = - allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); - symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); - TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); + // Initialize p2p state + auto p2pState = initP2pState(); + + // Allocate buffer + void* buffer = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); + + // Second handshake: exchange topology and CUDA IPC handles + struct IpcInfo { + NvlMesh nvlMesh; + Topology topology; + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + }; + + // Make p2p state and buffer available for IPC + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); + AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); + + IpcInfo ipcInfo{ + .nvlMesh = nvlMesh, + .topology = topology, + .p2pStateHandle = p2pStateHandle, + .bufferHandle = bufferHandle}; + + auto peerIpcInfos = + storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); + + for (const auto& info : peerIpcInfos) { + if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || + info.topology != peerIpcInfos.front().topology) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are observing different topologies (" + << int(info.topology) << " and " << int(topology) << ")"; + AT_CUDA_CHECK(cudaFree(p2pState)); + AT_CUDA_CHECK(cudaFree(buffer)); + return false; + } + } + + std::array p2pStates = {}, buffers = {}; + for (size_t r = 0; r < peerIpcInfos.size(); ++r) { + if (r == rank_) { + p2pStates[r] = p2pState; + buffers[r] = buffer; + } else { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &p2pStates[r], + peerIpcInfos[r].p2pStateHandle, + cudaIpcMemLazyEnablePeerAccess)); + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &buffers[r], + peerIpcInfos[r].bufferHandle, + cudaIpcMemLazyEnablePeerAccess)); + } + } + void* p2pStatesDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); + AT_CUDA_CHECK(cudaMemcpy( + p2pStatesDev, + p2pStates.data(), + sizeof(p2pStates), + cudaMemcpyHostToDevice)); + + void* buffersDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); + AT_CUDA_CHECK(cudaMemcpy( + buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); - buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); + std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); + std::copy(buffers.begin(), buffers.end(), buffers_.begin()); + p2pStatesDev_ = p2pStatesDev; + buffersDev_ = buffersDev; topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index ac751ff7be1e0..51fc6252d2235 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,8 +132,6 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; -static_assert(sizeof(P2pState) <= kP2pStateSize); - template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -524,7 +522,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -584,7 +582,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -634,7 +632,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -757,7 +755,15 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); + const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); + const auto elementSize = c10::elementSize(dtype); + TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); + auto options = at::TensorOptions().dtype(dtype).device( + at::kCUDA, at::cuda::current_device()); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storageOffset) + .options(options) + .make_tensor(); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index a67df5c34586a..5d7e2d426d30a 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,16 +4,12 @@ #include #include #include -#include #include namespace c10d::intra_node_comm { -using namespace c10d::symmetric_memory; - constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; -constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -31,7 +27,6 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -102,8 +97,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - void* symmetricMemoryPtr_ = nullptr; - c10::intrusive_ptr symmetricMemory_ = nullptr; + std::array p2pStates_{}; + std::array buffers_{}; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From 1877b7896c237567285804ecc138bc86180a7ced Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 18 Jun 2024 07:49:05 -0700 Subject: [PATCH 45/63] [checkpoint] Clean up selective activation checkpoint and make public (#125795) ### bc-breaking for existing users of the private API: - Existing policy functions must now change their return value to be [CheckpointPolicy](https://github.com/pytorch/pytorch/blob/c0b40ab42e38a208351911496b7153511304f8da/torch/utils/checkpoint.py#L1204-L1230) Enum instead of bool. - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy. - Policy function now accepts a `ctx` object instead of `mode` for its first argument. - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`. - Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint). Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit Memory considerations: - As with the existing SAC, cached values are cleared upon first use. - We error if the user wishes to backward a second time on a region forwarded with SAC enabled. In-place: - We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed. - `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place) Randomness, views - Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors) Tensor object preservation - ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor. Policy function - Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error. - The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3). - The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute. - The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below). Tensors guaranteed to be the same tensor as-is - Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795 Approved by: https://github.com/Chillee, https://github.com/fmassa --- docs/source/checkpoint.rst | 3 + test/dynamo/test_activation_checkpointing.py | 27 +- test/test_autograd.py | 416 ++++++++++++++++++- torch/_higher_order_ops/wrap.py | 6 +- torch/utils/checkpoint.py | 316 +++++++++----- 5 files changed, 643 insertions(+), 125 deletions(-) diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst index f7bc160fa98bd..8559d8bd73663 100644 --- a/docs/source/checkpoint.rst +++ b/docs/source/checkpoint.rst @@ -35,3 +35,6 @@ torch.utils.checkpoint .. autofunction:: checkpoint .. autofunction:: checkpoint_sequential .. autofunction:: set_checkpoint_debug_enabled +.. autoclass:: CheckpointPolicy +.. autoclass:: SelectiveCheckpointContext +.. autofunction:: create_selective_checkpoint_contexts diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 14851e51895b4..274e033028451 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,11 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor -from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint +from torch.utils.checkpoint import ( + checkpoint, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( @@ -105,8 +109,11 @@ def op_count(gm): def _get_custom_policy(no_recompute_list=None): - def _custom_policy(mode, func, *args, **kwargs): - return func in no_recompute_list + def _custom_policy(ctx, func, *args, **kwargs): + if func in no_recompute_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE return _custom_policy @@ -530,7 +537,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -580,7 +587,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -650,7 +657,7 @@ def _custom_policy(mode, func, *args, **kwargs): def selective_checkpointing_context_fn(): meta = {} - return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) def gn(x, y): return torch.sigmoid( @@ -698,7 +705,7 @@ def fn(x, y): ) def test_compile_selective_checkpoint_partial_ctx_fn(self): def selective_checkpointing_context_fn(no_recompute_list): - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -751,7 +758,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list), ) @@ -803,7 +810,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -854,7 +861,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) diff --git a/test/test_autograd.py b/test/test_autograd.py index c133ae95b4b3d..e45f5d47c6925 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2,6 +2,7 @@ import collections import contextlib +import functools import gc import io import math @@ -79,8 +80,14 @@ ) from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.utils.checkpoint import ( + checkpoint, + checkpoint_sequential, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) from torch.utils.cpp_extension import load_inline +from torch.utils.flop_counter import FlopCounterMode from torch.utils.hooks import RemovableHandle # noqa: TCH001 @@ -13215,6 +13222,413 @@ def fn2(x): self.assertEqual(counter[0], 1) +class TestSelectiveActivationCheckpoint(TestCase): + @unittest.skipIf(not TEST_CUDA, "requires CUDA") + def test_flops_and_mem(self): + # From https://github.com/pytorch/pytorch/pull/126320 + def get_act_mem(f): + out = f() + out.backward() + # Why do one forward and backward? + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so + # it will be able to observe whether an op is cached or not. + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + x = torch.randn(512, 512, requires_grad=True, device="cuda") + y = torch.randn(512, 512, requires_grad=True, device="cuda") + + def fn(x, y): + return torch.mm(x.cos(), y).sin().sum() + + def fn_ac(x, y): + return checkpoint(fn, x, y, use_reentrant=False) + + def fn_sac(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.mm.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn_sac2(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + act_mem_noac = get_act_mem(lambda: fn(x, y)) + bw_flops_noac = get_bw_flops(lambda: fn(x, y)) + + self.assertEqual(act_mem_noac, 2.0) + self.assertEqual(bw_flops_noac, 2.0) + + act_mem_ac = get_act_mem(lambda: fn_ac(x, y)) + bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y)) + + self.assertEqual(act_mem_ac, 0.0) + self.assertEqual(bw_flops_ac, 3.0) + + act_mem_sac = get_act_mem(lambda: fn_sac(x, y)) + bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y)) + + self.assertEqual(act_mem_sac, 1.0) + self.assertEqual(bw_flops_sac, 2.0) + + act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y)) + bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y)) + + self.assertEqual(act_mem_sac2, 1.0) + self.assertEqual(bw_flops_sac2, 2.0) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_output_already_has_autograd_meta(self): + # View of tensor of non-differentiable dtype still has AutogradMeta + def fn(x, y): + return x.view(-1), y.sin().cos() + + x = torch.tensor([1, 2, 3], dtype=torch.int64) + y = torch.randn(3, requires_grad=True) + + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.view.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + out[1].sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_subclass_dispatching_sizes(self): + # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default + # Caching such metadata ops can be problematic when the following are satisfied: + # + # 1. size/strides are dispatched upon + # 2. our policy saves sizes + ta = torch.randn(6, 2) + + class CustomSizeDynamicShapesTensor(torch.Tensor): + @staticmethod + def __new__(cls, inner): + return torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. + cls, + inner.size(), + inner.stride(), + None, + None, + inner.dtype, + inner.layout, + inner.device, + False, + inner.requires_grad, + "sizes", + ) + + def __init__(self, inner): + self.inner = inner + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = torch.utils._pytree.tree_map_only( + cls, lambda x: x.inner, args + ) + out_inner = func(*args_inner, **kwargs) + return torch.utils._pytree.tree_map_only( + torch.Tensor, lambda x: cls(x), out_inner + ) + + def policy_fn(ctx, op, *args, **kwargs): + if op is torch.ops.aten.sym_size.default: + # Silently ignored! + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + # We avoid the following case + # + # saved :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12] + # forward :sum ,sum,mul, mul , mul ,view , view + # recompute :sum ,sum,mul, view , view + # + # Views save the shape of their input, so we expect the second + # view to save 12, but because during AC packing during forward + # saves the shapes of the input for metadata checks later, + # we would save the wrong shape during the recompute. + view_out = (x * x.sum()).view(-1).view(4, 3) + self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12]) + return view_out.exp() + + x = torch.randn(4, 3, requires_grad=True) + x_wrapper = CustomSizeDynamicShapesTensor(x) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + + def test_bad_inputs(self): + bad_op_list1 = [2] + + with self.assertRaisesRegex( + ValueError, "Expected op in `op_list` to be an OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list1) + + bad_op_list2 = [torch.ops.aten.sin] + + with self.assertRaisesRegex( + ValueError, "update the OpOverloadPacket to a specific OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list2) + + with self.assertRaisesRegex(TypeError, "either a function or a list of ops."): + create_selective_checkpoint_contexts(2) + + # Dynamo fails for various reasons: + # - some tests using custom op that does not implement Fake + # - dynamo is trying to trace into saved variable hooks unpack hook for some reason + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_policy_with_state(self): + # If I have a stateful callable, state is shared between the original + # forward and the recompute. + counters = [] + + class Policy: + def __init__(self): + self.counter = [0] + self.recompute_counter = [0] + + def __call__(self, ctx, func, *args, **kwargs): + counter = self.recompute_counter if ctx.is_recompute else self.counter + counter[0] += 1 + counters.append(counter[0]) + if counter == 1 and func is torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().sin().sin() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + Policy(), + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + # 1. counter properly reset to 0 for the recompute + # 2. due to early-stop we do not recompute the final op + self.assertEqual(counters, [1, 2, 3, 1, 2]) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_storage_lifetime(self): + from torch.utils._python_dispatch import _get_current_dispatch_mode + from torch.utils.checkpoint import ( + _CachedTorchDispatchMode, + _CachingTorchDispatchMode, + ) + + def policy_fn(ctx, op, *args, **kwargs): + return CheckpointPolicy.MUST_SAVE + + ref = None + + def fn(x): + nonlocal ref + + self.assertIsInstance( + _get_current_dispatch_mode(), + (_CachingTorchDispatchMode, _CachedTorchDispatchMode), + ) + + out = x.cos().exp() + + if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode): + raw_val = ( + _get_current_dispatch_mode() + .storage[torch.ops.aten.exp.default][0] + .val + ) + # ref should've been detached + # to avoid graph -> the saved variable hooks -> recompute_context -> storage -> graph + self.assertFalse(raw_val.requires_grad) + ref = weakref.ref(raw_val) + + # Careful for early-stop + return out.sin() + + with disable_gc(): + # Case 1: If graph goes away without backward, make sure there's no reference cycle + # keeping storage alive. + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + del out + self.assertIsNone(ref()) + + # Case 2: After backward, even if retain_graph=True, the storage should go away + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + out.sum().backward(retain_graph=True) + # The dispatch mode's storage should still be alive, but the entries should've + # been cleared. + self.assertIsNone(ref()) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_version_counter(self): + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().mul_(2).cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 1) Error because the output of sin is saved and mutated by mul_ + with self.assertRaisesRegex(RuntimeError, "has been mutated"): + out.sum().backward() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 2) No longer should be an error because of allow_cache_entry_mutation + out.sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_more_than_one_output(self): + # maybe there is a more systematic way: + counter = [0] + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.var_mean.correction: + counter[0] += 1 + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + # var_mean has two outputs + def fn(x): + a, b = torch.var_mean(x) + return a * b + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + self.assertEqual(counter[0], 2) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_non_tensor_output(self): + # When SAC is enabled, the op is not computed a second time + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + counter = [0] + + @torch.library.custom_op("mylib::sin_with_extra", mutates_args=()) + def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + counter[0] += 1 + return x.sin(), 2 + + def setup_context(ctx, inputs, output) -> torch.Tensor: + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad, _unused): + (x,) = ctx.saved_tensors + return grad * x.cos() + + torch.library.register_autograd( + "mylib::sin_with_extra", backward, setup_context=setup_context + ) + + x = torch.randn(3, requires_grad=True) + + def fn(x): + return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin() + + ops_list = [torch.ops.mylib.sin_with_extra.default] + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, ops_list + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + self.assertEqual(counter[0], 1) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_can_only_trigger_recompute_once(self): + # We don't support this to avoid adding extra complexity for now. + # If there's a need, we could probably do some kind of use_count tracking. + # TODO: have a nice error message here. + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward(retain_graph=True) + + with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"): + out.sum().backward(retain_graph=True) + + class TestAutogradMultipleDispatch(TestCase): def test_autograd_multiple_dispatch_registrations(self, device): t = torch.randn(3, 3, device=device, requires_grad=True) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 6d83a44e752a0..e7fe553387d1c 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -1,15 +1,17 @@ # mypy: allow-untyped-defs import inspect +import itertools import logging import torch from torch._ops import HigherOrderOperator -from torch.utils.checkpoint import checkpoint, uid +from torch.utils.checkpoint import checkpoint + import torch._dynamo.config log = logging.getLogger(__name__) - +uid = itertools.count(1) # Used for testing the HigherOrderOperator mechanism class Wrap(HigherOrderOperator): diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 5cbfd1543cf42..dab7730d84397 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -5,18 +5,8 @@ import warnings import weakref from collections import defaultdict -from itertools import count -from typing import ( - Any, - Callable, - ContextManager, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Tuple, -) +from typing import * # noqa: F403 +import enum from weakref import ReferenceType import torch @@ -39,6 +29,10 @@ "set_checkpoint_early_stop", "DefaultDeviceType", "set_checkpoint_debug_enabled", + "CheckpointPolicy", + "SelectiveCheckpointContext", + "create_selective_checkpoint_contexts", + "SAC_IGNORED_OPS", ] _DEFAULT_DETERMINISM_MODE = "default" @@ -1153,149 +1147,247 @@ def _is_compiling(func, args, kwargs): return False -def _detach(x): - if isinstance(x, torch.Tensor): - return x.detach() +class _VersionWrapper: + # Check that cached tensors are not mutated. + def __init__(self, val): + self.val: Union[torch.Tensor, Any] = val + self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + + def get_val(self, allow_cache_entry_mutation): + if self.version is not None and not allow_cache_entry_mutation: + if self.val._version != self.version: + # Can we give user a stack trace of where the mutation happened? + raise RuntimeError( + "Tensor cached during selective activation checkpoint has been mutated" + ) + return self.val + + +def _maybe_detach(x, any_ret_has_alias_info): + # We detach for two separate reasons: + # - For view ops, we need to ensure that when the tensor is returned from + # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr + # - Avoid reference cycles + # For case 1, it is not enough to check whether x has differentiable dtype + # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. + # when the tensor is a view. + if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): + with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + # Ensure that view performed beneath autograd properly propagates + # version counter. TODO: Use reentrant_dispatch instead of + # manually manipulating dispatch keys. Using reentrant_dispatch + # would respect inference_mode, though that is not relevant for + # this case. + x = x.detach() return x -uid = count(1) +class SelectiveCheckpointContext: + """ + Context passed to policy function during selective checkpointing. + This class is used to pass relevant metadata to the policy function during + selective checkpointing. The metadata includes whether the current invocation + of the policy function is during recomputation or not. -# NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times -# (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call), -# so we ignore these ops and just always recompute them. -_ignored_ops = { - torch.ops.prim.device.default, + Example: + >>> # xdoctest: +SKIP(stub) + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> print(ctx.is_recompute) + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + def __init__(self, *, is_recompute): + self.is_recompute = is_recompute + + +class CheckpointPolicy(enum.Enum): + """ + Enum for specifying the policy for checkpointing during backpropagation. + + The following policies are supported: + + - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward + pass and will not be recomputed during the backward pass + - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the + forward pass and will be recomputed during the backward pass + + Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden + by other subsystems like `torch.compile`. + + .. note:: + A policy function that always returns ``PREFER_RECOMPUTE`` is + equivalent to vanilla checkpointing. + + A policy function that returns ``PREFER_SAVE`` every op is + NOT equivalent to not using checkpointing. Using such a policy would + save additional tensors not limited to ones that are actually needed for + gradient computation. + """ + MUST_SAVE = 0 + PREFER_SAVE = 1 + MUST_RECOMPUTE = 2 + PREFER_RECOMPUTE = 3 + + +SAC_IGNORED_OPS = { + # AC inserts different number of detach during forward and recompute. torch.ops.aten.detach.default, + # AC's determinism check invokes additional metadata ops during forward. + # With subclasses involved, these metadata ops become dispatchable, this + # can result in incorrectness if these ops are selected cached. + torch.ops.prim.device.default, } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) class _CachingTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachedTorchDispatchMode. - """ + # Used together with _CachedTorchDispatchMode to implement SAC. def __init__(self, policy_fn, storage): self.policy_fn = policy_fn self.storage = storage - def push_into_storage(self, out, func, args, kwargs): - out_detached = tree_map(_detach, out) - self.storage[func].append(out_detached) + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) - def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs): - if should_not_recompute: + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if is_compiling and policy == CheckpointPolicy.MUST_SAVE: fx_traceback.current_meta["recompute"] = 0 - # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode - # we decide and handle recomputation in the partitioner. + out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - return out - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: - return func(*args, **kwargs) - should_not_recompute = self.policy_fn("forward", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs) - else: - if should_not_recompute: - out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) + return out class _CachedTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachingTorchDispatchMode. - """ - def __init__(self, policy_fn, storage): + # Used together with _CachedTorchDispatchMode to implement SAC. + def __init__(self, policy_fn, storage, allow_cache_entry_mutation): self.policy_fn = policy_fn self.storage = storage - - def pop_from_storage(self, func, args, kwargs): - assert func in self.storage - out = self.storage[func].pop(0) - return out - - def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs): - out = self.pop_from_storage(func, args, kwargs) - return out + self.allow_cache_entry_mutation = allow_cache_entry_mutation def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: + if func in SAC_IGNORED_OPS: return func(*args, **kwargs) - should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + storage = self.storage.get(func) + if storage is None: + raise RuntimeError(f"{func} encountered during backward, but not found in storage") + if len(storage) == 0: + raise RuntimeError( + "Trying to backward an extra time. You are only allowed to backward once " + "on any region computed under selective activation checkpoint." + ) + out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) else: - if should_not_recompute: - out = self.pop_from_storage(func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + out = func(*args, **kwargs) + return out -def _pt2_selective_checkpoint_context_fn_gen(policy_fn): + +def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): """ - A helper function that generates a pair of contexts to be later passed into - `torch.utils.checkpoint` API to implment selective checkpointing. + Helper to avoid recomputing certain ops during activation checkpointing. - .. warning:: - This is context_fn is intended for use with torch.compile only. + Use this with `torch.utils.checkpoint.checkpoint` to control which + operations are recomputed during the backward pass. Args: - policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function - to decide whether a particular op should be recomputed in backward pass or not. - In eager mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op is guaranteed to be recomputed. - In torch.compile mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op may or may not be recomputed - (it's up to the partitioner to decide). - + policy_fn_or_list (Callable or List): + - If a policy function is provided, it should accept a + :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and + kwargs to the op, and return a :class:`CheckpointPolicy` enum value + indicating whether the execution of the op should be recomputed or not. + - If a list of operations is provided, it is equivalent to a policy + returning `CheckpointPolicy.MUST_SAVE` for the specified + operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other + operations. + allow_cache_entry_mutation (bool, optional): By default, an error is + raised if any tensors cached by selective activation checkpoint are + mutated in order to ensure correctness. If set to `True`, this check + is disabled. Returns: - A pair of generated contexts. + A tuple of two context managers. Example: >>> # xdoctest: +REQUIRES(LINUX) + >>> import functools >>> - >>> def get_custom_policy(): - >>> no_recompute_list = [ - >>> torch.ops.aten.mm.default, - >>> ] - >>> def custom_policy(mode, func, *args, **kwargs): - >>> return func in no_recompute_list - >>> return custom_policy + >>> x = torch.rand(10, 10, requires_grad=True) + >>> y = torch.rand(10, 10, requires_grad=True) >>> - >>> def selective_checkpointing_context_fn(): - >>> return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy()) + >>> ops_to_save = [ + >>> torch.ops.aten.mm.default, + >>> ] >>> - >>> def gn(x, y): - >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> if op in ops_to_save: + >>> return CheckpointPolicy.MUST_SAVE + >>> else: + >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> - >>> def fn(x, y): - >>> return torch.utils.checkpoint.checkpoint( - >>> gn, x, y, - >>> use_reentrant=False, - >>> context_fn=selective_checkpointing_context_fn, - >>> ) + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> # or equivalently + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> - >>> x = torch.randn(4, 4, requires_grad=True) - >>> y = torch.randn(4, 4, requires_grad=True) + >>> def fn(x, y): + >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> - >>> compiled_fn = torch.compile(fn) + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) """ - storage: Dict[Any, List[Any]] = defaultdict(list) - return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage) + # NB: If grad_mode is disabled, checkpoint would not run forward under + # context_fn anyway, so proceed as usual. + if isinstance(policy_fn_or_list, list): + for op in policy_fn_or_list: + if not isinstance(op, torch._ops.OpOverload): + _extra_msg = ( + "Please update the OpOverloadPacket to a specific OpOverload." + "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + raise ValueError( + f"Expected op in `op_list` to be an OpOverload but got: {op} " + f"of type {type(op)}. {_extra_msg}" + ) + def policy_fn(ctx, op, *args, **kwargs): + if op in policy_fn_or_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): + policy_fn = policy_fn_or_list + else: + raise TypeError("policy_fn_or_list must be either a function or a list of ops.") + + storage: Dict[Any, List[Any]] = defaultdict(list) + return ( + _CachingTorchDispatchMode(policy_fn, storage), + _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), + ) # NB: this helper wraps fn before calling checkpoint_impl. kwargs and # saving/restoring of global state is handled here. From d77a1aaa8623ba5e70f4f147362d84769784cf43 Mon Sep 17 00:00:00 2001 From: loganthomas Date: Tue, 18 Jun 2024 18:26:07 +0000 Subject: [PATCH 46/63] DOC: add note about same sized tensors to dist.gather() (#128676) Fixes #103305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128676 Approved by: https://github.com/wconstab --- torch/distributed/distributed_c10d.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bd81fd61b02f9..d44c3733a214e 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3041,11 +3041,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): """ Gathers tensors from the whole group in a list. - Complex tensors are supported. + Complex and uneven sized tensors are supported. Args: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. tensor (Tensor): Tensor to be broadcast from current process. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. @@ -3118,6 +3119,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal """ Gather tensors from all ranks and put them in a single output tensor. + This function requires all tensors to be the same size on each process. + Args: output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the @@ -3341,11 +3344,13 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): """ Gathers a list of tensors in a single process. + This function requires all tensors to be the same size on each process. + Args: tensor (Tensor): Input tensor. - gather_list (list[Tensor], optional): List of appropriately-sized - tensors to use for gathered data (default is None, must be specified - on the destination rank) + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. From 1a527915a64b8e5f60951715b09fa294b1a8844f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 09:54:11 -0700 Subject: [PATCH 47/63] [DSD] Correctly handle shared parameters for optimizer state_dict (#128685) * Fixes https://github.com/pytorch/pytorch/issues/128011 See the discussion in https://github.com/pytorch/pytorch/pull/128076 Current implementation of `set_optimizer_state_dict()` assumes that all the fqns returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters. In such a case, only one fqn of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue. Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128685 Approved by: https://github.com/LucasLLC --- .../distributed/checkpoint/test_state_dict.py | 27 ++++++++++++ torch/distributed/checkpoint/state_dict.py | 42 ++++++++++++++++--- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 3da18ea5cc600..ac6263569af45 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -851,6 +851,33 @@ def test_deprecate_fsdp_api(self) -> None: ): get_model_state_dict(model) + @with_comms + @skip_if_lt_x_gpu(2) + def test_shared_weight(self): + class TiedEmbeddingModel(nn.Module): + def __init__(self, vocab_size, embedding_dim): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + self.decoder = nn.Linear(embedding_dim, vocab_size) + self.decoder.weight = self.embedding.weight # Tying weights + + def forward(self, input): + input = (input * 10).to(torch.int) + embedded = self.embedding(input) + output = self.decoder(embedded) + return output + + def init_model_optim(): + device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) + orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) + dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) + return orig_model, orig_optim, copy_optim, dist_model, dist_optim + + self._test_save_load(init_model_optim) + class TestNoComm(MultiProcessTestCase): def setUp(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16a1ddde21586..6bdeb389e8a0c 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -153,6 +153,9 @@ class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] + ] = field(default_factory=dict) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True handle_optim: bool = True @@ -286,14 +289,29 @@ def _verify_options( fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[Set[str], torch.Tensor] ] = {} + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[Set[str], torch.Tensor] + ] = {} for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + fqns = _get_fqns(model, name) - if not isinstance(param, _EXTRA_STATE): - fqn_param_mapping[param] = fqns + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(Set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() for fqn in fqns: if not isinstance(param, _EXTRA_STATE): fqn_param_mapping[fqn] = param + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + submodule_prefixes: Set[str] = set() if submodules: submodules = set(submodules) @@ -361,6 +379,7 @@ def fsdp_state_dict_type_without_warning( return _StateDictInfo( **asdict(options), fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, submodule_prefixes=submodule_prefixes, fsdp_context=fsdp_context, fsdp_modules=cast(List[nn.Module], fsdp_modules), @@ -450,7 +469,7 @@ def _get_model_state_dict( for key in list(state_dict.keys()): fqns = _get_fqns(model, key) - assert len(fqns) == 1 + assert len(fqns) == 1, (key, fqns) fqn = next(iter(fqns)) if fqn != key: # As we only support FSDP, DDP, and TP, the only cases are @@ -797,6 +816,19 @@ def _split_optim_state_dict( pg_state.append({_PARAMS: []}) for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(List[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + params = pg_state[-1][_PARAMS] assert isinstance(params, list) params.append(fqn) @@ -805,9 +837,7 @@ def _split_optim_state_dict( for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): - params = loaded_param_group[_PARAMS] - assert isinstance(params, list) - if fqn in params: + if fqn in cast(List[str], loaded_param_group[_PARAMS]): pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 for param_group in cast(ListDictValueType, optim_state_dict[_PG]): From bdffd9f0c6f4564ee0cdd15d030215b5df58b2a9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 23:10:58 -0700 Subject: [PATCH 48/63] [export] Graph break on nn.Parameter construction (#128935) Fixes https://github.com/pytorch/pytorch/issues/126109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128935 Approved by: https://github.com/angelayi --- torch/_dynamo/variables/torch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 74c2193646bc0..1cc4622dea529 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -877,6 +877,9 @@ def handle_ntuple(value): @classmethod def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented("nn parameter construction not supported with export") + if isinstance(requires_grad, variables.VariableTracker): try: requires_grad = requires_grad.as_python_constant() From 44483972bdd3dcd0c047020694817210846b5d70 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:37 -0700 Subject: [PATCH 49/63] [EZ] Keep weight_norm var name aligned (#128955) To keep it aligned with https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/native_functions.yaml#L6484 I.e. `x`->`v`, `y`->`g` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128955 Approved by: https://github.com/albanD, https://github.com/Skylion007 --- torch/_decomp/decompositions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7ebc69462fa1c..dca552137ca6d 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4770,11 +4770,11 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): @register_decomposition(torch.ops.aten._weight_norm_interface) -def _weight_norm_interface(x, y, dim=0): +def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 - keep_dim = tuple(i for i in range(len(x.shape)) if i != dim) - norm = x.norm(2, keep_dim, keepdim=True) - return x * (y / norm), norm + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + norm = v.norm(2, keep_dim, keepdim=True) + return v * (g / norm), norm @register_decomposition(aten.isin) From 04a5d3228ecd5af790dabcfeb27c8c4f86742e11 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Jun 2024 19:11:04 +0000 Subject: [PATCH 50/63] [ts migration] Support prim::tolist and aten::len (#128894) Support prim::tolist and aten::len. Add unit tests for prim::min. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128894 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 106 +++++++++++++++++++++++++++++++++- torch/_export/converter.py | 12 +++- 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 300f70223a26b..8ea6a8089ae8b 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -111,13 +111,102 @@ def forward(self, x): def test_aten_len(self): class Module(torch.nn.Module): - def forward(self, x): + def forward(self, x: torch.Tensor): length = len(x) return torch.ones(length) + # aten::len.Tensor inp = (torch.ones(2, 3),) self._check_equal_ts_ep_converter(Module(), inp) + class Module(torch.nn.Module): + def forward(self, x: List[int]): + length = len(x) + return torch.ones(length) + + # aten::len.t + inp = ([1, 2, 3],) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[int, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_int + inp = ({1: "a", 2: "b", 3: "c"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[bool, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_bool + inp = ({True: "a", False: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[float, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_float + inp = ({1.2: "a", 3.4: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[torch.Tensor, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_Tensor + inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + # aten::len.str and aten::len.Dict_str are not supported + # since torch._C._jit_flatten does not support str + # inp = ("abcdefg",) + # self._check_equal_ts_ep_converter(Module(), inp) + # inp = ({"a": 1, "b": 2},) + # self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_min(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_len = len(x) + y_len = len(y) + + # prim::min.int + len_int = min(x_len, y_len) + + # prim::min.float + len_float = int(min(x_len * 2.0, y_len * 2.0)) + + # prim::min.self_int + len_self_int = min([x_len, y_len]) + + # prim::min.self_float + len_self_float = int(min([x_len * 2.0, y_len * 2.0])) + + # prim::min.float_int + len_float_int = int(min(x_len * 2.0, y_len)) + + # prim::min.int_float + len_int_float = int(min(x_len, y_len * 2.0)) + + return torch.ones( + len_int + + len_float + + len_self_int + + len_self_float + + len_float_int + + len_int_float + ) + + inp = (torch.randn(10, 2), torch.randn(5)) + self._check_equal_ts_ep_converter(Module(), inp) + def test_aten___getitem___list(self): class Module(torch.nn.Module): def forward(self, x): @@ -659,6 +748,21 @@ def forward(self, x): # inp = (torch.randn([2, 3, 4]),) # self._check_equal_ts_ep_converter(func6, inp) + def test_prim_tolist(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[int]: + return x.tolist() + + inp = (torch.tensor([1, 2, 3]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[List[int]]: + return x.tolist() + + inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 2c54db38dee8b..48f983b2917ef 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -91,6 +91,7 @@ def get_dtype_as_int(tensor): "aten::__not__": operator.not_, "aten::__contains__": operator.contains, "prim::dtype": get_dtype_as_int, + "aten::len": len, } @@ -187,7 +188,7 @@ def _map_blocks_to_lifted_attrs(entry): def get_op_overload(node: torch._C.Node): schema_str = node.schema() - schema = torch._C.parse_schema(schema_str) + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) ns, op_name = str(schema.name).split("::") override = schema.overload_name @@ -651,6 +652,15 @@ def convert_profiler__record_function_exit(self, node: torch._C.Node): args = tuple(self.get_fx_value(input) for input in node.inputs()) self.fx_graph.call_function(target, args) + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + def _convert_standard_operators(self, node: torch._C.Node): target = kind_to_standard_operators[node.kind()] args = tuple(self.get_fx_value(input) for input in node.inputs()) From abde6cab4c7f972672ae008223000c16fd3964cd Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 12 Jun 2024 19:33:15 -0700 Subject: [PATCH 51/63] Remove compile_threads=1 in test_inductor_collectives.py (#128580) Summary: I believe https://github.com/pytorch/pytorch/issues/125235 should be fixed after switching to subprocess-based parallel compile. Test Plan: Ran locally with python-3.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128580 Approved by: https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 35e44b19bedd5..ee4535fd5a73f 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -60,8 +60,6 @@ def world_size(self) -> int: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_broadcast_inductor(self): """ Testing if broadcast works correctly when using inductor @@ -94,8 +92,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -129,8 +125,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor_cudagraph_trees(self): """ Tests whether cudagraph trees support all_reduce from nccl @@ -177,8 +171,6 @@ def test_c10d_functional_tagged_pt2_compliant(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_eager_allreduce_inductor_wait(self): def eager_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -218,8 +210,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_inductor_allreduce_eager_wait(self): def inductor_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -256,8 +246,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_input_buffer_reuse(self): def func(a, *, tag, ranks, group_size): ar = _functional_collectives.all_reduce(a, "sum", ranks, tag) @@ -275,8 +263,6 @@ def func(a, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_permute_tensor(self): def func(tensor, src_dst_pairs, *, tag, ranks, group_size): return _functional_collectives.permute_tensor( @@ -304,8 +290,6 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_output_buffer_reuse(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -329,8 +313,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -355,8 +337,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_into_tensor_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -388,8 +368,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_reduce_scatter_tensor_inductor(self): def example(a, b, *, tag, ranks, group_size): c = torch.matmul(a, b) @@ -418,8 +396,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor(self): def example( inp, @@ -488,8 +464,6 @@ def example( @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor_split_sizes_none(self): def example(inp, *, tag, ranks, group_size): a2a = torch.ops.c10d_functional.all_to_all_single( From fe8558b7aa4ce55d06893c48d5cb00b7a7eb7dae Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 10:37:13 -0700 Subject: [PATCH 52/63] [DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755) HSDP1 + broadcast_from_rank0 actually behaves differently from FSDP1 + broadcast_from_rank0. So we need an unittest to cover this use case. This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446. Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755 Approved by: https://github.com/Skylion007, https://github.com/wz337 ghstack dependencies: #128685 --- .../distributed/checkpoint/test_state_dict.py | 157 ++++++++++-------- 1 file changed, 87 insertions(+), 70 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index ac6263569af45..7736350628802 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -33,7 +33,11 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, + StateDictType, +) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn.parallel import DistributedDataParallel as DDP @@ -70,7 +74,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: - return 2 + return min(4, torch.cuda.device_count()) def _test_save_load( self, @@ -567,55 +571,71 @@ def test_non_persistent_buffers(self) -> None: set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model)) - @with_comms - @skip_if_lt_x_gpu(2) - def test_broadcast_from_rank0(self) -> None: - def inner_test(wrapper): - model = CompositeParamModel(device=torch.device("cuda")) - optim = torch.optim.Adam(model.parameters()) - fsdp_model = wrapper(copy.deepcopy(model)) - fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) + def _test_broadcast_from_rank0(self, wrapper) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + optim = torch.optim.Adam(model.parameters()) + fsdp_model = wrapper(copy.deepcopy(model)) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") - model(batch).sum().backward() - optim.step() - states, optim_states = get_state_dict(model, optim) + batch = torch.rand(8, 100, device="cuda") + model(batch).sum().backward() + optim.step() + states, optim_states = get_state_dict(model, optim) - fsdp_model(batch).sum().backward() - fsdp_optim.step() + fsdp_model(batch).sum().backward() + fsdp_optim.step() - def check(equal): - fsdp_states = get_model_state_dict( - fsdp_model, - options=StateDictOptions(full_state_dict=True), - ) - fsdp_optim_states = get_optimizer_state_dict( - fsdp_model, - fsdp_optim, - options=StateDictOptions(full_state_dict=True), - ) - if equal: - self.assertEqual(states, fsdp_states) - self.assertEqual(optim_states, fsdp_optim_states) - else: - self.assertNotEqual(states, fsdp_states) - self.assertNotEqual(optim_states, fsdp_optim_states) - - check(equal=True) - fsdp_model(batch).sum().backward() - fsdp_optim.step() - check(equal=False) - - # Drop the states to simulate loading from rank0 - if dist.get_rank() > 0: - load_states = {} - load_states2 = {} - load_optim_states = {} + def check(equal): + fsdp_states = get_model_state_dict( + fsdp_model, + options=StateDictOptions(full_state_dict=True), + ) + fsdp_optim_states = get_optimizer_state_dict( + fsdp_model, + fsdp_optim, + options=StateDictOptions(full_state_dict=True), + ) + if equal: + self.assertEqual(states, fsdp_states) + self.assertEqual(optim_states, fsdp_optim_states) else: - load_states = copy.deepcopy(states) - load_states2 = copy.deepcopy(states) - load_optim_states = copy.deepcopy(optim_states) + self.assertNotEqual(states, fsdp_states) + self.assertNotEqual(optim_states, fsdp_optim_states) + + check(equal=True) + fsdp_model(batch).sum().backward() + fsdp_optim.step() + check(equal=False) + + # Drop the states to simulate loading from rank0 + if dist.get_rank() > 0: + load_states = {} + load_states2 = {} + load_optim_states = {} + else: + load_states = copy.deepcopy(states) + load_states2 = copy.deepcopy(states) + load_optim_states = copy.deepcopy(optim_states) + set_model_state_dict( + fsdp_model, + model_state_dict=load_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + set_optimizer_state_dict( + fsdp_model, + fsdp_optim, + optim_state_dict=load_optim_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + + check(equal=True) + # Verify the `strict` flag. + load_states = load_states2 + if load_states: + key = next(iter(load_states.keys())) + load_states.pop(key) + with self.assertRaisesRegex(RuntimeError, "Missing key"): set_model_state_dict( fsdp_model, model_state_dict=load_states, @@ -623,30 +643,10 @@ def check(equal): broadcast_from_rank0=True, full_state_dict=True ), ) - set_optimizer_state_dict( - fsdp_model, - fsdp_optim, - optim_state_dict=load_optim_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) - - check(equal=True) - # Verify the `strict` flag. - load_states = load_states2 - if load_states: - key = next(iter(load_states.keys())) - load_states.pop(key) - with self.assertRaisesRegex(RuntimeError, "Missing key"): - set_model_state_dict( - fsdp_model, - model_state_dict=load_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_broadcast_from_rank0(self) -> None: device_mesh = init_device_mesh("cuda", (self.world_size,)) self.run_subtests( { @@ -655,7 +655,24 @@ def check(equal): functools.partial(FSDP, device_mesh=device_mesh), ] }, - inner_test, + self._test_broadcast_from_rank0, + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_broadcast_from_rank0_hsdp(self) -> None: + device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) + self.run_subtests( + { + "wrapper": [ + functools.partial( + FSDP, + device_mesh=device_mesh, + sharding_strategy=ShardingStrategy.HYBRID_SHARD, + ), + ] + }, + self._test_broadcast_from_rank0, ) @with_comms From 9a7e2519d3d15f8d469b71cab914fcdaf071ebd6 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 18 Jun 2024 19:59:50 +0000 Subject: [PATCH 53/63] [MPS] Fused Adam & AdamW (#127242) Summary: This PR adds fused Adam and AdamW implementations. Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory: **Fast math enabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300 Times are in milliseconds (ms). ``` **Fast math disabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200 Times are in milliseconds (ms). ``` ```python def profile_fused_adam(): from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device = "mps" results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)] if adamWflag: fn = adamw.adamw else: fn = adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label='Fused Adam', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242 Approved by: https://github.com/kulinseth, https://github.com/janeyx99 --- aten/src/ATen/native/mps/OperationUtils.h | 19 +- aten/src/ATen/native/mps/OperationUtils.mm | 31 +- .../operations/FusedAdamAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamKernel.mm | 69 +++++ .../mps/operations/FusedAdamKernelImpl.h | 23 ++ .../mps/operations/FusedAdamKernelImpl.mm | 35 +++ .../operations/FusedAdamWAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamWAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamWKernel.mm | 68 +++++ .../mps/operations/FusedAdamWKernelImpl.h | 23 ++ .../mps/operations/FusedAdamWKernelImpl.mm | 35 +++ .../native/mps/operations/FusedOptimizerOps.h | 274 ++++++++++++++++++ .../native/mps/operations/MultiTensorApply.h | 190 ++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 34 +-- test/test_optim.py | 31 +- torch/optim/adam.py | 6 + torch/optim/adamw.py | 6 + torch/testing/_internal/common_optimizers.py | 4 +- torch/utils/_foreach_utils.py | 2 +- 21 files changed, 911 insertions(+), 63 deletions(-) create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedOptimizerOps.h create mode 100644 aten/src/ATen/native/mps/operations/MultiTensorApply.h diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 25e86e6d262f9..a9493cbce3ada 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -336,25 +336,34 @@ inline bool is_dense_in_storage(const at::Tensor& t) { class MetalShaderLibrary { public: - MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {} + MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {} MetalShaderLibrary(const MetalShaderLibrary&) = delete; inline id getPipelineStateForFunc(const std::string& fname) { - return getLibraryPipelineState(getLibrary(), fname); + return getLibraryPipelineState(getLibrary(), fname).first; } id getPipelineStateForFunc(const std::string& fname, const std::initializer_list& params) { - return getLibraryPipelineState(getLibrary(params), fname); + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline id getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + id getMTLFunction(const std::string& fname, const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; } private: - id getLibraryPipelineState(id lib, const std::string& fname); + std::pair, id> getLibraryPipelineState(id lib, const std::string& fname); id getLibrary(); id getLibrary(const std::initializer_list& params); id compileLibrary(const std::string& src); std::string shaderSource; unsigned nparams; + MTLCompileOptions* compile_options; id library = nil; std::unordered_map> libMap; - std::unordered_map> cplMap; + std::unordered_map, id>> cplMap; }; static inline void mtl_setBuffer(id encoder, const Tensor& t, unsigned idx) { diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 82d1fe9d92f48..8dc90e497fe4e 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -656,31 +656,38 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} id MetalShaderLibrary::compileLibrary(const std::string& src) { NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 - : MTLLanguageVersion2_3]; - // [options setFastMathEnabled: NO]; - auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; + MTLCompileOptions* options = compile_options; + if (!options) { + options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 + : MTLLanguageVersion2_3]; + [options setFastMathEnabled:NO]; + } + + const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; auto device = MPSDevice::getInstance()->device(); library = [device newLibraryWithSource:str options:options error:&error]; TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); return library; } -id MetalShaderLibrary::getLibraryPipelineState(id lib, const std::string& fname) { - auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); - auto cpl = cplMap[key]; - if (cpl) { - return cpl; +std::pair, id> MetalShaderLibrary::getLibraryPipelineState( + id lib, + const std::string& fname) { + const auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); + auto found_cpl = cplMap.find(key); + if (found_cpl != cplMap.end()) { + return found_cpl->second; } NSError* error = nil; id func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]]; TORCH_CHECK(func, "Failed to create function state object for: ", fname); - cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; + auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); - return cplMap[key] = cpl; + cplMap[key] = std::make_pair(cpl, func); + return cplMap[key]; } } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 0000000000000..8711cb228ee9f --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm new file mode 100644 index 0000000000000..be6069ad9694b --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm new file mode 100644 index 0000000000000..2e4d89ff851c3 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm @@ -0,0 +1,69 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adam_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 0000000000000..90d1ee1509323 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm new file mode 100644 index 0000000000000..e3c87ae9bc787 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 0000000000000..f03fcdb574139 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm new file mode 100644 index 0000000000000..fd94e9686fbce --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm new file mode 100644 index 0000000000000..ce08972ef9adf --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm @@ -0,0 +1,68 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adamw_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 0000000000000..284516e0b89ce --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm new file mode 100644 index 0000000000000..8899f6a5e9e13 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h new file mode 100644 index 0000000000000..00a75067b7f4b --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h @@ -0,0 +1,274 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +static const char* FUSED_ADAM_OPS = R"METAL( +#include + +#define kmaxThreadGroups 32 +#define kmaxTensors 32 +#define chunk_size 65536 + +constexpr constant uint kParamIdx = 0; +constexpr constant uint kGradIdx = kParamIdx + kmaxTensors; +constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors; +constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors; +constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors; + +template +struct AdamArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdx) ]]; +}; + +template +struct AdamAmsgradArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdxForAmsgrad) ]]; +}; + +struct MetadataArguments { + uint32_t numels[kmaxTensors]; + uint32_t threadgroup_to_tensor[kmaxThreadGroups]; + uint32_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +enum ADAM_MODE : uint8_t { + ORIGINAL = 0, + ADAMW = 1 +}; + +template +inline void adam_math_amsgrad( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device T & max_exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq); + + const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +inline void adam_math( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +kernel void fused_adam_amsgrad( + device AdamAmsgradArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math_amsgrad( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *(max_exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +template +kernel void fused_adam( + device AdamArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \ +template \ +[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \ +kernel void KERNEL_NAME( \ + device ARGUMENTS_STRUCT & args [[buffer(0)]],\ + constant MetadataArguments & metadata_args [[buffer(1)]],\ + constant float & lr [[buffer(2)]],\ + constant float & beta1 [[buffer(3)]],\ + constant float & beta2 [[buffer(4)]],\ + constant float & weight_decay [[buffer(5)]],\ + constant float & eps [[buffer(6)]],\ + constant uint8_t & maximize [[buffer(7)]],\ + uint tid [[thread_position_in_threadgroup]],\ + uint tgid [[threadgroup_position_in_grid]],\ + uint tptg [[threads_per_threadgroup]]) + +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); + +)METAL"; + +static std::pair, id> getCPLState(const std::string& fname) { + static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0); + return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname)); +} + +} //namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 0000000000000..fe9296cc0db79 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,190 @@ +#pragma once +#include +#include +#include + +namespace at::native { +namespace mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 bytes + uint numels[kmaxTensors]; + uint threadgroup_to_tensor[kmaxThreadGroups]; + uint threadgroup_to_chunk[kmaxThreadGroups]; +}; + +template +static void multi_tensor_apply_for_fused_adam( + const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize + ) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + TORCH_CHECK( + tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + float lr_lv = lr; + float beta1_lv = beta1; + float beta2_lv = beta2; + float weight_decay_lv = weight_decay; + float eps_lv = eps; + uint8_t maximize_lv = maximize; + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full){ + [computeEncoder setBuffer:tensorArgumentBuffer + offset:0 + atIndex:0]; + [computeEncoder setBytes:&metadata_arguments + length:sizeof(MetadataArguments) + atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. + metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + + [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0]; + [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + + } + }); +} + +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7474e0bc55d8b..b030141882c86 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15575,6 +15575,7 @@ dispatch: CPU: _fused_adam_kernel_cpu_ CUDA: _fused_adam_kernel_cuda_ + MPS: _fused_adam_kernel_mps_ autogen: _fused_adam, _fused_adam.out - func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15593,6 +15594,7 @@ dispatch: CPU: _fused_adamw_kernel_cpu_ CUDA: _fused_adamw_kernel_cuda_ + MPS: _fused_adamw_kernel_mps_ autogen: _fused_adamw, _fused_adamw.out - func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () diff --git a/test/test_mps.py b/test/test_mps.py index 311cf8245c4f3..a97b8fb8d6b13 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -76,7 +76,6 @@ def mps_ops_grad_modifier(ops): XFAILLIST_GRAD = { # precision issues - 'digamma': [torch.float32], 'special.polygammaspecial_polygamma_n_0': [torch.float16], 'polygammapolygamma_n_0': [torch.float16], 'nn.functional.binary_cross_entropy': [torch.float16], @@ -95,7 +94,6 @@ def mps_ops_grad_modifier(ops): 'masked.scatter': [torch.float16, torch.float32], 'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`. 'aminmax': [torch.float32, torch.float16], - 'polar': [torch.float32], # Correctness issues 'atanh': [torch.float32], @@ -569,7 +567,6 @@ def mps_ops_modifier(ops): 'special.ndtr': [torch.uint8], 'sqrt': [torch.uint8], 'sub': [torch.uint8], - 'tanh': [torch.uint8], 'trapezoid': [torch.uint8], 'trapz': [torch.uint8], 'true_divide': [torch.uint8], @@ -586,28 +583,13 @@ def mps_ops_modifier(ops): 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # cpu not giving nan for x/0.0 - 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # inconsistency errors between cpu and mps, max seen atol is 2 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_BEFORE_13_3_XFAILLIST = { - # Failure due to precision issues (still present on 13.3+) as well as non-standard behavior of - # cpu ops for the negative integers. - # Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ 'tan': [torch.float32], 'cdist': [torch.float32], @@ -656,20 +638,6 @@ def mps_ops_modifier(ops): # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour. 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16], - - # Failure due to precision issues as well as non-standard behavior of cpu ops for the - # negative integers. Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], } MACOS_BEFORE_14_4_XFAILLIST = { diff --git a/test/test_optim.py b/test/test_optim.py index d61c33e2adcea..fb655ce36a533 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -32,6 +32,7 @@ ) from torch.testing._internal.common_dtype import floating_types_and from torch.testing._internal.common_optimizers import ( + _get_device_type, _get_optim_inputs_including_global_cliquey_kwargs, optim_db, OptimizerErrorEnum, @@ -1004,7 +1005,6 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): self.assertLessEqual(mt_max_mem, expected_max_mem) - @onlyNativeDeviceTypes @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=floating_types_and( @@ -1013,10 +1013,15 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): ), ) def test_fused_matches_forloop(self, device, dtype, optim_info): - if device not in optim_info.supports_fused_on: + if _get_device_type(device) not in optim_info.supports_fused_on: self.skipTest( f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" ) + if _get_device_type(device) == "mps" and dtype not in ( + torch.float16, + torch.float32, + ): + self.skipTest("MPS supports only torch.float16 and torch.float32") self._test_derived_optimizers(device, dtype, optim_info, "fused") @onlyNativeDeviceTypes @@ -1076,7 +1081,6 @@ def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info): ) self.assertEqual(params, params_c) - @onlyCUDA @parametrize("impl", ["fused", "capturable"]) @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], @@ -1100,8 +1104,15 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): ): # Capturable SGD/Adagrad does not exist self.skipTest("SGD does not currently support capturable") - if impl == "fused" and device not in optim_info.supports_fused_on: + if _get_device_type(device) == "cpu": + self.skipTest("Test is only for non-cpu devices") + elif ( + impl == "fused" + and _get_device_type(device) not in optim_info.supports_fused_on + ): self.skipTest(f"{device} is not supported for fused on {opt_name}") + elif impl == "capturable" and _get_device_type(device) == "mps": + self.skipTest("MPS does not support capturable") cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu") for optim_input in cpu_optim_inputs: @@ -1114,12 +1125,12 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): # load optim_input.kwargs[impl] = True - param_cuda = param.clone().detach().to(device="cuda") - optimizer_cuda = optim_cls([param_cuda], **optim_input.kwargs) - optimizer_cuda.load_state_dict(optim_state_dict_cpu) - optimizer_cuda.zero_grad() - param_cuda.grad = torch.rand_like(param_cuda) - optimizer_cuda.step() + param_device = param.clone().detach().to(device=device) + optimizer_device = optim_cls([param_device], **optim_input.kwargs) + optimizer_device.load_state_dict(optim_state_dict_cpu) + optimizer_device.zero_grad() + param_device.grad = torch.rand_like(param_device) + optimizer_device.step() @optims(optim_db, dtypes=[torch.float32]) def test_param_groups_weight_decay(self, device, dtype, optim_info): diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 86785be4ed179..fa7397e02b424 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -309,6 +309,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _On the Convergence of Adam and Beyond: @@ -660,6 +662,10 @@ def _fused_adam( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 00931bed02272..20ab827552491 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -310,6 +310,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: @@ -662,6 +664,10 @@ def _fused_adamw( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 628bedad313dc..b7d06e7dc8083 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1232,7 +1232,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1354,7 +1354,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamw, optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index bcc274579ad01..c3100d41b6c0f 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -11,7 +11,7 @@ def _get_foreach_kernels_supported_devices() -> List[str]: def _get_fused_kernels_supported_devices() -> List[str]: r"""Return the device type list that supports fused kernels in optimizer.""" - return ["cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] + return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] TensorListList: TypeAlias = List[List[Optional[Tensor]]] Indices: TypeAlias = List[int] From 5bc9835d64eb5592cb606252ccf19212872cefc7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 54/63] Revert "[dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428)" This reverts commit c52eda896eb3ec7f8d04b6321861f4c5614a40bb. Reverted https://github.com/pytorch/pytorch/pull/128428 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2329ab305e763..dbcb259241fcb 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index abbef02e63c68..b5b12435a931a 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,6 +2669,26 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", + "torch.nn.functional._adaptive_max_pool1d", + "torch.nn.functional._adaptive_max_pool2d", + "torch.nn.functional._adaptive_max_pool3d", + "torch.nn.functional._canonical_mask", + "torch.nn.functional._fractional_max_pool2d", + "torch.nn.functional._fractional_max_pool3d", + "torch.nn.functional._get_softmax_dim", + "torch.nn.functional._in_projection_packed", + "torch.nn.functional._in_projection", + "torch.nn.functional._is_integer", + "torch.nn.functional._max_pool1d", + "torch.nn.functional._max_pool2d", + "torch.nn.functional._max_pool3d", + "torch.nn.functional._mha_shape_check", + "torch.nn.functional._no_grad_embedding_renorm_", + "torch.nn.functional._none_or_dtype", + "torch.nn.functional._threshold", + "torch.nn.functional._unpool_output_size", + "torch.nn.functional._verify_batch_size", + "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2766,7 +2786,15 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._arg_requires_grad", + "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.container._addindent", + "torch.nn.modules.transformer._detect_is_causal_mask", + "torch.nn.modules.transformer._generate_square_subsequent_mask", + "torch.nn.modules.transformer._get_activation_fn", + "torch.nn.modules.transformer._get_clones", + "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 1babeddbbf3a44318d13cf3b8afaac2a6d657115 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 55/63] Revert "[inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484)" This reverts commit 1f6e84fa6852805e15ddc9583c5f36c3a7f93df8. Reverted https://github.com/pytorch/pytorch/pull/128484 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/inductor/test_mkldnn_pattern_matcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index a80d723987602..810c22d037c54 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,8 +37,7 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - # Use floats for min/max, otherwise they can get converted to symints - torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, + torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 44722c6b1085611e0f20917a76fcf3f8f2776e13 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 56/63] Revert "[dynamo][fsdp] Dont take unspecializedNNModuleVariable path for FSDP modules (#128453)" This reverts commit 2b28b107dbafeec18d1095a2002e79511aa241df. Reverted https://github.com/pytorch/pytorch/pull/128453 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- torch/_dynamo/variables/builder.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index af91edb432c88..8a201410d6be3 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1164,11 +1164,7 @@ def wrap_module(self, value: torch.nn.Module): and not config.allow_rnn ): unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") - - # Dont take this path for FSDP - if not getattr( - value, "_is_fsdp_managed_module", None - ) and mutation_guard.is_dynamic_nn_module(value, self.tx.export): + if mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it self.install_guards(GuardBuilder.TYPE_MATCH) if ( From 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 18 Jun 2024 13:35:49 -0400 Subject: [PATCH 57/63] Backward support for unbind() with NJT (#128032) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032 Approved by: https://github.com/soulitzer --- test/test_nestedtensor.py | 19 +++++++++++++++++++ tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 +++++++++++++++++ torch/csrc/autograd/FunctionsManual.h | 4 ++++ torch/nested/_internal/ops.py | 11 +++++++++++ 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 50d6deea92911..fa33a13ed495d 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5610,6 +5610,25 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + @dtypes(torch.float32, torch.double, torch.half) + def test_unbind_backward(self, device, dtype): + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 4, device=device), + torch.randn(5, 4, device=device), + torch.randn(3, 4, device=device), + ], + layout=torch.jagged, + requires_grad=True, + ) + + a, b, c = nt.unbind() + b.sum().backward() + + expected_grad = torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a4..02a3e6c518ad8 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 9d897c667c906..f51c2f047f935 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,6 +1014,23 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } +Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim) { + TORCH_INTERNAL_ASSERT( + dim == 0, "unbind_backward_nested_jagged() only supports dim=0") + auto grad_nt = at::zeros_like(self); + auto unbound_grads = grad_nt.unbind(); + for (int64_t i : c10::irange(static_cast(grads.size()))) { + if (grads[i].defined()) { + unbound_grads[i].copy_(static_cast(grads[i])); + } + } + + return grad_nt; +} + Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index dedff70be1ba3..ecf99bd098057 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,6 +244,10 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); +at::Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd69471..8458f03717130 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,6 +472,17 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) +@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") +def zero__default(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + func(inp._values) + return inp + + @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From 4cc3fb5ee2296e1178cec710a945c99aa303170d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:38:22 -0700 Subject: [PATCH 58/63] Bump urllib3 from 2.2.1 to 2.2.2 in /tools/build/bazel (#128908) Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.2.1 to 2.2.2. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.2.1...2.2.2) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/build/bazel/requirements.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index cd95aeeec5c6f..fea6221c9b7ca 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -145,7 +145,7 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in pyyaml==6.0.1 \ --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ @@ -198,26 +198,26 @@ pyyaml==6.0.1 \ --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in requests==2.32.2 \ --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c - # via -r tools/build/bazel/requirements.in + # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a - # via -r tools/build/bazel/requirements.in -urllib3==2.2.1 \ - --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ - --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 + # via -r requirements.in +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 # via requests # The following packages are considered to be unsafe in a requirements file: setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in From 2227da44317f4ea836aaad96337b53533aed2770 Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Tue, 18 Jun 2024 21:01:01 +0000 Subject: [PATCH 59/63] [Profiler] Clean up use_mtia to follow standard use_device instead (#126284) Summary: use_mtia should instead set use_device='mtia' similar to cuda, xpu, and privateuseone. Avoid an ever-growing list of use_* arguments. Since use_mtia is specific to FBCode, we don't need a deprecation warning. Test Plan: CI. Differential Revision: D57338005 Pulled By: aaronenyeshi Pull Request resolved: https://github.com/pytorch/pytorch/pull/126284 Approved by: https://github.com/fenypatel99 --- torch/autograd/profiler.py | 13 +++++++------ torch/profiler/profiler.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 0392a87698463..f847fc13ff8ad 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -118,7 +118,7 @@ class profile: use_device (str, optional): Enables timing of device events. Adds approximately 4us of overhead to each tensor operation when use cuda. - The valid devices options are 'cuda', 'xpu' and 'privateuseone'. + The valid devices options are 'cuda', 'xpu', 'mtia' and 'privateuseone'. record_shapes (bool, optional): If shapes recording is set, information about input dimensions will be collected. This allows one to see which @@ -205,7 +205,6 @@ def __init__( with_modules=False, use_kineto=False, use_cpu=True, - use_mtia=False, experimental_config=None, ): self.enabled: bool = enabled @@ -231,7 +230,6 @@ def __init__( self.with_stack = with_stack self.with_modules = with_modules self.use_cpu = use_cpu - self.use_mtia = use_mtia if experimental_config is None: experimental_config = _ExperimentalConfig() self.experimental_config = experimental_config @@ -246,7 +244,7 @@ def __init__( ), "Device-only events supported only with Kineto (use_kineto=True)" if self.use_device is not None: - VALID_DEVICE_OPTIONS = ["cuda", "xpu"] + VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"] if _get_privateuse1_backend_name() != "privateuseone": VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name()) if self.use_device not in VALID_DEVICE_OPTIONS: @@ -265,8 +263,6 @@ def __init__( self.kineto_activities = set() if self.use_cpu: self.kineto_activities.add(ProfilerActivity.CPU) - if self.use_mtia: - self.kineto_activities.add(ProfilerActivity.MTIA) self.profiler_kind = ProfilerState.KINETO if self.use_device == "cuda": @@ -280,6 +276,11 @@ def __init__( use_kineto and ProfilerActivity.XPU in _supported_activities() ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device == "mtia": + assert ( + use_kineto and ProfilerActivity.MTIA in _supported_activities() + ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + self.kineto_activities.add(ProfilerActivity.MTIA) elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f43dcc06de209..2fd3ab9be6b80 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -132,6 +132,8 @@ def __init__( self.use_device = "cuda" elif ProfilerActivity.XPU in self.activities: self.use_device = "xpu" + elif ProfilerActivity.MTIA in self.activities: + self.use_device = "mtia" elif ProfilerActivity.PrivateUse1 in self.activities: self.use_device = _get_privateuse1_backend_name() @@ -149,7 +151,6 @@ def prepare_trace(self): if self.profiler is None: self.profiler = prof.profile( use_cpu=(ProfilerActivity.CPU in self.activities), - use_mtia=(ProfilerActivity.MTIA in self.activities), use_device=self.use_device, record_shapes=self.record_shapes, with_flops=self.with_flops, From e47603a5495b33d59be0b770ac9b243877c993ad Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:41 -0700 Subject: [PATCH 60/63] Fix weight_norm decomposition behavior (#128956) By upcasting norm to float32 to align with CUDA and CPU behaviors https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/WeightNorm.cpp#L56-L59 Discovered this when started running OpInfo tests, see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060 ``` File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" AssertionError: 1 Operation: aten._weight_norm_interface.default ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128956 Approved by: https://github.com/albanD ghstack dependencies: #128955 --- torch/_decomp/decompositions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index dca552137ca6d..42d1cb9a15270 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4773,8 +4773,10 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) - norm = v.norm(2, keep_dim, keepdim=True) - return v * (g / norm), norm + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm @register_decomposition(aten.isin) From cec31050b4609a4bbdcd332c823139666ad57224 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:43 +0800 Subject: [PATCH 61/63] [BE][Easy] enable UFMT for `torch/distributed/{tensor,_tensor}/` (#128868) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128868 Approved by: https://github.com/fegin --- .lintrunner.toml | 9 - .../distributed/_tensor/_collective_utils.py | 2 +- torch/distributed/_tensor/_dispatch.py | 2 +- torch/distributed/_tensor/_op_schema.py | 1 + torch/distributed/_tensor/_sharding_prop.py | 1 + torch/distributed/_tensor/_tp_conv.py | 1 + torch/distributed/_tensor/api.py | 1 - torch/distributed/_tensor/debug/__init__.py | 1 - .../distributed/_tensor/debug/_op_coverage.py | 1 - torch/distributed/_tensor/debug/comm_mode.py | 3 +- .../_tensor/debug/visualize_sharding.py | 1 - .../_tensor/examples/checkpoint_example.py | 2 - .../examples/comm_mode_features_example.py | 3 - .../examples/torchrec_sharding_example.py | 2 +- .../examples/visualize_sharding_example.py | 1 + .../_tensor/experimental/__init__.py | 1 + .../_tensor/experimental/attention.py | 1 + .../_tensor/experimental/local_map.py | 1 + torch/distributed/_tensor/ops/__init__.py | 8 +- .../distributed/_tensor/ops/basic_strategy.py | 2 - torch/distributed/_tensor/ops/conv_ops.py | 1 + .../distributed/_tensor/ops/embedding_ops.py | 3 +- .../_tensor/ops/experimental_ops.py | 12 +- torch/distributed/_tensor/ops/math_ops.py | 1 - torch/distributed/_tensor/ops/matrix_ops.py | 2 +- .../distributed/_tensor/ops/pointwise_ops.py | 2 - torch/distributed/_tensor/ops/random_ops.py | 1 + torch/distributed/_tensor/ops/tensor_ops.py | 1 - torch/distributed/_tensor/ops/view_ops.py | 3 +- torch/distributed/_tensor/placement_types.py | 1 - torch/distributed/_tensor/random.py | 1 - torch/distributed/tensor/parallel/__init__.py | 4 +- torch/distributed/tensor/parallel/_utils.py | 10 +- torch/distributed/tensor/parallel/api.py | 21 +- torch/distributed/tensor/parallel/ddp.py | 1 + torch/distributed/tensor/parallel/fsdp.py | 8 +- .../tensor/parallel/input_reshard.py | 13 +- torch/distributed/tensor/parallel/loss.py | 1 + torch/distributed/tensor/parallel/style.py | 226 ++++++++++++------ 39 files changed, 213 insertions(+), 143 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index a7bbdc884415e..e3f1b58027c3e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1443,15 +1443,6 @@ exclude_patterns = [ 'torch/distributed/rpc/rref_proxy.py', 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', - 'torch/distributed/tensor/__init__.py', - 'torch/distributed/tensor/parallel/__init__.py', - 'torch/distributed/tensor/parallel/_utils.py', - 'torch/distributed/tensor/parallel/_view_with_dim_change.py', - 'torch/distributed/tensor/parallel/api.py', - 'torch/distributed/tensor/parallel/fsdp.py', - 'torch/distributed/tensor/parallel/input_reshard.py', - 'torch/distributed/tensor/parallel/multihead_attention_tp.py', - 'torch/distributed/tensor/parallel/style.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py index 4c1d18403666f..15644ac798731 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -3,7 +3,6 @@ import math from dataclasses import dataclass from functools import lru_cache - from typing import List, Optional import torch @@ -21,6 +20,7 @@ Work, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 1739243a5d3ba..a659c54a3d932 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -6,7 +6,6 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch - import torch.distributed as dist import torch.distributed._tensor.api as dtensor import torch.distributed._tensor.random as random @@ -27,6 +26,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta from torch.distributed._tensor.random import is_rng_supported_mesh + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 071c2ac4748f1..6e6884f47306a 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -8,6 +8,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec from torch.distributed.device_mesh import DeviceMesh + try: from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec except ImportError: diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 449cf6c23775a..8f1cabeb0c43c 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -25,6 +25,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/_tensor/_tp_conv.py index d480e9d7f79ec..cc6f1968e6ef9 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/_tensor/_tp_conv.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.distributed._tensor.api as dtensor + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 22f7e690022a9..e1c01040a9094 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -5,7 +5,6 @@ from typing import Any, Callable, cast, Optional, Sequence, Tuple import torch - import torch.distributed._tensor._dispatch as op_dispatch import torch.distributed._tensor.random as random import torch.nn as nn diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index b7bde685fd1e1..b70529f203e1d 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from torch.distributed._tensor.api import DTensor - from torch.distributed._tensor.debug.comm_mode import CommDebugMode diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/_tensor/debug/_op_coverage.py index 4f54246332351..214c4f003ff2d 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/_tensor/debug/_op_coverage.py @@ -5,7 +5,6 @@ import torch import torch.fx import torch.nn as nn - from functorch.compile import make_boxed_func from torch._functorch.compilers import aot_module from torch._inductor.decomposition import select_decomp_table diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 5b69454828f3c..0241c739fb701 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -5,16 +5,15 @@ import torch from torch.autograd.graph import register_multi_grad_hook from torch.distributed._tensor.api import DTensor - from torch.nn.modules.module import ( register_module_forward_hook, register_module_forward_pre_hook, ) from torch.utils._python_dispatch import TorchDispatchMode - from torch.utils._pytree import tree_flatten from torch.utils.module_tracker import ModuleTracker + funcol_native = torch.ops._c10d_functional funcol_py = torch.ops.c10d_functional funcol_autograd = torch.ops._c10d_functional_autograd diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 76cd8f3e92088..8eae86e5c0ab5 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py @@ -5,7 +5,6 @@ from torch._prims_common import ShapeType from torch.distributed._tensor import DeviceMesh - from torch.distributed._tensor.placement_types import Placement, Shard diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index 1cb292f12c414..1701e28ac2ca7 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -5,7 +5,6 @@ checkpoint save/load the model. """ import os - from typing import cast, List import torch @@ -13,7 +12,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F - from torch.distributed._tensor import ( DeviceMesh, distribute_module, diff --git a/torch/distributed/_tensor/examples/comm_mode_features_example.py b/torch/distributed/_tensor/examples/comm_mode_features_example.py index 106a5db735107..93155687cf920 100644 --- a/torch/distributed/_tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/_tensor/examples/comm_mode_features_example.py @@ -1,16 +1,13 @@ import os import torch - from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.debug import CommDebugMode - from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, RowwiseParallel, ) - from torch.testing._internal.distributed._tensor.common_dtensor import ( MLPModule, MLPStacked, diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/_tensor/examples/torchrec_sharding_example.py index 3e6c63dd18eb9..33f8c7017f5be 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/_tensor/examples/torchrec_sharding_example.py @@ -9,7 +9,6 @@ from typing import List, TYPE_CHECKING import torch - from torch.distributed._tensor import ( DeviceMesh, DTensor, @@ -24,6 +23,7 @@ TensorStorageMetadata, ) + if TYPE_CHECKING: from torch.distributed._tensor.placement_types import Placement diff --git a/torch/distributed/_tensor/examples/visualize_sharding_example.py b/torch/distributed/_tensor/examples/visualize_sharding_example.py index 6e295e147b38b..0f83968891591 100644 --- a/torch/distributed/_tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/_tensor/examples/visualize_sharding_example.py @@ -4,6 +4,7 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding + world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py index 2dd21605ffcc5..bee73667e1eaf 100644 --- a/torch/distributed/_tensor/experimental/__init__.py +++ b/torch/distributed/_tensor/experimental/__init__.py @@ -5,6 +5,7 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.experimental.local_map import local_map + __all__ = ["local_map", "implicit_replication"] diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/_tensor/experimental/attention.py index eb7703a96ba5f..b7738cb2dee54 100644 --- a/torch/distributed/_tensor/experimental/attention.py +++ b/torch/distributed/_tensor/experimental/attention.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel.style import ParallelStyle + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 0fc6ce96e6e02..60d1796fdec4c 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -7,6 +7,7 @@ from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement + try: from torch.utils import _cxx_pytree as pytree except ImportError: diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/_tensor/ops/__init__.py index d19fdfa50cb70..eaccc8aa8d3f6 100644 --- a/torch/distributed/_tensor/ops/__init__.py +++ b/torch/distributed/_tensor/ops/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +from .conv_ops import * # noqa: F403 from .embedding_ops import * # noqa: F403 -from .matrix_ops import * # noqa: F403 +from .experimental_ops import * # noqa: F403 from .math_ops import * # noqa: F403 -from .tensor_ops import * # noqa: F403 +from .matrix_ops import * # noqa: F403 from .pointwise_ops import * # noqa: F403 from .random_ops import * # noqa: F403 +from .tensor_ops import * # noqa: F403 from .view_ops import * # noqa: F403 -from .conv_ops import * # noqa: F403 -from .experimental_ops import * # noqa: F403 diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py index cc28cc19d370a..97dd43b1524dc 100644 --- a/torch/distributed/_tensor/ops/basic_strategy.py +++ b/torch/distributed/_tensor/ops/basic_strategy.py @@ -1,6 +1,5 @@ import itertools from dataclasses import dataclass - from typing import List, Set, Tuple from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy @@ -11,7 +10,6 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/ops/conv_ops.py b/torch/distributed/_tensor/ops/conv_ops.py index f466a13aa4637..24e75593064ee 100644 --- a/torch/distributed/_tensor/ops/conv_ops.py +++ b/torch/distributed/_tensor/ops/conv_ops.py @@ -7,6 +7,7 @@ from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index 6f8cc8c67851e..5af79562adcb2 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -11,16 +11,15 @@ expand_to_full_mesh_op_strategy, register_op_strategy, ) - from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/experimental_ops.py b/torch/distributed/_tensor/ops/experimental_ops.py index 546945acd6220..6d6967d4ea8d1 100644 --- a/torch/distributed/_tensor/ops/experimental_ops.py +++ b/torch/distributed/_tensor/ops/experimental_ops.py @@ -2,19 +2,21 @@ # implement matrix related ops for distributed tensor from typing import List -try: - import numpy as np -except ModuleNotFoundError: - np = None # type: ignore[assignment] - import torch from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + @register_prop_rule(aten.slice_backward.default) def slice_backward_rules(op_schema: OpSchema) -> OutputSharding: grad_output_spec, input_sizes, dim, start, end, step = op_schema.args_schema diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 377c50dffa13e..412c566253ab1 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -6,7 +6,6 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch - from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, diff --git a/torch/distributed/_tensor/ops/matrix_ops.py b/torch/distributed/_tensor/ops/matrix_ops.py index 15f00af670d27..128a73a59ffec 100644 --- a/torch/distributed/_tensor/ops/matrix_ops.py +++ b/torch/distributed/_tensor/ops/matrix_ops.py @@ -19,9 +19,9 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index ab80f783cf5b3..96bfb808c1006 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -2,7 +2,6 @@ from typing import List, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, @@ -13,7 +12,6 @@ StrategyType, TupleStrategy, ) - from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, diff --git a/torch/distributed/_tensor/ops/random_ops.py b/torch/distributed/_tensor/ops/random_ops.py index 390dc419ecd78..d4b533aae09ac 100644 --- a/torch/distributed/_tensor/ops/random_ops.py +++ b/torch/distributed/_tensor/ops/random_ops.py @@ -9,6 +9,7 @@ from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index d2feb19ba2f95..a91d6261c51dc 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -3,7 +3,6 @@ from typing import cast, List, Optional, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, OpSchema, diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 7161988adf25c..ea088b7377a9b 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -15,7 +15,6 @@ ) import torch - from torch import Tensor from torch.distributed._tensor._op_schema import ( OpSchema, @@ -32,10 +31,10 @@ prod, register_op_strategy, ) - from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten Shape = Tuple[int, ...] diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 31e280c2f5b8b..352e12640bd74 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -6,7 +6,6 @@ import torch import torch.distributed._functional_collectives as funcol - from torch.distributed._tensor._collective_utils import ( fill_empty_tensor_to_shards, mesh_broadcast, diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/_tensor/random.py index ed331736c5ce4..3e43a9119ac20 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/_tensor/random.py @@ -6,7 +6,6 @@ import torch import torch.distributed as dist - from torch import Tensor from torch.distributed._tensor.placement_types import DTensorSpec, Shard from torch.distributed.device_mesh import _get_device_handle, DeviceMesh diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index 990550414ca47..9fe378c51b0d4 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from torch.distributed.tensor.parallel.api import parallelize_module - from torch.distributed.tensor.parallel.loss import loss_parallel from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -11,6 +10,7 @@ SequenceParallel, ) + __all__ = [ "ColwiseParallel", "ParallelStyle", @@ -19,5 +19,5 @@ "RowwiseParallel", "SequenceParallel", "parallelize_module", - "loss_parallel" + "loss_parallel", ] diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 394fde457bb21..3f47ec6f1ef34 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -5,12 +5,16 @@ from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources + + try: from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling except Exception: + def is_torchdynamo_compiling(): # type: ignore[misc] return False + LayoutsType = Union[Placement, Tuple[Placement, ...]] @@ -46,8 +50,10 @@ def _validate_tp_mesh_dim( is valid, `False` otherwise. """ if device_mesh.ndim > 1: - raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]') + raise ValueError( + f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' + ) parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh: diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index f78e9712d304b..e0fc4d2ef2b72 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -1,21 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, Union from fnmatch import fnmatch +from typing import Dict, Union import torch import torch.distributed._tensor.random as random import torch.nn as nn -from torch.distributed._tensor import ( - DeviceMesh, -) +from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.random import ( is_rng_supported_mesh, TensorParallelRNGTracker, ) from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim -from torch.distributed.tensor.parallel.style import ( - ParallelStyle, -) +from torch.distributed.tensor.parallel.style import ParallelStyle __all__ = [ @@ -98,14 +94,19 @@ def parallelize_module( # type: ignore[return] atom = path_splits.pop(0) matched_children = filter( # `t[0]` is child name - lambda t: fnmatch(t[0], atom), module.named_children() + lambda t: fnmatch(t[0], atom), + module.named_children(), ) # apply the plan to all matched submodules for _, submodule in matched_children: if path_splits: # we haven't reached the leaf, apply in dict style - leaf_path = ".".join(path_splits) # rest of the path after `atom` - parallelize_module(submodule, device_mesh, {leaf_path: parallelize_style}) + leaf_path = ".".join( + path_splits + ) # rest of the path after `atom` + parallelize_module( + submodule, device_mesh, {leaf_path: parallelize_style} + ) else: # otherwise, directly apply style to this submodule parallelize_module(submodule, device_mesh, parallelize_style) diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index baa9d638037d3..6c4d6f8016755 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -7,6 +7,7 @@ _unflatten_tensor, ) + __all__ = [] # type: ignore[var-annotated] diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index c38771ae86e2b..df51efaf87f54 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist - import torch.distributed._shard.sharding_spec as shard_spec import torch.distributed.distributed_c10d as c10d from torch.distributed._shard.sharded_tensor import ( @@ -13,12 +12,10 @@ ShardedTensorMetadata, TensorProperties, ) - from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor @@ -28,6 +25,7 @@ _unflatten_tensor, ) + __all__ = ["DTensorExtensions"] @@ -245,7 +243,6 @@ def _chunk_dtensor( # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): - # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. # TP is the inner dimension and FSDP is the outer dimension. # Therefore, shard placements for tensor is (Shard(0), Replicate()). @@ -324,6 +321,7 @@ class DTensorExtensions(FSDPExtensions): This is the implementation for FSDPExtensions defined in https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py """ + def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None @@ -352,7 +350,7 @@ def post_unflatten_transform( tensor, param_extension, device_handle=self.device_handle, - compute_stream=self.compute_stream + compute_stream=self.compute_stream, ) _set_fsdp_flattened(result) return result diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 3ea97846e313a..4e7af55d32c35 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -5,6 +5,7 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard + __all__ = [ "input_reshard", ] @@ -49,7 +50,9 @@ def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> N nonlocal cx cx = saved_tensor_hooks # type: ignore[name-defined] - def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any: + def input_reshard_backward_hook( + _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any + ) -> Any: nonlocal cx cx.__exit__() # type: ignore[name-defined, union-attr] @@ -60,7 +63,9 @@ def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any return module -def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 +def _pack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor +) -> Any: # noqa: D401 """Hook function called after FWD to shard input.""" if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) @@ -78,7 +83,9 @@ def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> return x -def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 +def _unpack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: Any +) -> torch.Tensor: # noqa: D401 """Hook function called before activation recomputing in BWD to restore input.""" if ( isinstance(x, DTensor) diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index f2776c5123b47..a51d14b0efbd5 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -18,6 +18,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index a4f4d4de0b985..42437a7084758 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,12 +1,20 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod -from typing import Optional, Union, Tuple, Dict, Any from functools import partial +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Placement, + Replicate, + Shard, +) __all__ = [ @@ -74,29 +82,35 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Replicate(), ) - self.output_layouts = (output_layouts or Shard(-1), ) + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) # colwise linear runtime sharding (desired sharding): # 1. requires replicate input # 2. shard output on last dim - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): # TODO: figure out dynamo support for instance method and switch this to instance method # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): @@ -104,17 +118,13 @@ def _partition_linear_fn(self, name, module, device_mesh): # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) def _partition_embedding_fn(self, name, module, device_mesh): # colwise shard embedding.weight is straight forward as Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(1)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(1)])) module.register_parameter(name, dist_param) @staticmethod @@ -131,14 +141,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn else: - raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "ColwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -180,41 +196,49 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Shard(-1), ) - self.output_layouts = (output_layouts or Replicate(), ) + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where # weight would become Shard(0) - module.register_parameter("weight", nn.Parameter( - distribute_tensor(module.weight, device_mesh, [Shard(1)]) - )) + module.register_parameter( + "weight", + nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(1)])), + ) if module.bias is not None: - module.register_parameter("bias", nn.Parameter( - distribute_tensor(module.bias, device_mesh, [Replicate()]) - )) + module.register_parameter( + "bias", + nn.Parameter( + distribute_tensor(module.bias, device_mesh, [Replicate()]) + ), + ) def _partition_embedding_fn(self, name, module, device_mesh): # rowwise shard embedding.weight is Shard(0) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) @staticmethod @@ -231,20 +255,26 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if isinstance(module, nn.Linear): partition_fn = self._partition_linear_fn # rowwise linear runtime sharding requires input tensor shard on last dim - self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), ) + self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),) elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn # rowwise embedding runtime sharding requires input tensor replicated - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) else: - raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "RowwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -287,17 +317,15 @@ class SequenceParallel(ParallelStyle): inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. """ - def __init__( - self, - *, - sequence_dim: int = 1, - use_local_output: bool = False - ): + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): super().__init__() self.sequence_dim = sequence_dim self.use_local_output = use_local_output - def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): for p_name, param in module.named_parameters(): # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow # us to simply just use from_local @@ -312,9 +340,13 @@ def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): if isinstance(input_tensor, DTensor): return inputs elif isinstance(input_tensor, torch.Tensor): - return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False) + return DTensor.from_local( + input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False + ) else: - raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}") + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) @staticmethod def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): @@ -380,32 +412,43 @@ def __init__( self, *, input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, - desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, Tuple[Optional[Placement]]] + ] = None, input_kwarg_layouts: Optional[Dict[str, Placement]] = None, desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, - use_local_output: bool = False + use_local_output: bool = False, ): - self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts - self.desired_input_layouts = \ - (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts + self.input_layouts = ( + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts + ) + self.desired_input_layouts = ( + (desired_input_layouts,) + if isinstance(desired_input_layouts, Placement) + else desired_input_layouts + ) self.use_local_output = use_local_output if self.input_layouts is not None: - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - assert len(self.input_layouts) == len(self.desired_input_layouts), \ - "input_layouts and desired_input_layouts should have same length!" + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + assert len(self.input_layouts) == len( + self.desired_input_layouts + ), "input_layouts and desired_input_layouts should have same length!" self.with_kwargs = input_kwarg_layouts is not None self.input_kwarg_layouts = input_kwarg_layouts or {} self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} if self.with_kwargs: - assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ - "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + assert len(self.input_kwarg_layouts) == len( + self.desired_input_kwarg_layouts + ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" def _prepare_input_arg( self, input: Any, mesh: DeviceMesh, input_layout: Optional[Placement], - desired_layout: Optional[Placement] + desired_layout: Optional[Placement], ): if input_layout is not None: if isinstance(input, DTensor): @@ -413,8 +456,12 @@ def _prepare_input_arg( # assert inp.placements[0] == input_layout dt_inp = input else: - assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" - dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + assert isinstance( + input, torch.Tensor + ), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -432,9 +479,15 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): - prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): + prepared_inputs.append( + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) + ) return tuple(prepared_inputs) def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): @@ -445,15 +498,19 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): input_layout = self.input_kwarg_layouts.get(kwarg_key) desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) - prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( + kwarg_val, device_mesh, input_layout, desired_input_layout + ) return (prepared_arg_inputs, prepared_kwarg_inputs) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if self.with_kwargs: module.register_forward_pre_hook( - lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh), - with_kwargs=True + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( + inputs, kwargs, device_mesh + ), + with_kwargs=True, ) # type: ignore[misc] else: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg] @@ -497,38 +554,55 @@ class PrepareModuleOutput(ParallelStyle): >>> ) >>> ) """ + def __init__( self, *, output_layouts: Union[Placement, Tuple[Placement]], desired_output_layouts: Union[Placement, Tuple[Placement]], - use_local_output: bool = True + use_local_output: bool = True, ): - self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts - self.desired_output_layouts = \ - (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts + self.output_layouts = ( + (output_layouts,) + if isinstance(output_layouts, Placement) + else output_layouts + ) + self.desired_output_layouts = ( + (desired_output_layouts,) + if isinstance(desired_output_layouts, Placement) + else desired_output_layouts + ) self.use_local_output = use_local_output - assert len(self.output_layouts) == len(self.desired_output_layouts), \ - "output_layouts and desired_output_layouts should have same length!" + assert len(self.output_layouts) == len( + self.desired_output_layouts + ), "output_layouts and desired_output_layouts should have same length!" def _prepare_out_fn(self, outputs, device_mesh): prepared_outputs = [] if not isinstance(outputs, tuple): outputs = (outputs,) if len(outputs) != len(self.output_layouts): - raise ValueError("module outputs and output_layouts should have same length!") - for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts): + raise ValueError( + "module outputs and output_layouts should have same length!" + ) + for out, out_layout, desired_out_layout in zip( + outputs, self.output_layouts, self.desired_output_layouts + ): if out_layout is not None: if isinstance(out, DTensor): # TODO: re-enable the check once we fix the compile path # assert out.placements[0] == out_layout dt_out = out else: - dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False) + dt_out = DTensor.from_local( + out, device_mesh, (out_layout,), run_check=False + ) if out_layout != desired_out_layout: dt_out = dt_out.redistribute(placements=(desired_out_layout,)) - prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out) + prepared_outputs.append( + dt_out.to_local() if self.use_local_output else dt_out + ) else: prepared_outputs.append(out) if len(prepared_outputs) == 1: From 3b798df853444d66077ffa846f5682e621b07388 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:44 +0800 Subject: [PATCH 62/63] [BE][Easy] enable UFMT for `torch/distributed/{fsdp,optim,rpc}/` (#128869) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869 Approved by: https://github.com/fegin ghstack dependencies: #128868 --- .lintrunner.toml | 27 ---- torch/distributed/fsdp/__init__.py | 1 + torch/distributed/fsdp/_common_utils.py | 2 + torch/distributed/fsdp/_debug_utils.py | 1 + torch/distributed/fsdp/_flat_param.py | 1 + torch/distributed/fsdp/_init_utils.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 1 + torch/distributed/fsdp/_runtime_utils.py | 1 + torch/distributed/fsdp/_state_dict_utils.py | 3 - .../distributed/fsdp/_unshard_param_utils.py | 1 + torch/distributed/fsdp/_wrap_utils.py | 1 - torch/distributed/fsdp/api.py | 2 +- .../fsdp/fully_sharded_data_parallel.py | 2 +- torch/distributed/fsdp/sharded_grad_scaler.py | 1 + torch/distributed/fsdp/wrap.py | 1 + torch/distributed/optim/__init__.py | 10 +- .../optim/apply_optimizer_in_backward.py | 10 +- .../distributed/optim/functional_adadelta.py | 5 +- torch/distributed/optim/functional_adagrad.py | 3 +- torch/distributed/optim/functional_adam.py | 3 +- torch/distributed/optim/functional_adamax.py | 3 +- torch/distributed/optim/functional_adamw.py | 3 +- torch/distributed/optim/functional_rmsprop.py | 3 +- torch/distributed/optim/functional_rprop.py | 3 +- torch/distributed/optim/functional_sgd.py | 3 +- torch/distributed/optim/named_optimizer.py | 13 +- torch/distributed/optim/optimizer.py | 5 +- torch/distributed/optim/utils.py | 2 + .../optim/zero_redundancy_optimizer.py | 37 +++--- torch/distributed/rpc/__init__.py | 69 +++++----- torch/distributed/rpc/_testing/__init__.py | 5 +- .../_testing/faulty_agent_backend_registry.py | 11 +- torch/distributed/rpc/_utils.py | 19 ++- torch/distributed/rpc/api.py | 118 ++++++++++-------- torch/distributed/rpc/backend_registry.py | 99 ++++++++++----- torch/distributed/rpc/constants.py | 3 +- torch/distributed/rpc/functions.py | 2 + torch/distributed/rpc/internal.py | 5 +- torch/distributed/rpc/options.py | 2 + torch/distributed/rpc/rref_proxy.py | 17 ++- .../rpc/server_process_global_profiler.py | 13 +- 41 files changed, 300 insertions(+), 213 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e3f1b58027c3e..99c04cac4fbb3 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1413,35 +1413,8 @@ exclude_patterns = [ 'torch/distributed/nn/jit/instantiator.py', 'torch/distributed/nn/jit/templates/__init__.py', 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/optim/__init__.py', - 'torch/distributed/optim/apply_optimizer_in_backward.py', - 'torch/distributed/optim/functional_adadelta.py', - 'torch/distributed/optim/functional_adagrad.py', - 'torch/distributed/optim/functional_adam.py', - 'torch/distributed/optim/functional_adamax.py', - 'torch/distributed/optim/functional_adamw.py', - 'torch/distributed/optim/functional_rmsprop.py', - 'torch/distributed/optim/functional_rprop.py', - 'torch/distributed/optim/functional_sgd.py', - 'torch/distributed/optim/named_optimizer.py', - 'torch/distributed/optim/optimizer.py', - 'torch/distributed/optim/post_localSGD_optimizer.py', - 'torch/distributed/optim/utils.py', - 'torch/distributed/optim/zero_redundancy_optimizer.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', - 'torch/distributed/rpc/__init__.py', - 'torch/distributed/rpc/_testing/__init__.py', - 'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py', - 'torch/distributed/rpc/_utils.py', - 'torch/distributed/rpc/api.py', - 'torch/distributed/rpc/backend_registry.py', - 'torch/distributed/rpc/constants.py', - 'torch/distributed/rpc/functions.py', - 'torch/distributed/rpc/internal.py', - 'torch/distributed/rpc/options.py', - 'torch/distributed/rpc/rref_proxy.py', - 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index d887730f442f6..6180dbb3df299 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -18,6 +18,7 @@ StateDictType, ) + __all__ = [ "BackwardPrefetch", "CPUOffload", diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index aae2405d0bb50..10d0f82126511 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -44,9 +44,11 @@ StateDictType, ) + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + from ._flat_param import FlatParamHandle FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 523330e5580df..163d9a045b68e 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -15,6 +15,7 @@ clean_tensor_name, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 816b91433063a..8bc975dc72fd5 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -50,6 +50,7 @@ FSDPExtensions, ) + __all__ = [ "FlatParameter", "FlatParamHandle", diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c8b58091bf89b..aaeedf22397a4 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -58,9 +58,9 @@ from torch.distributed.fsdp.wrap import _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions from torch.distributed.utils import _sync_params_and_buffers - from torch.utils._python_dispatch import is_traceable_wrapper_subclass + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 54f800a168653..4cfe761769a3b 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -55,6 +55,7 @@ ) from torch.utils._pytree import tree_map_only + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 833c1d45697ae..f84e7dd3e5055 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -39,6 +39,7 @@ ) from torch.utils import _pytree as pytree + logger = logging.getLogger(__name__) # Do not include "process_group" to enable hybrid shard and MoE cases diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 797a0116587bb..815cfb2dd4a1f 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -17,9 +17,7 @@ import torch import torch.distributed as dist - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper - import torch.nn as nn import torch.nn.functional as F from torch.distributed._shard.sharded_tensor import ( @@ -29,7 +27,6 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import ( _FSDPState, _get_module_fsdp_state_if_fully_sharded_module, diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 435193a88703a..4143d2928c8b8 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -26,6 +26,7 @@ from ._flat_param import FlatParamHandle + FLAT_PARAM = "_flat_param" diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 84cdf250d8ae1..895bcbd8e967b 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -11,7 +11,6 @@ _get_module_fsdp_state, _override_module_mixed_precision, ) - from torch.distributed.fsdp.wrap import ( _construct_wrap_fn, _or_policy, diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 0272ee0c57c9f..f2e4bdb7ea023 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -5,12 +5,12 @@ from dataclasses import dataclass from enum import auto, Enum - from typing import Optional, Sequence, Type import torch from torch.nn.modules.batchnorm import _BatchNorm + __all__ = [ "ShardingStrategy", "BackwardPrefetch", diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9edd057a8f371..1567bb973b22a 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -85,8 +85,8 @@ StateDictType, ) from torch.distributed.utils import _p_assert -from ._flat_param import FlatParameter, FlatParamHandle +from ._flat_param import FlatParameter, FlatParamHandle from ._optim_utils import ( _flatten_optim_state_dict, _get_param_id_to_param_from_optim_input, diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 3487e01263c71..7c1b2f8352868 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -8,6 +8,7 @@ from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState from torch.distributed.distributed_c10d import ProcessGroup + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index acb5a6f1f642a..f8604bbb1bb04 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -24,6 +24,7 @@ import torch.nn as nn + __all__ = [ "always_wrap_policy", "lambda_auto_wrap_policy", diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index fe33265fd532f..924b993ec8414 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -15,7 +15,6 @@ _get_in_backward_optimizers, ) from .functional_adadelta import _FunctionalAdadelta - from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam from .functional_adamax import _FunctionalAdamax @@ -26,6 +25,7 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -44,4 +44,10 @@ from .post_localSGD_optimizer import PostLocalSGDOptimizer from .zero_redundancy_optimizer import ZeroRedundancyOptimizer -__all__ = ["as_functional_optim", "DistributedOptimizer", "PostLocalSGDOptimizer", "ZeroRedundancyOptimizer"] + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 6bd182cca5736..36f679f4eba49 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -2,6 +2,7 @@ import torch + __all__: List[str] = [] # WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter @@ -11,6 +12,7 @@ param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + @no_type_check def _apply_optimizer_in_backward( optimizer_class: Type[torch.optim.Optimizer], @@ -48,9 +50,7 @@ def _apply_optimizer_in_backward( # have their registered optimizer(s) applied. """ - torch._C._log_api_usage_once( - "torch.distributed.optim.apply_optimizer_in_backward" - ) + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") @no_type_check def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: @@ -62,7 +62,9 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: # Don't create a new acc_grad if we already have one # i.e. for shared parameters or attaching multiple optimizers to a param. if param not in param_to_acc_grad_map: - param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0] + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] optimizer = optimizer_class([param], **optimizer_kwargs) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index bc5f7c63dd175..3ad51348b6afa 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adadelta Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, @@ -102,5 +103,5 @@ def step(self, gradients: List[Optional[Tensor]]): weight_decay=weight_decay, foreach=self.foreach, maximize=self.maximize, - has_complex=has_complex + has_complex=has_complex, ) diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 93a1fe2b2240d..67f7328489ed2 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adagrad Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 34868d23d8a53..3ed271765170c 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adam Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 32bce65dfe1f5..8f1fdc0ccc02b 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adamax Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 43addd0508221..d3f1f80e9209b 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional AdamW Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 851119c8600c0..7a03e8e9f462f 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional RMSprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 60742bc68896f..615015a95a316 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Rprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 3a8176e877057..32381855db6b5 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional SGD Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 9e1e5377873d1..8e0b539b14826 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,9 +1,18 @@ # mypy: allow-untyped-defs import logging import warnings - from copy import deepcopy -from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + Optional, + overload, + Union, +) import torch import torch.nn as nn diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index f2eca606c0261..65df14770c21c 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging - from collections import defaultdict from threading import Lock from typing import List, Optional @@ -12,8 +11,10 @@ import torch.nn as nn from torch import Tensor from torch.distributed.rpc import RRef + from .utils import functional_optim_map + __all__ = ["DistributedOptimizer"] logger = logging.getLogger(__name__) @@ -205,7 +206,7 @@ def __init__(self, optimizer_class, params_rref, *args, **kwargs): "(i.e. Distributed Model Parallel training on CPU) due to the Python's " "Global Interpreter Lock (GIL). Please file an issue if you need this " "optimizer in TorchScript. ", - optimizer_class + optimizer_class, ) optimizer_new_func = _new_local_optimizer diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index af2220ca55749..d2c75eee7e39b 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -2,6 +2,7 @@ from typing import Type from torch import optim + from .functional_adadelta import _FunctionalAdadelta from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam @@ -11,6 +12,7 @@ from .functional_rprop import _FunctionalRprop from .functional_sgd import _FunctionalSGD + # dict to map a user passed in optimizer_class to a functional # optimizer class if we have already defined inside the # distributed.optim package, this is so that we hide the diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 8a3be3b018153..f664d11afb79c 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -20,11 +20,12 @@ from torch.optim import Optimizer -logger = logging.getLogger(__name__) - __all__ = ["ZeroRedundancyOptimizer"] +logger = logging.getLogger(__name__) + + # Credits: classy_vision/generic/distributed_util.py def _recursive_copy_to_device( value: Any, @@ -925,9 +926,9 @@ def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: mapping bucket indices to :class:`_DDPBucketAssignment` s for each rank. """ - assert self._overlap_with_ddp, ( - "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" - ) + assert ( + self._overlap_with_ddp + ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" if len(self._bucket_assignments_per_rank_cache) > 0: return self._bucket_assignments_per_rank_cache @@ -1074,9 +1075,9 @@ def _local_step( "Specifying `gradients` should not " "be used when `overlap_with_ddp=False`" ) - assert closure is None, ( - "`closure` is not supported when using a local functional optimizer" - ) + assert ( + closure is None + ), "`closure` is not supported when using a local functional optimizer" loss = self.optim.step(gradients=gradients) # Sync any updated attributes in the local optimizer to the exposed @@ -1504,7 +1505,7 @@ def _init_local_optimizer(self) -> None: "%s does not support the argument " "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " "error due to an empty parameter list", - self._optim_constructor + self._optim_constructor, ) self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] @@ -1515,17 +1516,16 @@ def _init_local_optimizer(self) -> None: self._bucket_assignments_per_rank[self.global_rank] ) logger.info( - "rank %s with %s parameters " - "across %s buckets", - self.global_rank, local_numel, num_assigned_buckets + "rank %s with %s parameters " "across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, ) if self.global_rank == 0: logger.info( - "%s DDP " - "buckets and " - "%s bucket " - "assignments", - len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments + "%s DDP " "buckets and " "%s bucket " "assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, ) else: # NOTE: Passing `param_groups` into the local optimizer constructor @@ -1640,7 +1640,8 @@ def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: "Using the functional optimizer %s " "instead of %s since " "`overlap_with_ddp=True`", - optim_constructor, optimizer_class + optim_constructor, + optimizer_class, ) return optim_constructor else: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 581433d220c63..6c6608a2a773f 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,22 +1,25 @@ # mypy: allow-untyped-defs -from datetime import timedelta import logging import os import threading import warnings +from datetime import timedelta from typing import Generator, Tuple from urllib.parse import urlparse import torch import torch.distributed as dist + +__all__ = ["is_available"] + + logger = logging.getLogger(__name__) _init_counter = 0 _init_counter_lock = threading.Lock() -__all__ = ["is_available"] def is_available() -> bool: return hasattr(torch._C, "_rpc_init") @@ -27,54 +30,51 @@ def is_available() -> bool: if is_available(): + import numbers + + import torch.distributed.autograd as dist_autograd from torch._C._distributed_c10d import Store - from torch._C._distributed_rpc import ( + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, _disable_jit_rref_pickle, - _enable_jit_rref_pickle, _disable_server_process_global_profiler, + _enable_jit_rref_pickle, _enable_server_process_global_profiler, - _set_and_start_rpc_agent, - _reset_current_rpc_agent, - _delete_all_user_and_unforked_owner_rrefs, - _destroy_rref_context, - _set_profiler_node_id, - _is_current_rpc_agent_set, - _rref_context_get_debug_info, - _cleanup_python_rpc_handler, - _invoke_rpc_builtin, - _invoke_rpc_python_udf, - _invoke_rpc_torchscript, + _get_current_rpc_agent, _invoke_remote_builtin, _invoke_remote_python_udf, _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, _set_rpc_timeout, - _get_current_rpc_agent, - get_rpc_timeout, - enable_gil_profiling, - RpcBackendOptions, _TensorPipeRpcBackendOptionsBase, - RpcAgent, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, PyRRef, - TensorPipeAgent, RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + TensorPipeAgent, WorkerInfo, - _DEFAULT_INIT_METHOD, - _DEFAULT_NUM_WORKER_THREADS, - _UNSET_RPC_TIMEOUT, - _DEFAULT_RPC_TIMEOUT_SEC, - ) # noqa: F401 + ) from . import api, backend_registry, functions from .api import * # noqa: F401,F403 - import numbers - - import torch.distributed.autograd as dist_autograd - from .backend_registry import BackendType from .options import TensorPipeRpcBackendOptions # noqa: F401 - from .server_process_global_profiler import ( - _server_process_global_profile, - ) + from .server_process_global_profiler import _server_process_global_profile rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] @@ -153,7 +153,7 @@ def init_rpc( "corresponding to %(backend)s, hence that backend will be used " "instead of the default BackendType.TENSORPIPE. To silence this " "warning pass `backend=%(backend)s` explicitly.", - {'backend': backend} + {"backend": backend}, ) if backend is None: @@ -224,7 +224,6 @@ def _init_rpc_backend( world_size=None, rpc_backend_options=None, ): - _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) if _is_current_rpc_agent_set(): diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 640c4d09f0628..8ac1c02f4cee4 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -12,8 +12,9 @@ def is_available(): if is_available(): # Registers FAULTY_TENSORPIPE RPC backend. - from . import faulty_agent_backend_registry from torch._C._distributed_rpc_testing import ( - FaultyTensorPipeRpcBackendOptions, FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, ) + + from . import faulty_agent_backend_registry diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index 9e8660989e5a7..d04882e16e79a 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -4,6 +4,7 @@ import torch.distributed as dist import torch.distributed.rpc as rpc + def _faulty_tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, @@ -11,7 +12,7 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( messages_to_fail, messages_to_delay, num_fail_sends, - **kwargs + **kwargs, ): from . import FaultyTensorPipeRpcBackendOptions @@ -28,16 +29,14 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( def _faulty_tensorpipe_init_backend_handler( store, name, rank, world_size, rpc_backend_options ): - from . import FaultyTensorPipeAgent - from . import FaultyTensorPipeRpcBackendOptions from torch.distributed.rpc import api + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, FaultyTensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" ) diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index 6499a80e0e172..8925bc662b5f9 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,12 +1,14 @@ # mypy: allow-untyped-defs +import logging from contextlib import contextmanager from typing import cast -import logging -from . import api -from . import TensorPipeAgent + +from . import api, TensorPipeAgent + logger = logging.getLogger(__name__) + @contextmanager def _group_membership_management(store, name, is_join): token_key = "RpcGroupManagementToken" @@ -29,10 +31,17 @@ def _group_membership_management(store, name, is_join): try: store.wait([returned]) except RuntimeError: - logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned) + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) raise + def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) - ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) return ret diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index a33358eb0dc67..5fc9e61aa5592 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,6 +1,4 @@ # mypy: allow-untyped-defs -__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", - "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] import collections import contextlib @@ -8,17 +6,10 @@ import inspect import logging import threading -from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING +from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar import torch -from torch.futures import Future - from torch._C._distributed_rpc import ( - PyRRef, - RemoteProfilerManager, - WorkerInfo, - TensorPipeAgent, - get_rpc_timeout, _cleanup_python_rpc_handler, _delete_all_user_and_unforked_owner_rrefs, _destroy_rref_context, @@ -32,18 +23,36 @@ _is_current_rpc_agent_set, _reset_current_rpc_agent, _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + TensorPipeAgent, + WorkerInfo, ) +from torch.futures import Future +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, PythonUDF, RPCExecMode, - _internal_rpc_pickler, - _build_rpc_profiling_key, ) -from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT -from ._utils import _group_membership_management, _update_group_membership +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + logger = logging.getLogger(__name__) @@ -59,6 +68,7 @@ _ignore_rref_leak = True _default_pickler = _internal_rpc_pickler + @contextlib.contextmanager def _use_rpc_pickler(rpc_pickler): r""" @@ -107,7 +117,9 @@ def __init__(self): _ALL_WORKER_NAMES: Set[Any] = set() _all_gather_dict_lock = threading.RLock() _all_gather_sequence_id: Dict[str, int] = {} -_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates) +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) def _init_rpc_states(agent): @@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() + _thread_local_var = threading.local() @@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): follower_name, _broadcast_to_followers, args=(sequence_id, states.gathered_objects), - timeout=rpc_timeout + timeout=rpc_timeout, ) worker_name_to_response_future_dict[follower_name] = fut @@ -283,9 +296,7 @@ def _barrier(worker_names): try: _all_gather(None, set(worker_names)) except RuntimeError as ex: - logger.error( - "Failed to complete barrier, got error %s", ex - ) + logger.error("Failed to complete barrier, got error %s", ex) @_require_initialized @@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): all_worker_infos = agent.get_worker_infos() for worker in all_worker_infos: if worker.name != my_name: - rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False)) + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) agent.join(shutdown=True, timeout=timeout) finally: # In case of errors, continue to complete the local shutdown. @@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True): return future -def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True): - fut = rpc_async( - rref.owner(), - _rref_typeof_on_owner, - args=(rref,), - timeout=timeout - ) +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) if blocking: return fut.wait() else: @@ -463,13 +475,16 @@ def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: boo if TYPE_CHECKING: + class RRef(PyRRef[T], Generic[T]): pass + else: try: # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass + except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases @@ -517,7 +532,9 @@ def method(self, *args, **kwargs): assert docstring is not None, "RRef user-facing methods should all have docstrings." # Do surgery on pybind11 generated docstrings. - docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef") + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) # Attach user-facing RRef method with modified docstring. new_method = method_factory(method_name, docstring) @@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): dst_worker_info = _to_worker_info(to) should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): func = wrapped if qualified_name is not None: - rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs) + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) elif isinstance(func, torch.jit.ScriptFunction): rref = _invoke_remote_torchscript( dst_worker_info.name, @@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): PythonUDF(func, args, kwargs) ) rref = _invoke_remote_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec ) # attach profiling information if should_profile: @@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): return rref -def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT): +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): if not callable(func): raise TypeError("function should be callable.") @@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = if qualified_name is not None: fut = _invoke_rpc_builtin( - dst_worker_info, - qualified_name, - rpc_timeout, - *args, - **kwargs + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs ) elif isinstance(func, torch.jit.ScriptFunction): fut = _invoke_rpc_torchscript( @@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = args, kwargs, rpc_timeout, - is_async_exec + is_async_exec, ) else: (pickled_python_udf, tensors) = _default_pickler.serialize( PythonUDF(func, args, kwargs) ) fut = _invoke_rpc_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - rpc_timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec ) if should_profile: assert torch.autograd._profiler_enabled() @@ -915,12 +928,15 @@ def _get_should_profile(): # Kineto profiler. ActiveProfilerType = torch._C._profiler.ActiveProfilerType return ( - torch.autograd._profiler_enabled() and - torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() + == ActiveProfilerType.LEGACY # type: ignore[attr-defined] ) -def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): ctx_manager = contextlib.nullcontext() if should_profile: diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 6290f9e8e2054..a06f0276ede95 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] + import collections import enum @@ -7,13 +7,19 @@ import torch import torch.distributed as dist + +from . import api, constants as rpc_constants from ._utils import _group_membership_management, _update_group_membership -from . import api -from . import constants as rpc_constants -__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend", - "BackendValue", "BackendType"] +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] BackendValue = collections.namedtuple( "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] @@ -41,6 +47,7 @@ def _backend_type_repr(self): if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc + def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. @@ -80,7 +87,7 @@ def register_backend( init_backend_handler=init_backend_handler, ) }, - **existing_enum_dict + **existing_enum_dict, ) # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] @@ -90,20 +97,22 @@ def register_backend( BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] + def construct_rpc_backend_options( backend, rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, init_method=rpc_constants.DEFAULT_INIT_METHOD, - **kwargs + **kwargs, ): - return backend.value.construct_rpc_backend_options_handler( rpc_timeout, init_method, **kwargs ) + def init_backend(backend, *args, **kwargs): return backend.value.init_backend_handler(*args, **kwargs) + def _init_process_group(store, rank, world_size): # Initialize ProcessGroup. process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT @@ -115,22 +124,21 @@ def _init_process_group(store, rank, world_size): assert group is not None, "Failed to initialize default ProcessGroup." if (rank != -1) and (rank != group.rank()): - raise RuntimeError( - f"rank argument {rank} doesn't match pg rank {group.rank()}" - ) + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") if (world_size != -1) and (world_size != group.size()): raise RuntimeError( f"world_size argument {world_size} doesn't match pg size {group.size()}" ) return group + def _tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, _transports=None, _channels=None, - **kwargs + **kwargs, ): from . import TensorPipeRpcBackendOptions @@ -155,9 +163,9 @@ def _tensorpipe_validate_devices(devices, device_count): def _tensorpipe_exchange_and_check_all_device_maps( my_name, my_device_count, my_device_maps, my_devices, group ): - gathered: List[Tuple[ - str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device] - ]] = [("", 0, {}, []) for _ in range(group.size())] + gathered: List[ + Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] dist.all_gather_object( gathered, (my_name, my_device_count, my_device_maps, my_devices), group ) @@ -173,13 +181,15 @@ def _tensorpipe_exchange_and_check_all_device_maps( my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) return reverse_device_maps, my_devices -def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True): + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): for node in all_names: devices = all_devices[node] if len(set(devices)) != len(devices): raise ValueError( - f"Node {node} has duplicated devices\n" - f"devices = {devices}" + f"Node {node} has duplicated devices\n" f"devices = {devices}" ) if not _tensorpipe_validate_devices(devices, all_device_counts[node]): raise ValueError( @@ -190,7 +200,9 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev for source_node in all_names: # For dynamic group (non-static) do not check the target node name since it may not have joined yet - if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names): + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): raise ValueError( f"Node {source_node} has invalid target node names in its device maps\n" f"device maps = {all_device_maps[source_node].keys()}\n" @@ -238,6 +250,7 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev f"device count = {all_device_counts[target_node]}" ) + def _create_device_list(my_devices, my_device_maps, reverse_device_maps): if not my_devices: devices_set: Set[torch.device] = set() @@ -250,6 +263,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps): my_devices = sorted(my_devices, key=lambda d: d.index) return my_devices + def _create_reverse_mapping(my_name, all_names, all_device_maps): reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {} for node in all_names: @@ -259,8 +273,10 @@ def _create_reverse_mapping(my_name, all_names, all_device_maps): } return reverse_device_maps + def _get_device_infos(): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) opts = agent._get_backend_options() device_count = torch.cuda.device_count() @@ -268,8 +284,10 @@ def _get_device_infos(): torch.cuda.init() return device_count, opts.device_maps, opts.devices + def _set_devices_and_reverse_device_map(agent): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, agent) # Group state is retrieved from local agent # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid @@ -282,34 +300,52 @@ def _set_devices_and_reverse_device_map(agent): worker_name = worker_info.name if worker_name != my_name: # TODO: make async? - device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos) + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) else: opts = agent._get_backend_options() - device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) all_device_counts[worker_name] = device_count all_device_maps[worker_name] = device_map all_devices[worker_name] = devices all_names.append(worker_name) - _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False) + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps for worker_name in all_names: # Set device list for each worker - all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps) - api.rpc_sync(worker_name, _update_group_membership, - args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True)) + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions -def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options): - from . import TensorPipeAgent - from . import TensorPipeRpcBackendOptions if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, TensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" ) @@ -389,6 +425,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ raise return agent + register_backend( "TENSORPIPE", _tensorpipe_construct_rpc_backend_options_handler, diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index 3bc525b70d9bb..56f6db4db259d 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,5 +1,6 @@ from datetime import timedelta from typing import List + from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, _DEFAULT_NUM_WORKER_THREADS, @@ -17,7 +18,7 @@ DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS # Ensure that we don't time out when there are long periods of time without # any operations against the underlying ProcessGroup. -DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1) +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index c9e92980cf566..e48ea8cc534ab 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -159,9 +159,11 @@ def async_execution(fn): >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() >>> print(ret) # prints tensor([4., 4.]) """ + @functools.wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) + # Can't declare and use attributes of function objects (mypy#2087) wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] return wrapper diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 2fc647c414d96..5faf7d14d0da5 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -12,6 +12,7 @@ import torch.distributed as dist from torch._C._distributed_rpc import _get_current_rpc_agent + __all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] # Thread local tensor tables to store tensors while pickling torch.Tensor @@ -251,7 +252,9 @@ def _build_rpc_profiling_key( Returns: String representing profiling key """ - profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) return profile_key diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 70328f3459695..53bf473ba5628 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -3,6 +3,7 @@ import torch from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase + from . import constants as rpc_contants @@ -10,6 +11,7 @@ __all__ = ["TensorPipeRpcBackendOptions"] + def _to_device(device: DeviceType) -> torch.device: device = torch.device(device) if device.type != "cuda": diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index cdb0a5d22b742..85927b68bacb9 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,20 +1,22 @@ # mypy: allow-untyped-defs from functools import partial -from . import functions -from . import rpc_async - import torch -from .constants import UNSET_RPC_TIMEOUT from torch.futures import Future +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + @functions.async_execution def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _rref_type_cont(rref_fut): rref_type = rref_fut.value() @@ -33,7 +35,7 @@ def _rref_type_cont(rref_fut): rref.owner(), _invoke_func, args=(rref, func_name, args, kwargs), - timeout=timeout + timeout=timeout, ) rref_fut = rref._get_type(timeout=timeout, blocking=False) @@ -63,6 +65,7 @@ def _complete_op(fut): rref_fut.then(_wrap_rref_type_cont) return result + # This class manages proxied RPC API calls for RRefs. It is entirely used from # C++ (see python_rpc_handler.cpp). class RRefProxy: @@ -72,4 +75,6 @@ def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rpc_timeout = timeout def __getattr__(self, func_name): - return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout) + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 0543ab56a877f..b5d089d305253 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -2,18 +2,20 @@ # mypy: allow-untyped-defs import itertools +from typing import List import torch from torch.autograd.profiler_legacy import profile -from typing import List from . import ( _disable_server_process_global_profiler, _enable_server_process_global_profiler, ) + __all__: List[str] = [] + class _server_process_global_profile(profile): """ It has the same API as ``torch.autograd.profiler.profile`` class, @@ -123,7 +125,8 @@ def __enter__(self): False, False, False, - torch.profiler._ExperimentalConfig()) + torch.profiler._ExperimentalConfig(), + ) _enable_server_process_global_profiler(profiler_config) return self @@ -152,8 +155,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): process_global_function_events = [] for thread_local_events in process_global_events: # Parse from ``Event``s to ``FunctionEvent``s. - thread_local_function_events = torch.autograd.profiler_legacy._parse_legacy_records( - thread_local_events + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) ) thread_local_function_events.sort( key=lambda function_event: [ From a0e1e20c4157bb3e537fc784a51d7aef1e754157 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:45 +0800 Subject: [PATCH 63/63] [BE][Easy] enable UFMT for `torch/distributed/` (#128870) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870 Approved by: https://github.com/fegin ghstack dependencies: #128868, #128869 --- .lintrunner.toml | 27 - torch/distributed/__init__.py | 64 +-- .../_composable/fsdp/_fsdp_collectives.py | 1 + .../_composable/fsdp/_fsdp_common.py | 1 - .../_composable/fsdp/_fsdp_init.py | 2 +- .../_composable/fsdp/_fsdp_param.py | 3 +- .../_composable/fsdp/_fsdp_param_group.py | 3 +- .../_composable/fsdp/_fsdp_state.py | 3 +- .../_composable/fsdp/fully_shard.py | 1 - torch/distributed/_composable/fully_shard.py | 1 - torch/distributed/_composable/replicate.py | 1 + torch/distributed/_cuda_p2p/__init__.py | 3 +- torch/distributed/_functional_collectives.py | 2 + .../_functional_collectives_impl.py | 1 + torch/distributed/_sharded_tensor/__init__.py | 7 +- torch/distributed/_sharding_spec/__init__.py | 7 +- torch/distributed/_state_dict_utils.py | 1 + torch/distributed/_tools/memory_tracker.py | 19 +- torch/distributed/c10d_logger.py | 12 +- torch/distributed/collective_utils.py | 14 +- torch/distributed/constants.py | 7 +- torch/distributed/device_mesh.py | 3 +- torch/distributed/distributed_c10d.py | 520 +++++++++++++----- .../examples/memory_tracker_example.py | 2 +- torch/distributed/launcher/__init__.py | 2 +- torch/distributed/launcher/api.py | 13 +- torch/distributed/logging_handlers.py | 1 + torch/distributed/nn/__init__.py | 5 +- torch/distributed/nn/api/remote_module.py | 27 +- torch/distributed/nn/functional.py | 21 +- torch/distributed/pipelining/_IR.py | 6 +- torch/distributed/pipelining/__init__.py | 1 + torch/distributed/remote_device.py | 17 +- torch/distributed/rendezvous.py | 33 +- torch/distributed/run.py | 49 +- torch/distributed/utils.py | 1 + 36 files changed, 583 insertions(+), 298 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 99c04cac4fbb3..2c3da39f80ccf 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1389,33 +1389,6 @@ exclude_patterns = [ 'torch/contrib/_tensorboard_vis.py', "torch/cuda/_gpu_trace.py", 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable - 'torch/distributed/__init__.py', - 'torch/distributed/_composable_state.py', - 'torch/distributed/_sharded_tensor/__init__.py', - 'torch/distributed/_sharding_spec/__init__.py', - 'torch/distributed/_tools/__init__.py', - 'torch/distributed/_tools/memory_tracker.py', - 'torch/distributed/argparse_util.py', - 'torch/distributed/c10d_logger.py', - 'torch/distributed/collective_utils.py', - 'torch/distributed/constants.py', - 'torch/distributed/distributed_c10d.py', - 'torch/distributed/examples/memory_tracker_example.py', - 'torch/distributed/launch.py', - 'torch/distributed/launcher/__init__.py', - 'torch/distributed/launcher/api.py', - 'torch/distributed/logging_handlers.py', - 'torch/distributed/nn/__init__.py', - 'torch/distributed/nn/api/__init__.py', - 'torch/distributed/nn/api/remote_module.py', - 'torch/distributed/nn/functional.py', - 'torch/distributed/nn/jit/__init__.py', - 'torch/distributed/nn/jit/instantiator.py', - 'torch/distributed/nn/jit/templates/__init__.py', - 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/remote_device.py', - 'torch/distributed/rendezvous.py', - 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index eb339000e89e7..93b701732206f 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs -import sys import pdb +import sys import torch + def is_available() -> bool: """ Return ``True`` if the distributed package is available. @@ -29,31 +30,31 @@ def is_available() -> bool: if is_available(): from torch._C._distributed_c10d import ( - Store, - FileStore, - TCPStore, - ProcessGroup as ProcessGroup, - Backend as _Backend, - PrefixStore, - Reducer, - Logger, - BuiltinCommHookType, - GradBucket, - Work as _Work, - _DEFAULT_FIRST_BUCKET_BYTES, - _register_comm_hook, - _register_builtin_comm_hook, _broadcast_coalesced, _compute_bucket_assignment_by_size, - _verify_params_across_processes, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, DebugLevel, + FileStore, get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, set_debug_level, set_debug_level_from_env, - _make_nccl_premul_sum, - _ControlCollectives, - _StoreCollectives, + Store, + TCPStore, + Work as _Work, ) class _DistributedPdb(pdb.Pdb): @@ -63,10 +64,11 @@ class _DistributedPdb(pdb.Pdb): Usage: _DistributedPdb().set_trace() """ + def interaction(self, *args, **kwargs): _stdin = sys.stdin try: - sys.stdin = open('/dev/stdin') + sys.stdin = open("/dev/stdin") pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin @@ -98,37 +100,31 @@ def breakpoint(rank: int = 0): del guard if sys.platform != "win32": - from torch._C._distributed_c10d import ( - HashStore, - _round_robin_process_groups, - ) + from torch._C._distributed_c10d import _round_robin_process_groups, HashStore - from .distributed_c10d import * # noqa: F403 + from .device_mesh import DeviceMesh, init_device_mesh # Variables prefixed with underscore are not auto imported # See the comment in `distributed_c10d.py` above `_backend` on why we expose # this. - + from .distributed_c10d import * # noqa: F403 from .distributed_c10d import ( _all_gather_base, - _reduce_scatter_base, - _create_process_group_wrapper, - _rank_not_in_group, _coalescing_manager, _CoalescingManager, + _create_process_group_wrapper, _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, get_node_local_rank, ) - + from .remote_device import _remote_device from .rendezvous import ( - rendezvous, _create_store_from_options, register_rendezvous_handler, + rendezvous, ) - from .remote_device import _remote_device - from .device_mesh import init_device_mesh, DeviceMesh - set_debug_level_from_env() else: diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 1423cfd600fc8..14f7f8a313faf 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -5,6 +5,7 @@ import torch.distributed as dist from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp + from ._fsdp_common import ( _get_dim0_padded_size, _raise_assert_with_print, diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 594ec483bd3bf..36b181250f28d 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import math import traceback - from dataclasses import dataclass from enum import auto, Enum from typing import Any, cast, List, Optional diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py index 07fd45e9e3d71..141addc6b7191 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_init.py +++ b/torch/distributed/_composable/fsdp/_fsdp_init.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist import torch.nn as nn - from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed.device_mesh import _get_device_handle from torch.utils._python_dispatch import is_traceable_wrapper_subclass + from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo from ._fsdp_state import _get_module_fsdp_state diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index c56dc79e266bb..6e0e815f7a537 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -7,12 +7,12 @@ import torch import torch._dynamo.compiled_autograd as ca import torch.nn as nn - from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed._tensor.device_mesh import _mesh_resources from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta + from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( _chunk_with_empty, @@ -24,6 +24,7 @@ HSDPMeshInfo, ) + """ [Note: FSDP tensors] FSDP considers the following tensors: diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 06fa90e060e70..6592a815bacfa 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import contextlib - from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch @@ -11,6 +10,7 @@ from torch.profiler import record_function from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle + from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( AllGatherResult, @@ -21,6 +21,7 @@ from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState + _ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 79a09342704ff..c6cdb2b29880b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import torch @@ -13,10 +12,12 @@ ) from torch.distributed.utils import _to_kwargs from torch.utils._pytree import tree_flatten, tree_map + from ._fsdp_api import MixedPrecisionPolicy from ._fsdp_common import _cast_fp_tensor, TrainingState from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + if TYPE_CHECKING: from ._fsdp_param import FSDPParam diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 61b7878d467ff..e8ab3466118bc 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 950a034071a43..06b121aef80a8 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -8,7 +8,6 @@ from torch.distributed._composable_state import _get_module_state, _insert_module_state from torch.distributed.fsdp._common_utils import _FSDPState from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo - from torch.distributed.fsdp._init_utils import ( _init_buffer_state, _init_core_state, diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 0cb4ea79bc7d1..6ba70cf7bfc93 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -9,6 +9,7 @@ from .contract import _get_registry, contract + _ROOT_MODULE_PREFIX = "" diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 1d3f24c80f08a..a3998c8e1d3b4 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,15 +1,14 @@ # mypy: allow-untyped-defs from collections import defaultdict from contextlib import contextmanager - from functools import partial from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.distributed._functional_collectives as funcol - import torch.distributed.distributed_c10d as c10d + if TYPE_CHECKING: from torch._C._distributed_c10d import _DistributedBackendOptions, Backend diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 9ac89166b25fd..82ca3cb8b0738 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -11,6 +11,7 @@ from . import _functional_collectives_impl as fun_col_impl + try: from torch.utils._cxx_pytree import tree_map_only except ImportError: @@ -1134,6 +1135,7 @@ def all_gather_inplace( reduce_scatter_tensor as legacy_reducescatter, ) + # This dict should contain sets of functions that dynamo is allowed to remap. # Functions in this set should accept the same args/kwargs 1:1 as their mapping. traceable_collective_remaps = { diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index c39cb4a9d50d1..4bd193d662bd6 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -4,6 +4,7 @@ import torch import torch.distributed.distributed_c10d as c10d + """ This file contains the op impls for the legacy (c10d_functional) functional collectives. These impls simply call into the native (_c10d_functional) functional collectives. diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 6c6694cfb0813..5e6f4d2a1a6ec 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharded_tensor import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -15,4 +16,6 @@ stacklevel=2, ) -sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor +sys.modules[ + "torch.distributed._sharded_tensor" +] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index 21c56d5dc849e..c74dd3633e0f5 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharding_spec import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -16,4 +17,6 @@ ) import torch.distributed._shard.sharding_spec as _sharding_spec -sys.modules['torch.distributed._sharding_spec'] = _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 2f9f0555be641..cb9def721686c 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor + if dist.is_available() or TYPE_CHECKING: from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 10f70c9ce18e7..e4d8aa6e762b8 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -1,24 +1,14 @@ # mypy: allow-untyped-defs +import operator +import pickle from collections import defaultdict - from itertools import chain - -import pickle - -from typing import ( - Any, - Callable, - Dict, - List, - no_type_check, - Sequence, - TYPE_CHECKING, -) +from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING import torch import torch.nn as nn from torch.utils._python_dispatch import TorchDispatchMode -import operator + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle @@ -234,6 +224,7 @@ def load(self, path: str) -> None: def _create_pre_forward_hook(self, name: str) -> Callable: """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: self._cur_module_name = f"{name}.forward" if ( diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index c1cc67b40681b..2c92176c53eb2 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -15,9 +15,9 @@ import torch import torch.distributed as dist - from torch.distributed.logging_handlers import _log_handlers + __all__: List[str] = [] _DEFAULT_DESTINATION = "default" @@ -36,7 +36,9 @@ def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Lo return logger -def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[logging.Handler, str]: +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> Tuple[logging.Handler, str]: log_handler = _log_handlers[destination] log_handler_name = type(log_handler).__name__ return (log_handler, log_handler_name) @@ -69,8 +71,10 @@ def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: } return msg_dict -_T = TypeVar('_T') -_P = ParamSpec('_P') + +_T = TypeVar("_T") +_P = ParamSpec("_P") + def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(func) diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index ed6c93078299a..78199e7a26f22 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -14,8 +14,10 @@ import torch.distributed as dist + T = TypeVar("T") + @dataclass class SyncPayload(Generic[T]): stage_name: Optional[str] @@ -23,6 +25,7 @@ class SyncPayload(Generic[T]): payload: T exception: Optional[Exception] = None + def broadcast( data_or_fn: Union[T, Callable[[], T]], *, @@ -55,10 +58,12 @@ def broadcast( """ if not success and data_or_fn is not None: - raise AssertionError("Data or Function is expected to be None if not successful") + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None # if no pg is passed then execute if rank is 0 if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): # determine if it is an executable function or data payload only @@ -119,7 +124,7 @@ def all_gather( >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) """ payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None success = True # determine if it is an executable function or data payload only if callable(data_or_fn): @@ -161,7 +166,8 @@ def all_gather( if len(exception_list) > 0: raise RuntimeError( # type: ignore[misc] - error_msg, exception_list) from exception_list[0] + error_msg, exception_list + ) from exception_list[0] return ret_list else: if not sync_obj.success: diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index 47b1f90e406c5..b3754043644b8 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -1,8 +1,10 @@ -from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT from datetime import timedelta from typing import Optional -__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout'] +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] # Default process group wide timeout, if applicable. # This only applies to the non-nccl backends @@ -16,6 +18,7 @@ try: from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT except ImportError: # if C++ NCCL support is not compiled, we don't have access to the default nccl value. diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index e46356a368942..a1fee846d2545 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -6,10 +6,9 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch - from torch.distributed import is_available +from torch.utils._typing_utils import not_none -from ..utils._typing_utils import not_none __all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d44c3733a214e..91e4cf9f540c8 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,11 +1,11 @@ # mypy: allow-untyped-defs """Distributed Collective Communication (c10d).""" -import itertools import collections.abc import contextlib import hashlib import io +import itertools import logging import os import pickle @@ -14,19 +14,26 @@ import warnings from collections import namedtuple from datetime import timedelta -from typing import Any, Callable, Dict, Optional, Tuple, Union, List, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import deprecated import torch +from torch._C import _DistStoreError as DistStoreError from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, AllgatherOptions, AllreduceCoalescedOptions, AllreduceOptions, AllToAllOptions, - _DistributedBackendOptions, BarrierOptions, BroadcastOptions, + DebugLevel, GatherOptions, + get_debug_level, PrefixStore, ProcessGroup, ReduceOp, @@ -34,41 +41,88 @@ ReduceScatterOptions, ScatterOptions, Store, - DebugLevel, - get_debug_level, Work, - _register_process_group, - _resolve_process_group, - _unregister_all_process_groups, - _unregister_process_group, ) from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs -from .constants import default_pg_timeout, default_pg_nccl_timeout +from torch.utils._typing_utils import not_none + from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 -from ..utils._typing_utils import not_none -DistStoreError = torch._C._DistStoreError + __all__ = [ - 'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced', - 'all_gather_object', 'all_reduce', - 'all_reduce_coalesced', 'all_to_all', - 'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast', 'send_object_list', - 'recv_object_list', 'broadcast_object_list', 'destroy_process_group', - 'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank', - 'get_world_size', 'get_pg_count', 'group', 'init_process_group', 'irecv', - 'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available', - 'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available', - 'isend', 'monitored_barrier', 'new_group', 'new_subgroups', - 'new_subgroups_by_enumeration', 'recv', 'reduce', - 'reduce_scatter', 'scatter', - 'scatter_object_list', 'send', 'supports_complex', - 'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions', - 'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore', - 'ProcessGroup', 'ReduceOp', 'ReduceOptions', 'ReduceScatterOptions', - 'ScatterOptions', 'Store', 'DebugLevel', 'get_debug_level', 'Work', - 'default_pg_timeout', 'get_group_rank', 'get_global_rank', 'get_process_group_ranks', - 'reduce_op', 'all_gather_into_tensor', 'reduce_scatter_tensor', 'get_node_local_rank', + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", ] _MPI_AVAILABLE = True @@ -79,6 +133,7 @@ _pickler = pickle.Pickler _unpickler = pickle.Unpickler + # Change __module__ of all imported types from torch._C._distributed_c10d that are public def _export_c_types() -> None: _public_types_to_change_module = [ @@ -97,22 +152,25 @@ def _export_c_types() -> None: Store, DebugLevel, get_debug_level, - Work + Work, ] for type in _public_types_to_change_module: type.__module__ = "torch.distributed.distributed_c10d" + + _export_c_types() try: from torch._C._distributed_c10d import ProcessGroupMPI + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupMPI"] except ImportError: _MPI_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupNCCL - from torch._C._distributed_c10d import ProcessGroupCudaP2P + from torch._C._distributed_c10d import ProcessGroupCudaP2P, ProcessGroupNCCL + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"] @@ -120,8 +178,8 @@ def _export_c_types() -> None: _NCCL_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupGloo - from torch._C._distributed_c10d import _ProcessGroupWrapper + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupGloo"] except ImportError: @@ -129,6 +187,7 @@ def _export_c_types() -> None: try: from torch._C._distributed_c10d import ProcessGroupUCC + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupUCC"] except ImportError: @@ -191,20 +250,20 @@ class Backend(str): backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { - 'cpu' : GLOO, - 'cuda' : NCCL, + "cpu": GLOO, + "cuda": NCCL, } backend_capability: Dict[str, List[str]] = { - GLOO : ["cpu", "cuda"], - NCCL : ["cuda"], - UCC : ["cpu", "cuda"], - MPI : ["cpu", "cuda"], + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], } backend_type_map: Dict[str, ProcessGroup.BackendType] = { UNDEFINED: ProcessGroup.BackendType.UNDEFINED, - GLOO : ProcessGroup.BackendType.GLOO, + GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, UCC: ProcessGroup.BackendType.UCC, } @@ -220,7 +279,13 @@ def __new__(cls, name: str): return value @classmethod - def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None) -> None: + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, List[str]]] = None, + ) -> None: """ Register a new backend with the given name and instantiating function. @@ -247,19 +312,19 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio """ # Allow UCC plugin if Pytorch is not built with native support. # TODO: remove this exception once UCC plugin is fully deprecated. - if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())): - assert not hasattr(Backend, name.upper()), ( - f"{name.upper()} c10d backend already exist" - ) - assert name.upper() not in Backend._plugins, ( - f"{name.upper()} c10d backend creator function already exist" - ) + if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): + assert not hasattr( + Backend, name.upper() + ), f"{name.upper()} c10d backend already exist" + assert ( + name.upper() not in Backend._plugins + ), f"{name.upper()} c10d backend creator function already exist" setattr(Backend, name.upper(), name.lower()) Backend.backend_list.append(name.lower()) if devices is not None: for device in devices: - if device != 'cpu' and device != 'cuda': + if device != "cpu" and device != "cuda": Backend.default_device_backend_map[device] = name.lower() Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM @@ -281,6 +346,7 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + class BackendConfig: """Backend configuration class.""" @@ -294,7 +360,10 @@ def __init__(self, backend: Backend): # supported since PyTorch 2.0 for device, default_backend in Backend.default_device_backend_map.items(): if is_backend_available(default_backend): - if default_backend == Backend.NCCL and not torch.cuda.is_available(): + if ( + default_backend == Backend.NCCL + and not torch.cuda.is_available() + ): continue self.device_backend_map[device] = Backend(default_backend) elif backend.lower() in Backend.backend_list: @@ -316,12 +385,16 @@ def __init__(self, backend: Backend): for device_backend_pair_str in backend.lower().split(","): device_backend_pair = device_backend_pair_str.split(":") if len(device_backend_pair) != 2: - raise ValueError(f"Invalid device:backend pairing: \ - {device_backend_pair_str}. {backend_str_error_message}") + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) device, backend = device_backend_pair if device in self.device_backend_map: - raise ValueError(f"Duplicate device type {device} \ - in backend string: {backend}. {backend_str_error_message}") + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}" + ) self.device_backend_map[device] = Backend(backend) else: # User specified a single backend name whose device capability is @@ -334,23 +407,24 @@ def __init__(self, backend: Backend): ) backend_val = Backend(backend) self.device_backend_map = { - "cpu" : backend_val, - "cuda" : backend_val, - "xpu" : backend_val, + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, } - logger.info( - "Using backend config: %s", self.device_backend_map - ) + logger.info("Using backend config: %s", self.device_backend_map) def __repr__(self): """Return all the device:backend pairs separated by commas.""" - return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items()) + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) def get_device_backend_map(self) -> Dict[str, Backend]: """Return backend map of the device.""" return self.device_backend_map + class _reduce_op: r""" Deprecated enum-like class. @@ -397,8 +471,14 @@ class P2POp: tag (int, optional): Tag to match send with recv. """ - def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Init.""" self.op = op self.tensor = tensor @@ -406,8 +486,14 @@ def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, self.group = group self.tag = tag - def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Create and return a new instance of the class.""" _check_op(op) _check_single_tensor(tensor, "tensor") @@ -415,7 +501,9 @@ def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, def __repr__(self): my_group_rank = get_rank(self.group) - peer_group_rank = get_group_rank(self.group, self.peer) if self.group else self.peer + peer_group_rank = ( + get_group_rank(self.group, self.peer) if self.group else self.peer + ) op_name = self.op.__name__ group_name = self.group.group_name if self.group else "default_pg" if "send" in op_name: @@ -429,6 +517,7 @@ def __repr__(self): return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" + class _CollOp: """ A class to capture collective operations. @@ -441,8 +530,14 @@ class _CollOp: root (int, optional): root of broadcast or reduce. """ - def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torch.Tensor] = None, - redop: Optional[ReduceOp] = None, root: Optional[int] = None): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: Optional[torch.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): self.op = op self.tensor = tensor self.dst_tensor = dst_tensor @@ -462,6 +557,7 @@ def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torc _pg_to_tag: Dict[ProcessGroup, str] = {} _backend: Optional[str] = None + class _World: """ Container class for c10d process group state. @@ -597,6 +693,7 @@ def pg_config_info(self) -> List[Dict[str, Any]]: _world = _World() """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + class _WorldMeta(type): """ Meta class of ``group`` and ``GroupMember``. @@ -613,11 +710,13 @@ def WORLD(cls) -> Optional[ProcessGroup]: def WORLD(cls, pg: Optional[ProcessGroup]): _world.default_pg = pg + class group(metaclass=_WorldMeta): """Group class. Placeholder.""" pass + class GroupMember(metaclass=_WorldMeta): """Group member class.""" @@ -630,23 +729,28 @@ def _get_default_timeout(backend: Backend) -> timedelta: if not isinstance(default_pg_nccl_timeout, timedelta): # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was # changed to be a warning. We should fix the moco model. - warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled") + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) return default_pg_timeout return default_pg_nccl_timeout else: return default_pg_timeout + def _check_valid_timeout(timeout: Any) -> None: if not isinstance(timeout, timedelta): raise TypeError( f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" ) + # Default process group state _default_pg_init_method: Optional[str] = None STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: """ Return the device to use with ``group`` for control flow usage (object collectives, barrier). @@ -711,14 +815,20 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device _world.pg_default_device[group] = devices[0] logger.info( - "Using device %s for object " - "collectives.", _world.pg_default_device[group] + "Using device %s for object " "collectives.", _world.pg_default_device[group] ) return _world.pg_default_device[group] @_time_logger -def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)) -> None: +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: """ Store based barrier for synchronizing processes. @@ -755,7 +865,12 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log logger.debug( "Waiting in store based barrier to initialize process group for " "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", - rank, store_key, world_size, worker_count, timeout, e + rank, + store_key, + world_size, + worker_count, + timeout, + e, ) if timedelta(seconds=(time.time() - start)) > timeout: @@ -766,7 +881,10 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log ) logger.info( - "Rank %s: Completed store-based barrier for key:%s with %s nodes.", rank, store_key, world_size + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, ) @@ -803,13 +921,16 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int: if group is GroupMember.WORLD: return global_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) group_ranks = _world.pg_group_ranks[group] if global_rank not in group_ranks: raise ValueError(f"Global rank {global_rank} is not part of group {group}") return group_ranks[global_rank] + def get_global_rank(group: ProcessGroup, group_rank: int) -> int: """ Translate a group rank into a global rank. @@ -828,7 +949,9 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int: if group is GroupMember.WORLD: return group_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) for rank, grp_rank in _world.pg_group_ranks[group].items(): if grp_rank == group_rank: return rank @@ -858,6 +981,7 @@ def get_process_group_ranks(group: ProcessGroup) -> List[int]: """ return list(_world.pg_group_ranks[group].keys()) + def _get_group_size(group) -> int: """Get a given group's world size.""" if group is GroupMember.WORLD or group is None: @@ -906,13 +1030,16 @@ def _check_tensor_list(param, param_name) -> None: def _as_iterable(obj) -> collections.abc.Iterable: return obj if isinstance(obj, list) else (obj,) + def _ensure_all_tensors_same_dtype(*tensors) -> None: last_dtype = None for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): tensor_dtype = tensor.dtype # Mixing complex and its element type is allowed if tensor_dtype.is_complex: - tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) if last_dtype is None: last_dtype = tensor_dtype @@ -1049,6 +1176,7 @@ def _update_default_pg(pg) -> None: rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 torch._C._distributed_c10d._set_global_rank(rank) + def get_backend_config(group: Optional[ProcessGroup] = None) -> str: """ Return the backend configuration of the given process group. @@ -1071,6 +1199,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str: backend_config = _world.pg_backend_config.get(pg) return str(not_none(backend_config)) + def get_backend(group: Optional[ProcessGroup] = None) -> Backend: """ Return the backend of the given process group. @@ -1093,6 +1222,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend: pg_store = _world.pg_map[pg] if pg in _world.pg_map else None return Backend(not_none(pg_store)[0]) + def _get_process_group_uid(pg: ProcessGroup) -> int: backend = None try: @@ -1103,6 +1233,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int: return backend.uid return -1 + def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: """ Return the pg configuration of the given process group. @@ -1120,6 +1251,7 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: "ranks": get_process_group_ranks(pg), } + def _get_all_pg_configs() -> List[Dict[str, Any]]: """ Return the pg configuration of all the process groups. @@ -1130,6 +1262,7 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]: config_info.append(_get_pg_config(pg)) return config_info + def get_pg_count() -> int: """ Return the number of process groups. @@ -1137,6 +1270,7 @@ def get_pg_count() -> int: """ return _world.group_count + def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: """ Return the local rank of the current process relative to the node. @@ -1162,6 +1296,7 @@ def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: "assuming you are not running in a multi-device context and want the code to run locally instead." ) + def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: """ Set the timeout for the given process group when users want to use a different timeout instead of @@ -1349,7 +1484,14 @@ def init_process_group( ) default_pg, _ = _new_process_group_helper( - -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg" + -1, + -1, + [], + backend, + None, + group_name, + timeout=timeout, + group_desc="default_pg", ) _update_default_pg(default_pg) else: @@ -1375,7 +1517,7 @@ def init_process_group( pg_options=pg_options, timeout=timeout, device_id=device_id, - group_desc="default_pg" + group_desc="default_pg", ) _update_default_pg(default_pg) @@ -1394,7 +1536,9 @@ def _distributed_excepthook(*args): finally: sys.stderr = old_stderr msg = buf.getvalue() - msg = "\n".join(f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n")) + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) sys.stderr.write(msg) sys.stderr.flush() @@ -1421,6 +1565,7 @@ def _distributed_excepthook(*args): # default devices and messes up NCCL internal state. _store_based_barrier(rank, store, group_name, world_size, timeout) + def _get_split_source(pg): split_from = None if pg.bound_device_id: @@ -1442,6 +1587,7 @@ def _get_split_source(pg): return split_from + def _shutdown_backend(pg): """ Try to shut down the backend of a process group. @@ -1453,10 +1599,13 @@ def _shutdown_backend(pg): backend = pg._get_backend(torch.device("cuda")) except RuntimeError: pass - if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)): + if is_nccl_available() and isinstance( + backend, (ProcessGroupNCCL, ProcessGroupCudaP2P) + ): # explictly call shutdown to ensure that NCCL resources are released backend._shutdown() + def _new_process_group_helper( group_size, group_rank, @@ -1487,9 +1636,11 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != 'cuda'): - raise ValueError("init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu") + if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + raise ValueError( + "init_process_group device_id parameter must be a cuda device with an " + "id, e.g. cuda:0, not just cuda or cpu" + ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value _check_valid_timeout(timeout) @@ -1514,8 +1665,10 @@ def _new_process_group_helper( # ranks_. We can only know this if the group we are making is the # entire world or if we have bound a device id to the world (which # causes early connection initialization). - if (is_initialized() and - (len(global_ranks_in_group) == _get_default_group().size() or _get_default_group().bound_device_id)): + if is_initialized() and ( + len(global_ranks_in_group) == _get_default_group().size() + or _get_default_group().bound_device_id + ): split_from = _get_split_source(_get_default_group()) else: split_from = None @@ -1538,7 +1691,9 @@ def _new_process_group_helper( prefix_store = PrefixStore(f"{group_name}/", store) base_pg_options = ProcessGroup.Options(backend=str(backend)) base_pg_options._timeout = timeout - pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options) + pg: ProcessGroup = ProcessGroup( + prefix_store, group_rank, group_size, base_pg_options + ) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1561,12 +1716,19 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None # create new process group with accurate rank and size if pg.rank() == -1 and pg.size() == -1: - pg = ProcessGroup(backend_prefix_store, backend_class.rank(), backend_class.size(), base_pg_options) + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + base_pg_options, + ) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: # raise RuntimeError("GLOO options not supported") - backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupGloo( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.GLOO elif backend_str == Backend.NCCL: if not is_nccl_available(): @@ -1592,19 +1754,22 @@ def _new_process_group_helper( pg_options.global_ranks_in_group = global_ranks_in_group pg_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options) + backend_prefix_store, group_rank, group_size, pg_options + ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): # TODO: once UCC plugin is fully deprecated, remove # is_ucc_available() from above elif-condition and raise # RuntimeError if is_ucc_available() returns false. - backend_class = ProcessGroupUCC(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupUCC( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.UCC else: - assert backend_str.upper() in Backend._plugins, ( - f"Unknown c10d backend type {backend_str.upper()}" - ) + assert ( + backend_str.upper() in Backend._plugins + ), f"Unknown c10d backend type {backend_str.upper()}" backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -1612,7 +1777,9 @@ def _new_process_group_helper( backend_type = ProcessGroup.BackendType.CUSTOM if not extended_api: - backend_class = creator_fn(backend_prefix_store, group_rank, group_size, timeout) + backend_class = creator_fn( + backend_prefix_store, group_rank, group_size, timeout + ) else: dist_backend_opts = _DistributedBackendOptions() dist_backend_opts.store = backend_prefix_store @@ -1640,7 +1807,10 @@ def _new_process_group_helper( break # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set - if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] or backend_str.upper() in Backend._plugins: + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): # In debug mode and if GLOO is available, wrap in a wrapper PG that # enables enhanced collective checking for debuggability. if get_debug_level() == DebugLevel.DETAIL: @@ -1698,6 +1868,7 @@ def _new_process_group_helper( _world.pg_to_tag[pg] = pg_tag return pg, prefix_store + def destroy_process_group(group: Optional[ProcessGroup] = None): """ Destroy a given process group, and deinitialize the distributed package. @@ -1736,7 +1907,9 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): if group is None or group == GroupMember.WORLD: # shutdown all backends in the order of pg names. shutting down in order because # ncclCommAbort() was a 'collective' call in some versions of NCCL. - for pg_to_shutdown in sorted(_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True): + for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): _shutdown_backend(pg_to_shutdown) _update_default_pg(None) @@ -1832,7 +2005,9 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int: return _get_group_size(group) -def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: +def isend( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> Optional[Work]: """ Send a tensor asynchronously. @@ -1871,7 +2046,13 @@ def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, return pg.send([tensor], dst, tag) -def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: + +def irecv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> Optional[Work]: """ Receives a tensor asynchronously. @@ -1913,8 +2094,11 @@ def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proce group_src_rank = get_group_rank(pg, src) return pg.recv([tensor], group_src_rank, tag) + @_exception_logger -def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None: +def send( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> None: """ Send a tensor synchronously. @@ -1951,8 +2135,14 @@ def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, t group_dst_rank = get_group_rank(group, dst) group.send([tensor], group_dst_rank, tag).wait() + @_exception_logger -def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> int: +def recv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> int: """ Receives a tensor synchronously. @@ -2004,7 +2194,15 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces class _IllegalWork(Work): def __getattribute__(self, name): - if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]: + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: raise ValueError(f"Illegal to call {name} on IllegalWork object") @@ -2057,7 +2255,9 @@ def _coalescing_manager( group = group or _get_default_group() op_list = _world.pg_coalesce_state.setdefault(group, []) if op_list: - raise ValueError("ProcessGroup has non-empty op list at the start of coalescing") + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) if device: group._start_coalescing(device) cm = _CoalescingManager() @@ -2212,6 +2412,7 @@ def broadcast(tensor, src, group=None, async_op=False): else: work.wait() + @_exception_logger def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2292,6 +2493,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger @deprecated( "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " @@ -2359,6 +2561,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2404,6 +2607,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + def _object_to_tensor(obj, device, group): f = io.BytesIO() _pickler(f).dump(obj) @@ -2416,7 +2620,9 @@ def _object_to_tensor(obj, device, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) - logger.warning("_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash) + logger.warning( + "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash + ) local_size = torch.LongTensor([byte_tensor.numel()]).to(device) return byte_tensor, local_size @@ -2426,7 +2632,9 @@ def _tensor_to_object(tensor, tensor_size, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([tensor]) - logger.warning("_tensor_to_object size: %s hash value: %s", tensor.numel(), hash) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) tensor = tensor.cpu() buf = tensor.numpy().tobytes()[:tensor_size] return _unpickler(io.BytesIO(buf)).load() @@ -2709,7 +2917,9 @@ def send_object_list(object_list, dst, group=None, device=None): # sent to this device. current_device = device or _get_pg_default_device(group) # Serialize object_list elements to tensors on src rank. - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) # Send object sizes @@ -2793,7 +3003,9 @@ def recv_object_list(object_list, src=None, group=None, device=None): # case it is not ``None`` we move the size and object tensors to be # received to this device. current_device = device or _get_pg_default_device(group) - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Receive object sizes rank_sizes = recv(object_sizes_tensor, src=src, group=group) @@ -2802,11 +3014,13 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) rank_objects = recv(object_tensor, src=src, group=group) - assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + assert ( + rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -2816,6 +3030,7 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_list[i] = _tensor_to_object(obj_view, obj_size, group) return rank_objects + @_exception_logger def broadcast_object_list(object_list, src=0, group=None, device=None): """ @@ -2892,10 +3107,14 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): my_rank = get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -2912,7 +3131,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) broadcast(object_tensor, src=src, group=group) @@ -3000,7 +3219,10 @@ def scatter_object_list( pg_device = _get_pg_default_device(group) if my_rank == src: tensor_list, tensor_sizes = zip( - *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] ) tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) @@ -3015,7 +3237,9 @@ def scatter_object_list( broadcast(max_tensor_size, src=src, group=group) # Scatter actual serialized objects - output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) scatter( output_tensor, scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined] @@ -3033,7 +3257,9 @@ def scatter_object_list( ) # Deserialize back to object - scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) @_exception_logger @@ -3900,6 +4126,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False else: work.wait() + @_exception_logger def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): """ @@ -4041,15 +4268,18 @@ def _create_process_group_wrapper( wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) return wrapped_pg + # helper function for deterministically hashing a list of ranks def _hash_ranks(ranks: List[int]): return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + # Takes a list of ranks and computes an integer color def _process_group_color(ranks: List[int]) -> int: # Convert our hash to an int, but avoid negative numbers by shifting a bit. return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) + def _process_group_name(ranks, use_hashed_name): global _world if use_hashed_name: @@ -4061,6 +4291,7 @@ def _process_group_name(ranks, use_hashed_name): _world.group_count += 1 return pg_name + def _get_backend_from_str(backend: Optional[str] = None) -> Backend: # Default to the same backend as the global process group # if backend is not specified. @@ -4070,7 +4301,14 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend: @_time_logger -def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None): +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, +): """ Create a new distributed group. @@ -4137,6 +4375,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local group_desc=group_desc, ) + def _new_group_with_tag( ranks=None, timeout=None, @@ -4144,7 +4383,7 @@ def _new_group_with_tag( pg_options=None, pg_tag=None, use_local_synchronization=False, - group_desc=None + group_desc=None, ): """ Variant of ``new_group`` that exposes tag creation. @@ -4159,7 +4398,6 @@ def _new_group_with_tag( global_rank = default_pg.rank() global_world_size = default_pg.size() - # Default to the same backend as the global process group # if the backend is not specified. if not backend: @@ -4175,7 +4413,9 @@ def _new_group_with_tag( if use_local_synchronization: # MPI backend doesn't have have a way for us to perform a partial sync if backend == Backend.MPI: - raise ValueError("MPI backend doesn't support use_local_synchronization=True") + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) if ranks is not None and get_rank() not in ranks: return None @@ -4217,7 +4457,7 @@ def _new_group_with_tag( pg_options=pg_options, timeout=timeout, pg_tag=pg_tag, - group_desc=group_desc + group_desc=group_desc, ) # Create the global rank to group rank mapping @@ -4246,7 +4486,9 @@ def _new_group_with_tag( world_size = len(ranks) if use_local_synchronization else get_world_size() # Use store based barrier here since barrier() used a bunch of # default devices and messes up NCCL internal state. - _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout) + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) return pg @@ -4332,16 +4574,20 @@ def new_subgroups( """ if group_size is None: if not torch.cuda.is_available(): - raise ValueError("Default group size only takes effect when CUDA is available." - "If your subgroup using a backend that does not depend on CUDA," - "please pass in 'group_size' correctly.") + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) group_size = torch.cuda.device_count() if group_size <= 0: raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") world_size = get_world_size() if world_size < group_size: - raise ValueError(f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})") + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) if world_size % group_size != 0: raise ValueError("The world size must be divisible by 'group_size'") @@ -4364,10 +4610,7 @@ def new_subgroups( rank = get_rank() if rank in ranks_in_subgroup: cur_subgroup = subgroup - logger.info( - "Rank %s is assigned to subgroup %s", - rank, ranks_in_subgroup - ) + logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup) return cur_subgroup, subgroups @@ -4479,8 +4722,13 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro return group return None -def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup: - assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: List[int], stride: int +) -> ProcessGroup: + assert ( + len(ranks) % stride == 0 + ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" my_rank = get_rank() my_ranks = None @@ -4505,6 +4753,7 @@ def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) # TODO copy settings and timeout from default PG return _new_group_with_tag(my_ranks, pg_tag=tag) + def _get_group_tag(pg: ProcessGroup) -> str: """Return the tag associated with ``pg``.""" tag = _world.pg_to_tag[pg] @@ -4512,12 +4761,15 @@ def _get_group_tag(pg: ProcessGroup) -> str: tag = tag[5:] return tag + def _get_process_group_name(pg: ProcessGroup) -> str: return _world.pg_names.get(pg, "None") + def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + # This ops are not friendly to TorchDynamo. So, we decide to disallow these ops # in FX graph, allowing them to run them on eager, with torch.compile. dynamo_unsupported_distributed_c10d_ops = [ diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index cb2ba03777d8f..e40cfb8b3f594 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs -import torch import torchvision +import torch from torch.distributed._tools import MemoryTracker diff --git a/torch/distributed/launcher/__init__.py b/torch/distributed/launcher/__init__.py index f0d25f8080c26..fb744a2b93615 100644 --- a/torch/distributed/launcher/__init__.py +++ b/torch/distributed/launcher/__init__.py @@ -8,7 +8,7 @@ from torch.distributed.launcher.api import ( # noqa: F401 - LaunchConfig, elastic_launch, launch_agent, + LaunchConfig, ) diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 937647f77828f..a3bcd4073c9ba 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -15,13 +15,18 @@ from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent -from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger -__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent'] + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] logger = get_logger(__name__) @@ -212,8 +217,8 @@ def launch_agent( "max_restarts": config.max_restarts, "monitor_interval": config.monitor_interval, "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] - "metrics_cfg": config.metrics_cfg - } + "metrics_cfg": config.metrics_cfg, + }, ) rdzv_parameters = RendezvousParameters( diff --git a/torch/distributed/logging_handlers.py b/torch/distributed/logging_handlers.py index 3c607fe45da77..021ad100f06a8 100644 --- a/torch/distributed/logging_handlers.py +++ b/torch/distributed/logging_handlers.py @@ -9,6 +9,7 @@ import logging from typing import Dict, List + __all__: List[str] = [] _log_handlers: Dict[str, logging.Handler] = { diff --git a/torch/distributed/nn/__init__.py b/torch/distributed/nn/__init__.py index 3ed1b42cbe158..e15fb517052e4 100644 --- a/torch/distributed/nn/__init__.py +++ b/torch/distributed/nn/__init__.py @@ -1,4 +1,7 @@ import torch + +from .functional import * # noqa: F403 + + if torch.distributed.rpc.is_available(): from .api.remote_module import RemoteModule -from .functional import * # noqa: F403 diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index de8a15dd65da5..5583da8c3e8d4 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -21,14 +21,15 @@ import torch import torch.distributed.rpc as rpc -from torch import Tensor, device, dtype, nn -from torch.distributed.nn.jit import instantiator +from torch import device, dtype, nn, Tensor from torch.distributed import _remote_device +from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.internal import _internal_rpc_pickler from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle + __all__ = ["RemoteModule"] _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -120,7 +121,6 @@ def _raise_not_supported(name: str) -> None: class _RemoteModule(nn.Module): - def __new__(cls, *args, **kwargs): # Use __new__ for logging purposes. torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module") @@ -370,7 +370,10 @@ def register_forward_pre_hook( # type: ignore[return] self, hook: Union[ Callable[[T, Tuple[Any, ...]], Optional[Any]], - Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + Callable[ + [T, Tuple[Any, ...], Dict[str, Any]], + Optional[Tuple[Any, Dict[str, Any]]], + ], ], prepend: bool = False, with_kwargs: bool = False, @@ -405,10 +408,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: ) def named_parameters( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Parameter]]: _raise_not_supported(self.named_parameters.__name__) @@ -416,10 +416,7 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[ret _raise_not_supported(self.buffers.__name__) def named_buffers( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) @@ -464,7 +461,11 @@ def _prepare_init(self, remote_device_str: str) -> bool: assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." remote_device = _remote_device(remote_device_str) - self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank() + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) self.device = str(remote_device.device()) agent = rpc._get_current_rpc_agent() # If the device map of the remote worker is set, diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index e90a78a69324b..110df578552a5 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,11 +2,13 @@ import torch import torch.distributed as dist from torch.autograd import Function + # The two imports below are not always available depending on the # USE_DISTRIBUTED compile flag. Make sure they raise import error # if we're trying to use them. from torch.distributed import group, ReduceOp + def broadcast(tensor, src, group=group.WORLD): """ Broadcasts the tensor to the whole group. @@ -116,6 +118,7 @@ def all_gather(tensor, group=group.WORLD): """ return _AllGather.apply(group, tensor) + def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): """ Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. @@ -340,6 +343,7 @@ def backward(ctx, *grad_outputs): gx = torch.sum(torch.stack(gxs), dim=0) return (None, gx) + class _AllGatherBase(Function): @staticmethod def forward(ctx, output_tensor, input_tensor, group): @@ -354,16 +358,19 @@ def backward(ctx, grad_output): out_size = list(grad_output.size()) if out_size[0] % world_size != 0: raise RuntimeError( - f'Tensor with dimensions: {out_size} does ' - f'not have first dimension divisible by world_size: {world_size}' + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" ) out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) - gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) else: raise RuntimeError("Backend not supported!") return (None, gx, None) + class _AlltoAll(Function): @staticmethod def forward(ctx, group, out_tensor_list, *tensors): @@ -391,7 +398,9 @@ def forward(ctx, group, out_tensor_list, *tensors): @staticmethod def backward(ctx, *grad_outputs): tensor_list = [ - torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) for size in ctx.input_tensor_size_list ] return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) @@ -415,7 +424,9 @@ def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): @staticmethod def backward(ctx, grad_output): - tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype) + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) return (None, None, None, None) + ( _AlltoAllSingle.apply( ctx.group, diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7d0aede8943eb..81ddeb8bfe0ad 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -5,7 +5,7 @@ import operator from collections import defaultdict from enum import Enum -from inspect import Parameter, signature, Signature +from inspect import Parameter, Signature, signature from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -21,6 +21,7 @@ ) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module + from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules from ._utils import PipeInfo @@ -1176,7 +1177,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): predecessor_module = getattr(predecessor_module, atom) except AttributeError as e: raise AttributeError( - f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' + f"Specified target {qualname} referenced " + f'nonexistent module {".".join(atoms[: i + 1])}' ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 18b3191add5b6..5b1843a33f6fd 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -8,6 +8,7 @@ ) from .stage import build_stage, PipelineStage + __all__ = [ "Pipe", "pipe_split", diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index da664f7408bb2..bdb1974b1b37c 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -47,7 +47,7 @@ def __init__(self, remote_device: Union[str, torch.device]): else: raise ValueError(PARSE_ERROR) else: - raise TypeError(f'Invalid type for remote_device: {type(remote_device)}') + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") # Do some basic sanity check (no empty string) if self._worker_name is not None and not self._worker_name: @@ -96,18 +96,18 @@ def device(self) -> torch.device: def __repr__(self): if self._device is not None: if self._worker_name is not None: - return f'{self._worker_name}/{self._device}' + return f"{self._worker_name}/{self._device}" elif self._rank is not None: - return f'rank:{self._rank}/{self._device}' + return f"rank:{self._rank}/{self._device}" else: return str(self._device) else: if self._worker_name is not None: - return f'{self._worker_name}' + return f"{self._worker_name}" elif self._rank is not None: - return f'{self._rank}' + return f"{self._rank}" else: - raise RuntimeError('Invalid state!') + raise RuntimeError("Invalid state!") def __eq__(self, other): if not isinstance(other, _remote_device): @@ -122,8 +122,5 @@ def __eq__(self, other): return False - def __hash__(self): - return hash(self._worker_name) ^ \ - hash(self._device) ^ \ - hash(self._rank) + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index e3266cb238aca..a944a75271b0d 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -10,7 +10,7 @@ import os import sys from datetime import timedelta -from typing import Dict, Optional, Callable, Iterator, Tuple +from typing import Callable, Dict, Iterator, Optional, Tuple from torch.distributed import FileStore, PrefixStore, Store, TCPStore @@ -21,6 +21,7 @@ __all__ = ["register_rendezvous_handler", "rendezvous"] + def register_rendezvous_handler(scheme, handler): """ Register a new rendezvous handler. @@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler): """ global _rendezvous_handlers if scheme in _rendezvous_handlers: - raise RuntimeError( - f"Rendezvous handler for {scheme}:// already registered" - ) + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") _rendezvous_handlers[scheme] = handler # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} def _query_to_dict(query: str) -> Dict[str, str]: - return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: @@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store: +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. @@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True else: start_daemon = rank == 0 return TCPStore( - hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv + hostname, + port, + world_size, + start_daemon, + timeout, + multi_tenant=True, + use_libuv=use_libuv, ) @@ -208,7 +218,9 @@ def _error(msg): assert result.hostname is not None - store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) @@ -250,12 +262,13 @@ def _get_env_or_raise(env_var: str) -> str: else: world_size = int(_get_env_or_raise("WORLD_SIZE")) - master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) use_libuv = _get_use_libuv_from_query_dict(query_dict) - store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 5654693f3dfca..aa34891d1ecd2 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -397,9 +397,9 @@ def main(): import os import sys import uuid -import importlib.metadata as metadata -from argparse import REMAINDER, ArgumentParser -from typing import Callable, List, Tuple, Type, Union, Optional, Set +from argparse import ArgumentParser, REMAINDER +from importlib import metadata +from typing import Callable, List, Optional, Set, Tuple, Type, Union import torch from torch.distributed.argparse_util import check_env, env @@ -408,9 +408,9 @@ def main(): from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from torch.distributed.launcher.api import elastic_launch, LaunchConfig from torch.utils.backend_registration import _get_custom_mod_func -import torch.multiprocessing + logger = get_logger(__name__) @@ -693,21 +693,26 @@ def determine_local_world_size(nproc_per_node: str): if torch.cuda.is_available(): num_proc = torch.cuda.device_count() device_type = "gpu" - elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \ - _get_custom_mod_func("is_available")(): + elif ( + hasattr(torch, torch._C._get_privateuse1_backend_name()) + and _get_custom_mod_func("is_available")() + ): num_proc = _get_custom_mod_func("device_count")() device_type = torch._C._get_privateuse1_backend_name() else: num_proc = os.cpu_count() device_type = "cpu" else: - raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e logger.info( - "Using nproc_per_node=%s," - " setting to %s since the instance " - "has %s %s", - nproc_per_node, num_proc, os.cpu_count(), device_type + "Using nproc_per_node=%s," " setting to %s since the instance " "has %s %s", + nproc_per_node, + num_proc, + os.cpu_count(), + device_type, ) return num_proc @@ -753,9 +758,13 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: logs_specs_cls = entrypoint_list[0].load() if logs_specs_cls is None: - raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key") + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) - logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)) + logging.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) else: logs_specs_cls = DefaultLogsSpecs @@ -768,7 +777,11 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str assert 0 < min_nodes <= max_nodes assert args.max_restarts >= 0 - if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint: + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + ): logger.warning( "master_addr is only used for static rdzv_backend and when rdzv_endpoint " "is not specified." @@ -784,7 +797,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str "please further tune the variable for optimal performance in " "your application as needed. \n" "*****************************************", - omp_num_threads + omp_num_threads, ) # This env variable will be passed down to the subprocesses os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) @@ -888,7 +901,9 @@ def run(args): "--rdzv-endpoint=%s " "--rdzv-id=%s\n" "**************************************\n", - args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, ) config, cmd, cmd_args = config_from_args(args) diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index f13d066415015..1a0b849f955d1 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -21,6 +21,7 @@ from torch.nn.parallel.scatter_gather import _is_namedtuple from torch.nn.utils.rnn import PackedSequence + __all__ = [] # type: ignore[var-annotated]