From d9c294c6726ec833406dfaf1a2cdee77c4a5785d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 18 Jun 2024 22:06:53 +0000 Subject: [PATCH 01/18] [Inductor] Fix arguments passed to triton kernel launch hooks (#128732) `binary.launch_enter_hook` is treated as an instance method and will add a `self` argument to the hooks. `CompiledKernel.launch_enter_hook` is a static method, which matches the hook calling convention of profilers (i.e., a single `LazyDict` argument only). Pull Request resolved: https://github.com/pytorch/pytorch/pull/128732 Approved by: https://github.com/shunting314, https://github.com/bertmaher --- test/inductor/test_profiler.py | 4 ++-- torch/_inductor/runtime/triton_heuristics.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index d2ff71dd73bb6..9d0270a9aae8d 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -158,10 +158,10 @@ def test_inductor_profiling_triton_hooks(self): hooks_called = {"enter": False, "exit": False} - def launch_enter_hook(*args): + def launch_enter_hook(lazy_dict): hooks_called["enter"] = True - def launch_exit_hook(*args): + def launch_exit_hook(lazy_dict): hooks_called["exit"] = True CompiledKernel.launch_enter_hook = launch_enter_hook diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5396ccf3e70d5..82a25392b5e95 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -50,6 +50,7 @@ if triton is not None: from triton import Config + from triton.compiler import CompiledKernel from triton.runtime.autotuner import OutOfResources from triton.runtime.jit import KernelInterface @@ -453,8 +454,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): scope = { "grid_meta": cfg.kwargs, "bin": binary, - "launch_enter_hook": binary.launch_enter_hook, - "launch_exit_hook": binary.launch_exit_hook, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, "metadata": binary.packed_metadata if hasattr(binary, "packed_metadata") else binary.metadata, From ac5f565fa7010bd77b9e779415e8709d347234b6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 18 Jun 2024 11:41:03 -0700 Subject: [PATCH 02/18] [FSDP2] Added `set_post_optim_event` (#128975) This PR adds `set_post_optim_event` that allows power users to provide their own CUDA event that is recorded after the optimizer step for the FSDP root module to wait the all-gather streams on. ``` def set_post_optim_event(self, event: torch.cuda.Event) -> None: ``` By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models. To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event to wait the all-gather streams on. We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`). --- One possible way to use the API is to register a post-step hook on the optimizer. For example: https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128975 Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy ghstack dependencies: #128884 --- .../fsdp/test_fully_shard_overlap.py | 82 ++++++++++++++++--- .../fsdp/test_fully_shard_training.py | 41 ++++++++++ .../_composable/fsdp/_fsdp_state.py | 14 +++- .../_composable/fsdp/fully_shard.py | 19 +++++ 4 files changed, 142 insertions(+), 14 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index 99823883abfbb..1fca6c3f3c5a0 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools from typing import Callable import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor.experimental import implicit_replication from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( FSDPTest, @@ -23,15 +25,6 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fully_shard_training_overlap(self): - class LinearWithSleep(nn.Module): - def __init__(self, dim: int, sleep_ms: int): - super().__init__() - self.weight = nn.Parameter(torch.randn((dim, dim))) - self.sleep_ms = sleep_ms - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) - torch.manual_seed(42) # Use non-trivial comm. time but still shorter than compute time @@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: fully_shard(model, reshard_after_forward=True) orig_all_gather_into_tensor = dist.all_gather_into_tensor - orig_reduce_scatter = dist.reduce_scatter_tensor + orig_reduce_scatter_tensor = dist.reduce_scatter_tensor comm_stream = torch.cuda.Stream() def delay_collective(): @@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs): def delayed_reduce_scatter(*args, **kwargs): delay_collective() - return orig_reduce_scatter(*args, **kwargs) + return orig_reduce_scatter_tensor(*args, **kwargs) inp = torch.randn((2, dim), device="cuda") loss = model(inp).sum() # warmup CUDA and allocator @@ -92,6 +85,63 @@ def fwd_bwd(): ) self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time) + @skip_if_lt_x_gpu(2) + def test_fully_shard_post_optim_event_overlap(self): + torch.manual_seed(42) + + # Use non-trivial comm. time but still shorter than compute time + dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10) + # Define the model to have a high-compute linear followed by a + # low-compute linear, where only the low-compute linear uses FSDP + model = nn.Sequential( + LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim) + ).cuda() + fully_shard(model[1], reshard_after_forward=False) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + orig_all_gather_into_tensor = dist.all_gather_into_tensor + + def delayed_all_gather(*args, **kwargs): + torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms())) + return orig_all_gather_into_tensor(*args, **kwargs) + + inp = torch.randn((2, dim), device="cuda") + + def run_train_steps(num_iters: int, use_post_optim_event: bool): + for _ in range(num_iters): + optim.zero_grad() + with patch_all_gather(delayed_all_gather): + loss = model(inp).sum() + loss.backward() + with implicit_replication(): + optim.step() + if use_post_optim_event: + post_optim_event = torch.cuda.current_stream().record_event() + model[1].set_post_optim_event(post_optim_event) + + run_train_steps(1, False) # warmup CUDA and allocator + num_iters = 5 + baseline_time = self._time_fn( + functools.partial(run_train_steps, num_iters, False) + ) + test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True)) + + buffer_ms = 4 # CPU delays and copies + # Baseline: FSDP all-gather is exposed since the FSDP module waits for + # the current stream and hence the high-compute linear + self.assertLessEqual( + baseline_time, + num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms), + ) + # Test: FSDP all-gather is overlapped with the high-compute linear + # since the FSDP module only waits for the post-optim event (except on + # the 1st iteration when no event has been recorded) + expected_test_time = ( + num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms + ) + self.assertLessEqual(test_time, expected_test_time) + self.assertGreater(baseline_time, expected_test_time) + def _time_fn(self, fn: Callable): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input, grad_weight, None +class LinearWithSleep(nn.Module): + def __init__(self, dim: int, sleep_ms: int): + super().__init__() + self.weight = nn.Parameter(torch.randn((dim, dim))) + self.sleep_ms = sleep_ms + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3dbaa65243794..abc579b40d624 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -532,6 +532,47 @@ def test_explicit_prefetching(self): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_post_optim_event(self): + torch.manual_seed(42) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + def step_post_hook( + fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs + ) -> None: + post_optim_event = torch.cuda.current_stream().record_event() + fsdp_module.set_post_optim_event(post_optim_event) + + optim.register_step_post_hook(functools.partial(step_post_hook, model)) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + # Track all losses and check for equality at the end to avoid a CPU + # sync point after each iteration + ref_losses: List[torch.Tensor] = [] + losses: List[torch.Tensor] = [] + for iter_idx in range(10): + ref_optim.zero_grad() + ref_losses.append(ref_model(inp).sum()) + ref_losses[-1].backward() + ref_optim.step() + for iter_idx in range(10): + optim.zero_grad() + losses.append(model(inp).sum()) + losses[-1].backward() + optim.step() + # Sleep after the optimizer step to allow CPU to run ahead into the + # next iteration's forward, exercising the post-optim stream sync + torch.cuda._sleep(int(25 * get_cycles_per_ms())) + for ref_loss, loss in zip(ref_losses, losses): + self.assertEqual(ref_loss, loss) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index c6cdb2b29880b..f04e6f6d09292 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -36,6 +36,9 @@ def __init__(self): self.post_backward_final_callback_queued: bool = False # Whether to finalize backward in this backward's final callback self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.cuda.Event] = None def disable_if_config_true(func): @@ -84,9 +87,14 @@ def _root_pre_forward( self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): # Wait for optimizer before implicitly prefetched all-gathers - current_stream = torch.cuda.current_stream() - self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) - self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = torch.cuda.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) if self._device.type == "cuda": with torch.profiler.record_function("FSDP::inputs_to_device"): args_tuple, kwargs_tuple = _to_kwargs( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index e8ab3466118bc..88180f40f792c 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: module._get_fsdp_state() for module in modules ] + def set_post_optim_event(self, event: torch.cuda.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.cuda.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") From cb5e9183c6056a7f929a12f574372e87e879d29e Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 19 Jun 2024 00:05:50 +0000 Subject: [PATCH 03/18] [Caffe2] [2/N] Remove Caffe2 from tests (#128911) Follows #128675 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128911 Approved by: https://github.com/titaiwangms, https://github.com/r-barnes --- test/jit/test_tracer.py | 45 ----------- test/onnx/pytorch_test_common.py | 4 +- test/onnx/test_operators.py | 27 ------- test/quantization/core/test_quantized_op.py | 47 ------------ test/test_determination.py | 7 -- test/test_public_bindings.py | 1 - test/test_tensorboard.py | 83 +-------------------- test/test_torch.py | 17 +---- 8 files changed, 4 insertions(+), 227 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 5da8ab61c5b3c..d5ef39ba0c8b4 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -911,51 +911,6 @@ def forward(self, x): self.assertEqual(len(list(g.inputs())), 2) FileCheck().check("mul").check("add").run(str(g)) - def test_trace_c10_ops(self): - try: - _ = torch.ops._caffe2.GenerateProposals - except AttributeError: - self.skipTest("Skip the test since c2 ops are not registered.") - - class MyModel(torch.nn.Module): - def forward(self, scores, bbox_deltas, im_info, anchors): - a, b = torch.ops._caffe2.GenerateProposals( - (scores), - (bbox_deltas), - (im_info), - (anchors), - 2.0, - 6000, - 300, - 0.7, - 16, - True, - -90, - 90, - 1.0, - True, - ) - return a, b - - model = MyModel() - A = 4 - H = 10 - W = 8 - img_count = 3 - scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - bbox_deltas = torch.linspace( - 0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32 - ) - bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - im_info = torch.ones(img_count, 3, dtype=torch.float32) - anchors = torch.ones(A, 4, dtype=torch.float32) - inputs = (scores, bbox_deltas, im_info, anchors) - traced_model = torch.jit.trace(model, inputs) - self.assertEqual(traced_model(*inputs), model(*inputs)) - self.assertExportImportModule( - traced_model, (scores, bbox_deltas, im_info, anchors) - ) - def run_ge_tests(self, optimize, use_cuda): with enable_profiling_mode_for_profiling_tests(): with torch.jit.optimized_execution(optimize): diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 6fdbf4e92839c..3b66750f45d8d 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -340,8 +340,8 @@ def inner(self, *args, **kwargs): # skips tests for opset_versions listed in unsupported_opset_versions. -# if the caffe2 test cannot be run for a specific version, add this wrapper -# (for example, an op was modified but the change is not supported in caffe2) +# if the PyTorch test cannot be run for a specific version, add this wrapper +# (for example, an op was modified but the change is not supported in PyTorch) def skipIfUnsupportedOpsetVersion(unsupported_opset_versions): def skip_dec(func): @functools.wraps(func) diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 87ec424cf65d5..b3c75486450a5 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -873,33 +873,6 @@ def test_cumsum(self): x = torch.randn(2, 3, 4, requires_grad=True) self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11) - # Github Issue: https://github.com/pytorch/pytorch/issues/71095 - # def test_c2_op(self): - # class MyModel(torch.nn.Module): - # def __init__(self): - # super().__init__() - # - # def forward(self, scores, bbox_deltas, im_info, anchors): - # a, b = torch.ops._caffe2.GenerateProposals( - # (scores), (bbox_deltas), (im_info), (anchors), - # 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, - # ) - # return a, b - # - # model = MyModel() - # A = 4 - # H = 10 - # W = 8 - # img_count = 3 - # scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - # bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, - # dtype=torch.float32) - # bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - # im_info = torch.ones(img_count, 3, dtype=torch.float32) - # anchors = torch.ones(A, 4, dtype=torch.float32) - # inputs = (scores, bbox_deltas, im_info, anchors) - # self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0}) - def test_dict(self): class MyModel(torch.nn.Module): def forward(self, x_in): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 2e606938192dd..25b062a7ab13f 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4457,54 +4457,7 @@ def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimize self.assertEqual(unpacked_weight.q_per_channel_scales(), qweight.q_per_channel_scales()) self.assertEqual(unpacked_weight.q_per_channel_zero_points(), qweight.q_per_channel_zero_points()) - # compare against C2 to ensure numerical equivalency. - from caffe2.python import core, workspace - conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == torch.float32 else "HalfFloatToFused8BitRowwiseQuantized" - reverse_conversion_op = None - if bit_rate == 4: - conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused4BitRowwiseQuantized" - reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat" - elif bit_rate == 2: - conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused2BitRowwiseQuantized" - reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat" - - def get_c2_weights(weights, engine_str): - workspace.ResetWorkspace() - - workspace.FeedBlob("weights", weights) - workspace.RunOperatorOnce( - core.CreateOperator( - conversion_op, ["weights"], ["quantized_weights"], engine=engine_str - ) - ) - emb_q = workspace.FetchBlob("quantized_weights") - if bit_rate == 4 or bit_rate == 2: - workspace.RunOperatorOnce( - core.CreateOperator( - reverse_conversion_op, ["quantized_weights"], ["dequantized_weights"] - ) - ) - dequantized_data = torch.from_numpy(workspace.FetchBlob("dequantized_weights")) - else: - dequantized_data = torch.ops._caffe2.Fused8BitRowwiseQuantizedToFloat( - torch.tensor(emb_q) - ) - return torch.from_numpy(emb_q), dequantized_data - - if optimized_qparams: - engine = "GREEDY" - else: - engine = "" - - # C2 quantization needs the memory format of Tensor to be `continuous`, otherwise it will - # throw exceptions. torch.clone() will make the memory format to be `continuous` - c2_copy = torch.clone(weights) - w_packed_c2, w_unpacked_c2 = get_c2_weights(c2_copy, engine) - # Compare packed weights against C2. - np.testing.assert_allclose(w_packed.numpy(), w_packed_c2.numpy(), atol=1e-6, rtol=1e-6) - # Compare unpacked weights against C2 - np.testing.assert_allclose(w_unpacked.numpy(), w_unpacked_c2.numpy(), atol=1e-6, rtol=1e-6) def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, diff --git a/test/test_determination.py b/test/test_determination.py index 50cc2fa9975da..09a67de45dc69 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -121,13 +121,6 @@ def test_torch_file(self): ], ) - def test_caffe2_file(self): - """Caffe2 files trigger dependent tests""" - self.assertEqual(self.determined_tests(["caffe2/python/brew_test.py"]), []) - self.assertEqual( - self.determined_tests(["caffe2/python/context.py"]), self.TESTS - ) - def test_new_folder(self): """New top-level Python folder triggers all tests""" self.assertEqual(self.determined_tests(["new_module/file.py"]), self.TESTS) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 8ab2ac1f511f0..65a5bf90b9f93 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -342,7 +342,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.rpc.rpc_test", "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", - "torch.utils.tensorboard._caffe2_graph", "torch._inductor.codegen.cuda.cuda_template", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.runtime.triton_helpers", diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index 3ce2ab2a172c8..1e79a2bf910ce 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -23,15 +23,6 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -TEST_CAFFE2 = True -try: - import caffe2.python.caffe2_pybind11_state as _caffe2_pybind11_state # noqa: F401 - from caffe2.python import brew, cnn, core, workspace - from caffe2.python.model_helper import ModelHelper -except ImportError: - TEST_CAFFE2 = False -skipIfNoCaffe2 = unittest.skipIf(not TEST_CAFFE2, "no caffe2") - TEST_MATPLOTLIB = True try: import matplotlib @@ -48,7 +39,6 @@ parametrize, TestCase, run_tests, - TEST_WITH_ASAN, TEST_WITH_CROSSREF, IS_WINDOWS, IS_MACOS, @@ -94,8 +84,6 @@ def tearDown(self): from torch.utils.tensorboard._pytorch_graph import graph from google.protobuf import text_format from PIL import Image -if TEST_TENSORBOARD and TEST_CAFFE2: - from torch.utils.tensorboard import _caffe2_graph as c2_graph class TestTensorBoardPyTorchNumpy(BaseTestCase): def test_pytorch_np(self): @@ -754,80 +742,11 @@ def test_scalar(self): res = make_np(np.int64(100000000000)) self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - @skipIfNoCaffe2 - def test_caffe2_np(self): - workspace.FeedBlob("testBlob", tensor_N(shape=(1, 3, 64, 64))) - self.assertIsInstance(make_np('testBlob'), np.ndarray) - - @skipIfNoCaffe2 - def test_caffe2_np_expect_fail(self): - with self.assertRaises(RuntimeError): - res = make_np('This_blob_does_not_exist') - def test_pytorch_np_expect_fail(self): with self.assertRaises(NotImplementedError): res = make_np({'pytorch': 1.0}) - @skipIfNoCaffe2 - @unittest.skipIf(TEST_WITH_ASAN, "Caffe2 failure with ASAN") - def test_caffe2_simple_model(self): - model = ModelHelper(name="mnist") - # how come those inputs don't break the forward pass =.=a - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - - with core.NameScope("conv1"): - conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5) - # Image size: 24 x 24 -> 12 x 12 - pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2) - # Image size: 12 x 12 -> 8 x 8 - conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5) - # Image size: 8 x 8 -> 4 x 4 - pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2) - with core.NameScope("classifier"): - # 50 * 4 * 4 stands for dim_out from previous layer multiplied by the image size - fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500) - relu = brew.relu(model, fc3, fc3) - pred = brew.fc(model, relu, 'pred', 500, 10) - softmax = brew.softmax(model, pred, 'softmax') - xent = model.LabelCrossEntropy([softmax, "label"], 'xent') - # compute the expected loss - loss = model.AveragedLoss(xent, "loss") - model.net.RunAllOnMKL() - model.param_init_net.RunAllOnMKL() - model.AddGradientOperators([loss], skip=1) - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) - - @skipIfNoCaffe2 - def test_caffe2_simple_cnnmodel(self): - model = cnn.CNNModelHelper("NCHW", name="overfeat") - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - with core.NameScope("conv1"): - conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4) - relu1 = model.Relu(conv1, conv1) - pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2) - with core.NameScope("classifier"): - fc = model.FC(pool1, "fc", 4096, 1000) - pred = model.Softmax(fc, "pred") - xent = model.LabelCrossEntropy([pred, "label"], "xent") - loss = model.AveragedLoss(xent, "loss") - - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) + class TestTensorProtoSummary(BaseTestCase): @parametrize( diff --git a/test/test_torch.py b/test/test_torch.py index f252ddf4a5745..86844c77faf4a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -41,7 +41,7 @@ skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, + bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( @@ -8632,21 +8632,6 @@ def test_allow_tensor_metadata_change(self): a = torch.ones(2, 3) # Metadata changes are allowed on view tensors that are created from detach(). - @skipIfNotRegistered("LayerNorm", "Skipping as LayerNorm is not registered") - def test_c10_layer_norm(self): - # test that we can call c10 ops and they return a reasonable result - X = torch.rand(5, 5, dtype=torch.float) - weight = torch.rand(*X.size()[1:], dtype=torch.float) - bias = torch.rand(*X.size()[1:], dtype=torch.float) - epsilon = 1e-4 - - expected_norm = torch.nn.functional.layer_norm( - X, X.size()[1:], weight=weight, bias=bias, eps=epsilon) - actual_norm, actual_mean, actual_stdev = \ - torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( - weight), torch.tensor(bias), 1, epsilon, True) - torch.testing.assert_close(expected_norm, actual_norm) - def test_memory_format(self): def test_helper(x, memory_format): y = x.contiguous(memory_format=memory_format) From c5e0b844847c5c34ee824b0de2adeda85ce64133 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:24 -0700 Subject: [PATCH 04/18] [dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428) Co-authored-by: Laith Sakka Pull Request resolved: https://github.com/pytorch/pytorch/pull/128428 Approved by: https://github.com/yanboliang, https://github.com/mlazos --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ---------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index dbcb259241fcb..2329ab305e763 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index b5b12435a931a..abbef02e63c68 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,26 +2669,6 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", - "torch.nn.functional._adaptive_max_pool1d", - "torch.nn.functional._adaptive_max_pool2d", - "torch.nn.functional._adaptive_max_pool3d", - "torch.nn.functional._canonical_mask", - "torch.nn.functional._fractional_max_pool2d", - "torch.nn.functional._fractional_max_pool3d", - "torch.nn.functional._get_softmax_dim", - "torch.nn.functional._in_projection_packed", - "torch.nn.functional._in_projection", - "torch.nn.functional._is_integer", - "torch.nn.functional._max_pool1d", - "torch.nn.functional._max_pool2d", - "torch.nn.functional._max_pool3d", - "torch.nn.functional._mha_shape_check", - "torch.nn.functional._no_grad_embedding_renorm_", - "torch.nn.functional._none_or_dtype", - "torch.nn.functional._threshold", - "torch.nn.functional._unpool_output_size", - "torch.nn.functional._verify_batch_size", - "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2786,15 +2766,7 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", - "torch.nn.modules.activation._arg_requires_grad", - "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", - "torch.nn.modules.container._addindent", - "torch.nn.modules.transformer._detect_is_causal_mask", - "torch.nn.modules.transformer._generate_square_subsequent_mask", - "torch.nn.modules.transformer._get_activation_fn", - "torch.nn.modules.transformer._get_clones", - "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 670b94c9c826756495b9e1ca34be1d43756d5296 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:25 -0700 Subject: [PATCH 05/18] [inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128484 Approved by: https://github.com/mlazos ghstack dependencies: #128428 --- test/inductor/test_mkldnn_pattern_matcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 810c22d037c54..a80d723987602 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,7 +37,8 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, + # Use floats for min/max, otherwise they can get converted to symints + torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 99f042d336b53844b509406f1ecf78cb6f5e5714 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:21:21 +0000 Subject: [PATCH 06/18] Revert "Forward fix to skip ROCm tests for #122836 (#128891)" This reverts commit 4061b3b8225f522ae0ed6db00111441e7d3cc3d5. Reverted https://github.com/pytorch/pytorch/pull/128891 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128891#issuecomment-2177291249)) --- test/test_nestedtensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index fa33a13ed495d..6b9b8f3be45d5 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5470,7 +5470,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): # shape (B, *, D) nt = random_nt_from_dims( @@ -5501,7 +5500,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 @@ -5538,7 +5536,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): # shape (B, *, D) # min seq len: 7 @@ -5575,7 +5572,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 From 35c78668b408046e032a1e025b01250875959cc6 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 18 Jun 2024 13:37:50 -0700 Subject: [PATCH 07/18] Improve the debugging message for when foreach mta_called (#128991) The hope that lives in this PR: I am currently trying to debug why the foreach tests are so flaky. It looks like every flaky test falls under this pattern: - a test is flaky due to the mta_called assertion, which gathers data from the profiler regarding whether the multi_tensor_apply_kernel has been called. - then, a later test fails deterministically, usually failing to compare two results. ``` ================== 1 failed, 241 deselected, 2 rerun in 1.76s ================== Got exit code 1 Stopping at first consistent failure The following tests failed and then succeeded when run in a new process ['test/test_foreach.py::TestForeachCUDA::test_binary_op_float_inf_nan__foreach_add_cuda_bfloat16'] The following tests failed consistently: ['test/test_foreach.py::TestForeachCUDA::test_binary_op_list_error_cases__foreach_add_cuda_bfloat16'] ``` So my suspicion is that the first causes the second, but what causes the first? Idk! So it would be nice to have the error message tell us what the profiler actually saw in case it's getting muddled. This change would help mostly because I have not been able to repro this flakiness locally. Also undo the useless changes in #128220 which are actually redundant as Joel and I realized that we set the seed during the setUp of every test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128991 Approved by: https://github.com/clee2000 --- test/test_foreach.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 567d09cff02d7..99d4cbe5ec003 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -90,7 +90,7 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): mta_called = any("multi_tensor_apply_kernel" in k for k in keys) assert mta_called == ( expect_fastpath and (not zero_size) - ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}" + ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" else: actual = self.func(*inputs, **kwargs) if self.is_inplace: @@ -205,7 +205,6 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_parity(self, device, dtype, op, noncontiguous, inplace): - torch.manual_seed(2024) if inplace: _, _, func, ref = self._get_funcs(op) else: @@ -585,7 +584,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125035", ) def test_binary_op_list_error_cases(self, device, dtype, op): - torch.manual_seed(202406) foreach_op, foreach_op_, ref, ref_ = ( op.method_variant, op.inplace_variant, @@ -680,7 +678,6 @@ def test_binary_op_list_error_cases(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125775", ) def test_binary_op_list_slow_path(self, device, dtype, op): - torch.manual_seed(20240607) foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op) # 0-strides tensor1 = make_tensor((10, 10), dtype=dtype, device=device) @@ -799,7 +796,6 @@ def test_binary_op_list_slow_path(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_float_inf_nan(self, device, dtype, op): - torch.manual_seed(2024) inputs = ( [ torch.tensor([float("inf")], device=device, dtype=dtype), @@ -869,9 +865,6 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_tensors_on_different_devices(self, device, dtype, op): - torch.manual_seed(202406) - # `tensors1`: ['cuda', 'cpu'] - # `tensors2`: ['cuda', 'cpu'] _cuda_tensors = next( iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True)) ).input From 5ffb032be682a34b959c82ce289b457ea6c6e504 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:26:38 +0000 Subject: [PATCH 08/18] Revert "Backward support for unbind() with NJT (#128032)" This reverts commit 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74. Reverted https://github.com/pytorch/pytorch/pull/128032 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128032#issuecomment-2177296325)) --- test/test_nestedtensor.py | 19 ------------------- tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 ----------------- torch/csrc/autograd/FunctionsManual.h | 4 ---- torch/nested/_internal/ops.py | 11 ----------- 5 files changed, 1 insertion(+), 52 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6b9b8f3be45d5..86f58b5a0de3a 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5606,25 +5606,6 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - @dtypes(torch.float32, torch.double, torch.half) - def test_unbind_backward(self, device, dtype): - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 4, device=device), - torch.randn(5, 4, device=device), - torch.randn(3, 4, device=device), - ], - layout=torch.jagged, - requires_grad=True, - ) - - a, b, c = nt.unbind() - b.sum().backward() - - expected_grad = torch.zeros_like(nt) - expected_grad.unbind()[1].add_(1.0) - torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 02a3e6c518ad8..76a7a0a1e42a4 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index f51c2f047f935..9d897c667c906 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,23 +1014,6 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } -Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim) { - TORCH_INTERNAL_ASSERT( - dim == 0, "unbind_backward_nested_jagged() only supports dim=0") - auto grad_nt = at::zeros_like(self); - auto unbound_grads = grad_nt.unbind(); - for (int64_t i : c10::irange(static_cast(grads.size()))) { - if (grads[i].defined()) { - unbound_grads[i].copy_(static_cast(grads[i])); - } - } - - return grad_nt; -} - Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index ecf99bd098057..dedff70be1ba3 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,10 +244,6 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); -at::Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8458f03717130..6f1c47dd69471 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,17 +472,6 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) -@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") -def zero__default(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - func(inp._values) - return inp - - @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From b0d2fe6299c4462d28b23ef73d872eb608d73d96 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:28:53 +0000 Subject: [PATCH 09/18] Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)" This reverts commit 2a41fc03903de63270d325bd1886a50faf32d7e4. Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245)) --- aten/src/ATen/FunctionalInverses.cpp | 9 +- aten/src/ATen/native/native_functions.yaml | 14 +- test/dynamo/test_subclasses.py | 6 +- ...asDecompTest.test_has_decomposition.expect | 2 - test/test_nestedtensor.py | 173 +---------------- tools/autograd/derivatives.yaml | 4 +- torch/nested/_internal/nested_tensor.py | 174 ++++-------------- torch/nested/_internal/ops.py | 37 +--- torch/nested/_internal/sdpa.py | 62 ++----- 9 files changed, 69 insertions(+), 412 deletions(-) diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index a1cf449cde7c7..16b59333f918f 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, return Tensor(); } -Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const c10::optional& min_seqlen, const c10::optional& max_seqlen) { +Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx) { auto values = at::_nested_get_values(mutated_view); if (inverse_return_mode != InverseReturnMode::NeverView) { return values; @@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const auto lengths = at::_nested_get_lengths(base); auto ragged_idx = at::_nested_get_ragged_idx(base); auto dummy = at::_nested_get_jagged_dummy(base); - auto min_seqlen = at::_nested_get_min_seqlen(base); - auto max_seqlen = at::_nested_get_max_seqlen(base); - auto nt = at::_nested_view_from_jagged( - mutated_view, offsets, dummy, lengths, ragged_idx, - (min_seqlen.defined() ? c10::optional(min_seqlen) : c10::nullopt), - (max_seqlen.defined() ? c10::optional(max_seqlen) : c10::nullopt)); + auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx); if (inverse_return_mode != InverseReturnMode::NeverView) { return nt; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b030141882c86..a2d9095d56a38 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6185,12 +6185,12 @@ CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy autogen: _nested_view_from_buffer_copy.out -- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) variants: function device_check: NoCheck dispatch: {} -- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor +- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor variants: function device_check: NoCheck tags: view_copy @@ -6227,16 +6227,6 @@ device_check: NoCheck dispatch: {} -- func: _nested_get_min_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - -- func: _nested_get_max_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - - func: _nested_get_jagged_dummy(Tensor any) -> Tensor category_override: dummy dispatch: {} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index f16ef15990fd8..302b07e4ddb78 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1616,15 +1616,15 @@ def backend(gm, args): guard_str, """\ Eq(s3 - 1, s0) -Eq(zf1, zf6)""", +Eq(zf1, zf4)""", ) else: self.assertExpectedInline( guard_str, """\ Eq(s4 - 1, s1) -Eq(s12 - 1, s7) -Eq(s11, s9)""", +Eq(s10 - 1, s5) +Eq(s9, s7)""", ) return gm diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 132d25a8b12f9..1179142e15d9e 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -446,8 +446,6 @@ aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out aten::_nested_get_jagged_dummy aten::_nested_get_lengths -aten::_nested_get_max_seqlen -aten::_nested_get_min_seqlen aten::_nested_get_offsets aten::_nested_get_ragged_idx aten::_nested_get_values diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 86f58b5a0de3a..78d082702aecb 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -67,21 +67,6 @@ def _iter_constructors(): yield torch.nested.nested_tensor -# Returns True if the function recompiles between inputs1 and inputs2 with the -# specified dynamic setting. -def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): - compile_count = [0] - - def counter(gm, example_inputs): - compile_count[0] += 1 - return gm - - compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) - compiled_f(*inputs1) - compiled_f(*inputs2) - return compile_count[0] > 1 - - # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -4833,18 +4818,19 @@ def fn(values, same_size): check_results(fn, compiled_fn, generate_inp(20)) self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views + @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT @unittest.skipIf( TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @dtypes( - *( - [torch.float16, torch.bfloat16, torch.float32] - if SM80OrLater - else [torch.float16, torch.float32] - ) + @parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32], ) def test_sdpa(self, device, dtype): batch_size = 1 @@ -5187,6 +5173,8 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): ) self.assertEqual(output._values, output_dense) + # Doesn't work until we have real views + @xfailIfTorchDynamo @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, @@ -5463,149 +5451,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): padded, [offsets_wrong], total_L ) - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_preserves_metadata_cache(self, device, dtype): - # shape (B, *, D) - nt = random_nt_from_dims( - [4, None, 3, 16], - device=device, - dtype=dtype, - layout=torch.jagged, - requires_grad=True, - ) - - # expect min / max seqlen to be stored here - cache = dict(nt._metadata_cache) - - @torch.compile - def f(nt): - q = nt.transpose(-3, -2) - output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2) - return output - - output = f(nt) - output.backward(torch.ones_like(output)) - self.assertEqual(output._metadata_cache, cache) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_max_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_min_seq_len(self, device, dtype): - # shape (B, *, D) - # min seq len: 7 - nt = torch.nested.nested_tensor( - [ - torch.randn(7, 5), - torch.randn(8, 5), - torch.randn(9, 5), - ], - layout=torch.jagged, - ) - - # min seq len: 8 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(8, 5), - torch.randn(9, 5), - torch.randn(10, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_min_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - nt2 = nt.sin() + 1 - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt2) * nt2._get_max_seqlen() - - ref = f(nt) - output = torch.compile(f, fullgraph=True, dynamic=False)(nt) - self.assertEqual(ref, output) - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a4..1e9b9091a20e9 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2794,14 +2794,14 @@ nested_size: non_differentiable nested_strides: non_differentiable -- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) self: grad.values() offsets: non_differentiable lengths: non_differentiable dummy: non_differentiable - name: _nested_get_values(Tensor(a) self) -> Tensor(a) - self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional(at::_nested_get_min_seqlen(self)) : c10::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional(at::_nested_get_max_seqlen(self)) : c10::nullopt)" + self: _nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self)) # Transformers - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 92423cf32b2fe..66d25eacc7ad4 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -27,15 +27,6 @@ def _get_sdpa_extreme_seqlen(func, tensor): return int(func(tensor).item()) -def _store_val_in_tensor(val) -> torch.Tensor: - # hack to get dynamic shapes support: store in a (val, 0) shaped tensor - return torch.zeros(val, 0) - - -def _load_val_from_tensor(t: torch.Tensor): - return t.shape[0] - - class NestedTensor(torch.Tensor): _values: torch.Tensor # type: ignore[assignment] _offsets: torch.Tensor @@ -131,14 +122,6 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) - # min / max sequence length should be dynamic if present - max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None) - if max_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(max_seqlen_tensor, 0) - min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None) - if min_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(min_seqlen_tensor, 0) - def values(self): # dispatch to get proper view relationship return torch._nested_get_values(self) # type: ignore[attr-defined] @@ -149,56 +132,25 @@ def offsets(self): def lengths(self): return self._lengths - # Private accessor functions for min / max sequence length. They're - # purposefully not @properties because those don't work with PT2 (yet). - # These compute / cache if not present. - # TODO: Revisit this when @properties are better supported by PT2. I think the ideal - # state would be to have public @properties for min / max sequence length that compile - # (including setters). - def _get_max_seqlen(self): - max_seqlen_tensor = self._max_seqlen_tensor - if max_seqlen_tensor is None: + @property + def _max_seqlen(self): + if "max_seqlen" not in self._metadata_cache: # compute & cache - max_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["max_seqlen"] = _get_sdpa_extreme_seqlen( torch.max, self._offsets.diff() if self._lengths is None else self._lengths, ) - max_seqlen_tensor = _store_val_in_tensor(max_val) - self._metadata_cache["max_seqlen"] = max_seqlen_tensor - return _load_val_from_tensor(max_seqlen_tensor) + return self._metadata_cache["max_seqlen"] - def _get_min_seqlen(self): - min_seqlen_tensor = self._min_seqlen_tensor - if min_seqlen_tensor is None: + @property + def _min_seqlen(self): + if "min_seqlen" not in self._metadata_cache: # compute & cache - min_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["min_seqlen"] = _get_sdpa_extreme_seqlen( torch.min, self._offsets.diff() if self._lengths is None else self._lengths, ) - min_seqlen_tensor = _store_val_in_tensor(min_val) - self._metadata_cache["min_seqlen"] = min_seqlen_tensor - return _load_val_from_tensor(min_seqlen_tensor) - - # Private accessors used for treating min / max seqlen as inner tensors for - # flatten / unflatten. These must be properties to work with the traceable wrapper - # subclass logic. These do not compute / cache if not present. - @property - def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("max_seqlen", None) - - @property - def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("min_seqlen", None) - - # These are old private @property accessors that are kept around for internal BC - # reasons. TODO: Remove these! - @property - def _max_seqlen(self): - return self._get_max_seqlen() - - @property - def _min_seqlen(self): - return self._get_min_seqlen() + return self._metadata_cache["min_seqlen"] def __repr__(self): # We should implement this in torch/_tensor_str.py instead @@ -218,7 +170,6 @@ def __reduce_ex__(self, proto): del state["_size"] del state["_strides"] - # TODO: Update this to handle the other inner tensors func = NestedTensor args = (self._values, self._offsets) return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state)) @@ -226,33 +177,22 @@ def __reduce_ex__(self, proto): def __tensor_flatten__(self): ctx = { "requires_grad": self.requires_grad, + # TODO: Don't guard on this! + "metadata_cache": self._metadata_cache, "ragged_idx": self._ragged_idx, } inner_tensors = ["_values", "_offsets"] if self._lengths is not None: inner_tensors.append("_lengths") - if self._min_seqlen_tensor is not None: - inner_tensors.append("_min_seqlen_tensor") - if self._max_seqlen_tensor is not None: - inner_tensors.append("_max_seqlen_tensor") return inner_tensors, ctx @staticmethod def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): - # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] - assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5 + # inner tensors: _values, _offsets, [_lengths] + assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3 values = inner_tensors["_values"] offsets = inner_tensors["_offsets"] lengths = inner_tensors.get("_lengths", None) - min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None) - max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None) - - metadata_cache = {} - if min_seqlen_tensor is not None: - metadata_cache["min_seqlen"] = min_seqlen_tensor - if max_seqlen_tensor is not None: - metadata_cache["max_seqlen"] = max_seqlen_tensor - ragged_idx = meta["ragged_idx"] # Note that we cannot simply check if is_fake(values) because @@ -271,7 +211,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): lengths=lengths, requires_grad=meta["requires_grad"], _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, + _metadata_cache=meta["metadata_cache"], ) @classmethod @@ -336,15 +276,6 @@ def forward( offsets: torch.Tensor, metadata_cache: Optional[Dict[str, Any]] = None, ): # type: ignore[override] - # maintain BC with this usages of this where the seqlens are stuffed - # directly into the metadata cache as non-Tensors / ints - if metadata_cache is not None: - min_seqlen = metadata_cache.get("min_seqlen", None) - max_seqlen = metadata_cache.get("max_seqlen", None) - if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor): - metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen) - if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor): - metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen) return NestedTensor( values.detach(), offsets=offsets, @@ -412,12 +343,12 @@ def jagged_from_list( ] ) - # compute this now since it's easy - min_seqlen = min([t.shape[0] for t in tensors]) - max_seqlen = max([t.shape[0] for t in tensors]) - ret_nt = nested_view_from_values_offsets( - values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen - ) + ret_nt = nested_view_from_values_offsets(values, offsets) + ret_nt._metadata_cache = { + # compute this now since it's easy + "max_seqlen": max(t.shape[0] for t in tensors), + "min_seqlen": min(t.shape[0] for t in tensors), + } return (ret_nt, offsets) # type: ignore[return-value] @@ -474,19 +405,16 @@ def jagged_from_tensor_and_lengths( if is_contiguous: ret_nt = nested_view_from_values_offsets( - values[offsets[0] : offsets[-1]], - offsets - offsets[0], - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, + values[offsets[0] : offsets[-1]], offsets - offsets[0] ) else: - ret_nt = nested_view_from_values_offsets_lengths( - values, - offsets, - length_list, - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, - ) + ret_nt = nested_view_from_values_offsets_lengths(values, offsets, length_list) + + # populate metadata cache with computed seqlen extremes + ret_nt._metadata_cache = { + "max_seqlen": actual_max_seqlen, + "min_seqlen": min_seqlen, + } return (ret_nt, offsets, None if is_contiguous else length_list) @@ -508,45 +436,13 @@ def _nt_view_dummy() -> torch.Tensor: return _dummy_instance -def nested_view_from_values_offsets( - values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) - - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) - +def nested_view_from_values_offsets(values, offsets, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - None, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] - - -def nested_view_from_values_offsets_lengths( - values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + values, offsets, _nt_view_dummy(), None, ragged_idx + ) - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) +def nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - lengths, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] + values, offsets, _nt_view_dummy(), lengths, ragged_idx + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd69471..6ec3ba538f977 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1088,7 +1088,7 @@ def values_default(func, *args, **kwargs): @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, - "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", + "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?", ) def _nested_view_from_jagged_default(func, *args, **kwargs): _, new_kwargs = normalize_function( @@ -1101,21 +1101,8 @@ def _nested_view_from_jagged_default(func, *args, **kwargs): new_kwargs["lengths"], ) ragged_idx = new_kwargs["ragged_idx"] - min_seqlen = new_kwargs["min_seqlen"] - max_seqlen = new_kwargs["max_seqlen"] - metadata_cache = {} - if min_seqlen is not None: - metadata_cache["min_seqlen"] = min_seqlen - if max_seqlen is not None: - metadata_cache["max_seqlen"] = max_seqlen - return NestedTensor( - values, - offsets, - lengths=lengths, - _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, - ) + return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx) @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") @@ -1148,26 +1135,6 @@ def _nested_get_ragged_idx(func, *args, **kwargs): return inp._ragged_idx -@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") -def _nested_get_min_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("min_seqlen", None) - - -@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") -def _nested_get_max_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("max_seqlen", None) - - # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 8f2eba4db3e46..b7c69c905e9a8 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -15,7 +15,7 @@ ) from torch.nn.attention import SDPBackend -from .nested_tensor import NestedTensor +from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer log = logging.getLogger(__name__) @@ -125,7 +125,7 @@ def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( return False # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] - if param._get_min_seqlen() == 0: + if param._min_seqlen == 0: if debug: log.warning( "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", @@ -315,7 +315,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in if qkv.lengths() is None: # TODO: Explore performance impact of copying cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen n_elem = qkv.values().shape[0] else: # TODO: Explore performance impact of copying @@ -323,7 +323,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) ) batch_size = qkv.size(0) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen # TODO: Explore performance impact when compiling n_elem = int(cumulative_seqlen[-1].item()) return cumulative_seqlen, max_seqlen, n_elem @@ -364,7 +364,7 @@ def _view_as_dense( tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int ) -> torch.Tensor: if tensor.is_nested: - return tensor.values() + return buffer_from_jagged(tensor) return tensor.view(Nnz, num_heads, head_dim) @@ -567,8 +567,8 @@ def _sdpa_nested_preprocessing(query, key, value): output_nt_info = { "offsets": q_t.offsets(), - "_max_seqlen": q_t._get_max_seqlen(), - "_min_seqlen": q_t._get_min_seqlen(), + "_max_seqlen": q_t._max_seqlen, + "_min_seqlen": q_t._min_seqlen, } return ( @@ -694,14 +694,9 @@ def jagged_scaled_dot_product_attention( False, scale=og_scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - attention = nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + attention = ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -737,14 +732,9 @@ def jagged_scaled_dot_product_attention( scale=scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - return nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + return ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output @@ -754,19 +744,12 @@ def jagged_scaled_dot_product_attention( d1 = query._size[1] d2 = value._size[-1] - min_seqlen_tensor = query._metadata_cache.get( - "min_seqlen", None - ) # type: ignore[attr-defined] - max_seqlen_tensor = query._metadata_cache.get( - "max_seqlen", None - ) # type: ignore[attr-defined] - # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1] transpose = torch.transpose(jagged_layout_nt, 1, 2) - tensor_list = transpose.values().split(list(lengths), dim=0) + tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0) strided_nt = torch.nested.as_nested_tensor(list(tensor_list)) strided_nt = strided_nt.transpose(1, 2).contiguous() return strided_nt @@ -779,28 +762,11 @@ def get_strided_layout_nested_tensor(jagged_layout_nt): query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] - from torch.nested._internal.nested_tensor import ( - _load_val_from_tensor, - nested_view_from_values_offsets, - ) - # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) - attn_out = nested_view_from_values_offsets( - attn_out, - offsets, - min_seqlen=( - None - if min_seqlen_tensor is None - else _load_val_from_tensor(min_seqlen_tensor) - ), - max_seqlen=( - None - if max_seqlen_tensor is None - else _load_val_from_tensor(max_seqlen_tensor) - ), - ).transpose(1, 2) + attn_out = ViewNestedFromBuffer.apply(attn_out, offsets) + attn_out = attn_out.transpose(1, 2) return attn_out else: From 2458f79f83e865a0469f844e87a64edfcecc7065 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Mon, 17 Jun 2024 12:40:38 -0700 Subject: [PATCH 10/18] [Inductor UT][Intel GPU] Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU (#128881) Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU because it have not implemented reduction kernel split. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128881 Approved by: https://github.com/blaine-rister, https://github.com/EikanWang, https://github.com/malfet --- test/inductor/test_torchinductor_strided_blocks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index bd859802892df..bf96ad8d486d8 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, @@ -214,6 +215,7 @@ def get_input(view_size: Tuple[int]) -> torch.Tensor: # Expect 3 block pointers: 2 inputs one output self.run_and_compare(foo, x, y, expected_num_block_pointers=3) + @skipIfXpu @parametrize( "view_size,num_block_pointers,num_triton_kernels", [ From eda375a49078f5fecc90f28ca8ff949e8e5811e9 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Mon, 17 Jun 2024 19:54:34 -0700 Subject: [PATCH 11/18] [Inductor] Remove min/max from inductor opinfo test (#128925) **Summary** Remove `max.binary, min.binary, maximum, minimum` from `inductor_one_sample` op list as we fix the bool vectorization issue in https://github.com/pytorch/pytorch/pull/126841. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_maximum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_minimum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_min_binary python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_max_binary ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128925 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 29be591dc006c..c7153b5b6d849 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -425,11 +425,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "logspace": {f16}, "logspace.tensor_overload": {f16, f32, f64, i32, i64}, "masked_logsumexp": {i64}, - "max.binary": {b8}, "max_pool2d_with_indices_backward": {f16, f32, f64}, - "maximum": {b8}, - "min.binary": {b8}, - "minimum": {b8}, "new_empty_strided": {f16}, "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, From 4bc90185fb77438717d59b2d9bb63096ae682935 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Wed, 19 Jun 2024 01:17:05 +0000 Subject: [PATCH 12/18] fix: Print statements causing parse error (#128969) The print statements for the get_workflow_type script is problematic because the shell script calling this script is expecting the output to only be JSON. This PR resolves this by removing all print statements to covert them to a message field in the JSON return output so that the output can continue to expect to be JSON while giving us the debug data we are looking for. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128969 Approved by: https://github.com/tylertitsworth, https://github.com/ZainRizvi --- .github/scripts/get_workflow_type.py | 47 ++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/.github/scripts/get_workflow_type.py b/.github/scripts/get_workflow_type.py index 4a5303ae9212f..5384ef92c12f2 100644 --- a/.github/scripts/get_workflow_type.py +++ b/.github/scripts/get_workflow_type.py @@ -1,6 +1,6 @@ import json from argparse import ArgumentParser -from typing import Any +from typing import Any, Tuple from github import Auth, Github from github.Issue import Issue @@ -9,6 +9,8 @@ WORKFLOW_LABEL_META = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation LABEL_TYPE_KEY = "label_type" +MESSAGE_KEY = "message" +MESSAGE = "" # Debug message to return to the caller def parse_args() -> Any: @@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_workflow_type(issue: Issue, username: str) -> str: +def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]: try: user_list = issue.get_comments()[0].body.split() if user_list[0] == "!": - print("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = "LF Workflows are disabled for everyone. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE elif user_list[0] == "*": - print("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = "LF Workflows are enabled for everyone. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE elif username in user_list: - print(f"LF Workflows are enabled for {username}. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE else: - print(f"LF Workflows are disabled for {username}. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE except Exception as e: - print( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" + return WORKFLOW_LABEL_META, MESSAGE def main() -> None: args = parse_args() if is_exception_branch(args.github_branch): - print(f"Exception branch: '{args.github_branch}', using meta runners") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners", + } else: try: gh = get_gh_client(args.github_token) # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 issue = get_issue(gh, args.github_repo, args.github_issue) - - output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)} + label_type, message = get_workflow_type(issue, args.github_user) + output = { + LABEL_TYPE_KEY: label_type, + MESSAGE_KEY: message, + } except Exception as e: - print(f"Failed to get issue. Falling back to meta runners. Exception: {e}") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}", + } json_output = json.dumps(output) print(json_output) From df85f34a14dd30f784418624b05bd52b12ab8b0b Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 14 Jun 2024 01:51:17 -0700 Subject: [PATCH 13/18] Add test to xfail_list only for abi_compatible (#128506) https://github.com/pytorch/pytorch/pull/126717 will skip the tests in both ABI compatible and non-ABI compatible mode. It's not expected to skip them in non-ABI compatible mode since they can actually run successfully in such mode but only have issues in ABI compatible mode. We leverage the existing `xfail_list` for those that will only fail in ABI compatible mode. - `test_qlinear_add` is already in the `xfail_list`. - `test_linear_packed` doesn't fail either in my local run (running with `TORCHINDUCTOR_ABI_COMPATIBLE=1`) or in the CI of this PR so I didn't add it into `xfail_list`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128506 Approved by: https://github.com/jgong5, https://github.com/desertfire --- test/inductor/test_cpu_cpp_wrapper.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8bf9b1e6a61f8..0a2b75ddb5544 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -95,7 +95,9 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_relu_cpu", "test_qlinear_cpu", "test_qlinear_add_cpu", + "test_qlinear_add_relu_cpu", "test_qlinear_dequant_promotion_cpu", + "test_qlinear_gelu_cpu", "test_qlinear_relu_cpu", ] for test_name in xfail_list: @@ -125,7 +127,6 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, - skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -134,8 +135,6 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func - if skip: - func = unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -183,7 +182,6 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} - skip: str = None for item in [ BaseTest("test_add_complex"), @@ -242,9 +240,7 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest( - "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" - ), + BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -318,21 +314,18 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -388,7 +381,6 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, - skip=item.skip, ) test_torchinductor.copy_tests( From ed5b8432cdf8451520c064b16d9b0e971c5a5211 Mon Sep 17 00:00:00 2001 From: Alnis Murtovi Date: Wed, 19 Jun 2024 03:12:15 +0000 Subject: [PATCH 14/18] Enable mixed_mm only if casting from lower-bitwidth type to a higher one (#128899) This PR changes the behavior of `cuda_and_enabled_mixed_mm` such that mixed_mm is only enabled if we are casting from a lower-bitwidth type to a higher one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128899 Approved by: https://github.com/eellison --- test/inductor/test_pattern_matcher.py | 26 ++++++++++++++++++-------- torch/_inductor/fx_passes/post_grad.py | 9 +++++++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 58f1ff88499f7..d4570a8a2dbc0 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -442,7 +442,15 @@ def fn(a, b): .sub(8), ) - args_list = [ + def check_uint4x2_mixed_mm(args, expect_mixed_mm): + torch._dynamo.reset() + counters.clear() + ref = fn(*args) + test, (code,) = run_and_get_code(torch.compile(fn), *args) + torch.testing.assert_close(ref, test) + self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm) + + args_expect_mixed_mm = [ ( torch.randn(8, 8, device="cuda"), torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"), @@ -454,6 +462,13 @@ def fn(a, b): .contiguous() .t(), ), + ] + + for args in args_expect_mixed_mm: + check_uint4x2_mixed_mm(args, True) + + # mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one + args_expect_no_mixed_mm = [ ( torch.randn(8, 8, device="cuda"), torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"), @@ -464,13 +479,8 @@ def fn(a, b): ), ] - for args in args_list: - torch._dynamo.reset() - counters.clear() - ref = fn(*args) - test, (code,) = run_and_get_code(torch.compile(fn), *args) - torch.testing.assert_close(ref, test) - self.assertTrue("uint4x2_mixed_mm" in code) + for args in args_expect_no_mixed_mm: + check_uint4x2_mixed_mm(args, False) @unittest.skipIf(not SM80OrLater, "need sm_80") @inductor_config.patch(use_mixed_mm=True) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 4bb0244b97f38..c67471c55ab7c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -229,8 +229,13 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): def cuda_and_enabled_mixed_mm(match): - return (config.use_mixed_mm or config.mixed_mm_choice != "default") and getattr( - match.kwargs["mat1"].meta.get("val"), "is_cuda", False + return ( + (config.use_mixed_mm or config.mixed_mm_choice != "default") + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and ( + match.kwargs["mat2_dtype"].itemsize + > match.kwargs["mat2"].meta.get("val").dtype.itemsize + ) ) From 8771e3429c3d7327f08c48d547ad73546d5603b3 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Tue, 18 Jun 2024 14:24:22 -0700 Subject: [PATCH 15/18] Introduce a prototype for SymmetricMemory (#128582) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them. ### SymmetricMemory `SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers). ### Python API Example ```python from torch._C.distributed_c10d import _SymmetricMemory # Set a store for rendezvousing symmetric allocations on a group of devices # identified by group_name. The concept of groups is logical; users can # utilize predefined groups (e.g., a group of device identified by a # ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator # backends might employ a more efficient communication channel for the actual # rendezvous process and only use the store for bootstrapping purposes. _SymmetricMemory.set_group_info(group_name, rank, world_size, store) # Identical to empty_strided, but allows symmetric memory access to be # established for the allocated tensor via _SymmetricMemory.rendezvous(). # This function itself is not a collective operation. t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name) # Users can write Python custom ops that leverages the symmetric memory access. # Below are examples of things users can do (assuming the group's world_size is 2). # Establishes symmetric memory access on tensors allocated via # _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process, # and the mapping between a local memory region and the associated SymmetricMemory # object is unique. Subsequent calls to rendezvous() with the same tensor will receive # the cached SymmetricMemory object. # # The function has a collective semantic and must be invoked simultaneously # from all rendezvous participants. symm_mem = _SymmetricMemory.rendezvous(t) # This represents the allocation on rank 0 and is accessible from all devices. buf = symm_mem.get_buffer(0, (64, 64), torch.float32) if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) assert buf.eq(42).all() else: # The remote buffer can be used as a regular tensor buf.fill_(42) symm_mem.put_signal(dst_rank=0) symm_mem.barrier() if symm_mem.rank == 0: symm_mem.barrier() assert buf.eq(43).all() else: new_val = torch.empty_like(buf) new_val.fill_(43) # Contiguous copies to/from a remote buffer utilize copy engines # which bypasses SMs (i.e. no need to load the data into registers) buf.copy_(new_val) symm_mem.barrier() ``` ### Custom CUDA Comm Kernels Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels. ```cpp TORCH_API c10::intrusive_ptr get_symmetric_memory( const at::Tensor& tensor); class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { public: ... virtual std::vector get_buffer_ptrs() = 0; virtual std::vector get_signal_pad_ptrs() = 0; virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; ... }; ``` ### Limitations of IntraNodeComm and ProcessGroupCudaP2p Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach: - Leads to awkward UX in which the required workspace needs to be specified upfront. - Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather). - Prevents torch.compile from eliminating all copies. In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels. * __->__ #128582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582 Approved by: https://github.com/wanchaol --- .lintrunner.toml | 1 + BUILD.bazel | 1 + build_variables.bzl | 2 + c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 + test/distributed/test_symmetric_memory.py | 156 +++++ torch/_C/_distributed_c10d.pyi | 30 + .../distributed/c10d/CUDASymmetricMemory.cu | 539 ++++++++++++++++++ .../distributed/c10d/CUDASymmetricMemory.hpp | 107 ++++ .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 + .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ++++++ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 +++++ torch/csrc/distributed/c10d/init.cpp | 39 ++ .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +--- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 1252 insertions(+), 111 deletions(-) create mode 100644 test/distributed/test_symmetric_memory.py create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index 2c3da39f80ccf..76dedf9ea0bdb 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,6 +68,7 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', + 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index 10c065f5084c7..c563c52d861e6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,6 +744,7 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index ceb28707897e5..793b611a0a6f0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,6 +501,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", + "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -684,6 +685,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 43bcbd1d70bac..cbbdf16823ec7 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,14 +18,17 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ + _(cuMemGetAllocationGranularity) \ + _(cuMemExportToShareableHandle) \ + _(cuMemImportFromShareableHandle) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 89c31fab11347..8426741609fe7 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py new file mode 100644 index 0000000000000..a768e059044f7 --- /dev/null +++ b/test/distributed/test_symmetric_memory.py @@ -0,0 +1,156 @@ +# Owner(s): ["module: c10d"] + +import torch + +import torch.distributed as dist +from torch._C._distributed_c10d import _SymmetricMemory +from torch.distributed.distributed_c10d import _get_process_group_store + +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + skip_but_pass_in_sandcastle_if, + skipIfRocm, +) + + +def requires_cuda_p2p_access(): + cuda_p2p_access_available = ( + torch.cuda.is_available() and torch.cuda.device_count() >= 2 + ) + num_devices = torch.cuda.device_count() + for i in range(num_devices - 1): + for j in range(i + 1, num_devices): + if not torch.cuda.can_device_access_peer(i, j): + cuda_p2p_access_available = False + break + if not cuda_p2p_access_available: + break + + return skip_but_pass_in_sandcastle_if( + not cuda_p2p_access_available, + "cuda p2p access is not available", + ) + + +@instantiate_parametrized_tests +@requires_cuda_p2p_access() +class SymmetricMemoryTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 2 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + _SymmetricMemory.set_group_info( + "0", + self.rank, + self.world_size, + _get_process_group_store(dist.GroupMember.WORLD), + ) + + def _verify_symmetric_memory(self, symm_mem): + self.assertEqual(symm_mem.world_size, 2) + + buf = symm_mem.get_buffer(0, (64, 64), torch.float32) + if symm_mem.rank == 0: + symm_mem.wait_signal(src_rank=1) + self.assertTrue(buf.eq(42).all()) + else: + buf.fill_(42) + symm_mem.put_signal(dst_rank=0) + + symm_mem.barrier() + + if symm_mem.rank == 0: + symm_mem.barrier() + self.assertTrue(buf.eq(43).all()) + else: + buf.fill_(43) + symm_mem.barrier() + + symm_mem.barrier() + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name) + + t = torch.empty(shape, dtype=dtype, device=device) + with self.assertRaises(RuntimeError): + _SymmetricMemory.rendezvous(t) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + del t + self._verify_symmetric_memory(symm_mem) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p_persistent(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + alloc_id = 42 # Persistent allocation + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name, alloc_id) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + data_ptr = t.data_ptr() + + # Verify that persistent allocation would fail if there's an active + # allocation with the same alloc_id. + with self.assertRaises(RuntimeError): + _SymmetricMemory.empty_strided_p2p(*alloc_args) + + # Verify that persistent allocation would succeed in lieu of activate + # allocations with the same alloc_id, and the returned tensor would + # have the same data pointer. + del t + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + self.assertEqual(t.data_ptr(), data_ptr) + + # Verify that get_symmetric_memory would fail if called before + # rendezvous. + with self.assertRaises(RuntimeError): + _SymmetricMemory.get_symmetric_memory(t) + + symm_mem_0 = _SymmetricMemory.rendezvous(t) + symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) + self.assertEqual(id(symm_mem_0), id(symm_mem_1)) + + self._verify_symmetric_memory(symm_mem_0) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index cffbf22219c8e..0095b5af434b5 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,3 +637,33 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class _SymmetricMemory: + @staticmethod + def set_group_info( + group_name: str, rank: int, world_size: int, store: Store + ) -> None: ... + @staticmethod + def empty_strided_p2p( + size: torch.types._size, + stride: torch.types._size, + dtype: torch.dtype, + device: torch.device, + group_name: str, + ) -> torch.Tensor: ... + @property + def rank(self) -> int: ... + @property + def world_size(self) -> int: ... + @staticmethod + def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... + def get_buffer( + self, + rank: int, + sizes: torch.Size, + dtype: torch.dtype, + storage_offset: Optional[int] = 0, + ) -> torch.Tensor: ... + def barrier(self, channel: int = 0) -> None: ... + def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... + def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu new file mode 100644 index 0000000000000..d923fb6044f2b --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -0,0 +1,539 @@ +#include + +#include +#include +#include +#include + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#endif + +#include +#include + +namespace { + +constexpr size_t signal_pad_size = 2048; +const std::string store_comm_prefix = "CUDASymmetricMemory"; + +static size_t store_comm_seq_id = 0; + +template +std::vector store_all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; + peer_keys.push_back(oss.str()); + } + ++store_comm_seq_id; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; +} + +void store_barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + store_all_gather(store, rank, world_size, 0); +} + +int import_remote_fd(int pid, int fd) { +#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) + int pidfd = syscall(SYS_pidfd_open, pid, 0); + return syscall(SYS_pidfd_getfd, pidfd, fd, 0); +#else + TORCH_CHECK( + false, + "CUDASymmetricMemory requires pidfd_open ", + "and pidfd_getfd support"); +#endif +} + +void map_block( + void** ptr, + c10d::symmetric_memory::HandleType handle, + size_t size, + int device_idx) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + auto dev_ptr = reinterpret_cast(ptr); + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); + + CUmemAccessDesc desc; + desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + desc.location.id = static_cast(device_idx); + desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +CUDASymmetricMemory::CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : handles_(std::move(handles)), + block_size_(block_size), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + + c10::cuda::CUDAGuard guard(local_device_idx); + AT_CUDA_CHECK(cudaMemcpy( + buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); + AT_CUDA_CHECK(cudaMemcpy( + signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); +} + +CUDASymmetricMemory::~CUDASymmetricMemory() { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + c10::cuda::CUDAGuard guard(local_device_idx_); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + auto driver_api = c10::cuda::DriverAPI::get(); + for (int r = 0; r < world_size_; ++r) { + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(buffers_[r]), block_size_)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); + } + c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); + c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +std::vector CUDASymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** CUDASymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t CUDASymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t CUDASymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +at::Tensor CUDASymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const auto numel = + std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "CUDASymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storage_offset) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +__device__ __forceinline__ void release_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); +#endif +} + +__device__ __forceinline__ void acquire_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); +#endif +} + +static __global__ void barrier_kernel( + uint32_t** signal_pads, + int channel, + int rank, + int world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + release_signal(signal_pads[target_rank] + world_size * channel + rank); + acquire_signal(signal_pads[rank] + world_size * channel + target_rank); + } +} + +void CUDASymmetricMemory::barrier(int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void put_signal_kernel( + uint32_t** signal_pads, + int dst_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + release_signal(signal_pads[dst_rank] + world_size * channel + rank); + } +} + +void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + dst_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void wait_signal_kernel( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + acquire_signal(signal_pads[rank] + world_size * channel + src_rank); + } + __threadfence_system(); +} + +void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + src_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +int CUDASymmetricMemory::get_rank() { + return rank_; +} + +int CUDASymmetricMemory::get_world_size() { + return world_size_; +} + +void* CUDASymmetricMemoryAllocator::alloc( + size_t size, + int device_idx, + const std::string& group_name) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + prop.location.id = device_idx; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + size_t signal_pad_offset = at::round_up(size, 16UL); + size_t block_size = signal_pad_offset + signal_pad_size; + + size_t granularity; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + block_size = at::round_up(block_size, granularity); + + HandleType handle; + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); + + void* ptr = nullptr; + map_block(&ptr, handle, block_size, device_idx); + + c10::cuda::CUDAGuard guard(device_idx); + AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); + + auto block = c10::make_intrusive( + handle, device_idx, block_size, size, signal_pad_offset, group_name); + { + std::unique_lock lock(mutex_); + ptr_to_block_.emplace(ptr, std::move(block)); + } + return ptr; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +void CUDASymmetricMemoryAllocator::free(void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + if (block == nullptr) { + return; + } + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. + if (block->symm_mem == nullptr) { + auto driver_api = c10::cuda::DriverAPI::get(); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(ptr), block->block_size)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); + } + { + std::unique_lock lock(mutex_); + ptr_to_block_.erase(ptr); + } +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->buffer_size; +} + +struct RendezvousRequest { + int device_idx; + int block_fd; + int pid; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; +}; + +void validate_rendezvous_requests( + const std::vector reqs, + int world_size) { + TORCH_CHECK(reqs.size() == (size_t)world_size); + + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + if (device_indices.size() < (size_t)world_size) { + TORCH_CHECK( + false, + "CUDASymmetricMemoryAllocator::rendezvous: ", + "detected allocations from overlapping devices ", + "from different ranks."); + } + + for (int r = 1; r < world_size; ++r) { + TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); + TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); + TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); + } +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( + void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + + if (block->symm_mem != nullptr) { + return block->symm_mem; + } + + auto group_info = get_group_info(block->group_name); + auto driver_api = c10::cuda::DriverAPI::get(); + int block_fd; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .block_fd = block_fd, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset}; + auto reqs = store_all_gather( + group_info.store, group_info.rank, group_info.world_size, local_req); + validate_rendezvous_requests(reqs, group_info.world_size); + + std::vector handles(group_info.world_size); + std::vector buffers(group_info.world_size, nullptr); + std::vector signal_pads(group_info.world_size, nullptr); + for (int r = 0; r < group_info.world_size; ++r) { + if (r == group_info.rank) { + handles[r] = block->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } + int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)(uintptr_t)imported_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + map_block(&buffers[r], handles[r], block->block_size, block->device_idx); + signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + close(imported_fd); + } + store_barrier(group_info.store, group_info.rank, group_info.world_size); + close(block_fd); + + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. So that outstanding + // references to the CUDASymmetricMemory object can keep the allocation + // alive. + block->symm_mem = c10::make_intrusive( + std::move(handles), + block->block_size, + std::move(buffers), + std::move(signal_pads), + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + return block->symm_mem; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->symm_mem != nullptr; +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterCUDASymmetricMemoryAllocator { + RegisterCUDASymmetricMemoryAllocator() { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); + } +}; + +static RegisterCUDASymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp new file mode 100644 index 0000000000000..82e75d22c84f6 --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +using HandleType = CUmemGenericAllocationHandle; +#else +using HandleType = void*; +#endif + +class CUDASymmetricMemory : public SymmetricMemory { + public: + CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~CUDASymmetricMemory() override; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) override; + + void barrier(int channel) override; + void put_signal(int dst_rank, int channel) override; + void wait_signal(int src_rank, int channel) override; + + int get_rank() override; + int get_world_size() override; + + private: + std::vector handles_; + size_t block_size_; + std::vector buffers_; + std::vector signal_pads_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::optional> finalizer_; +}; + +struct Block : public c10::intrusive_ptr_target { + HandleType handle; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::string group_name; + c10::intrusive_ptr symm_mem = nullptr; + + Block( + HandleType handle, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::string& group_name) + : handle(handle), + device_idx(device_idx), + block_size(block_size), + buffer_size(buffer_size), + signal_pad_offset(signal_pad_offset), + group_name(group_name), + symm_mem(nullptr) {} +}; + +class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc(size_t size, int device_idx, const std::string& group_name) + override; + + void free(void* ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous(void* ptr) override; + bool is_rendezvous_completed(void* ptr) override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index cff4ad09b7064..7c41414c4e4e1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,6 +10,7 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp new file mode 100644 index 0000000000000..b3d9f31bb0342 --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -0,0 +1,189 @@ +#include + +namespace { + +using namespace c10d::symmetric_memory; + +class AllocatorMap { + public: + static AllocatorMap& get() { + static AllocatorMap instance; + return instance; + } + + void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + map_[device_type] = std::move(allocator); + } + + c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + auto it = map_.find(device_type); + TORCH_CHECK( + it != map_.end(), + "SymmetricMemory does not support device type ", + device_type); + return it->second; + } + + ~AllocatorMap() { + for (auto& it : map_) { + it.second.release(); + } + } + + private: + AllocatorMap() = default; + AllocatorMap(const AllocatorMap&) = delete; + AllocatorMap& operator=(const AllocatorMap&) = delete; + + std::unordered_map< + c10::DeviceType, + c10::intrusive_ptr> + map_; +}; + +static std::unordered_map group_info_map{}; + +// Data structures for tracking persistent allocations +static std::unordered_map alloc_id_to_dev_ptr{}; +static std::unordered_map> + alloc_id_to_storage{}; + +static at::Tensor empty_strided_p2p_persistent( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + uint64_t alloc_id) { + // Make the allocation fails if a previous allocation with the same alloc_id + // is still active. + auto storage = alloc_id_to_storage.find(alloc_id); + if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { + TORCH_CHECK( + false, + "SymmetricMemory::empty_strided_p2p_persistent: ", + "can not allocate with alloc_id == ", + alloc_id, + " because a previous allocation with the same alloc_id " + "is still active."); + } + + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = nullptr; + if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { + dev_ptr = alloc_id_to_dev_ptr[alloc_id]; + TORCH_CHECK( + alloc_size == allocator->get_alloc_size(dev_ptr), + "SymmetricMemory::empty_strided_p2p_persistent: ", + "requested allocation size (", + alloc_size, + ") is different from the size of a previous allocation ", + "with the same alloc_id ", + allocator->get_alloc_size(dev_ptr)); + } else { + dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + alloc_id_to_dev_ptr[alloc_id] = dev_ptr; + } + + auto options = at::TensorOptions().dtype(dtype).device(device); + auto allocated = at::from_blob(dev_ptr, size, stride, options); + + // Track the allocation's activeness + alloc_id_to_storage.erase(alloc_id); + alloc_id_to_storage.emplace( + alloc_id, allocated.storage().getWeakStorageImpl()); + return allocated; +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + return AllocatorMap::get().register_allocator( + device_type, std::move(allocator)); +} + +c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + return AllocatorMap::get().get_allocator(device_type); +} + +void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store) { + TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); + GroupInfo group_info; + group_info.rank = rank; + group_info.world_size = world_size; + group_info.store = std::move(store); + group_info_map.emplace(group_name, std::move(group_info)); +} + +const GroupInfo& get_group_info(const std::string& group_name) { + TORCH_CHECK( + group_info_map.find(group_name) != group_info_map.end(), + "get_group_info: no group info associated with the group name ", + group_name); + return group_info_map[group_name]; +} + +at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id) { + if (alloc_id.has_value()) { + return empty_strided_p2p_persistent( + size, stride, dtype, device, group_name, *alloc_id); + } + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::from_blob( + dev_ptr, + size, + stride, + [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, + options); +} + +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + return allocator->rendezvous(tensor.data_ptr()); +} + +c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + TORCH_CHECK( + allocator->is_rendezvous_completed(tensor.data_ptr()), + "SymmetricMemory: must invoke rendezvous on a tensor ", + "before calling get_symmetric_memory on it"); + return allocator->rendezvous(tensor.data_ptr()); +} + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp new file mode 100644 index 0000000000000..344b86ea5c7e3 --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include +#include + +namespace c10d { +namespace symmetric_memory { + +// SymmetricMemory represents symmetric allocations across a group of devices. +// The allocations represented by a SymmetricMemory object are accessible by +// all devices in the group. The class can be used for op-level custom +// communication patterns (via the get_buffer APIs and the synchronization +// primitives), as well as custom communication kernels (via the buffer and +// signal_pad device pointers). +// +// To acquire a SymmetricMemory object, each rank first allocates +// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes +// SymmetricMemoryAllocator::rendezvous() on the memory to establish the +// association across peer buffers. The rendezvous is a one-time process, and +// the mapping between a local memory memory and the associated SymmetricMemory +// object is unique. +// +// NOTE [symmetric memory signal pad] +// Signal pads are P2P-accessible memory regions designated for +// synchronization. SymmetricMemory offers built-in synchronization primitives +// such as barriers, put_signal, and wait_signal, which are all based on signal +// pads. Users may utilize signal pads for their own synchronization logic, +// provided that the signal pads remain zero-filled following successful +// synchronization. +// +// NOTE [symmetric memory synchronization channel] +// Synchronization channels allow users to use a single SymmetricMemory object +// to perform isolated synchronizations on different streams. For example, +// consider the case in which two barriers are issued on two streams for +// different purposes. Without the concept of channels, we cannot guarantee the +// correctness of the barriers since signals issued from barrier on stream A +// can be received by the barrier on stream B. By specifying different channels +// for these two barriers, they can operate correctly in parallel. +class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemory() {} + + virtual std::vector get_buffer_ptrs() = 0; + virtual std::vector get_signal_pad_ptrs() = 0; + + // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer + // to a device array of size world_size, containing buffer pointers and + // signal pad pointers, respectively. + virtual void** get_buffer_ptrs_dev() = 0; + virtual void** get_signal_pad_ptrs_dev() = 0; + virtual size_t get_buffer_size() = 0; + virtual size_t get_signal_pad_size() = 0; + + virtual at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) = 0; + + virtual void barrier(int channel) = 0; + virtual void put_signal(int dst_rank, int channel) = 0; + virtual void wait_signal(int src_rank, int channel) = 0; + + virtual int get_rank() = 0; + virtual int get_world_size() = 0; +}; + +class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemoryAllocator(){}; + + virtual void* alloc( + size_t size, + int device_idx, + const std::string& group_name) = 0; + + virtual void free(void* ptr) = 0; + virtual size_t get_alloc_size(void* ptr) = 0; + virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; + virtual bool is_rendezvous_completed(void* ptr) = 0; +}; + +C10_EXPORT void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator); + +C10_EXPORT c10::intrusive_ptr get_allocator( + c10::DeviceType device_type); + +// Set a store for rendezvousing symmetric allocations on a group of devices +// identified by `group_name`. The concept of groups is logical; users can +// utilize predefined groups (e.g., a group of device identified by a +// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator +// backends might employ a more efficient communication channel for the actual +// rendezvous process and only use the store for bootstrapping purposes. +TORCH_API void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store); + +struct GroupInfo { + int rank; + int world_size; + c10::intrusive_ptr store; +}; + +C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); + +// Identical to empty_strided, but allows symmetric memory access to be +// established for the allocated tensor via SymmetricMemory::rendezvous(). This +// function itself is not a collective operation. It invokes +// SymmetricMemoryAllocator::alloc() for the requested device under the hood. +// +// NOTE [symmetric memory persistent allocation] +// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent +// allocation. This makes the function cache allocated memory and ensure that +// invocations with the same `alloc_id` receive tensors backed by the same +// memory address. For safety, if a previous persistent allocation is still +// active (i.e., the storage of the returned tensor is still alive), persistent +// allocations with the same `alloc_id` will fail. This determinism coupled +// with memory planning of communication buffers (e.g., by Inductor) allows +// communication algorithms to reliably reuse previously established remote +// memory access. +TORCH_API at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id); + +// Establishes symmetric memory access on tensors allocated via +// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a +// one-time process, and the mapping between a local memory region and the +// associated SymmetricMemory object is unique. Subsequent calls to +// rendezvous() with the same tensor, or tensors allocated with +// empty_strided_p2p_persistent() using the same alloc_id, will receive the +// cached SymmetricMemory object. +// +// The function has a collective semantic and must be invoked simultaneously +// from all rendezvous participants. +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor); + +// Returns the SymmetricMemory object associated with the tensor. It can only +// be invoked after rendezvous() but does not need to be invoked collectively. +TORCH_API c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f1b28886b989..db5778efcf354 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include @@ -975,6 +976,44 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); + using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; + py::class_>( + module, "_SymmetricMemory") + .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) + .def_static( + "empty_strided_p2p", + ::c10d::symmetric_memory::empty_strided_p2p, + py::arg("size"), + py::arg("stride"), + py::arg("dtype"), + py::arg("device"), + py::arg("group_name"), + py::arg("alloc_id") = py::none()) + .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) + .def_static( + "get_symmetric_memory", + &::c10d::symmetric_memory::get_symmetric_memory) + .def_property_readonly("rank", &SymmetricMemory::get_rank) + .def_property_readonly("world_size", &SymmetricMemory::get_world_size) + .def( + "get_buffer", + &SymmetricMemory::get_buffer, + py::arg("rank"), + py::arg("sizes"), + py::arg("dtype"), + py::arg("storage_offset") = 0) + .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) + .def( + "put_signal", + &SymmetricMemory::put_signal, + py::arg("dst_rank"), + py::arg("channel") = 0) + .def( + "wait_signal", + &SymmetricMemory::wait_signal, + py::arg("src_rank"), + py::arg("channel") = 0); + auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 85136a91e0256..9d7ba5abf951d 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,23 +218,8 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - // Intentionally releasing resources without synchronizing devices. The - // teardown logic is safe for propoerly sync'd user program. We don't want - // improperly sync'd user program to hang here. - for (size_t r = 0; r < worldSize_; ++r) { - if (r == rank_) { - continue; - } - AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); - AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); - } - AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); - AT_CUDA_CHECK(cudaFree(buffers_[rank_])); - if (topoInfo_ != nullptr) { - AT_CUDA_CHECK(cudaFree(topoInfo_)); - } - AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); - AT_CUDA_CHECK(cudaFree(buffersDev_)); + auto allocator = get_allocator(c10::DeviceType::CUDA); + allocator->free(symmetricMemoryPtr_); } bool IntraNodeComm::isEnabled() { @@ -344,83 +329,19 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - // Initialize p2p state - auto p2pState = initP2pState(); - - // Allocate buffer - void* buffer = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); - - // Second handshake: exchange topology and CUDA IPC handles - struct IpcInfo { - NvlMesh nvlMesh; - Topology topology; - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - }; - - // Make p2p state and buffer available for IPC - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); - AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); - - IpcInfo ipcInfo{ - .nvlMesh = nvlMesh, - .topology = topology, - .p2pStateHandle = p2pStateHandle, - .bufferHandle = bufferHandle}; - - auto peerIpcInfos = - storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); - - for (const auto& info : peerIpcInfos) { - if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || - info.topology != peerIpcInfos.front().topology) { - LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " - "participants are observing different topologies (" - << int(info.topology) << " and " << int(topology) << ")"; - AT_CUDA_CHECK(cudaFree(p2pState)); - AT_CUDA_CHECK(cudaFree(buffer)); - return false; - } - } - - std::array p2pStates = {}, buffers = {}; - for (size_t r = 0; r < peerIpcInfos.size(); ++r) { - if (r == rank_) { - p2pStates[r] = p2pState; - buffers[r] = buffer; - } else { - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &p2pStates[r], - peerIpcInfos[r].p2pStateHandle, - cudaIpcMemLazyEnablePeerAccess)); - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &buffers[r], - peerIpcInfos[r].bufferHandle, - cudaIpcMemLazyEnablePeerAccess)); - } - } - void* p2pStatesDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); - AT_CUDA_CHECK(cudaMemcpy( - p2pStatesDev, - p2pStates.data(), - sizeof(p2pStates), - cudaMemcpyHostToDevice)); - - void* buffersDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); - AT_CUDA_CHECK(cudaMemcpy( - buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); + set_group_info("IntraNodeComm", rank_, worldSize_, store_); + auto allocator = get_allocator(c10::DeviceType::CUDA); + symmetricMemoryPtr_ = + allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); + symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); + TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); - std::copy(buffers.begin(), buffers.end(), buffers_.begin()); - p2pStatesDev_ = p2pStatesDev; - buffersDev_ = buffersDev; + p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); + buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index 51fc6252d2235..ac751ff7be1e0 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,6 +132,8 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; +static_assert(sizeof(P2pState) <= kP2pStateSize); + template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -522,7 +524,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -582,7 +584,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -632,7 +634,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -755,15 +757,7 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); - const auto elementSize = c10::elementSize(dtype); - TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); - auto options = at::TensorOptions().dtype(dtype).device( - at::kCUDA, at::cuda::current_device()); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storageOffset) - .options(options) - .make_tensor(); + return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index 5d7e2d426d30a..a67df5c34586a 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,12 +4,16 @@ #include #include #include +#include #include namespace c10d::intra_node_comm { +using namespace c10d::symmetric_memory; + constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; +constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -27,6 +31,7 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -97,8 +102,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - std::array p2pStates_{}; - std::array buffers_{}; + void* symmetricMemoryPtr_ = nullptr; + c10::intrusive_ptr symmetricMemory_ = nullptr; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From eb9f4da11e86882b5c628cea539112de9638760a Mon Sep 17 00:00:00 2001 From: chilli Date: Tue, 18 Jun 2024 11:23:49 -0700 Subject: [PATCH 16/18] Modified template indexing to broadcast indices to out instead of mask and some other flexattention micro-opts (#128938) For headdim=64 and headdim=128 Old: image New: image Note, this does regress headdim=256. We can unregress it by special casing `headdim=256`, but ehh.... we can do it later Pull Request resolved: https://github.com/pytorch/pytorch/pull/128938 Approved by: https://github.com/drisspg --- benchmarks/transformer/score_mod.py | 4 +-- torch/_inductor/kernel/flex_attention.py | 32 +++++++++--------- torch/_inductor/select_algorithm.py | 42 +++++++++++++++++++----- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 29d951bc1dee8..135f26b0df2d9 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -264,7 +264,7 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]: batch_sizes = [2, 8, 16] num_heads = [16] q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)] - head_dims = [64, 128, 256] + head_dims = [64, 128] dtypes = [ torch.bfloat16, ] @@ -302,8 +302,6 @@ def main(dynamic: bool, calculate_bwd: bool): results.append( Experiment(config, run_single_experiment(config, dynamic=dynamic)) ) - for config in tqdm(generate_experiment_configs(calculate_bwd)): - results.append(Experiment(config, run_single_experiment(config))) print_results(results) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 987dc6d89328b..edb69068f0cd0 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -242,10 +242,8 @@ def build_subgraph_buffer( start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk) + qk = tl.dot(q, k) # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ m = offs_m[:, None] n = start_n + offs_n[None, :] @@ -265,24 +263,26 @@ def build_subgraph_buffer( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -- compute scaling constant --- - row_max = tl.max(post_mod_scores, 1) - m_i_new = tl.maximum(m_i, row_max) + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(post_mod_scores - m_i_new[:, None]) + alpha = tl.math.exp2(m_i - m_ij) + p = tl.math.exp2(post_mod_scores - m_ij[:, None]) if not ROWS_GUARANTEED_SAFE: - masked_out_rows = (m_i_new == float("-inf")) + masked_out_rows = (m_ij == float("-inf")) alpha = tl.where(masked_out_rows, 0, alpha) p = tl.where(masked_out_rows[:, None], 0, p) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc) - - # -- update m_i and l_i -- + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new + # # -- scale and update acc -- + acc = acc * alpha[:, None] + v = tl.load(V_block_ptr) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc) + + # -- update m_i + m_i = m_ij # update pointers K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -294,8 +294,8 @@ def build_subgraph_buffer( idx_m = offs_m[:, None] idx_d = tl.arange(0, BLOCK_DMODEL)[None, :] + mask = idx_m < Q_LEN # TODO generalize and add proper mask support - mask = (idx_m != -1) & (idx_d != -1) {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}} # TODO dont want to write this if we dont require grad diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index dd78d2869ce24..fb43e7da1d139 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -13,6 +13,7 @@ import sys import textwrap import time +from collections import namedtuple from concurrent.futures import ThreadPoolExecutor from io import StringIO @@ -102,6 +103,16 @@ def finalize_all(self) -> str: return self.code +SubgraphInfo = namedtuple( + "SubgraphInfo", + [ + "body", + "template_mask", + "template_out", + ], +) + + class TritonTemplateKernel(TritonKernel): def __init__( self, @@ -132,7 +143,6 @@ def __init__( self.named_input_nodes = {} # type: ignore[var-annotated] self.defines = defines self.kernel_name = kernel_name - self.template_mask = None self.use_jit = use_jit self.num_stages = num_stages self.num_warps = num_warps @@ -147,21 +157,34 @@ def __init__( self.triton_meta: Optional[Dict[str, object]] = None # For Templated Attention this can be a list of ir.Subgraph self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: Dict[str, SubgraphInfo] = {} + self.body: IndentedBuffer = FakeIndentedBuffer() - self.subgraph_bodies: Dict[str, IndentedBuffer] = {} + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None @contextlib.contextmanager def set_subgraph_body(self, body_name: str): - old_body = self.body + old_body, old_mask, old_out = self.body, self.template_mask, self.template_out assert body_name in self.subgraph_bodies, body_name - self.body = self.subgraph_bodies[body_name] + self.body, self.template_mask, self.template_out = self.subgraph_bodies[ + body_name + ] yield - self.body = old_body + self.subgraph_bodies[body_name] = SubgraphInfo( + self.body, self.template_mask, self.template_out + ) + self.body, self.template_mask, self.template_out = old_body, old_mask, old_out @contextlib.contextmanager def create_subgraph_body(self, body_name: str): assert body_name not in self.subgraph_bodies - self.subgraph_bodies[body_name] = IndentedBuffer() + self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None) with self.set_subgraph_body(body_name): yield @@ -406,7 +429,8 @@ def store_output( self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) ).set_name("xindex") - self.template_mask = mask # type: ignore[assignment] + self.template_mask = mask + self.template_out = val self.template_indices = indices output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) @@ -492,7 +516,9 @@ def indexing( return super().indexing( index, dense_indexing=False, - copy_shape=self.template_mask, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out, override_mask=self.template_mask, block_ptr=block_ptr, ) From acefc5c0160d8e37858b3c28fff07e6513b78e10 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 19 Jun 2024 03:45:41 +0000 Subject: [PATCH 17/18] [torch.compile] Enable bwd compilation metrics (#128973) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128973 Approved by: https://github.com/dshi7 --- torch/_dynamo/utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9fa70e0c98d52..e283308aa37dc 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -714,18 +714,15 @@ def record_compilation_metrics( name = "compilation_metrics" else: name = "bwd_compilation_metrics" - # Currently only record fwd compilation metrics, will add bwd compilation metrics - # after the internal Scuba logging changes finish. - if isinstance(compilation_metrics, CompilationMetrics): - torch._logging.trace_structured( - name, - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, - ) - if config.log_compilation_metrics: - log_compilation_event(compilation_metrics) + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + ) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) def set_compilation_metrics_limit(new_size: int) -> None: From 1f0a68b57290afff9691d823829fda6ba4f73cbb Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Wed, 19 Jun 2024 03:56:20 +0000 Subject: [PATCH 18/18] [ROCm] Fix fp32 atomicAdd for non-MI100 GPUs (#128750) Current implementation is very specific to MI100. This is causing performance degradation for other GPUs. Fixes #128631 Benchmarking on MI300X: ``` Before: 1918.5126953125 ms After: 0.8285150527954102 ms ``` Co-authored-by: Jeff Daily Pull Request resolved: https://github.com/pytorch/pytorch/pull/128750 Approved by: https://github.com/xw285cornell --- aten/src/ATen/cuda/Atomic.cuh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 56ee8f87e2530..c8f5e91d3ff7e 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -334,7 +334,13 @@ static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) /* Special case fp32 atomic. */ #if defined(USE_ROCM) -static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { +#if defined(__gfx908__) + atomicAddNoRet(address, val); +#else + (void)unsafeAtomicAdd(address, val); +#endif +} #else static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); } #endif