From 108318ad1038f4f3ad0da4f54f53effdd9ef365a Mon Sep 17 00:00:00 2001 From: David Berard Date: Tue, 18 Jun 2024 15:40:45 +0000 Subject: [PATCH 01/64] [BE][JIT] Handle case where codegen object can be unset (#128951) Summary: Unblocks a test that's failing. `codegen` can be unset until `compile` is called. If `codegen` is not set, then just use the kernel name directly. Test Plan: ``` buck2 run //caffe2/test:tensorexpr -- --regex test_simple_add ``` Differential Revision: D58727391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128951 Approved by: https://github.com/aaronenyeshi --- torch/csrc/jit/tensorexpr/kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index d7c737d8f8f2..e5ea5bb46e0e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -181,7 +181,7 @@ class TORCH_API TensorExprKernel { } const std::string& getKernelName() const { - return codegen_->kernel_func_name(); + return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_); } const std::vector& getSymbolicShapeInputs() const { From ec616da51848bcfa9d0bd9c693c62b50fbe84c0f Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 16:16:38 +0000 Subject: [PATCH 02/64] RNN API cleanup for cuDNN 9.1 (#122011) Can potentially avoid a bit of boilerplate if we move directly to cuDNN 9.1's RNN API... Co-authored-by: Aaron Gokaslan Pull Request resolved: https://github.com/pytorch/pytorch/pull/122011 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cudnn/RNN.cpp | 32 +++++++++++------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 55c666eeca83..c90a6fd7a6c9 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -614,8 +614,6 @@ void add_projection_weights( /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( /*handle=*/handle, @@ -626,8 +624,8 @@ void add_projection_weights( /*linLayerID=*/linear_id, /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif cudnnDataType_t data_type; @@ -735,8 +733,6 @@ get_parameters( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; for (int stateless = 0; stateless < 100; stateless++) { if (cudnn_method) { // matrix @@ -749,8 +745,8 @@ get_parameters( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -759,8 +755,8 @@ get_parameters( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -922,8 +918,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; if (cudnn_method) { // matrix AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -935,8 +929,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -945,8 +939,8 @@ std::vector get_expected_data_ptrs( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -972,8 +966,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -985,8 +977,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif data_ptrs.push_back(matrix_pointer); } From 9818283da18de00047760ec4431870d3f8e620a6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 14 Jun 2024 19:12:10 +0000 Subject: [PATCH 03/64] re-enable jacrev/jacfwd/hessian after #128028 landed (#128622) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128622 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 69 ------------------- ...ion_no_setup_context_transform_hessian_cpu | 0 ...tion_no_setup_context_transform_jacfwd_cpu | 0 ...essianCPU.test_jacfwd_different_levels_cpu | 0 test/functorch/test_eager_transforms.py | 4 +- torch/_functorch/eager_transforms.py | 4 -- torch/testing/_internal/common_utils.py | 1 - 7 files changed, 2 insertions(+), 76 deletions(-) create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu create mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index dca6d28d1912..f2df33bdda67 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2746,26 +2746,6 @@ def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm - def test_hessian_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.hessian(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_hessian(self): counters.clear() @@ -2900,7 +2880,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_hessian_argnums(self): counters.clear() @@ -3046,7 +3025,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """ return (unflatten,)""", ) - @unittest.expectedFailure def test_hessian_disable_capture(self): counters.clear() @@ -3073,26 +3051,6 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) - def test_jacrev_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacrev(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacrev(self): counters.clear() @@ -3169,7 +3127,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacrev_two_tensors_argnums(self): counters.clear() @@ -3252,7 +3209,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_has_aux(self): counters.clear() @@ -3337,7 +3293,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_disable_capture(self): counters.clear() @@ -4284,26 +4239,6 @@ def wrapper_fn(x, y): self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) - def test_jacfwd_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacfwd(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacfwd(self): counters.clear() @@ -4387,7 +4322,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacfwd_two_tensors_argnums(self): counters.clear() @@ -4477,7 +4411,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_has_aux(self): counters.clear() @@ -4572,7 +4505,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_randomness(self): counters.clear() @@ -4676,7 +4608,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_disable_capture(self): counters.clear() diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 8107f865f7bc..c767810beb85 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -77,6 +77,7 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -2341,8 +2342,7 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - # it won't fail as jacrev/jacfwd were not inlined (see #128255) - # @xfailIfTorchDynamo + @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fbea5164014b..fff6bd67838f 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -767,8 +767,6 @@ def compute_jacobian_preallocate_and_copy(): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn @@ -1350,8 +1348,6 @@ def push_jvp(basis): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2d5ea4a6c64f..8daeefdee9d8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -5008,7 +5008,6 @@ def repl_frame(m): return m.group(0) s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s) - s = re.sub(r'( Date: Tue, 18 Jun 2024 17:15:05 +0000 Subject: [PATCH 04/64] [EZ] Fix typos in RELEASE.md (#128769) This PR fixes typo in `RELEASE.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128769 Approved by: https://github.com/yumium, https://github.com/mikaylagawarecki --- RELEASE.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 3c9d68f9a6cd..7091052c85bd 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -290,7 +290,7 @@ After the final RC is created. The following tasks should be performed : * Create validation issue for the release, see for example [Validations for 2.1.2 release](https://github.com/pytorch/pytorch/issues/114904) and perform required validations. -* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no prerformance regressions. +* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no performance regressions. * Prepare and stage PyPI binaries for promotion. This is done with this script: [`pytorch/builder:release/pypi/promote_pypi_to_staging.sh`](https://github.com/pytorch/builder/blob/main/release/pypi/promote_pypi_to_staging.sh) @@ -429,12 +429,12 @@ need to support these particular versions of software. ## Operating Systems Supported OS flavors are summarized in the table below: -| Operating System family | Architectrue | Notes | +| Operating System family | Architecture | Notes | | --- | --- | --- | | Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. | | MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). | | MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 | -| Windows | x86_64 | Buils are compatible with Windows-10 or newer. | +| Windows | x86_64 | Builds are compatible with Windows-10 or newer. | # Submitting Tutorials From 4e03263224af813fbf5e0e745e84c13268c48dc7 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 17:26:23 +0000 Subject: [PATCH 05/64] [CUDA][Convolution] Add missing launch bounds to `vol2col_kernel` (#128740) Fix "too many resources requested" that can happen with recent toolkits on V100. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128740 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/cuda/vol2col.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh index 98ec2c3522d5..222270e86216 100644 --- a/aten/src/ATen/native/cuda/vol2col.cuh +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -14,6 +14,7 @@ using namespace at::cuda::detail; // Kernel for fast unfold+copy on volumes template +C10_LAUNCH_BOUNDS_1(1024) __global__ void vol2col_kernel( const int64_t n, const T* data_vol, From 84c86e56bd8b86ae47c18b77141c1fe46188c5b7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 18 Jun 2024 17:48:47 +0000 Subject: [PATCH 06/64] Update tracker issues after successfully cherry-picking a PR (#128924) This extends the capacity of the cherry-pick bot to automatically update the tracker issue with the information. For this to work, the tracker issue needs to be an open one with a `release tracker` label, i.e. https://github.com/pytorch/pytorch/issues/128436. The version from the release branch, i.e. `release/2.4`, will be match with the title of the tracker issue, i.e. `[v.2.4.0] Release Tracker` or `[v.2.4.1] Release Tracker` ### Testing `python cherry_pick.py --onto-branch release/2.4 --classification release --fixes "DEBUG DEBUG" --github-actor huydhn 128718` * On the PR https://github.com/pytorch/pytorch/pull/128718#issuecomment-2174846771 * On the tracker issue https://github.com/pytorch/pytorch/issues/128436#issuecomment-2174846757 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128924 Approved by: https://github.com/atalman --- .github/scripts/cherry_pick.py | 114 ++++++++++++++++++++++++++++---- .github/scripts/github_utils.py | 9 +++ 2 files changed, 111 insertions(+), 12 deletions(-) diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py index 4c892de21da8..2650a5060d0f 100755 --- a/.github/scripts/cherry_pick.py +++ b/.github/scripts/cherry_pick.py @@ -3,11 +3,11 @@ import json import os import re -from typing import Any, Optional +from typing import Any, cast, Dict, List, Optional from urllib.error import HTTPError -from github_utils import gh_fetch_url, gh_post_pr_comment +from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo from trymerge import get_pr_commit_sha, GitHubPR @@ -19,6 +19,7 @@ "critical", "fixnewfeature", } +RELEASE_BRANCH_REGEX = re.compile(r"release/(?P.+)") def parse_args() -> Any: @@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]: return commit_sha if pr.is_closed() else None +def get_release_version(onto_branch: str) -> Optional[str]: + """ + Return the release version if the target branch is a release branch + """ + m = re.match(RELEASE_BRANCH_REGEX, onto_branch) + return m.group("version") if m else "" + + +def get_tracker_issues( + org: str, project: str, onto_branch: str +) -> List[Dict[str, Any]]: + """ + Find the tracker issue from the repo. The tracker issue needs to have the title + like [VERSION] Release Tracker following the convention on PyTorch + """ + version = get_release_version(onto_branch) + if not version: + return [] + + tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"]) + if not tracker_issues: + return [] + + # Figure out the tracker issue from the list by looking at the title + return [issue for issue in tracker_issues if version in issue.get("title", "")] + + def cherry_pick( github_actor: str, repo: GitRepo, @@ -77,17 +105,49 @@ def cherry_pick( ) try: + org, project = repo.gh_owner_and_name() + + cherry_pick_pr = "" if not dry_run: - org, project = repo.gh_owner_and_name() cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch) - msg = f"The cherry pick PR is at {cherry_pick_pr}" - if fixes: - msg += f" and it is linked with issue {fixes}" - elif classification in REQUIRES_ISSUE: - msg += f" and it is recommended to link a {classification} cherry pick PR with an issue" + tracker_issues_comments = [] + tracker_issues = get_tracker_issues(org, project, onto_branch) + for issue in tracker_issues: + issue_number = int(str(issue.get("number", "0"))) + if not issue_number: + continue + + res = cast( + Dict[str, Any], + post_tracker_issue_comment( + org, + project, + issue_number, + pr.pr_num, + cherry_pick_pr, + classification, + fixes, + dry_run, + ), + ) + + comment_url = res.get("html_url", "") + if comment_url: + tracker_issues_comments.append(comment_url) - post_comment(org, project, pr.pr_num, msg) + msg = f"The cherry pick PR is at {cherry_pick_pr}" + if fixes: + msg += f" and it is linked with issue {fixes}." + elif classification in REQUIRES_ISSUE: + msg += f" and it is recommended to link a {classification} cherry pick PR with an issue." + + if tracker_issues_comments: + msg += " The following tracker issues are updated:\n" + for tracker_issues_comment in tracker_issues_comments: + msg += f"* {tracker_issues_comment}\n" + + post_pr_comment(org, project, pr.pr_num, msg, dry_run) finally: if current_branch: @@ -159,7 +219,9 @@ def submit_pr( raise RuntimeError(msg) from error -def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: +def post_pr_comment( + org: str, project: str, pr_num: int, msg: str, dry_run: bool = False +) -> List[Dict[str, Any]]: """ Post a comment on the PR itself to point to the cherry picking PR when success or print the error when failure @@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: comment = "\n".join( (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}") ) - gh_post_pr_comment(org, project, pr_num, comment) + return gh_post_pr_comment(org, project, pr_num, comment, dry_run) + + +def post_tracker_issue_comment( + org: str, + project: str, + issue_num: int, + pr_num: int, + cherry_pick_pr: str, + classification: str, + fixes: str, + dry_run: bool = False, +) -> List[Dict[str, Any]]: + """ + Post a comment on the tracker issue (if any) to record the cherry pick + """ + comment = "\n".join( + ( + "Link to landed trunk PR (if applicable):", + f"* https://github.com/{org}/{project}/pull/{pr_num}", + "", + "Link to release branch PR:", + f"* {cherry_pick_pr}", + "", + "Criteria Category:", + " - ".join((classification.capitalize(), fixes.capitalize())), + ) + ) + return gh_post_pr_comment(org, project, issue_num, comment, dry_run) def main() -> None: @@ -214,7 +304,7 @@ def main() -> None: except RuntimeError as error: if not args.dry_run: - post_comment(org, project, pr_num, str(error)) + post_pr_comment(org, project, pr_num, str(error)) else: raise error diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index d76d32f624d8..f804c6e197dd 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> ) else: raise + + +def gh_query_issues_by_labels( + org: str, repo: str, labels: List[str], state: str = "open" +) -> List[Dict[str, Any]]: + url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues" + return gh_fetch_json( + url, method="GET", params={"labels": ",".join(labels), "state": state} + ) From 77830d509fcae41be37f5b3a2fa05faabc778e29 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 18:11:43 +0000 Subject: [PATCH 07/64] Revert "Introduce a prototype for SymmetricMemory (#128582)" This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1. Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232)) --- .lintrunner.toml | 1 - BUILD.bazel | 1 - build_variables.bzl | 2 - c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 - test/distributed/test_symmetric_memory.py | 156 ----- torch/_C/_distributed_c10d.pyi | 30 - .../distributed/c10d/CUDASymmetricMemory.cu | 539 ------------------ .../distributed/c10d/CUDASymmetricMemory.cuh | 109 ---- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 - .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ------ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 ----- torch/csrc/distributed/c10d/init.cpp | 39 -- .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +++- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 111 insertions(+), 1254 deletions(-) delete mode 100644 test/distributed/test_symmetric_memory.py delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index dc9f9ddd46c7..a7bbdc884415 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,7 +68,6 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', - 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index c563c52d861e..10c065f5084c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,7 +744,6 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 793b611a0a6f..ceb28707897e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,7 +501,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", - "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -685,7 +684,6 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index cbbdf16823ec..43bcbd1d70ba 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,17 +18,14 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8426741609fe..89c31fab1134 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,7 +560,6 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py deleted file mode 100644 index a768e059044f..000000000000 --- a/test/distributed/test_symmetric_memory.py +++ /dev/null @@ -1,156 +0,0 @@ -# Owner(s): ["module: c10d"] - -import torch - -import torch.distributed as dist -from torch._C._distributed_c10d import _SymmetricMemory -from torch.distributed.distributed_c10d import _get_process_group_store - -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - run_tests, - skip_but_pass_in_sandcastle_if, - skipIfRocm, -) - - -def requires_cuda_p2p_access(): - cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 - ) - num_devices = torch.cuda.device_count() - for i in range(num_devices - 1): - for j in range(i + 1, num_devices): - if not torch.cuda.can_device_access_peer(i, j): - cuda_p2p_access_available = False - break - if not cuda_p2p_access_available: - break - - return skip_but_pass_in_sandcastle_if( - not cuda_p2p_access_available, - "cuda p2p access is not available", - ) - - -@instantiate_parametrized_tests -@requires_cuda_p2p_access() -class SymmetricMemoryTest(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - @property - def world_size(self) -> int: - return 2 - - @property - def device(self) -> torch.device: - return torch.device(f"cuda:{self.rank}") - - def _init_process(self): - torch.cuda.set_device(self.device) - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - _SymmetricMemory.set_group_info( - "0", - self.rank, - self.world_size, - _get_process_group_store(dist.GroupMember.WORLD), - ) - - def _verify_symmetric_memory(self, symm_mem): - self.assertEqual(symm_mem.world_size, 2) - - buf = symm_mem.get_buffer(0, (64, 64), torch.float32) - if symm_mem.rank == 0: - symm_mem.wait_signal(src_rank=1) - self.assertTrue(buf.eq(42).all()) - else: - buf.fill_(42) - symm_mem.put_signal(dst_rank=0) - - symm_mem.barrier() - - if symm_mem.rank == 0: - symm_mem.barrier() - self.assertTrue(buf.eq(43).all()) - else: - buf.fill_(43) - symm_mem.barrier() - - symm_mem.barrier() - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name) - - t = torch.empty(shape, dtype=dtype, device=device) - with self.assertRaises(RuntimeError): - _SymmetricMemory.rendezvous(t) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - symm_mem = _SymmetricMemory.rendezvous(t) - - del t - self._verify_symmetric_memory(symm_mem) - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p_persistent(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - alloc_id = 42 # Persistent allocation - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name, alloc_id) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - data_ptr = t.data_ptr() - - # Verify that persistent allocation would fail if there's an active - # allocation with the same alloc_id. - with self.assertRaises(RuntimeError): - _SymmetricMemory.empty_strided_p2p(*alloc_args) - - # Verify that persistent allocation would succeed in lieu of activate - # allocations with the same alloc_id, and the returned tensor would - # have the same data pointer. - del t - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - self.assertEqual(t.data_ptr(), data_ptr) - - # Verify that get_symmetric_memory would fail if called before - # rendezvous. - with self.assertRaises(RuntimeError): - _SymmetricMemory.get_symmetric_memory(t) - - symm_mem_0 = _SymmetricMemory.rendezvous(t) - symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) - self.assertEqual(id(symm_mem_0), id(symm_mem_1)) - - self._verify_symmetric_memory(symm_mem_0) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 0095b5af434b..cffbf22219c8 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,33 +637,3 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... - -class _SymmetricMemory: - @staticmethod - def set_group_info( - group_name: str, rank: int, world_size: int, store: Store - ) -> None: ... - @staticmethod - def empty_strided_p2p( - size: torch.types._size, - stride: torch.types._size, - dtype: torch.dtype, - device: torch.device, - group_name: str, - ) -> torch.Tensor: ... - @property - def rank(self) -> int: ... - @property - def world_size(self) -> int: ... - @staticmethod - def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... - def get_buffer( - self, - rank: int, - sizes: torch.Size, - dtype: torch.dtype, - storage_offset: Optional[int] = 0, - ) -> torch.Tensor: ... - def barrier(self, channel: int = 0) -> None: ... - def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... - def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu deleted file mode 100644 index f27db85f7ff8..000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ /dev/null @@ -1,539 +0,0 @@ -#include - -#include -#include -#include -#include - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include -#endif - -#include -#include - -namespace { - -constexpr size_t signal_pad_size = 2048; -const std::string store_comm_prefix = "CUDASymmetricMemory"; - -static size_t store_comm_seq_id = 0; - -template -std::vector store_all_gather( - const c10::intrusive_ptr& store, - int rank, - int world_size, - T val) { - static_assert(std::is_trivially_copyable_v); - - std::vector peer_keys; - for (int r = 0; r < world_size; ++r) { - std::ostringstream oss; - oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; - peer_keys.push_back(oss.str()); - } - ++store_comm_seq_id; - - { - std::vector payload( - reinterpret_cast(&val), - reinterpret_cast(&val) + sizeof(T)); - store->set(peer_keys[rank], payload); - } - - std::vector peer_vals; - for (int r = 0; r < world_size; ++r) { - if (r == rank) { - peer_vals.push_back(val); - continue; - } - store->wait({peer_keys[r]}); - auto payload = store->get(peer_keys[r]); - TORCH_CHECK(payload.size() == sizeof(T)); - T peer_val{}; - std::memcpy(&peer_val, payload.data(), sizeof(T)); - peer_vals.push_back(peer_val); - } - return peer_vals; -} - -void store_barrier( - const c10::intrusive_ptr& store, - int rank, - int world_size) { - store_all_gather(store, rank, world_size, 0); -} - -int import_remote_fd(int pid, int fd) { -#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) - int pidfd = syscall(SYS_pidfd_open, pid, 0); - return syscall(SYS_pidfd_getfd, pidfd, fd, 0); -#else - TORCH_CHECK( - false, - "CUDASymmetricMemory requires pidfd_open ", - "and pidfd_getfd support"); -#endif -} - -void map_block( - void** ptr, - c10d::symmetric_memory::HandleType handle, - size_t size, - int device_idx) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - auto dev_ptr = reinterpret_cast(ptr); - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); - - CUmemAccessDesc desc; - desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - desc.location.id = static_cast(device_idx); - desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -CUDASymmetricMemory::CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size) - : handles_(std::move(handles)), - block_size_(block_size), - buffers_(std::move(buffers)), - signal_pads_(std::move(signal_pads)), - buffer_size_(buffer_size), - local_device_idx_(local_device_idx), - rank_(rank), - world_size_(world_size) { - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - - c10::cuda::CUDAGuard guard(local_device_idx); - AT_CUDA_CHECK(cudaMemcpy( - buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); - AT_CUDA_CHECK(cudaMemcpy( - signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); -} - -CUDASymmetricMemory::~CUDASymmetricMemory() { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - c10::cuda::CUDAGuard guard(local_device_idx_); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - auto driver_api = c10::cuda::DriverAPI::get(); - for (int r = 0; r < world_size_; ++r) { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(buffers_[r]), block_size_)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); - } - c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); - c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -std::vector CUDASymmetricMemory::get_buffer_ptrs() { - return buffers_; -} - -std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { - return signal_pads_; -} - -void** CUDASymmetricMemory::get_buffer_ptrs_dev() { - return buffers_dev_; -} - -void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { - return signal_pads_dev_; -} - -size_t CUDASymmetricMemory::get_buffer_size() { - return buffer_size_; -} - -size_t CUDASymmetricMemory::get_signal_pad_size() { - return signal_pad_size; -} - -at::Tensor CUDASymmetricMemory::get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - const auto numel = - std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "CUDASymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storage_offset) - .options(options) - .target_device(device) - .make_tensor(); -} - -void check_channel(int channel, int world_size) { - TORCH_CHECK( - channel >= 0, - "channel for barrier(), put_signal() and wait_signal() ", - "must be greater than 0 (got ", - channel, - ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; - TORCH_CHECK( - static_cast(channel) < num_channels, - "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", - num_channels - 1, - " (got ", - channel, - ")"); -} - -__device__ __forceinline__ void release_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); -#endif -} - -__device__ __forceinline__ void acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); -#endif -} - -static __global__ void barrier_kernel( - uint32_t** signal_pads, - int channel, - int rank, - int world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + world_size * channel + rank); - acquire_signal(signal_pads[rank] + world_size * channel + target_rank); - } -} - -void CUDASymmetricMemory::barrier(int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void put_signal_kernel( - uint32_t** signal_pads, - int dst_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - release_signal(signal_pads[dst_rank] + world_size * channel + rank); - } -} - -void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - dst_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void wait_signal_kernel( - uint32_t** signal_pads, - int src_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - acquire_signal(signal_pads[rank] + world_size * channel + src_rank); - } - __threadfence_system(); -} - -void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - src_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -int CUDASymmetricMemory::get_rank() { - return rank_; -} - -int CUDASymmetricMemory::get_world_size() { - return world_size_; -} - -void* CUDASymmetricMemoryAllocator::alloc( - size_t size, - int device_idx, - const std::string& group_name) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - prop.location.id = device_idx; - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; - - size_t granularity; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - block_size = at::round_up(block_size, granularity); - - HandleType handle; - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); - - void* ptr = nullptr; - map_block(&ptr, handle, block_size, device_idx); - - c10::cuda::CUDAGuard guard(device_idx); - AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); - - auto block = c10::make_intrusive( - handle, device_idx, block_size, size, signal_pad_offset, group_name); - { - std::unique_lock lock(mutex_); - ptr_to_block_.emplace(ptr, std::move(block)); - } - return ptr; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -void CUDASymmetricMemoryAllocator::free(void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - if (block == nullptr) { - return; - } - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. - if (block->symm_mem == nullptr) { - auto driver_api = c10::cuda::DriverAPI::get(); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(ptr), block->block_size)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); - } - { - std::unique_lock lock(mutex_); - ptr_to_block_.erase(ptr); - } -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->buffer_size; -} - -struct RendezvousRequest { - int device_idx; - int block_fd; - int pid; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; -}; - -void validate_rendezvous_requests( - const std::vector reqs, - int world_size) { - TORCH_CHECK(reqs.size() == (size_t)world_size); - - std::unordered_set device_indices; - device_indices.reserve(world_size); - for (auto req : reqs) { - device_indices.insert(req.device_idx); - } - if (device_indices.size() < (size_t)world_size) { - TORCH_CHECK( - false, - "CUDASymmetricMemoryAllocator::rendezvous: ", - "detected allocations from overlapping devices ", - "from different ranks."); - } - - for (int r = 1; r < world_size; ++r) { - TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); - TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); - TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); - } -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( - void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - - if (block->symm_mem != nullptr) { - return block->symm_mem; - } - - auto group_info = get_group_info(block->group_name); - auto driver_api = c10::cuda::DriverAPI::get(); - int block_fd; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); - - auto local_req = RendezvousRequest{ - .device_idx = block->device_idx, - .block_fd = block_fd, - .pid = getpid(), - .block_size = block->block_size, - .buffer_size = block->buffer_size, - .signal_pad_offset = block->signal_pad_offset}; - auto reqs = store_all_gather( - group_info.store, group_info.rank, group_info.world_size, local_req); - validate_rendezvous_requests(reqs, group_info.world_size); - - std::vector handles(group_info.world_size); - std::vector buffers(group_info.world_size, nullptr); - std::vector signal_pads(group_info.world_size, nullptr); - for (int r = 0; r < group_info.world_size; ++r) { - if (r == group_info.rank) { - handles[r] = block->handle; - buffers[r] = ptr; - signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); - continue; - } - int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)(uintptr_t)imported_fd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - map_block(&buffers[r], handles[r], block->block_size, block->device_idx); - signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); - close(imported_fd); - } - store_barrier(group_info.store, group_info.rank, group_info.world_size); - close(block_fd); - - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. So that outstanding - // references to the CUDASymmetricMemory object can keep the allocation - // alive. - block->symm_mem = c10::make_intrusive( - std::move(handles), - block->block_size, - std::move(buffers), - std::move(signal_pads), - block->buffer_size, - block->device_idx, - group_info.rank, - group_info.world_size); - return block->symm_mem; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->symm_mem != nullptr; -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { - std::shared_lock lock(mutex_); - auto it = ptr_to_block_.find(ptr); - if (it == ptr_to_block_.end()) { - return nullptr; - } - return it->second; -} - -struct RegisterCUDASymmetricMemoryAllocator { - RegisterCUDASymmetricMemoryAllocator() { - register_allocator( - c10::DeviceType::CUDA, - c10::make_intrusive()); - } -}; - -static RegisterCUDASymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh deleted file mode 100644 index 0e0e40a6bd09..000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10d { -namespace symmetric_memory { - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -using HandleType = CUmemGenericAllocationHandle; -#else -using HandleType = void*; -#endif - -class CUDASymmetricMemory : public SymmetricMemory { - public: - CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size); - - ~CUDASymmetricMemory() override; - - std::vector get_buffer_ptrs() override; - std::vector get_signal_pad_ptrs() override; - void** get_buffer_ptrs_dev() override; - void** get_signal_pad_ptrs_dev() override; - size_t get_buffer_size() override; - size_t get_signal_pad_size() override; - - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) override; - - void barrier(int channel) override; - void put_signal(int dst_rank, int channel) override; - void wait_signal(int src_rank, int channel) override; - - int get_rank() override; - int get_world_size() override; - - private: - std::vector handles_; - size_t block_size_; - std::vector buffers_; - std::vector signal_pads_; - size_t buffer_size_; - int local_device_idx_; - int rank_; - int world_size_; - void** buffers_dev_; - void** signal_pads_dev_; - std::optional> finalizer_; -}; - -struct Block : public c10::intrusive_ptr_target { - HandleType handle; - int device_idx; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; - std::string group_name; - c10::intrusive_ptr symm_mem = nullptr; - - Block( - HandleType handle, - int device_idx, - size_t block_size, - size_t buffer_size, - size_t signal_pad_offset, - const std::string& group_name) - : handle(handle), - device_idx(device_idx), - block_size(block_size), - buffer_size(buffer_size), - signal_pad_offset(signal_pad_offset), - group_name(group_name), - symm_mem(nullptr) {} -}; - -class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { - public: - void* alloc( - size_t size, - int device_idx, - const std::string& group_name) override; - - void free(void *ptr) override; - size_t get_alloc_size(void* ptr) override; - c10::intrusive_ptr rendezvous(void* ptr) override; - bool is_rendezvous_completed(void* ptr) override; - - private: - c10::intrusive_ptr find_block(void* ptr); - - std::shared_mutex mutex_; - std::unordered_map> ptr_to_block_; -}; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index 7c41414c4e4e..cff4ad09b706 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,7 +10,6 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp deleted file mode 100644 index b3d9f31bb034..000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include - -namespace { - -using namespace c10d::symmetric_memory; - -class AllocatorMap { - public: - static AllocatorMap& get() { - static AllocatorMap instance; - return instance; - } - - void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - map_[device_type] = std::move(allocator); - } - - c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - auto it = map_.find(device_type); - TORCH_CHECK( - it != map_.end(), - "SymmetricMemory does not support device type ", - device_type); - return it->second; - } - - ~AllocatorMap() { - for (auto& it : map_) { - it.second.release(); - } - } - - private: - AllocatorMap() = default; - AllocatorMap(const AllocatorMap&) = delete; - AllocatorMap& operator=(const AllocatorMap&) = delete; - - std::unordered_map< - c10::DeviceType, - c10::intrusive_ptr> - map_; -}; - -static std::unordered_map group_info_map{}; - -// Data structures for tracking persistent allocations -static std::unordered_map alloc_id_to_dev_ptr{}; -static std::unordered_map> - alloc_id_to_storage{}; - -static at::Tensor empty_strided_p2p_persistent( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - uint64_t alloc_id) { - // Make the allocation fails if a previous allocation with the same alloc_id - // is still active. - auto storage = alloc_id_to_storage.find(alloc_id); - if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { - TORCH_CHECK( - false, - "SymmetricMemory::empty_strided_p2p_persistent: ", - "can not allocate with alloc_id == ", - alloc_id, - " because a previous allocation with the same alloc_id " - "is still active."); - } - - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = nullptr; - if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { - dev_ptr = alloc_id_to_dev_ptr[alloc_id]; - TORCH_CHECK( - alloc_size == allocator->get_alloc_size(dev_ptr), - "SymmetricMemory::empty_strided_p2p_persistent: ", - "requested allocation size (", - alloc_size, - ") is different from the size of a previous allocation ", - "with the same alloc_id ", - allocator->get_alloc_size(dev_ptr)); - } else { - dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - alloc_id_to_dev_ptr[alloc_id] = dev_ptr; - } - - auto options = at::TensorOptions().dtype(dtype).device(device); - auto allocated = at::from_blob(dev_ptr, size, stride, options); - - // Track the allocation's activeness - alloc_id_to_storage.erase(alloc_id); - alloc_id_to_storage.emplace( - alloc_id, allocated.storage().getWeakStorageImpl()); - return allocated; -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - return AllocatorMap::get().register_allocator( - device_type, std::move(allocator)); -} - -c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - return AllocatorMap::get().get_allocator(device_type); -} - -void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store) { - TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); - GroupInfo group_info; - group_info.rank = rank; - group_info.world_size = world_size; - group_info.store = std::move(store); - group_info_map.emplace(group_name, std::move(group_info)); -} - -const GroupInfo& get_group_info(const std::string& group_name) { - TORCH_CHECK( - group_info_map.find(group_name) != group_info_map.end(), - "get_group_info: no group info associated with the group name ", - group_name); - return group_info_map[group_name]; -} - -at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id) { - if (alloc_id.has_value()) { - return empty_strided_p2p_persistent( - size, stride, dtype, device, group_name, *alloc_id); - } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::from_blob( - dev_ptr, - size, - stride, - [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, - options); -} - -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - return allocator->rendezvous(tensor.data_ptr()); -} - -c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - TORCH_CHECK( - allocator->is_rendezvous_completed(tensor.data_ptr()), - "SymmetricMemory: must invoke rendezvous on a tensor ", - "before calling get_symmetric_memory on it"); - return allocator->rendezvous(tensor.data_ptr()); -} - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp deleted file mode 100644 index 344b86ea5c7e..000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ /dev/null @@ -1,152 +0,0 @@ -#pragma once - -#include -#include - -namespace c10d { -namespace symmetric_memory { - -// SymmetricMemory represents symmetric allocations across a group of devices. -// The allocations represented by a SymmetricMemory object are accessible by -// all devices in the group. The class can be used for op-level custom -// communication patterns (via the get_buffer APIs and the synchronization -// primitives), as well as custom communication kernels (via the buffer and -// signal_pad device pointers). -// -// To acquire a SymmetricMemory object, each rank first allocates -// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes -// SymmetricMemoryAllocator::rendezvous() on the memory to establish the -// association across peer buffers. The rendezvous is a one-time process, and -// the mapping between a local memory memory and the associated SymmetricMemory -// object is unique. -// -// NOTE [symmetric memory signal pad] -// Signal pads are P2P-accessible memory regions designated for -// synchronization. SymmetricMemory offers built-in synchronization primitives -// such as barriers, put_signal, and wait_signal, which are all based on signal -// pads. Users may utilize signal pads for their own synchronization logic, -// provided that the signal pads remain zero-filled following successful -// synchronization. -// -// NOTE [symmetric memory synchronization channel] -// Synchronization channels allow users to use a single SymmetricMemory object -// to perform isolated synchronizations on different streams. For example, -// consider the case in which two barriers are issued on two streams for -// different purposes. Without the concept of channels, we cannot guarantee the -// correctness of the barriers since signals issued from barrier on stream A -// can be received by the barrier on stream B. By specifying different channels -// for these two barriers, they can operate correctly in parallel. -class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemory() {} - - virtual std::vector get_buffer_ptrs() = 0; - virtual std::vector get_signal_pad_ptrs() = 0; - - // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer - // to a device array of size world_size, containing buffer pointers and - // signal pad pointers, respectively. - virtual void** get_buffer_ptrs_dev() = 0; - virtual void** get_signal_pad_ptrs_dev() = 0; - virtual size_t get_buffer_size() = 0; - virtual size_t get_signal_pad_size() = 0; - - virtual at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) = 0; - - virtual void barrier(int channel) = 0; - virtual void put_signal(int dst_rank, int channel) = 0; - virtual void wait_signal(int src_rank, int channel) = 0; - - virtual int get_rank() = 0; - virtual int get_world_size() = 0; -}; - -class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemoryAllocator(){}; - - virtual void* alloc( - size_t size, - int device_idx, - const std::string& group_name) = 0; - - virtual void free(void* ptr) = 0; - virtual size_t get_alloc_size(void* ptr) = 0; - virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; - virtual bool is_rendezvous_completed(void* ptr) = 0; -}; - -C10_EXPORT void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator); - -C10_EXPORT c10::intrusive_ptr get_allocator( - c10::DeviceType device_type); - -// Set a store for rendezvousing symmetric allocations on a group of devices -// identified by `group_name`. The concept of groups is logical; users can -// utilize predefined groups (e.g., a group of device identified by a -// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator -// backends might employ a more efficient communication channel for the actual -// rendezvous process and only use the store for bootstrapping purposes. -TORCH_API void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store); - -struct GroupInfo { - int rank; - int world_size; - c10::intrusive_ptr store; -}; - -C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); - -// Identical to empty_strided, but allows symmetric memory access to be -// established for the allocated tensor via SymmetricMemory::rendezvous(). This -// function itself is not a collective operation. It invokes -// SymmetricMemoryAllocator::alloc() for the requested device under the hood. -// -// NOTE [symmetric memory persistent allocation] -// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent -// allocation. This makes the function cache allocated memory and ensure that -// invocations with the same `alloc_id` receive tensors backed by the same -// memory address. For safety, if a previous persistent allocation is still -// active (i.e., the storage of the returned tensor is still alive), persistent -// allocations with the same `alloc_id` will fail. This determinism coupled -// with memory planning of communication buffers (e.g., by Inductor) allows -// communication algorithms to reliably reuse previously established remote -// memory access. -TORCH_API at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id); - -// Establishes symmetric memory access on tensors allocated via -// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a -// one-time process, and the mapping between a local memory region and the -// associated SymmetricMemory object is unique. Subsequent calls to -// rendezvous() with the same tensor, or tensors allocated with -// empty_strided_p2p_persistent() using the same alloc_id, will receive the -// cached SymmetricMemory object. -// -// The function has a collective semantic and must be invoked simultaneously -// from all rendezvous participants. -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor); - -// Returns the SymmetricMemory object associated with the tensor. It can only -// be invoked after rendezvous() but does not need to be invoked collectively. -TORCH_API c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor); - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index db5778efcf35..6f1b28886b98 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include @@ -976,44 +975,6 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); - using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; - py::class_>( - module, "_SymmetricMemory") - .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) - .def_static( - "empty_strided_p2p", - ::c10d::symmetric_memory::empty_strided_p2p, - py::arg("size"), - py::arg("stride"), - py::arg("dtype"), - py::arg("device"), - py::arg("group_name"), - py::arg("alloc_id") = py::none()) - .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) - .def_static( - "get_symmetric_memory", - &::c10d::symmetric_memory::get_symmetric_memory) - .def_property_readonly("rank", &SymmetricMemory::get_rank) - .def_property_readonly("world_size", &SymmetricMemory::get_world_size) - .def( - "get_buffer", - &SymmetricMemory::get_buffer, - py::arg("rank"), - py::arg("sizes"), - py::arg("dtype"), - py::arg("storage_offset") = 0) - .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) - .def( - "put_signal", - &SymmetricMemory::put_signal, - py::arg("dst_rank"), - py::arg("channel") = 0) - .def( - "wait_signal", - &SymmetricMemory::wait_signal, - py::arg("src_rank"), - py::arg("channel") = 0); - auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 9d7ba5abf951..85136a91e025 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,8 +218,23 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - auto allocator = get_allocator(c10::DeviceType::CUDA); - allocator->free(symmetricMemoryPtr_); + // Intentionally releasing resources without synchronizing devices. The + // teardown logic is safe for propoerly sync'd user program. We don't want + // improperly sync'd user program to hang here. + for (size_t r = 0; r < worldSize_; ++r) { + if (r == rank_) { + continue; + } + AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); + AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); + } + AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); + AT_CUDA_CHECK(cudaFree(buffers_[rank_])); + if (topoInfo_ != nullptr) { + AT_CUDA_CHECK(cudaFree(topoInfo_)); + } + AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); + AT_CUDA_CHECK(cudaFree(buffersDev_)); } bool IntraNodeComm::isEnabled() { @@ -329,19 +344,83 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - set_group_info("IntraNodeComm", rank_, worldSize_, store_); - auto allocator = get_allocator(c10::DeviceType::CUDA); - symmetricMemoryPtr_ = - allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); - symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); - TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); + // Initialize p2p state + auto p2pState = initP2pState(); + + // Allocate buffer + void* buffer = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); + + // Second handshake: exchange topology and CUDA IPC handles + struct IpcInfo { + NvlMesh nvlMesh; + Topology topology; + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + }; + + // Make p2p state and buffer available for IPC + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); + AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); + + IpcInfo ipcInfo{ + .nvlMesh = nvlMesh, + .topology = topology, + .p2pStateHandle = p2pStateHandle, + .bufferHandle = bufferHandle}; + + auto peerIpcInfos = + storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); + + for (const auto& info : peerIpcInfos) { + if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || + info.topology != peerIpcInfos.front().topology) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are observing different topologies (" + << int(info.topology) << " and " << int(topology) << ")"; + AT_CUDA_CHECK(cudaFree(p2pState)); + AT_CUDA_CHECK(cudaFree(buffer)); + return false; + } + } + + std::array p2pStates = {}, buffers = {}; + for (size_t r = 0; r < peerIpcInfos.size(); ++r) { + if (r == rank_) { + p2pStates[r] = p2pState; + buffers[r] = buffer; + } else { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &p2pStates[r], + peerIpcInfos[r].p2pStateHandle, + cudaIpcMemLazyEnablePeerAccess)); + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &buffers[r], + peerIpcInfos[r].bufferHandle, + cudaIpcMemLazyEnablePeerAccess)); + } + } + void* p2pStatesDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); + AT_CUDA_CHECK(cudaMemcpy( + p2pStatesDev, + p2pStates.data(), + sizeof(p2pStates), + cudaMemcpyHostToDevice)); + + void* buffersDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); + AT_CUDA_CHECK(cudaMemcpy( + buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); - buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); + std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); + std::copy(buffers.begin(), buffers.end(), buffers_.begin()); + p2pStatesDev_ = p2pStatesDev; + buffersDev_ = buffersDev; topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index ac751ff7be1e..51fc6252d223 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,8 +132,6 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; -static_assert(sizeof(P2pState) <= kP2pStateSize); - template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -524,7 +522,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -584,7 +582,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -634,7 +632,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -757,7 +755,15 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); + const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); + const auto elementSize = c10::elementSize(dtype); + TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); + auto options = at::TensorOptions().dtype(dtype).device( + at::kCUDA, at::cuda::current_device()); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storageOffset) + .options(options) + .make_tensor(); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index a67df5c34586..5d7e2d426d30 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,16 +4,12 @@ #include #include #include -#include #include namespace c10d::intra_node_comm { -using namespace c10d::symmetric_memory; - constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; -constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -31,7 +27,6 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -102,8 +97,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - void* symmetricMemoryPtr_ = nullptr; - c10::intrusive_ptr symmetricMemory_ = nullptr; + std::array p2pStates_{}; + std::array buffers_{}; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From 1877b7896c237567285804ecc138bc86180a7ced Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 18 Jun 2024 07:49:05 -0700 Subject: [PATCH 08/64] [checkpoint] Clean up selective activation checkpoint and make public (#125795) ### bc-breaking for existing users of the private API: - Existing policy functions must now change their return value to be [CheckpointPolicy](https://github.com/pytorch/pytorch/blob/c0b40ab42e38a208351911496b7153511304f8da/torch/utils/checkpoint.py#L1204-L1230) Enum instead of bool. - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy. - Policy function now accepts a `ctx` object instead of `mode` for its first argument. - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`. - Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint). Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit Memory considerations: - As with the existing SAC, cached values are cleared upon first use. - We error if the user wishes to backward a second time on a region forwarded with SAC enabled. In-place: - We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed. - `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place) Randomness, views - Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors) Tensor object preservation - ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor. Policy function - Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error. - The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3). - The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute. - The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below). Tensors guaranteed to be the same tensor as-is - Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795 Approved by: https://github.com/Chillee, https://github.com/fmassa --- docs/source/checkpoint.rst | 3 + test/dynamo/test_activation_checkpointing.py | 27 +- test/test_autograd.py | 416 ++++++++++++++++++- torch/_higher_order_ops/wrap.py | 6 +- torch/utils/checkpoint.py | 316 +++++++++----- 5 files changed, 643 insertions(+), 125 deletions(-) diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst index f7bc160fa98b..8559d8bd7366 100644 --- a/docs/source/checkpoint.rst +++ b/docs/source/checkpoint.rst @@ -35,3 +35,6 @@ torch.utils.checkpoint .. autofunction:: checkpoint .. autofunction:: checkpoint_sequential .. autofunction:: set_checkpoint_debug_enabled +.. autoclass:: CheckpointPolicy +.. autoclass:: SelectiveCheckpointContext +.. autofunction:: create_selective_checkpoint_contexts diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 14851e51895b..274e03302845 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,11 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor -from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint +from torch.utils.checkpoint import ( + checkpoint, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( @@ -105,8 +109,11 @@ def op_count(gm): def _get_custom_policy(no_recompute_list=None): - def _custom_policy(mode, func, *args, **kwargs): - return func in no_recompute_list + def _custom_policy(ctx, func, *args, **kwargs): + if func in no_recompute_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE return _custom_policy @@ -530,7 +537,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -580,7 +587,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -650,7 +657,7 @@ def _custom_policy(mode, func, *args, **kwargs): def selective_checkpointing_context_fn(): meta = {} - return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) def gn(x, y): return torch.sigmoid( @@ -698,7 +705,7 @@ def fn(x, y): ) def test_compile_selective_checkpoint_partial_ctx_fn(self): def selective_checkpointing_context_fn(no_recompute_list): - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -751,7 +758,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list), ) @@ -803,7 +810,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -854,7 +861,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) diff --git a/test/test_autograd.py b/test/test_autograd.py index c133ae95b4b3..e45f5d47c692 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2,6 +2,7 @@ import collections import contextlib +import functools import gc import io import math @@ -79,8 +80,14 @@ ) from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.utils.checkpoint import ( + checkpoint, + checkpoint_sequential, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) from torch.utils.cpp_extension import load_inline +from torch.utils.flop_counter import FlopCounterMode from torch.utils.hooks import RemovableHandle # noqa: TCH001 @@ -13215,6 +13222,413 @@ def fn2(x): self.assertEqual(counter[0], 1) +class TestSelectiveActivationCheckpoint(TestCase): + @unittest.skipIf(not TEST_CUDA, "requires CUDA") + def test_flops_and_mem(self): + # From https://github.com/pytorch/pytorch/pull/126320 + def get_act_mem(f): + out = f() + out.backward() + # Why do one forward and backward? + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so + # it will be able to observe whether an op is cached or not. + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + x = torch.randn(512, 512, requires_grad=True, device="cuda") + y = torch.randn(512, 512, requires_grad=True, device="cuda") + + def fn(x, y): + return torch.mm(x.cos(), y).sin().sum() + + def fn_ac(x, y): + return checkpoint(fn, x, y, use_reentrant=False) + + def fn_sac(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.mm.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn_sac2(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + act_mem_noac = get_act_mem(lambda: fn(x, y)) + bw_flops_noac = get_bw_flops(lambda: fn(x, y)) + + self.assertEqual(act_mem_noac, 2.0) + self.assertEqual(bw_flops_noac, 2.0) + + act_mem_ac = get_act_mem(lambda: fn_ac(x, y)) + bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y)) + + self.assertEqual(act_mem_ac, 0.0) + self.assertEqual(bw_flops_ac, 3.0) + + act_mem_sac = get_act_mem(lambda: fn_sac(x, y)) + bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y)) + + self.assertEqual(act_mem_sac, 1.0) + self.assertEqual(bw_flops_sac, 2.0) + + act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y)) + bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y)) + + self.assertEqual(act_mem_sac2, 1.0) + self.assertEqual(bw_flops_sac2, 2.0) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_output_already_has_autograd_meta(self): + # View of tensor of non-differentiable dtype still has AutogradMeta + def fn(x, y): + return x.view(-1), y.sin().cos() + + x = torch.tensor([1, 2, 3], dtype=torch.int64) + y = torch.randn(3, requires_grad=True) + + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.view.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + out[1].sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_subclass_dispatching_sizes(self): + # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default + # Caching such metadata ops can be problematic when the following are satisfied: + # + # 1. size/strides are dispatched upon + # 2. our policy saves sizes + ta = torch.randn(6, 2) + + class CustomSizeDynamicShapesTensor(torch.Tensor): + @staticmethod + def __new__(cls, inner): + return torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. + cls, + inner.size(), + inner.stride(), + None, + None, + inner.dtype, + inner.layout, + inner.device, + False, + inner.requires_grad, + "sizes", + ) + + def __init__(self, inner): + self.inner = inner + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = torch.utils._pytree.tree_map_only( + cls, lambda x: x.inner, args + ) + out_inner = func(*args_inner, **kwargs) + return torch.utils._pytree.tree_map_only( + torch.Tensor, lambda x: cls(x), out_inner + ) + + def policy_fn(ctx, op, *args, **kwargs): + if op is torch.ops.aten.sym_size.default: + # Silently ignored! + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + # We avoid the following case + # + # saved :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12] + # forward :sum ,sum,mul, mul , mul ,view , view + # recompute :sum ,sum,mul, view , view + # + # Views save the shape of their input, so we expect the second + # view to save 12, but because during AC packing during forward + # saves the shapes of the input for metadata checks later, + # we would save the wrong shape during the recompute. + view_out = (x * x.sum()).view(-1).view(4, 3) + self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12]) + return view_out.exp() + + x = torch.randn(4, 3, requires_grad=True) + x_wrapper = CustomSizeDynamicShapesTensor(x) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + + def test_bad_inputs(self): + bad_op_list1 = [2] + + with self.assertRaisesRegex( + ValueError, "Expected op in `op_list` to be an OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list1) + + bad_op_list2 = [torch.ops.aten.sin] + + with self.assertRaisesRegex( + ValueError, "update the OpOverloadPacket to a specific OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list2) + + with self.assertRaisesRegex(TypeError, "either a function or a list of ops."): + create_selective_checkpoint_contexts(2) + + # Dynamo fails for various reasons: + # - some tests using custom op that does not implement Fake + # - dynamo is trying to trace into saved variable hooks unpack hook for some reason + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_policy_with_state(self): + # If I have a stateful callable, state is shared between the original + # forward and the recompute. + counters = [] + + class Policy: + def __init__(self): + self.counter = [0] + self.recompute_counter = [0] + + def __call__(self, ctx, func, *args, **kwargs): + counter = self.recompute_counter if ctx.is_recompute else self.counter + counter[0] += 1 + counters.append(counter[0]) + if counter == 1 and func is torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().sin().sin() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + Policy(), + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + # 1. counter properly reset to 0 for the recompute + # 2. due to early-stop we do not recompute the final op + self.assertEqual(counters, [1, 2, 3, 1, 2]) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_storage_lifetime(self): + from torch.utils._python_dispatch import _get_current_dispatch_mode + from torch.utils.checkpoint import ( + _CachedTorchDispatchMode, + _CachingTorchDispatchMode, + ) + + def policy_fn(ctx, op, *args, **kwargs): + return CheckpointPolicy.MUST_SAVE + + ref = None + + def fn(x): + nonlocal ref + + self.assertIsInstance( + _get_current_dispatch_mode(), + (_CachingTorchDispatchMode, _CachedTorchDispatchMode), + ) + + out = x.cos().exp() + + if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode): + raw_val = ( + _get_current_dispatch_mode() + .storage[torch.ops.aten.exp.default][0] + .val + ) + # ref should've been detached + # to avoid graph -> the saved variable hooks -> recompute_context -> storage -> graph + self.assertFalse(raw_val.requires_grad) + ref = weakref.ref(raw_val) + + # Careful for early-stop + return out.sin() + + with disable_gc(): + # Case 1: If graph goes away without backward, make sure there's no reference cycle + # keeping storage alive. + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + del out + self.assertIsNone(ref()) + + # Case 2: After backward, even if retain_graph=True, the storage should go away + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + out.sum().backward(retain_graph=True) + # The dispatch mode's storage should still be alive, but the entries should've + # been cleared. + self.assertIsNone(ref()) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_version_counter(self): + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().mul_(2).cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 1) Error because the output of sin is saved and mutated by mul_ + with self.assertRaisesRegex(RuntimeError, "has been mutated"): + out.sum().backward() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 2) No longer should be an error because of allow_cache_entry_mutation + out.sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_more_than_one_output(self): + # maybe there is a more systematic way: + counter = [0] + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.var_mean.correction: + counter[0] += 1 + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + # var_mean has two outputs + def fn(x): + a, b = torch.var_mean(x) + return a * b + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + self.assertEqual(counter[0], 2) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_non_tensor_output(self): + # When SAC is enabled, the op is not computed a second time + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + counter = [0] + + @torch.library.custom_op("mylib::sin_with_extra", mutates_args=()) + def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + counter[0] += 1 + return x.sin(), 2 + + def setup_context(ctx, inputs, output) -> torch.Tensor: + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad, _unused): + (x,) = ctx.saved_tensors + return grad * x.cos() + + torch.library.register_autograd( + "mylib::sin_with_extra", backward, setup_context=setup_context + ) + + x = torch.randn(3, requires_grad=True) + + def fn(x): + return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin() + + ops_list = [torch.ops.mylib.sin_with_extra.default] + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, ops_list + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + self.assertEqual(counter[0], 1) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_can_only_trigger_recompute_once(self): + # We don't support this to avoid adding extra complexity for now. + # If there's a need, we could probably do some kind of use_count tracking. + # TODO: have a nice error message here. + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward(retain_graph=True) + + with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"): + out.sum().backward(retain_graph=True) + + class TestAutogradMultipleDispatch(TestCase): def test_autograd_multiple_dispatch_registrations(self, device): t = torch.randn(3, 3, device=device, requires_grad=True) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 6d83a44e752a..e7fe553387d1 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -1,15 +1,17 @@ # mypy: allow-untyped-defs import inspect +import itertools import logging import torch from torch._ops import HigherOrderOperator -from torch.utils.checkpoint import checkpoint, uid +from torch.utils.checkpoint import checkpoint + import torch._dynamo.config log = logging.getLogger(__name__) - +uid = itertools.count(1) # Used for testing the HigherOrderOperator mechanism class Wrap(HigherOrderOperator): diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 5cbfd1543cf4..dab7730d8439 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -5,18 +5,8 @@ import warnings import weakref from collections import defaultdict -from itertools import count -from typing import ( - Any, - Callable, - ContextManager, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Tuple, -) +from typing import * # noqa: F403 +import enum from weakref import ReferenceType import torch @@ -39,6 +29,10 @@ "set_checkpoint_early_stop", "DefaultDeviceType", "set_checkpoint_debug_enabled", + "CheckpointPolicy", + "SelectiveCheckpointContext", + "create_selective_checkpoint_contexts", + "SAC_IGNORED_OPS", ] _DEFAULT_DETERMINISM_MODE = "default" @@ -1153,149 +1147,247 @@ def _is_compiling(func, args, kwargs): return False -def _detach(x): - if isinstance(x, torch.Tensor): - return x.detach() +class _VersionWrapper: + # Check that cached tensors are not mutated. + def __init__(self, val): + self.val: Union[torch.Tensor, Any] = val + self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + + def get_val(self, allow_cache_entry_mutation): + if self.version is not None and not allow_cache_entry_mutation: + if self.val._version != self.version: + # Can we give user a stack trace of where the mutation happened? + raise RuntimeError( + "Tensor cached during selective activation checkpoint has been mutated" + ) + return self.val + + +def _maybe_detach(x, any_ret_has_alias_info): + # We detach for two separate reasons: + # - For view ops, we need to ensure that when the tensor is returned from + # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr + # - Avoid reference cycles + # For case 1, it is not enough to check whether x has differentiable dtype + # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. + # when the tensor is a view. + if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): + with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + # Ensure that view performed beneath autograd properly propagates + # version counter. TODO: Use reentrant_dispatch instead of + # manually manipulating dispatch keys. Using reentrant_dispatch + # would respect inference_mode, though that is not relevant for + # this case. + x = x.detach() return x -uid = count(1) +class SelectiveCheckpointContext: + """ + Context passed to policy function during selective checkpointing. + This class is used to pass relevant metadata to the policy function during + selective checkpointing. The metadata includes whether the current invocation + of the policy function is during recomputation or not. -# NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times -# (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call), -# so we ignore these ops and just always recompute them. -_ignored_ops = { - torch.ops.prim.device.default, + Example: + >>> # xdoctest: +SKIP(stub) + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> print(ctx.is_recompute) + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + def __init__(self, *, is_recompute): + self.is_recompute = is_recompute + + +class CheckpointPolicy(enum.Enum): + """ + Enum for specifying the policy for checkpointing during backpropagation. + + The following policies are supported: + + - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward + pass and will not be recomputed during the backward pass + - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the + forward pass and will be recomputed during the backward pass + + Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden + by other subsystems like `torch.compile`. + + .. note:: + A policy function that always returns ``PREFER_RECOMPUTE`` is + equivalent to vanilla checkpointing. + + A policy function that returns ``PREFER_SAVE`` every op is + NOT equivalent to not using checkpointing. Using such a policy would + save additional tensors not limited to ones that are actually needed for + gradient computation. + """ + MUST_SAVE = 0 + PREFER_SAVE = 1 + MUST_RECOMPUTE = 2 + PREFER_RECOMPUTE = 3 + + +SAC_IGNORED_OPS = { + # AC inserts different number of detach during forward and recompute. torch.ops.aten.detach.default, + # AC's determinism check invokes additional metadata ops during forward. + # With subclasses involved, these metadata ops become dispatchable, this + # can result in incorrectness if these ops are selected cached. + torch.ops.prim.device.default, } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) class _CachingTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachedTorchDispatchMode. - """ + # Used together with _CachedTorchDispatchMode to implement SAC. def __init__(self, policy_fn, storage): self.policy_fn = policy_fn self.storage = storage - def push_into_storage(self, out, func, args, kwargs): - out_detached = tree_map(_detach, out) - self.storage[func].append(out_detached) + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) - def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs): - if should_not_recompute: + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if is_compiling and policy == CheckpointPolicy.MUST_SAVE: fx_traceback.current_meta["recompute"] = 0 - # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode - # we decide and handle recomputation in the partitioner. + out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - return out - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: - return func(*args, **kwargs) - should_not_recompute = self.policy_fn("forward", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs) - else: - if should_not_recompute: - out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) + return out class _CachedTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachingTorchDispatchMode. - """ - def __init__(self, policy_fn, storage): + # Used together with _CachedTorchDispatchMode to implement SAC. + def __init__(self, policy_fn, storage, allow_cache_entry_mutation): self.policy_fn = policy_fn self.storage = storage - - def pop_from_storage(self, func, args, kwargs): - assert func in self.storage - out = self.storage[func].pop(0) - return out - - def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs): - out = self.pop_from_storage(func, args, kwargs) - return out + self.allow_cache_entry_mutation = allow_cache_entry_mutation def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: + if func in SAC_IGNORED_OPS: return func(*args, **kwargs) - should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + storage = self.storage.get(func) + if storage is None: + raise RuntimeError(f"{func} encountered during backward, but not found in storage") + if len(storage) == 0: + raise RuntimeError( + "Trying to backward an extra time. You are only allowed to backward once " + "on any region computed under selective activation checkpoint." + ) + out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) else: - if should_not_recompute: - out = self.pop_from_storage(func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + out = func(*args, **kwargs) + return out -def _pt2_selective_checkpoint_context_fn_gen(policy_fn): + +def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): """ - A helper function that generates a pair of contexts to be later passed into - `torch.utils.checkpoint` API to implment selective checkpointing. + Helper to avoid recomputing certain ops during activation checkpointing. - .. warning:: - This is context_fn is intended for use with torch.compile only. + Use this with `torch.utils.checkpoint.checkpoint` to control which + operations are recomputed during the backward pass. Args: - policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function - to decide whether a particular op should be recomputed in backward pass or not. - In eager mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op is guaranteed to be recomputed. - In torch.compile mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op may or may not be recomputed - (it's up to the partitioner to decide). - + policy_fn_or_list (Callable or List): + - If a policy function is provided, it should accept a + :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and + kwargs to the op, and return a :class:`CheckpointPolicy` enum value + indicating whether the execution of the op should be recomputed or not. + - If a list of operations is provided, it is equivalent to a policy + returning `CheckpointPolicy.MUST_SAVE` for the specified + operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other + operations. + allow_cache_entry_mutation (bool, optional): By default, an error is + raised if any tensors cached by selective activation checkpoint are + mutated in order to ensure correctness. If set to `True`, this check + is disabled. Returns: - A pair of generated contexts. + A tuple of two context managers. Example: >>> # xdoctest: +REQUIRES(LINUX) + >>> import functools >>> - >>> def get_custom_policy(): - >>> no_recompute_list = [ - >>> torch.ops.aten.mm.default, - >>> ] - >>> def custom_policy(mode, func, *args, **kwargs): - >>> return func in no_recompute_list - >>> return custom_policy + >>> x = torch.rand(10, 10, requires_grad=True) + >>> y = torch.rand(10, 10, requires_grad=True) >>> - >>> def selective_checkpointing_context_fn(): - >>> return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy()) + >>> ops_to_save = [ + >>> torch.ops.aten.mm.default, + >>> ] >>> - >>> def gn(x, y): - >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> if op in ops_to_save: + >>> return CheckpointPolicy.MUST_SAVE + >>> else: + >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> - >>> def fn(x, y): - >>> return torch.utils.checkpoint.checkpoint( - >>> gn, x, y, - >>> use_reentrant=False, - >>> context_fn=selective_checkpointing_context_fn, - >>> ) + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> # or equivalently + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> - >>> x = torch.randn(4, 4, requires_grad=True) - >>> y = torch.randn(4, 4, requires_grad=True) + >>> def fn(x, y): + >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> - >>> compiled_fn = torch.compile(fn) + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) """ - storage: Dict[Any, List[Any]] = defaultdict(list) - return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage) + # NB: If grad_mode is disabled, checkpoint would not run forward under + # context_fn anyway, so proceed as usual. + if isinstance(policy_fn_or_list, list): + for op in policy_fn_or_list: + if not isinstance(op, torch._ops.OpOverload): + _extra_msg = ( + "Please update the OpOverloadPacket to a specific OpOverload." + "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + raise ValueError( + f"Expected op in `op_list` to be an OpOverload but got: {op} " + f"of type {type(op)}. {_extra_msg}" + ) + def policy_fn(ctx, op, *args, **kwargs): + if op in policy_fn_or_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): + policy_fn = policy_fn_or_list + else: + raise TypeError("policy_fn_or_list must be either a function or a list of ops.") + + storage: Dict[Any, List[Any]] = defaultdict(list) + return ( + _CachingTorchDispatchMode(policy_fn, storage), + _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), + ) # NB: this helper wraps fn before calling checkpoint_impl. kwargs and # saving/restoring of global state is handled here. From d77a1aaa8623ba5e70f4f147362d84769784cf43 Mon Sep 17 00:00:00 2001 From: loganthomas Date: Tue, 18 Jun 2024 18:26:07 +0000 Subject: [PATCH 09/64] DOC: add note about same sized tensors to dist.gather() (#128676) Fixes #103305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128676 Approved by: https://github.com/wconstab --- torch/distributed/distributed_c10d.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bd81fd61b02f..d44c3733a214 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3041,11 +3041,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): """ Gathers tensors from the whole group in a list. - Complex tensors are supported. + Complex and uneven sized tensors are supported. Args: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. tensor (Tensor): Tensor to be broadcast from current process. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. @@ -3118,6 +3119,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal """ Gather tensors from all ranks and put them in a single output tensor. + This function requires all tensors to be the same size on each process. + Args: output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the @@ -3341,11 +3344,13 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): """ Gathers a list of tensors in a single process. + This function requires all tensors to be the same size on each process. + Args: tensor (Tensor): Input tensor. - gather_list (list[Tensor], optional): List of appropriately-sized - tensors to use for gathered data (default is None, must be specified - on the destination rank) + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. From 1a527915a64b8e5f60951715b09fa294b1a8844f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 09:54:11 -0700 Subject: [PATCH 10/64] [DSD] Correctly handle shared parameters for optimizer state_dict (#128685) * Fixes https://github.com/pytorch/pytorch/issues/128011 See the discussion in https://github.com/pytorch/pytorch/pull/128076 Current implementation of `set_optimizer_state_dict()` assumes that all the fqns returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters. In such a case, only one fqn of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue. Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128685 Approved by: https://github.com/LucasLLC --- .../distributed/checkpoint/test_state_dict.py | 27 ++++++++++++ torch/distributed/checkpoint/state_dict.py | 42 ++++++++++++++++--- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 3da18ea5cc60..ac6263569af4 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -851,6 +851,33 @@ def test_deprecate_fsdp_api(self) -> None: ): get_model_state_dict(model) + @with_comms + @skip_if_lt_x_gpu(2) + def test_shared_weight(self): + class TiedEmbeddingModel(nn.Module): + def __init__(self, vocab_size, embedding_dim): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + self.decoder = nn.Linear(embedding_dim, vocab_size) + self.decoder.weight = self.embedding.weight # Tying weights + + def forward(self, input): + input = (input * 10).to(torch.int) + embedded = self.embedding(input) + output = self.decoder(embedded) + return output + + def init_model_optim(): + device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) + orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) + dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) + return orig_model, orig_optim, copy_optim, dist_model, dist_optim + + self._test_save_load(init_model_optim) + class TestNoComm(MultiProcessTestCase): def setUp(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16a1ddde2158..6bdeb389e8a0 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -153,6 +153,9 @@ class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] + ] = field(default_factory=dict) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True handle_optim: bool = True @@ -286,14 +289,29 @@ def _verify_options( fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[Set[str], torch.Tensor] ] = {} + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[Set[str], torch.Tensor] + ] = {} for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + fqns = _get_fqns(model, name) - if not isinstance(param, _EXTRA_STATE): - fqn_param_mapping[param] = fqns + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(Set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() for fqn in fqns: if not isinstance(param, _EXTRA_STATE): fqn_param_mapping[fqn] = param + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + submodule_prefixes: Set[str] = set() if submodules: submodules = set(submodules) @@ -361,6 +379,7 @@ def fsdp_state_dict_type_without_warning( return _StateDictInfo( **asdict(options), fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, submodule_prefixes=submodule_prefixes, fsdp_context=fsdp_context, fsdp_modules=cast(List[nn.Module], fsdp_modules), @@ -450,7 +469,7 @@ def _get_model_state_dict( for key in list(state_dict.keys()): fqns = _get_fqns(model, key) - assert len(fqns) == 1 + assert len(fqns) == 1, (key, fqns) fqn = next(iter(fqns)) if fqn != key: # As we only support FSDP, DDP, and TP, the only cases are @@ -797,6 +816,19 @@ def _split_optim_state_dict( pg_state.append({_PARAMS: []}) for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(List[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + params = pg_state[-1][_PARAMS] assert isinstance(params, list) params.append(fqn) @@ -805,9 +837,7 @@ def _split_optim_state_dict( for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): - params = loaded_param_group[_PARAMS] - assert isinstance(params, list) - if fqn in params: + if fqn in cast(List[str], loaded_param_group[_PARAMS]): pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 for param_group in cast(ListDictValueType, optim_state_dict[_PG]): From bdffd9f0c6f4564ee0cdd15d030215b5df58b2a9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 23:10:58 -0700 Subject: [PATCH 11/64] [export] Graph break on nn.Parameter construction (#128935) Fixes https://github.com/pytorch/pytorch/issues/126109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128935 Approved by: https://github.com/angelayi --- torch/_dynamo/variables/torch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 74c2193646bc..1cc4622dea52 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -877,6 +877,9 @@ def handle_ntuple(value): @classmethod def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented("nn parameter construction not supported with export") + if isinstance(requires_grad, variables.VariableTracker): try: requires_grad = requires_grad.as_python_constant() From 44483972bdd3dcd0c047020694817210846b5d70 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:37 -0700 Subject: [PATCH 12/64] [EZ] Keep weight_norm var name aligned (#128955) To keep it aligned with https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/native_functions.yaml#L6484 I.e. `x`->`v`, `y`->`g` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128955 Approved by: https://github.com/albanD, https://github.com/Skylion007 --- torch/_decomp/decompositions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7ebc69462fa1..dca552137ca6 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4770,11 +4770,11 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): @register_decomposition(torch.ops.aten._weight_norm_interface) -def _weight_norm_interface(x, y, dim=0): +def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 - keep_dim = tuple(i for i in range(len(x.shape)) if i != dim) - norm = x.norm(2, keep_dim, keepdim=True) - return x * (y / norm), norm + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + norm = v.norm(2, keep_dim, keepdim=True) + return v * (g / norm), norm @register_decomposition(aten.isin) From 04a5d3228ecd5af790dabcfeb27c8c4f86742e11 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Jun 2024 19:11:04 +0000 Subject: [PATCH 13/64] [ts migration] Support prim::tolist and aten::len (#128894) Support prim::tolist and aten::len. Add unit tests for prim::min. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128894 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 106 +++++++++++++++++++++++++++++++++- torch/_export/converter.py | 12 +++- 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 300f70223a26..8ea6a8089ae8 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -111,13 +111,102 @@ def forward(self, x): def test_aten_len(self): class Module(torch.nn.Module): - def forward(self, x): + def forward(self, x: torch.Tensor): length = len(x) return torch.ones(length) + # aten::len.Tensor inp = (torch.ones(2, 3),) self._check_equal_ts_ep_converter(Module(), inp) + class Module(torch.nn.Module): + def forward(self, x: List[int]): + length = len(x) + return torch.ones(length) + + # aten::len.t + inp = ([1, 2, 3],) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[int, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_int + inp = ({1: "a", 2: "b", 3: "c"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[bool, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_bool + inp = ({True: "a", False: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[float, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_float + inp = ({1.2: "a", 3.4: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[torch.Tensor, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_Tensor + inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + # aten::len.str and aten::len.Dict_str are not supported + # since torch._C._jit_flatten does not support str + # inp = ("abcdefg",) + # self._check_equal_ts_ep_converter(Module(), inp) + # inp = ({"a": 1, "b": 2},) + # self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_min(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_len = len(x) + y_len = len(y) + + # prim::min.int + len_int = min(x_len, y_len) + + # prim::min.float + len_float = int(min(x_len * 2.0, y_len * 2.0)) + + # prim::min.self_int + len_self_int = min([x_len, y_len]) + + # prim::min.self_float + len_self_float = int(min([x_len * 2.0, y_len * 2.0])) + + # prim::min.float_int + len_float_int = int(min(x_len * 2.0, y_len)) + + # prim::min.int_float + len_int_float = int(min(x_len, y_len * 2.0)) + + return torch.ones( + len_int + + len_float + + len_self_int + + len_self_float + + len_float_int + + len_int_float + ) + + inp = (torch.randn(10, 2), torch.randn(5)) + self._check_equal_ts_ep_converter(Module(), inp) + def test_aten___getitem___list(self): class Module(torch.nn.Module): def forward(self, x): @@ -659,6 +748,21 @@ def forward(self, x): # inp = (torch.randn([2, 3, 4]),) # self._check_equal_ts_ep_converter(func6, inp) + def test_prim_tolist(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[int]: + return x.tolist() + + inp = (torch.tensor([1, 2, 3]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[List[int]]: + return x.tolist() + + inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 2c54db38dee8..48f983b2917e 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -91,6 +91,7 @@ def get_dtype_as_int(tensor): "aten::__not__": operator.not_, "aten::__contains__": operator.contains, "prim::dtype": get_dtype_as_int, + "aten::len": len, } @@ -187,7 +188,7 @@ def _map_blocks_to_lifted_attrs(entry): def get_op_overload(node: torch._C.Node): schema_str = node.schema() - schema = torch._C.parse_schema(schema_str) + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) ns, op_name = str(schema.name).split("::") override = schema.overload_name @@ -651,6 +652,15 @@ def convert_profiler__record_function_exit(self, node: torch._C.Node): args = tuple(self.get_fx_value(input) for input in node.inputs()) self.fx_graph.call_function(target, args) + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + def _convert_standard_operators(self, node: torch._C.Node): target = kind_to_standard_operators[node.kind()] args = tuple(self.get_fx_value(input) for input in node.inputs()) From abde6cab4c7f972672ae008223000c16fd3964cd Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 12 Jun 2024 19:33:15 -0700 Subject: [PATCH 14/64] Remove compile_threads=1 in test_inductor_collectives.py (#128580) Summary: I believe https://github.com/pytorch/pytorch/issues/125235 should be fixed after switching to subprocess-based parallel compile. Test Plan: Ran locally with python-3.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128580 Approved by: https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 35e44b19bedd..ee4535fd5a73 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -60,8 +60,6 @@ def world_size(self) -> int: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_broadcast_inductor(self): """ Testing if broadcast works correctly when using inductor @@ -94,8 +92,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -129,8 +125,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor_cudagraph_trees(self): """ Tests whether cudagraph trees support all_reduce from nccl @@ -177,8 +171,6 @@ def test_c10d_functional_tagged_pt2_compliant(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_eager_allreduce_inductor_wait(self): def eager_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -218,8 +210,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_inductor_allreduce_eager_wait(self): def inductor_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -256,8 +246,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_input_buffer_reuse(self): def func(a, *, tag, ranks, group_size): ar = _functional_collectives.all_reduce(a, "sum", ranks, tag) @@ -275,8 +263,6 @@ def func(a, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_permute_tensor(self): def func(tensor, src_dst_pairs, *, tag, ranks, group_size): return _functional_collectives.permute_tensor( @@ -304,8 +290,6 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_output_buffer_reuse(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -329,8 +313,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -355,8 +337,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_into_tensor_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -388,8 +368,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_reduce_scatter_tensor_inductor(self): def example(a, b, *, tag, ranks, group_size): c = torch.matmul(a, b) @@ -418,8 +396,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor(self): def example( inp, @@ -488,8 +464,6 @@ def example( @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor_split_sizes_none(self): def example(inp, *, tag, ranks, group_size): a2a = torch.ops.c10d_functional.all_to_all_single( From fe8558b7aa4ce55d06893c48d5cb00b7a7eb7dae Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 10:37:13 -0700 Subject: [PATCH 15/64] [DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755) HSDP1 + broadcast_from_rank0 actually behaves differently from FSDP1 + broadcast_from_rank0. So we need an unittest to cover this use case. This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446. Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755 Approved by: https://github.com/Skylion007, https://github.com/wz337 ghstack dependencies: #128685 --- .../distributed/checkpoint/test_state_dict.py | 157 ++++++++++-------- 1 file changed, 87 insertions(+), 70 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index ac6263569af4..773635062880 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -33,7 +33,11 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, + StateDictType, +) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn.parallel import DistributedDataParallel as DDP @@ -70,7 +74,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: - return 2 + return min(4, torch.cuda.device_count()) def _test_save_load( self, @@ -567,55 +571,71 @@ def test_non_persistent_buffers(self) -> None: set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model)) - @with_comms - @skip_if_lt_x_gpu(2) - def test_broadcast_from_rank0(self) -> None: - def inner_test(wrapper): - model = CompositeParamModel(device=torch.device("cuda")) - optim = torch.optim.Adam(model.parameters()) - fsdp_model = wrapper(copy.deepcopy(model)) - fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) + def _test_broadcast_from_rank0(self, wrapper) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + optim = torch.optim.Adam(model.parameters()) + fsdp_model = wrapper(copy.deepcopy(model)) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") - model(batch).sum().backward() - optim.step() - states, optim_states = get_state_dict(model, optim) + batch = torch.rand(8, 100, device="cuda") + model(batch).sum().backward() + optim.step() + states, optim_states = get_state_dict(model, optim) - fsdp_model(batch).sum().backward() - fsdp_optim.step() + fsdp_model(batch).sum().backward() + fsdp_optim.step() - def check(equal): - fsdp_states = get_model_state_dict( - fsdp_model, - options=StateDictOptions(full_state_dict=True), - ) - fsdp_optim_states = get_optimizer_state_dict( - fsdp_model, - fsdp_optim, - options=StateDictOptions(full_state_dict=True), - ) - if equal: - self.assertEqual(states, fsdp_states) - self.assertEqual(optim_states, fsdp_optim_states) - else: - self.assertNotEqual(states, fsdp_states) - self.assertNotEqual(optim_states, fsdp_optim_states) - - check(equal=True) - fsdp_model(batch).sum().backward() - fsdp_optim.step() - check(equal=False) - - # Drop the states to simulate loading from rank0 - if dist.get_rank() > 0: - load_states = {} - load_states2 = {} - load_optim_states = {} + def check(equal): + fsdp_states = get_model_state_dict( + fsdp_model, + options=StateDictOptions(full_state_dict=True), + ) + fsdp_optim_states = get_optimizer_state_dict( + fsdp_model, + fsdp_optim, + options=StateDictOptions(full_state_dict=True), + ) + if equal: + self.assertEqual(states, fsdp_states) + self.assertEqual(optim_states, fsdp_optim_states) else: - load_states = copy.deepcopy(states) - load_states2 = copy.deepcopy(states) - load_optim_states = copy.deepcopy(optim_states) + self.assertNotEqual(states, fsdp_states) + self.assertNotEqual(optim_states, fsdp_optim_states) + + check(equal=True) + fsdp_model(batch).sum().backward() + fsdp_optim.step() + check(equal=False) + + # Drop the states to simulate loading from rank0 + if dist.get_rank() > 0: + load_states = {} + load_states2 = {} + load_optim_states = {} + else: + load_states = copy.deepcopy(states) + load_states2 = copy.deepcopy(states) + load_optim_states = copy.deepcopy(optim_states) + set_model_state_dict( + fsdp_model, + model_state_dict=load_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + set_optimizer_state_dict( + fsdp_model, + fsdp_optim, + optim_state_dict=load_optim_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + + check(equal=True) + # Verify the `strict` flag. + load_states = load_states2 + if load_states: + key = next(iter(load_states.keys())) + load_states.pop(key) + with self.assertRaisesRegex(RuntimeError, "Missing key"): set_model_state_dict( fsdp_model, model_state_dict=load_states, @@ -623,30 +643,10 @@ def check(equal): broadcast_from_rank0=True, full_state_dict=True ), ) - set_optimizer_state_dict( - fsdp_model, - fsdp_optim, - optim_state_dict=load_optim_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) - - check(equal=True) - # Verify the `strict` flag. - load_states = load_states2 - if load_states: - key = next(iter(load_states.keys())) - load_states.pop(key) - with self.assertRaisesRegex(RuntimeError, "Missing key"): - set_model_state_dict( - fsdp_model, - model_state_dict=load_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_broadcast_from_rank0(self) -> None: device_mesh = init_device_mesh("cuda", (self.world_size,)) self.run_subtests( { @@ -655,7 +655,24 @@ def check(equal): functools.partial(FSDP, device_mesh=device_mesh), ] }, - inner_test, + self._test_broadcast_from_rank0, + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_broadcast_from_rank0_hsdp(self) -> None: + device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) + self.run_subtests( + { + "wrapper": [ + functools.partial( + FSDP, + device_mesh=device_mesh, + sharding_strategy=ShardingStrategy.HYBRID_SHARD, + ), + ] + }, + self._test_broadcast_from_rank0, ) @with_comms From 9a7e2519d3d15f8d469b71cab914fcdaf071ebd6 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 18 Jun 2024 19:59:50 +0000 Subject: [PATCH 16/64] [MPS] Fused Adam & AdamW (#127242) Summary: This PR adds fused Adam and AdamW implementations. Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory: **Fast math enabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300 Times are in milliseconds (ms). ``` **Fast math disabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200 Times are in milliseconds (ms). ``` ```python def profile_fused_adam(): from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device = "mps" results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)] if adamWflag: fn = adamw.adamw else: fn = adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label='Fused Adam', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242 Approved by: https://github.com/kulinseth, https://github.com/janeyx99 --- aten/src/ATen/native/mps/OperationUtils.h | 19 +- aten/src/ATen/native/mps/OperationUtils.mm | 31 +- .../operations/FusedAdamAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamKernel.mm | 69 +++++ .../mps/operations/FusedAdamKernelImpl.h | 23 ++ .../mps/operations/FusedAdamKernelImpl.mm | 35 +++ .../operations/FusedAdamWAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamWAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamWKernel.mm | 68 +++++ .../mps/operations/FusedAdamWKernelImpl.h | 23 ++ .../mps/operations/FusedAdamWKernelImpl.mm | 35 +++ .../native/mps/operations/FusedOptimizerOps.h | 274 ++++++++++++++++++ .../native/mps/operations/MultiTensorApply.h | 190 ++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 34 +-- test/test_optim.py | 31 +- torch/optim/adam.py | 6 + torch/optim/adamw.py | 6 + torch/testing/_internal/common_optimizers.py | 4 +- torch/utils/_foreach_utils.py | 2 +- 21 files changed, 911 insertions(+), 63 deletions(-) create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedOptimizerOps.h create mode 100644 aten/src/ATen/native/mps/operations/MultiTensorApply.h diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 25e86e6d262f..a9493cbce3ad 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -336,25 +336,34 @@ inline bool is_dense_in_storage(const at::Tensor& t) { class MetalShaderLibrary { public: - MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {} + MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {} MetalShaderLibrary(const MetalShaderLibrary&) = delete; inline id getPipelineStateForFunc(const std::string& fname) { - return getLibraryPipelineState(getLibrary(), fname); + return getLibraryPipelineState(getLibrary(), fname).first; } id getPipelineStateForFunc(const std::string& fname, const std::initializer_list& params) { - return getLibraryPipelineState(getLibrary(params), fname); + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline id getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + id getMTLFunction(const std::string& fname, const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; } private: - id getLibraryPipelineState(id lib, const std::string& fname); + std::pair, id> getLibraryPipelineState(id lib, const std::string& fname); id getLibrary(); id getLibrary(const std::initializer_list& params); id compileLibrary(const std::string& src); std::string shaderSource; unsigned nparams; + MTLCompileOptions* compile_options; id library = nil; std::unordered_map> libMap; - std::unordered_map> cplMap; + std::unordered_map, id>> cplMap; }; static inline void mtl_setBuffer(id encoder, const Tensor& t, unsigned idx) { diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 82d1fe9d92f4..8dc90e497fe4 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -656,31 +656,38 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} id MetalShaderLibrary::compileLibrary(const std::string& src) { NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 - : MTLLanguageVersion2_3]; - // [options setFastMathEnabled: NO]; - auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; + MTLCompileOptions* options = compile_options; + if (!options) { + options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 + : MTLLanguageVersion2_3]; + [options setFastMathEnabled:NO]; + } + + const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; auto device = MPSDevice::getInstance()->device(); library = [device newLibraryWithSource:str options:options error:&error]; TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); return library; } -id MetalShaderLibrary::getLibraryPipelineState(id lib, const std::string& fname) { - auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); - auto cpl = cplMap[key]; - if (cpl) { - return cpl; +std::pair, id> MetalShaderLibrary::getLibraryPipelineState( + id lib, + const std::string& fname) { + const auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); + auto found_cpl = cplMap.find(key); + if (found_cpl != cplMap.end()) { + return found_cpl->second; } NSError* error = nil; id func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]]; TORCH_CHECK(func, "Failed to create function state object for: ", fname); - cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; + auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); - return cplMap[key] = cpl; + cplMap[key] = std::make_pair(cpl, func); + return cplMap[key]; } } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 000000000000..8711cb228ee9 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm new file mode 100644 index 000000000000..be6069ad9694 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm new file mode 100644 index 000000000000..2e4d89ff851c --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm @@ -0,0 +1,69 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adam_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 000000000000..90d1ee150932 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm new file mode 100644 index 000000000000..e3c87ae9bc78 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 000000000000..f03fcdb57413 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm new file mode 100644 index 000000000000..fd94e9686fbc --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm new file mode 100644 index 000000000000..ce08972ef9ad --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm @@ -0,0 +1,68 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adamw_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 000000000000..284516e0b89c --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm new file mode 100644 index 000000000000..8899f6a5e9e1 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h new file mode 100644 index 000000000000..00a75067b7f4 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h @@ -0,0 +1,274 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +static const char* FUSED_ADAM_OPS = R"METAL( +#include + +#define kmaxThreadGroups 32 +#define kmaxTensors 32 +#define chunk_size 65536 + +constexpr constant uint kParamIdx = 0; +constexpr constant uint kGradIdx = kParamIdx + kmaxTensors; +constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors; +constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors; +constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors; + +template +struct AdamArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdx) ]]; +}; + +template +struct AdamAmsgradArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdxForAmsgrad) ]]; +}; + +struct MetadataArguments { + uint32_t numels[kmaxTensors]; + uint32_t threadgroup_to_tensor[kmaxThreadGroups]; + uint32_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +enum ADAM_MODE : uint8_t { + ORIGINAL = 0, + ADAMW = 1 +}; + +template +inline void adam_math_amsgrad( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device T & max_exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq); + + const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +inline void adam_math( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +kernel void fused_adam_amsgrad( + device AdamAmsgradArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math_amsgrad( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *(max_exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +template +kernel void fused_adam( + device AdamArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \ +template \ +[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \ +kernel void KERNEL_NAME( \ + device ARGUMENTS_STRUCT & args [[buffer(0)]],\ + constant MetadataArguments & metadata_args [[buffer(1)]],\ + constant float & lr [[buffer(2)]],\ + constant float & beta1 [[buffer(3)]],\ + constant float & beta2 [[buffer(4)]],\ + constant float & weight_decay [[buffer(5)]],\ + constant float & eps [[buffer(6)]],\ + constant uint8_t & maximize [[buffer(7)]],\ + uint tid [[thread_position_in_threadgroup]],\ + uint tgid [[threadgroup_position_in_grid]],\ + uint tptg [[threads_per_threadgroup]]) + +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); + +)METAL"; + +static std::pair, id> getCPLState(const std::string& fname) { + static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0); + return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname)); +} + +} //namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 000000000000..fe9296cc0db7 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,190 @@ +#pragma once +#include +#include +#include + +namespace at::native { +namespace mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 bytes + uint numels[kmaxTensors]; + uint threadgroup_to_tensor[kmaxThreadGroups]; + uint threadgroup_to_chunk[kmaxThreadGroups]; +}; + +template +static void multi_tensor_apply_for_fused_adam( + const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize + ) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + TORCH_CHECK( + tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + float lr_lv = lr; + float beta1_lv = beta1; + float beta2_lv = beta2; + float weight_decay_lv = weight_decay; + float eps_lv = eps; + uint8_t maximize_lv = maximize; + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full){ + [computeEncoder setBuffer:tensorArgumentBuffer + offset:0 + atIndex:0]; + [computeEncoder setBytes:&metadata_arguments + length:sizeof(MetadataArguments) + atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. + metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + + [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0]; + [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + + } + }); +} + +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7474e0bc55d8..b030141882c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15575,6 +15575,7 @@ dispatch: CPU: _fused_adam_kernel_cpu_ CUDA: _fused_adam_kernel_cuda_ + MPS: _fused_adam_kernel_mps_ autogen: _fused_adam, _fused_adam.out - func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15593,6 +15594,7 @@ dispatch: CPU: _fused_adamw_kernel_cpu_ CUDA: _fused_adamw_kernel_cuda_ + MPS: _fused_adamw_kernel_mps_ autogen: _fused_adamw, _fused_adamw.out - func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () diff --git a/test/test_mps.py b/test/test_mps.py index 311cf8245c4f..a97b8fb8d6b1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -76,7 +76,6 @@ def mps_ops_grad_modifier(ops): XFAILLIST_GRAD = { # precision issues - 'digamma': [torch.float32], 'special.polygammaspecial_polygamma_n_0': [torch.float16], 'polygammapolygamma_n_0': [torch.float16], 'nn.functional.binary_cross_entropy': [torch.float16], @@ -95,7 +94,6 @@ def mps_ops_grad_modifier(ops): 'masked.scatter': [torch.float16, torch.float32], 'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`. 'aminmax': [torch.float32, torch.float16], - 'polar': [torch.float32], # Correctness issues 'atanh': [torch.float32], @@ -569,7 +567,6 @@ def mps_ops_modifier(ops): 'special.ndtr': [torch.uint8], 'sqrt': [torch.uint8], 'sub': [torch.uint8], - 'tanh': [torch.uint8], 'trapezoid': [torch.uint8], 'trapz': [torch.uint8], 'true_divide': [torch.uint8], @@ -586,28 +583,13 @@ def mps_ops_modifier(ops): 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # cpu not giving nan for x/0.0 - 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # inconsistency errors between cpu and mps, max seen atol is 2 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_BEFORE_13_3_XFAILLIST = { - # Failure due to precision issues (still present on 13.3+) as well as non-standard behavior of - # cpu ops for the negative integers. - # Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ 'tan': [torch.float32], 'cdist': [torch.float32], @@ -656,20 +638,6 @@ def mps_ops_modifier(ops): # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour. 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16], - - # Failure due to precision issues as well as non-standard behavior of cpu ops for the - # negative integers. Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], } MACOS_BEFORE_14_4_XFAILLIST = { diff --git a/test/test_optim.py b/test/test_optim.py index d61c33e2adce..fb655ce36a53 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -32,6 +32,7 @@ ) from torch.testing._internal.common_dtype import floating_types_and from torch.testing._internal.common_optimizers import ( + _get_device_type, _get_optim_inputs_including_global_cliquey_kwargs, optim_db, OptimizerErrorEnum, @@ -1004,7 +1005,6 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): self.assertLessEqual(mt_max_mem, expected_max_mem) - @onlyNativeDeviceTypes @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=floating_types_and( @@ -1013,10 +1013,15 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): ), ) def test_fused_matches_forloop(self, device, dtype, optim_info): - if device not in optim_info.supports_fused_on: + if _get_device_type(device) not in optim_info.supports_fused_on: self.skipTest( f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" ) + if _get_device_type(device) == "mps" and dtype not in ( + torch.float16, + torch.float32, + ): + self.skipTest("MPS supports only torch.float16 and torch.float32") self._test_derived_optimizers(device, dtype, optim_info, "fused") @onlyNativeDeviceTypes @@ -1076,7 +1081,6 @@ def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info): ) self.assertEqual(params, params_c) - @onlyCUDA @parametrize("impl", ["fused", "capturable"]) @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], @@ -1100,8 +1104,15 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): ): # Capturable SGD/Adagrad does not exist self.skipTest("SGD does not currently support capturable") - if impl == "fused" and device not in optim_info.supports_fused_on: + if _get_device_type(device) == "cpu": + self.skipTest("Test is only for non-cpu devices") + elif ( + impl == "fused" + and _get_device_type(device) not in optim_info.supports_fused_on + ): self.skipTest(f"{device} is not supported for fused on {opt_name}") + elif impl == "capturable" and _get_device_type(device) == "mps": + self.skipTest("MPS does not support capturable") cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu") for optim_input in cpu_optim_inputs: @@ -1114,12 +1125,12 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): # load optim_input.kwargs[impl] = True - param_cuda = param.clone().detach().to(device="cuda") - optimizer_cuda = optim_cls([param_cuda], **optim_input.kwargs) - optimizer_cuda.load_state_dict(optim_state_dict_cpu) - optimizer_cuda.zero_grad() - param_cuda.grad = torch.rand_like(param_cuda) - optimizer_cuda.step() + param_device = param.clone().detach().to(device=device) + optimizer_device = optim_cls([param_device], **optim_input.kwargs) + optimizer_device.load_state_dict(optim_state_dict_cpu) + optimizer_device.zero_grad() + param_device.grad = torch.rand_like(param_device) + optimizer_device.step() @optims(optim_db, dtypes=[torch.float32]) def test_param_groups_weight_decay(self, device, dtype, optim_info): diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 86785be4ed17..fa7397e02b42 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -309,6 +309,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _On the Convergence of Adam and Beyond: @@ -660,6 +662,10 @@ def _fused_adam( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 00931bed0227..20ab82755249 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -310,6 +310,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: @@ -662,6 +664,10 @@ def _fused_adamw( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 628bedad313d..b7d06e7dc808 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1232,7 +1232,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1354,7 +1354,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamw, optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index bcc274579ad0..c3100d41b6c0 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -11,7 +11,7 @@ def _get_foreach_kernels_supported_devices() -> List[str]: def _get_fused_kernels_supported_devices() -> List[str]: r"""Return the device type list that supports fused kernels in optimizer.""" - return ["cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] + return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] TensorListList: TypeAlias = List[List[Optional[Tensor]]] Indices: TypeAlias = List[int] From 5bc9835d64eb5592cb606252ccf19212872cefc7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 17/64] Revert "[dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428)" This reverts commit c52eda896eb3ec7f8d04b6321861f4c5614a40bb. Reverted https://github.com/pytorch/pytorch/pull/128428 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2329ab305e76..dbcb259241fc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index abbef02e63c6..b5b12435a931 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,6 +2669,26 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", + "torch.nn.functional._adaptive_max_pool1d", + "torch.nn.functional._adaptive_max_pool2d", + "torch.nn.functional._adaptive_max_pool3d", + "torch.nn.functional._canonical_mask", + "torch.nn.functional._fractional_max_pool2d", + "torch.nn.functional._fractional_max_pool3d", + "torch.nn.functional._get_softmax_dim", + "torch.nn.functional._in_projection_packed", + "torch.nn.functional._in_projection", + "torch.nn.functional._is_integer", + "torch.nn.functional._max_pool1d", + "torch.nn.functional._max_pool2d", + "torch.nn.functional._max_pool3d", + "torch.nn.functional._mha_shape_check", + "torch.nn.functional._no_grad_embedding_renorm_", + "torch.nn.functional._none_or_dtype", + "torch.nn.functional._threshold", + "torch.nn.functional._unpool_output_size", + "torch.nn.functional._verify_batch_size", + "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2766,7 +2786,15 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._arg_requires_grad", + "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.container._addindent", + "torch.nn.modules.transformer._detect_is_causal_mask", + "torch.nn.modules.transformer._generate_square_subsequent_mask", + "torch.nn.modules.transformer._get_activation_fn", + "torch.nn.modules.transformer._get_clones", + "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 1babeddbbf3a44318d13cf3b8afaac2a6d657115 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 18/64] Revert "[inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484)" This reverts commit 1f6e84fa6852805e15ddc9583c5f36c3a7f93df8. Reverted https://github.com/pytorch/pytorch/pull/128484 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/inductor/test_mkldnn_pattern_matcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index a80d72398760..810c22d037c5 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,8 +37,7 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - # Use floats for min/max, otherwise they can get converted to symints - torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, + torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 44722c6b1085611e0f20917a76fcf3f8f2776e13 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 19/64] Revert "[dynamo][fsdp] Dont take unspecializedNNModuleVariable path for FSDP modules (#128453)" This reverts commit 2b28b107dbafeec18d1095a2002e79511aa241df. Reverted https://github.com/pytorch/pytorch/pull/128453 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- torch/_dynamo/variables/builder.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index af91edb432c8..8a201410d6be 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1164,11 +1164,7 @@ def wrap_module(self, value: torch.nn.Module): and not config.allow_rnn ): unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") - - # Dont take this path for FSDP - if not getattr( - value, "_is_fsdp_managed_module", None - ) and mutation_guard.is_dynamic_nn_module(value, self.tx.export): + if mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it self.install_guards(GuardBuilder.TYPE_MATCH) if ( From 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 18 Jun 2024 13:35:49 -0400 Subject: [PATCH 20/64] Backward support for unbind() with NJT (#128032) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032 Approved by: https://github.com/soulitzer --- test/test_nestedtensor.py | 19 +++++++++++++++++++ tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 +++++++++++++++++ torch/csrc/autograd/FunctionsManual.h | 4 ++++ torch/nested/_internal/ops.py | 11 +++++++++++ 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 50d6deea9291..fa33a13ed495 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5610,6 +5610,25 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + @dtypes(torch.float32, torch.double, torch.half) + def test_unbind_backward(self, device, dtype): + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 4, device=device), + torch.randn(5, 4, device=device), + torch.randn(3, 4, device=device), + ], + layout=torch.jagged, + requires_grad=True, + ) + + a, b, c = nt.unbind() + b.sum().backward() + + expected_grad = torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a..02a3e6c518ad 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 9d897c667c90..f51c2f047f93 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,6 +1014,23 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } +Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim) { + TORCH_INTERNAL_ASSERT( + dim == 0, "unbind_backward_nested_jagged() only supports dim=0") + auto grad_nt = at::zeros_like(self); + auto unbound_grads = grad_nt.unbind(); + for (int64_t i : c10::irange(static_cast(grads.size()))) { + if (grads[i].defined()) { + unbound_grads[i].copy_(static_cast(grads[i])); + } + } + + return grad_nt; +} + Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index dedff70be1ba..ecf99bd09805 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,6 +244,10 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); +at::Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd6947..8458f0371713 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,6 +472,17 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) +@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") +def zero__default(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + func(inp._values) + return inp + + @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From 4cc3fb5ee2296e1178cec710a945c99aa303170d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:38:22 -0700 Subject: [PATCH 21/64] Bump urllib3 from 2.2.1 to 2.2.2 in /tools/build/bazel (#128908) Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.2.1 to 2.2.2. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.2.1...2.2.2) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/build/bazel/requirements.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index cd95aeeec5c6..fea6221c9b7c 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -145,7 +145,7 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in pyyaml==6.0.1 \ --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ @@ -198,26 +198,26 @@ pyyaml==6.0.1 \ --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in requests==2.32.2 \ --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c - # via -r tools/build/bazel/requirements.in + # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a - # via -r tools/build/bazel/requirements.in -urllib3==2.2.1 \ - --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ - --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 + # via -r requirements.in +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 # via requests # The following packages are considered to be unsafe in a requirements file: setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in From 2227da44317f4ea836aaad96337b53533aed2770 Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Tue, 18 Jun 2024 21:01:01 +0000 Subject: [PATCH 22/64] [Profiler] Clean up use_mtia to follow standard use_device instead (#126284) Summary: use_mtia should instead set use_device='mtia' similar to cuda, xpu, and privateuseone. Avoid an ever-growing list of use_* arguments. Since use_mtia is specific to FBCode, we don't need a deprecation warning. Test Plan: CI. Differential Revision: D57338005 Pulled By: aaronenyeshi Pull Request resolved: https://github.com/pytorch/pytorch/pull/126284 Approved by: https://github.com/fenypatel99 --- torch/autograd/profiler.py | 13 +++++++------ torch/profiler/profiler.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 0392a8769846..f847fc13ff8a 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -118,7 +118,7 @@ class profile: use_device (str, optional): Enables timing of device events. Adds approximately 4us of overhead to each tensor operation when use cuda. - The valid devices options are 'cuda', 'xpu' and 'privateuseone'. + The valid devices options are 'cuda', 'xpu', 'mtia' and 'privateuseone'. record_shapes (bool, optional): If shapes recording is set, information about input dimensions will be collected. This allows one to see which @@ -205,7 +205,6 @@ def __init__( with_modules=False, use_kineto=False, use_cpu=True, - use_mtia=False, experimental_config=None, ): self.enabled: bool = enabled @@ -231,7 +230,6 @@ def __init__( self.with_stack = with_stack self.with_modules = with_modules self.use_cpu = use_cpu - self.use_mtia = use_mtia if experimental_config is None: experimental_config = _ExperimentalConfig() self.experimental_config = experimental_config @@ -246,7 +244,7 @@ def __init__( ), "Device-only events supported only with Kineto (use_kineto=True)" if self.use_device is not None: - VALID_DEVICE_OPTIONS = ["cuda", "xpu"] + VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"] if _get_privateuse1_backend_name() != "privateuseone": VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name()) if self.use_device not in VALID_DEVICE_OPTIONS: @@ -265,8 +263,6 @@ def __init__( self.kineto_activities = set() if self.use_cpu: self.kineto_activities.add(ProfilerActivity.CPU) - if self.use_mtia: - self.kineto_activities.add(ProfilerActivity.MTIA) self.profiler_kind = ProfilerState.KINETO if self.use_device == "cuda": @@ -280,6 +276,11 @@ def __init__( use_kineto and ProfilerActivity.XPU in _supported_activities() ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device == "mtia": + assert ( + use_kineto and ProfilerActivity.MTIA in _supported_activities() + ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + self.kineto_activities.add(ProfilerActivity.MTIA) elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f43dcc06de20..2fd3ab9be6b8 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -132,6 +132,8 @@ def __init__( self.use_device = "cuda" elif ProfilerActivity.XPU in self.activities: self.use_device = "xpu" + elif ProfilerActivity.MTIA in self.activities: + self.use_device = "mtia" elif ProfilerActivity.PrivateUse1 in self.activities: self.use_device = _get_privateuse1_backend_name() @@ -149,7 +151,6 @@ def prepare_trace(self): if self.profiler is None: self.profiler = prof.profile( use_cpu=(ProfilerActivity.CPU in self.activities), - use_mtia=(ProfilerActivity.MTIA in self.activities), use_device=self.use_device, record_shapes=self.record_shapes, with_flops=self.with_flops, From e47603a5495b33d59be0b770ac9b243877c993ad Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:41 -0700 Subject: [PATCH 23/64] Fix weight_norm decomposition behavior (#128956) By upcasting norm to float32 to align with CUDA and CPU behaviors https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/WeightNorm.cpp#L56-L59 Discovered this when started running OpInfo tests, see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060 ``` File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" AssertionError: 1 Operation: aten._weight_norm_interface.default ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128956 Approved by: https://github.com/albanD ghstack dependencies: #128955 --- torch/_decomp/decompositions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index dca552137ca6..42d1cb9a1527 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4773,8 +4773,10 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) - norm = v.norm(2, keep_dim, keepdim=True) - return v * (g / norm), norm + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm @register_decomposition(aten.isin) From cec31050b4609a4bbdcd332c823139666ad57224 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:43 +0800 Subject: [PATCH 24/64] [BE][Easy] enable UFMT for `torch/distributed/{tensor,_tensor}/` (#128868) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128868 Approved by: https://github.com/fegin --- .lintrunner.toml | 9 - .../distributed/_tensor/_collective_utils.py | 2 +- torch/distributed/_tensor/_dispatch.py | 2 +- torch/distributed/_tensor/_op_schema.py | 1 + torch/distributed/_tensor/_sharding_prop.py | 1 + torch/distributed/_tensor/_tp_conv.py | 1 + torch/distributed/_tensor/api.py | 1 - torch/distributed/_tensor/debug/__init__.py | 1 - .../distributed/_tensor/debug/_op_coverage.py | 1 - torch/distributed/_tensor/debug/comm_mode.py | 3 +- .../_tensor/debug/visualize_sharding.py | 1 - .../_tensor/examples/checkpoint_example.py | 2 - .../examples/comm_mode_features_example.py | 3 - .../examples/torchrec_sharding_example.py | 2 +- .../examples/visualize_sharding_example.py | 1 + .../_tensor/experimental/__init__.py | 1 + .../_tensor/experimental/attention.py | 1 + .../_tensor/experimental/local_map.py | 1 + torch/distributed/_tensor/ops/__init__.py | 8 +- .../distributed/_tensor/ops/basic_strategy.py | 2 - torch/distributed/_tensor/ops/conv_ops.py | 1 + .../distributed/_tensor/ops/embedding_ops.py | 3 +- .../_tensor/ops/experimental_ops.py | 12 +- torch/distributed/_tensor/ops/math_ops.py | 1 - torch/distributed/_tensor/ops/matrix_ops.py | 2 +- .../distributed/_tensor/ops/pointwise_ops.py | 2 - torch/distributed/_tensor/ops/random_ops.py | 1 + torch/distributed/_tensor/ops/tensor_ops.py | 1 - torch/distributed/_tensor/ops/view_ops.py | 3 +- torch/distributed/_tensor/placement_types.py | 1 - torch/distributed/_tensor/random.py | 1 - torch/distributed/tensor/parallel/__init__.py | 4 +- torch/distributed/tensor/parallel/_utils.py | 10 +- torch/distributed/tensor/parallel/api.py | 21 +- torch/distributed/tensor/parallel/ddp.py | 1 + torch/distributed/tensor/parallel/fsdp.py | 8 +- .../tensor/parallel/input_reshard.py | 13 +- torch/distributed/tensor/parallel/loss.py | 1 + torch/distributed/tensor/parallel/style.py | 226 ++++++++++++------ 39 files changed, 213 insertions(+), 143 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index a7bbdc884415..e3f1b58027c3 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1443,15 +1443,6 @@ exclude_patterns = [ 'torch/distributed/rpc/rref_proxy.py', 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', - 'torch/distributed/tensor/__init__.py', - 'torch/distributed/tensor/parallel/__init__.py', - 'torch/distributed/tensor/parallel/_utils.py', - 'torch/distributed/tensor/parallel/_view_with_dim_change.py', - 'torch/distributed/tensor/parallel/api.py', - 'torch/distributed/tensor/parallel/fsdp.py', - 'torch/distributed/tensor/parallel/input_reshard.py', - 'torch/distributed/tensor/parallel/multihead_attention_tp.py', - 'torch/distributed/tensor/parallel/style.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py index 4c1d18403666..15644ac79873 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -3,7 +3,6 @@ import math from dataclasses import dataclass from functools import lru_cache - from typing import List, Optional import torch @@ -21,6 +20,7 @@ Work, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 1739243a5d3b..a659c54a3d93 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -6,7 +6,6 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch - import torch.distributed as dist import torch.distributed._tensor.api as dtensor import torch.distributed._tensor.random as random @@ -27,6 +26,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta from torch.distributed._tensor.random import is_rng_supported_mesh + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 071c2ac4748f..6e6884f47306 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -8,6 +8,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec from torch.distributed.device_mesh import DeviceMesh + try: from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec except ImportError: diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 449cf6c23775..8f1cabeb0c43 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -25,6 +25,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/_tensor/_tp_conv.py index d480e9d7f79e..cc6f1968e6ef 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/_tensor/_tp_conv.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.distributed._tensor.api as dtensor + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 22f7e690022a..e1c01040a909 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -5,7 +5,6 @@ from typing import Any, Callable, cast, Optional, Sequence, Tuple import torch - import torch.distributed._tensor._dispatch as op_dispatch import torch.distributed._tensor.random as random import torch.nn as nn diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index b7bde685fd1e..b70529f203e1 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from torch.distributed._tensor.api import DTensor - from torch.distributed._tensor.debug.comm_mode import CommDebugMode diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/_tensor/debug/_op_coverage.py index 4f5424633235..214c4f003ff2 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/_tensor/debug/_op_coverage.py @@ -5,7 +5,6 @@ import torch import torch.fx import torch.nn as nn - from functorch.compile import make_boxed_func from torch._functorch.compilers import aot_module from torch._inductor.decomposition import select_decomp_table diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 5b69454828f3..0241c739fb70 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -5,16 +5,15 @@ import torch from torch.autograd.graph import register_multi_grad_hook from torch.distributed._tensor.api import DTensor - from torch.nn.modules.module import ( register_module_forward_hook, register_module_forward_pre_hook, ) from torch.utils._python_dispatch import TorchDispatchMode - from torch.utils._pytree import tree_flatten from torch.utils.module_tracker import ModuleTracker + funcol_native = torch.ops._c10d_functional funcol_py = torch.ops.c10d_functional funcol_autograd = torch.ops._c10d_functional_autograd diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 76cd8f3e9208..8eae86e5c0ab 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py @@ -5,7 +5,6 @@ from torch._prims_common import ShapeType from torch.distributed._tensor import DeviceMesh - from torch.distributed._tensor.placement_types import Placement, Shard diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index 1cb292f12c41..1701e28ac2ca 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -5,7 +5,6 @@ checkpoint save/load the model. """ import os - from typing import cast, List import torch @@ -13,7 +12,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F - from torch.distributed._tensor import ( DeviceMesh, distribute_module, diff --git a/torch/distributed/_tensor/examples/comm_mode_features_example.py b/torch/distributed/_tensor/examples/comm_mode_features_example.py index 106a5db73510..93155687cf92 100644 --- a/torch/distributed/_tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/_tensor/examples/comm_mode_features_example.py @@ -1,16 +1,13 @@ import os import torch - from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.debug import CommDebugMode - from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, RowwiseParallel, ) - from torch.testing._internal.distributed._tensor.common_dtensor import ( MLPModule, MLPStacked, diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/_tensor/examples/torchrec_sharding_example.py index 3e6c63dd18eb..33f8c7017f5b 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/_tensor/examples/torchrec_sharding_example.py @@ -9,7 +9,6 @@ from typing import List, TYPE_CHECKING import torch - from torch.distributed._tensor import ( DeviceMesh, DTensor, @@ -24,6 +23,7 @@ TensorStorageMetadata, ) + if TYPE_CHECKING: from torch.distributed._tensor.placement_types import Placement diff --git a/torch/distributed/_tensor/examples/visualize_sharding_example.py b/torch/distributed/_tensor/examples/visualize_sharding_example.py index 6e295e147b38..0f8396889159 100644 --- a/torch/distributed/_tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/_tensor/examples/visualize_sharding_example.py @@ -4,6 +4,7 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding + world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py index 2dd21605ffcc..bee73667e1ea 100644 --- a/torch/distributed/_tensor/experimental/__init__.py +++ b/torch/distributed/_tensor/experimental/__init__.py @@ -5,6 +5,7 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.experimental.local_map import local_map + __all__ = ["local_map", "implicit_replication"] diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/_tensor/experimental/attention.py index eb7703a96ba5..b7738cb2dee5 100644 --- a/torch/distributed/_tensor/experimental/attention.py +++ b/torch/distributed/_tensor/experimental/attention.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel.style import ParallelStyle + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 0fc6ce96e6e0..60d1796fdec4 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -7,6 +7,7 @@ from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement + try: from torch.utils import _cxx_pytree as pytree except ImportError: diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/_tensor/ops/__init__.py index d19fdfa50cb7..eaccc8aa8d3f 100644 --- a/torch/distributed/_tensor/ops/__init__.py +++ b/torch/distributed/_tensor/ops/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +from .conv_ops import * # noqa: F403 from .embedding_ops import * # noqa: F403 -from .matrix_ops import * # noqa: F403 +from .experimental_ops import * # noqa: F403 from .math_ops import * # noqa: F403 -from .tensor_ops import * # noqa: F403 +from .matrix_ops import * # noqa: F403 from .pointwise_ops import * # noqa: F403 from .random_ops import * # noqa: F403 +from .tensor_ops import * # noqa: F403 from .view_ops import * # noqa: F403 -from .conv_ops import * # noqa: F403 -from .experimental_ops import * # noqa: F403 diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py index cc28cc19d370..97dd43b1524d 100644 --- a/torch/distributed/_tensor/ops/basic_strategy.py +++ b/torch/distributed/_tensor/ops/basic_strategy.py @@ -1,6 +1,5 @@ import itertools from dataclasses import dataclass - from typing import List, Set, Tuple from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy @@ -11,7 +10,6 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/ops/conv_ops.py b/torch/distributed/_tensor/ops/conv_ops.py index f466a13aa463..24e75593064e 100644 --- a/torch/distributed/_tensor/ops/conv_ops.py +++ b/torch/distributed/_tensor/ops/conv_ops.py @@ -7,6 +7,7 @@ from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index 6f8cc8c67851..5af79562adcb 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -11,16 +11,15 @@ expand_to_full_mesh_op_strategy, register_op_strategy, ) - from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/experimental_ops.py b/torch/distributed/_tensor/ops/experimental_ops.py index 546945acd622..6d6967d4ea8d 100644 --- a/torch/distributed/_tensor/ops/experimental_ops.py +++ b/torch/distributed/_tensor/ops/experimental_ops.py @@ -2,19 +2,21 @@ # implement matrix related ops for distributed tensor from typing import List -try: - import numpy as np -except ModuleNotFoundError: - np = None # type: ignore[assignment] - import torch from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + @register_prop_rule(aten.slice_backward.default) def slice_backward_rules(op_schema: OpSchema) -> OutputSharding: grad_output_spec, input_sizes, dim, start, end, step = op_schema.args_schema diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 377c50dffa13..412c566253ab 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -6,7 +6,6 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch - from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, diff --git a/torch/distributed/_tensor/ops/matrix_ops.py b/torch/distributed/_tensor/ops/matrix_ops.py index 15f00af670d2..128a73a59ffe 100644 --- a/torch/distributed/_tensor/ops/matrix_ops.py +++ b/torch/distributed/_tensor/ops/matrix_ops.py @@ -19,9 +19,9 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index ab80f783cf5b..96bfb808c100 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -2,7 +2,6 @@ from typing import List, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, @@ -13,7 +12,6 @@ StrategyType, TupleStrategy, ) - from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, diff --git a/torch/distributed/_tensor/ops/random_ops.py b/torch/distributed/_tensor/ops/random_ops.py index 390dc419ecd7..d4b533aae09a 100644 --- a/torch/distributed/_tensor/ops/random_ops.py +++ b/torch/distributed/_tensor/ops/random_ops.py @@ -9,6 +9,7 @@ from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index d2feb19ba2f9..a91d6261c51d 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -3,7 +3,6 @@ from typing import cast, List, Optional, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, OpSchema, diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 7161988adf25..ea088b7377a9 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -15,7 +15,6 @@ ) import torch - from torch import Tensor from torch.distributed._tensor._op_schema import ( OpSchema, @@ -32,10 +31,10 @@ prod, register_op_strategy, ) - from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten Shape = Tuple[int, ...] diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 31e280c2f5b8..352e12640bd7 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -6,7 +6,6 @@ import torch import torch.distributed._functional_collectives as funcol - from torch.distributed._tensor._collective_utils import ( fill_empty_tensor_to_shards, mesh_broadcast, diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/_tensor/random.py index ed331736c5ce..3e43a9119ac2 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/_tensor/random.py @@ -6,7 +6,6 @@ import torch import torch.distributed as dist - from torch import Tensor from torch.distributed._tensor.placement_types import DTensorSpec, Shard from torch.distributed.device_mesh import _get_device_handle, DeviceMesh diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index 990550414ca4..9fe378c51b0d 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from torch.distributed.tensor.parallel.api import parallelize_module - from torch.distributed.tensor.parallel.loss import loss_parallel from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -11,6 +10,7 @@ SequenceParallel, ) + __all__ = [ "ColwiseParallel", "ParallelStyle", @@ -19,5 +19,5 @@ "RowwiseParallel", "SequenceParallel", "parallelize_module", - "loss_parallel" + "loss_parallel", ] diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 394fde457bb2..3f47ec6f1ef3 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -5,12 +5,16 @@ from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources + + try: from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling except Exception: + def is_torchdynamo_compiling(): # type: ignore[misc] return False + LayoutsType = Union[Placement, Tuple[Placement, ...]] @@ -46,8 +50,10 @@ def _validate_tp_mesh_dim( is valid, `False` otherwise. """ if device_mesh.ndim > 1: - raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]') + raise ValueError( + f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' + ) parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh: diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index f78e9712d304..e0fc4d2ef2b7 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -1,21 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, Union from fnmatch import fnmatch +from typing import Dict, Union import torch import torch.distributed._tensor.random as random import torch.nn as nn -from torch.distributed._tensor import ( - DeviceMesh, -) +from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.random import ( is_rng_supported_mesh, TensorParallelRNGTracker, ) from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim -from torch.distributed.tensor.parallel.style import ( - ParallelStyle, -) +from torch.distributed.tensor.parallel.style import ParallelStyle __all__ = [ @@ -98,14 +94,19 @@ def parallelize_module( # type: ignore[return] atom = path_splits.pop(0) matched_children = filter( # `t[0]` is child name - lambda t: fnmatch(t[0], atom), module.named_children() + lambda t: fnmatch(t[0], atom), + module.named_children(), ) # apply the plan to all matched submodules for _, submodule in matched_children: if path_splits: # we haven't reached the leaf, apply in dict style - leaf_path = ".".join(path_splits) # rest of the path after `atom` - parallelize_module(submodule, device_mesh, {leaf_path: parallelize_style}) + leaf_path = ".".join( + path_splits + ) # rest of the path after `atom` + parallelize_module( + submodule, device_mesh, {leaf_path: parallelize_style} + ) else: # otherwise, directly apply style to this submodule parallelize_module(submodule, device_mesh, parallelize_style) diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index baa9d638037d..6c4d6f801675 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -7,6 +7,7 @@ _unflatten_tensor, ) + __all__ = [] # type: ignore[var-annotated] diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index c38771ae86e2..df51efaf87f5 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist - import torch.distributed._shard.sharding_spec as shard_spec import torch.distributed.distributed_c10d as c10d from torch.distributed._shard.sharded_tensor import ( @@ -13,12 +12,10 @@ ShardedTensorMetadata, TensorProperties, ) - from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor @@ -28,6 +25,7 @@ _unflatten_tensor, ) + __all__ = ["DTensorExtensions"] @@ -245,7 +243,6 @@ def _chunk_dtensor( # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): - # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. # TP is the inner dimension and FSDP is the outer dimension. # Therefore, shard placements for tensor is (Shard(0), Replicate()). @@ -324,6 +321,7 @@ class DTensorExtensions(FSDPExtensions): This is the implementation for FSDPExtensions defined in https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py """ + def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None @@ -352,7 +350,7 @@ def post_unflatten_transform( tensor, param_extension, device_handle=self.device_handle, - compute_stream=self.compute_stream + compute_stream=self.compute_stream, ) _set_fsdp_flattened(result) return result diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 3ea97846e313..4e7af55d32c3 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -5,6 +5,7 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard + __all__ = [ "input_reshard", ] @@ -49,7 +50,9 @@ def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> N nonlocal cx cx = saved_tensor_hooks # type: ignore[name-defined] - def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any: + def input_reshard_backward_hook( + _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any + ) -> Any: nonlocal cx cx.__exit__() # type: ignore[name-defined, union-attr] @@ -60,7 +63,9 @@ def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any return module -def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 +def _pack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor +) -> Any: # noqa: D401 """Hook function called after FWD to shard input.""" if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) @@ -78,7 +83,9 @@ def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> return x -def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 +def _unpack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: Any +) -> torch.Tensor: # noqa: D401 """Hook function called before activation recomputing in BWD to restore input.""" if ( isinstance(x, DTensor) diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index f2776c5123b4..a51d14b0efbd 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -18,6 +18,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index a4f4d4de0b98..42437a708475 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,12 +1,20 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod -from typing import Optional, Union, Tuple, Dict, Any from functools import partial +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Placement, + Replicate, + Shard, +) __all__ = [ @@ -74,29 +82,35 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Replicate(), ) - self.output_layouts = (output_layouts or Shard(-1), ) + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) # colwise linear runtime sharding (desired sharding): # 1. requires replicate input # 2. shard output on last dim - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): # TODO: figure out dynamo support for instance method and switch this to instance method # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): @@ -104,17 +118,13 @@ def _partition_linear_fn(self, name, module, device_mesh): # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) def _partition_embedding_fn(self, name, module, device_mesh): # colwise shard embedding.weight is straight forward as Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(1)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(1)])) module.register_parameter(name, dist_param) @staticmethod @@ -131,14 +141,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn else: - raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "ColwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -180,41 +196,49 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Shard(-1), ) - self.output_layouts = (output_layouts or Replicate(), ) + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where # weight would become Shard(0) - module.register_parameter("weight", nn.Parameter( - distribute_tensor(module.weight, device_mesh, [Shard(1)]) - )) + module.register_parameter( + "weight", + nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(1)])), + ) if module.bias is not None: - module.register_parameter("bias", nn.Parameter( - distribute_tensor(module.bias, device_mesh, [Replicate()]) - )) + module.register_parameter( + "bias", + nn.Parameter( + distribute_tensor(module.bias, device_mesh, [Replicate()]) + ), + ) def _partition_embedding_fn(self, name, module, device_mesh): # rowwise shard embedding.weight is Shard(0) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) @staticmethod @@ -231,20 +255,26 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if isinstance(module, nn.Linear): partition_fn = self._partition_linear_fn # rowwise linear runtime sharding requires input tensor shard on last dim - self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), ) + self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),) elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn # rowwise embedding runtime sharding requires input tensor replicated - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) else: - raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "RowwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -287,17 +317,15 @@ class SequenceParallel(ParallelStyle): inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. """ - def __init__( - self, - *, - sequence_dim: int = 1, - use_local_output: bool = False - ): + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): super().__init__() self.sequence_dim = sequence_dim self.use_local_output = use_local_output - def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): for p_name, param in module.named_parameters(): # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow # us to simply just use from_local @@ -312,9 +340,13 @@ def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): if isinstance(input_tensor, DTensor): return inputs elif isinstance(input_tensor, torch.Tensor): - return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False) + return DTensor.from_local( + input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False + ) else: - raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}") + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) @staticmethod def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): @@ -380,32 +412,43 @@ def __init__( self, *, input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, - desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, Tuple[Optional[Placement]]] + ] = None, input_kwarg_layouts: Optional[Dict[str, Placement]] = None, desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, - use_local_output: bool = False + use_local_output: bool = False, ): - self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts - self.desired_input_layouts = \ - (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts + self.input_layouts = ( + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts + ) + self.desired_input_layouts = ( + (desired_input_layouts,) + if isinstance(desired_input_layouts, Placement) + else desired_input_layouts + ) self.use_local_output = use_local_output if self.input_layouts is not None: - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - assert len(self.input_layouts) == len(self.desired_input_layouts), \ - "input_layouts and desired_input_layouts should have same length!" + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + assert len(self.input_layouts) == len( + self.desired_input_layouts + ), "input_layouts and desired_input_layouts should have same length!" self.with_kwargs = input_kwarg_layouts is not None self.input_kwarg_layouts = input_kwarg_layouts or {} self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} if self.with_kwargs: - assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ - "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + assert len(self.input_kwarg_layouts) == len( + self.desired_input_kwarg_layouts + ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" def _prepare_input_arg( self, input: Any, mesh: DeviceMesh, input_layout: Optional[Placement], - desired_layout: Optional[Placement] + desired_layout: Optional[Placement], ): if input_layout is not None: if isinstance(input, DTensor): @@ -413,8 +456,12 @@ def _prepare_input_arg( # assert inp.placements[0] == input_layout dt_inp = input else: - assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" - dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + assert isinstance( + input, torch.Tensor + ), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -432,9 +479,15 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): - prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): + prepared_inputs.append( + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) + ) return tuple(prepared_inputs) def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): @@ -445,15 +498,19 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): input_layout = self.input_kwarg_layouts.get(kwarg_key) desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) - prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( + kwarg_val, device_mesh, input_layout, desired_input_layout + ) return (prepared_arg_inputs, prepared_kwarg_inputs) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if self.with_kwargs: module.register_forward_pre_hook( - lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh), - with_kwargs=True + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( + inputs, kwargs, device_mesh + ), + with_kwargs=True, ) # type: ignore[misc] else: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg] @@ -497,38 +554,55 @@ class PrepareModuleOutput(ParallelStyle): >>> ) >>> ) """ + def __init__( self, *, output_layouts: Union[Placement, Tuple[Placement]], desired_output_layouts: Union[Placement, Tuple[Placement]], - use_local_output: bool = True + use_local_output: bool = True, ): - self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts - self.desired_output_layouts = \ - (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts + self.output_layouts = ( + (output_layouts,) + if isinstance(output_layouts, Placement) + else output_layouts + ) + self.desired_output_layouts = ( + (desired_output_layouts,) + if isinstance(desired_output_layouts, Placement) + else desired_output_layouts + ) self.use_local_output = use_local_output - assert len(self.output_layouts) == len(self.desired_output_layouts), \ - "output_layouts and desired_output_layouts should have same length!" + assert len(self.output_layouts) == len( + self.desired_output_layouts + ), "output_layouts and desired_output_layouts should have same length!" def _prepare_out_fn(self, outputs, device_mesh): prepared_outputs = [] if not isinstance(outputs, tuple): outputs = (outputs,) if len(outputs) != len(self.output_layouts): - raise ValueError("module outputs and output_layouts should have same length!") - for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts): + raise ValueError( + "module outputs and output_layouts should have same length!" + ) + for out, out_layout, desired_out_layout in zip( + outputs, self.output_layouts, self.desired_output_layouts + ): if out_layout is not None: if isinstance(out, DTensor): # TODO: re-enable the check once we fix the compile path # assert out.placements[0] == out_layout dt_out = out else: - dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False) + dt_out = DTensor.from_local( + out, device_mesh, (out_layout,), run_check=False + ) if out_layout != desired_out_layout: dt_out = dt_out.redistribute(placements=(desired_out_layout,)) - prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out) + prepared_outputs.append( + dt_out.to_local() if self.use_local_output else dt_out + ) else: prepared_outputs.append(out) if len(prepared_outputs) == 1: From 3b798df853444d66077ffa846f5682e621b07388 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:44 +0800 Subject: [PATCH 25/64] [BE][Easy] enable UFMT for `torch/distributed/{fsdp,optim,rpc}/` (#128869) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869 Approved by: https://github.com/fegin ghstack dependencies: #128868 --- .lintrunner.toml | 27 ---- torch/distributed/fsdp/__init__.py | 1 + torch/distributed/fsdp/_common_utils.py | 2 + torch/distributed/fsdp/_debug_utils.py | 1 + torch/distributed/fsdp/_flat_param.py | 1 + torch/distributed/fsdp/_init_utils.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 1 + torch/distributed/fsdp/_runtime_utils.py | 1 + torch/distributed/fsdp/_state_dict_utils.py | 3 - .../distributed/fsdp/_unshard_param_utils.py | 1 + torch/distributed/fsdp/_wrap_utils.py | 1 - torch/distributed/fsdp/api.py | 2 +- .../fsdp/fully_sharded_data_parallel.py | 2 +- torch/distributed/fsdp/sharded_grad_scaler.py | 1 + torch/distributed/fsdp/wrap.py | 1 + torch/distributed/optim/__init__.py | 10 +- .../optim/apply_optimizer_in_backward.py | 10 +- .../distributed/optim/functional_adadelta.py | 5 +- torch/distributed/optim/functional_adagrad.py | 3 +- torch/distributed/optim/functional_adam.py | 3 +- torch/distributed/optim/functional_adamax.py | 3 +- torch/distributed/optim/functional_adamw.py | 3 +- torch/distributed/optim/functional_rmsprop.py | 3 +- torch/distributed/optim/functional_rprop.py | 3 +- torch/distributed/optim/functional_sgd.py | 3 +- torch/distributed/optim/named_optimizer.py | 13 +- torch/distributed/optim/optimizer.py | 5 +- torch/distributed/optim/utils.py | 2 + .../optim/zero_redundancy_optimizer.py | 37 +++--- torch/distributed/rpc/__init__.py | 69 +++++----- torch/distributed/rpc/_testing/__init__.py | 5 +- .../_testing/faulty_agent_backend_registry.py | 11 +- torch/distributed/rpc/_utils.py | 19 ++- torch/distributed/rpc/api.py | 118 ++++++++++-------- torch/distributed/rpc/backend_registry.py | 99 ++++++++++----- torch/distributed/rpc/constants.py | 3 +- torch/distributed/rpc/functions.py | 2 + torch/distributed/rpc/internal.py | 5 +- torch/distributed/rpc/options.py | 2 + torch/distributed/rpc/rref_proxy.py | 17 ++- .../rpc/server_process_global_profiler.py | 13 +- 41 files changed, 300 insertions(+), 213 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e3f1b58027c3..99c04cac4fbb 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1413,35 +1413,8 @@ exclude_patterns = [ 'torch/distributed/nn/jit/instantiator.py', 'torch/distributed/nn/jit/templates/__init__.py', 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/optim/__init__.py', - 'torch/distributed/optim/apply_optimizer_in_backward.py', - 'torch/distributed/optim/functional_adadelta.py', - 'torch/distributed/optim/functional_adagrad.py', - 'torch/distributed/optim/functional_adam.py', - 'torch/distributed/optim/functional_adamax.py', - 'torch/distributed/optim/functional_adamw.py', - 'torch/distributed/optim/functional_rmsprop.py', - 'torch/distributed/optim/functional_rprop.py', - 'torch/distributed/optim/functional_sgd.py', - 'torch/distributed/optim/named_optimizer.py', - 'torch/distributed/optim/optimizer.py', - 'torch/distributed/optim/post_localSGD_optimizer.py', - 'torch/distributed/optim/utils.py', - 'torch/distributed/optim/zero_redundancy_optimizer.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', - 'torch/distributed/rpc/__init__.py', - 'torch/distributed/rpc/_testing/__init__.py', - 'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py', - 'torch/distributed/rpc/_utils.py', - 'torch/distributed/rpc/api.py', - 'torch/distributed/rpc/backend_registry.py', - 'torch/distributed/rpc/constants.py', - 'torch/distributed/rpc/functions.py', - 'torch/distributed/rpc/internal.py', - 'torch/distributed/rpc/options.py', - 'torch/distributed/rpc/rref_proxy.py', - 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index d887730f442f..6180dbb3df29 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -18,6 +18,7 @@ StateDictType, ) + __all__ = [ "BackwardPrefetch", "CPUOffload", diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index aae2405d0bb5..10d0f8212651 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -44,9 +44,11 @@ StateDictType, ) + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + from ._flat_param import FlatParamHandle FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 523330e5580d..163d9a045b68 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -15,6 +15,7 @@ clean_tensor_name, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 816b91433063..8bc975dc72fd 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -50,6 +50,7 @@ FSDPExtensions, ) + __all__ = [ "FlatParameter", "FlatParamHandle", diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c8b58091bf89..aaeedf22397a 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -58,9 +58,9 @@ from torch.distributed.fsdp.wrap import _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions from torch.distributed.utils import _sync_params_and_buffers - from torch.utils._python_dispatch import is_traceable_wrapper_subclass + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 54f800a16865..4cfe761769a3 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -55,6 +55,7 @@ ) from torch.utils._pytree import tree_map_only + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 833c1d45697a..f84e7dd3e505 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -39,6 +39,7 @@ ) from torch.utils import _pytree as pytree + logger = logging.getLogger(__name__) # Do not include "process_group" to enable hybrid shard and MoE cases diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 797a0116587b..815cfb2dd4a1 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -17,9 +17,7 @@ import torch import torch.distributed as dist - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper - import torch.nn as nn import torch.nn.functional as F from torch.distributed._shard.sharded_tensor import ( @@ -29,7 +27,6 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import ( _FSDPState, _get_module_fsdp_state_if_fully_sharded_module, diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 435193a88703..4143d2928c8b 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -26,6 +26,7 @@ from ._flat_param import FlatParamHandle + FLAT_PARAM = "_flat_param" diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 84cdf250d8ae..895bcbd8e967 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -11,7 +11,6 @@ _get_module_fsdp_state, _override_module_mixed_precision, ) - from torch.distributed.fsdp.wrap import ( _construct_wrap_fn, _or_policy, diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 0272ee0c57c9..f2e4bdb7ea02 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -5,12 +5,12 @@ from dataclasses import dataclass from enum import auto, Enum - from typing import Optional, Sequence, Type import torch from torch.nn.modules.batchnorm import _BatchNorm + __all__ = [ "ShardingStrategy", "BackwardPrefetch", diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9edd057a8f37..1567bb973b22 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -85,8 +85,8 @@ StateDictType, ) from torch.distributed.utils import _p_assert -from ._flat_param import FlatParameter, FlatParamHandle +from ._flat_param import FlatParameter, FlatParamHandle from ._optim_utils import ( _flatten_optim_state_dict, _get_param_id_to_param_from_optim_input, diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 3487e01263c7..7c1b2f835286 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -8,6 +8,7 @@ from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState from torch.distributed.distributed_c10d import ProcessGroup + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index acb5a6f1f642..f8604bbb1bb0 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -24,6 +24,7 @@ import torch.nn as nn + __all__ = [ "always_wrap_policy", "lambda_auto_wrap_policy", diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index fe33265fd532..924b993ec841 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -15,7 +15,6 @@ _get_in_backward_optimizers, ) from .functional_adadelta import _FunctionalAdadelta - from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam from .functional_adamax import _FunctionalAdamax @@ -26,6 +25,7 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -44,4 +44,10 @@ from .post_localSGD_optimizer import PostLocalSGDOptimizer from .zero_redundancy_optimizer import ZeroRedundancyOptimizer -__all__ = ["as_functional_optim", "DistributedOptimizer", "PostLocalSGDOptimizer", "ZeroRedundancyOptimizer"] + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 6bd182cca573..36f679f4eba4 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -2,6 +2,7 @@ import torch + __all__: List[str] = [] # WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter @@ -11,6 +12,7 @@ param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + @no_type_check def _apply_optimizer_in_backward( optimizer_class: Type[torch.optim.Optimizer], @@ -48,9 +50,7 @@ def _apply_optimizer_in_backward( # have their registered optimizer(s) applied. """ - torch._C._log_api_usage_once( - "torch.distributed.optim.apply_optimizer_in_backward" - ) + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") @no_type_check def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: @@ -62,7 +62,9 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: # Don't create a new acc_grad if we already have one # i.e. for shared parameters or attaching multiple optimizers to a param. if param not in param_to_acc_grad_map: - param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0] + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] optimizer = optimizer_class([param], **optimizer_kwargs) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index bc5f7c63dd17..3ad51348b6af 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adadelta Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, @@ -102,5 +103,5 @@ def step(self, gradients: List[Optional[Tensor]]): weight_decay=weight_decay, foreach=self.foreach, maximize=self.maximize, - has_complex=has_complex + has_complex=has_complex, ) diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 93a1fe2b2240..67f7328489ed 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adagrad Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 34868d23d8a5..3ed271765170 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adam Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 32bce65dfe1f..8f1fdc0ccc02 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adamax Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 43addd050822..d3f1f80e9209 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional AdamW Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 851119c8600c..7a03e8e9f462 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional RMSprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 60742bc68896..615015a95a31 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Rprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 3a8176e87705..32381855db6b 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional SGD Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 9e1e5377873d..8e0b539b1482 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,9 +1,18 @@ # mypy: allow-untyped-defs import logging import warnings - from copy import deepcopy -from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + Optional, + overload, + Union, +) import torch import torch.nn as nn diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index f2eca606c026..65df14770c21 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging - from collections import defaultdict from threading import Lock from typing import List, Optional @@ -12,8 +11,10 @@ import torch.nn as nn from torch import Tensor from torch.distributed.rpc import RRef + from .utils import functional_optim_map + __all__ = ["DistributedOptimizer"] logger = logging.getLogger(__name__) @@ -205,7 +206,7 @@ def __init__(self, optimizer_class, params_rref, *args, **kwargs): "(i.e. Distributed Model Parallel training on CPU) due to the Python's " "Global Interpreter Lock (GIL). Please file an issue if you need this " "optimizer in TorchScript. ", - optimizer_class + optimizer_class, ) optimizer_new_func = _new_local_optimizer diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index af2220ca5574..d2c75eee7e39 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -2,6 +2,7 @@ from typing import Type from torch import optim + from .functional_adadelta import _FunctionalAdadelta from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam @@ -11,6 +12,7 @@ from .functional_rprop import _FunctionalRprop from .functional_sgd import _FunctionalSGD + # dict to map a user passed in optimizer_class to a functional # optimizer class if we have already defined inside the # distributed.optim package, this is so that we hide the diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 8a3be3b01815..f664d11afb79 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -20,11 +20,12 @@ from torch.optim import Optimizer -logger = logging.getLogger(__name__) - __all__ = ["ZeroRedundancyOptimizer"] +logger = logging.getLogger(__name__) + + # Credits: classy_vision/generic/distributed_util.py def _recursive_copy_to_device( value: Any, @@ -925,9 +926,9 @@ def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: mapping bucket indices to :class:`_DDPBucketAssignment` s for each rank. """ - assert self._overlap_with_ddp, ( - "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" - ) + assert ( + self._overlap_with_ddp + ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" if len(self._bucket_assignments_per_rank_cache) > 0: return self._bucket_assignments_per_rank_cache @@ -1074,9 +1075,9 @@ def _local_step( "Specifying `gradients` should not " "be used when `overlap_with_ddp=False`" ) - assert closure is None, ( - "`closure` is not supported when using a local functional optimizer" - ) + assert ( + closure is None + ), "`closure` is not supported when using a local functional optimizer" loss = self.optim.step(gradients=gradients) # Sync any updated attributes in the local optimizer to the exposed @@ -1504,7 +1505,7 @@ def _init_local_optimizer(self) -> None: "%s does not support the argument " "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " "error due to an empty parameter list", - self._optim_constructor + self._optim_constructor, ) self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] @@ -1515,17 +1516,16 @@ def _init_local_optimizer(self) -> None: self._bucket_assignments_per_rank[self.global_rank] ) logger.info( - "rank %s with %s parameters " - "across %s buckets", - self.global_rank, local_numel, num_assigned_buckets + "rank %s with %s parameters " "across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, ) if self.global_rank == 0: logger.info( - "%s DDP " - "buckets and " - "%s bucket " - "assignments", - len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments + "%s DDP " "buckets and " "%s bucket " "assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, ) else: # NOTE: Passing `param_groups` into the local optimizer constructor @@ -1640,7 +1640,8 @@ def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: "Using the functional optimizer %s " "instead of %s since " "`overlap_with_ddp=True`", - optim_constructor, optimizer_class + optim_constructor, + optimizer_class, ) return optim_constructor else: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 581433d220c6..6c6608a2a773 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,22 +1,25 @@ # mypy: allow-untyped-defs -from datetime import timedelta import logging import os import threading import warnings +from datetime import timedelta from typing import Generator, Tuple from urllib.parse import urlparse import torch import torch.distributed as dist + +__all__ = ["is_available"] + + logger = logging.getLogger(__name__) _init_counter = 0 _init_counter_lock = threading.Lock() -__all__ = ["is_available"] def is_available() -> bool: return hasattr(torch._C, "_rpc_init") @@ -27,54 +30,51 @@ def is_available() -> bool: if is_available(): + import numbers + + import torch.distributed.autograd as dist_autograd from torch._C._distributed_c10d import Store - from torch._C._distributed_rpc import ( + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, _disable_jit_rref_pickle, - _enable_jit_rref_pickle, _disable_server_process_global_profiler, + _enable_jit_rref_pickle, _enable_server_process_global_profiler, - _set_and_start_rpc_agent, - _reset_current_rpc_agent, - _delete_all_user_and_unforked_owner_rrefs, - _destroy_rref_context, - _set_profiler_node_id, - _is_current_rpc_agent_set, - _rref_context_get_debug_info, - _cleanup_python_rpc_handler, - _invoke_rpc_builtin, - _invoke_rpc_python_udf, - _invoke_rpc_torchscript, + _get_current_rpc_agent, _invoke_remote_builtin, _invoke_remote_python_udf, _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, _set_rpc_timeout, - _get_current_rpc_agent, - get_rpc_timeout, - enable_gil_profiling, - RpcBackendOptions, _TensorPipeRpcBackendOptionsBase, - RpcAgent, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, PyRRef, - TensorPipeAgent, RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + TensorPipeAgent, WorkerInfo, - _DEFAULT_INIT_METHOD, - _DEFAULT_NUM_WORKER_THREADS, - _UNSET_RPC_TIMEOUT, - _DEFAULT_RPC_TIMEOUT_SEC, - ) # noqa: F401 + ) from . import api, backend_registry, functions from .api import * # noqa: F401,F403 - import numbers - - import torch.distributed.autograd as dist_autograd - from .backend_registry import BackendType from .options import TensorPipeRpcBackendOptions # noqa: F401 - from .server_process_global_profiler import ( - _server_process_global_profile, - ) + from .server_process_global_profiler import _server_process_global_profile rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] @@ -153,7 +153,7 @@ def init_rpc( "corresponding to %(backend)s, hence that backend will be used " "instead of the default BackendType.TENSORPIPE. To silence this " "warning pass `backend=%(backend)s` explicitly.", - {'backend': backend} + {"backend": backend}, ) if backend is None: @@ -224,7 +224,6 @@ def _init_rpc_backend( world_size=None, rpc_backend_options=None, ): - _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) if _is_current_rpc_agent_set(): diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 640c4d09f062..8ac1c02f4cee 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -12,8 +12,9 @@ def is_available(): if is_available(): # Registers FAULTY_TENSORPIPE RPC backend. - from . import faulty_agent_backend_registry from torch._C._distributed_rpc_testing import ( - FaultyTensorPipeRpcBackendOptions, FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, ) + + from . import faulty_agent_backend_registry diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index 9e8660989e5a..d04882e16e79 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -4,6 +4,7 @@ import torch.distributed as dist import torch.distributed.rpc as rpc + def _faulty_tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, @@ -11,7 +12,7 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( messages_to_fail, messages_to_delay, num_fail_sends, - **kwargs + **kwargs, ): from . import FaultyTensorPipeRpcBackendOptions @@ -28,16 +29,14 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( def _faulty_tensorpipe_init_backend_handler( store, name, rank, world_size, rpc_backend_options ): - from . import FaultyTensorPipeAgent - from . import FaultyTensorPipeRpcBackendOptions from torch.distributed.rpc import api + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, FaultyTensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" ) diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index 6499a80e0e17..8925bc662b5f 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,12 +1,14 @@ # mypy: allow-untyped-defs +import logging from contextlib import contextmanager from typing import cast -import logging -from . import api -from . import TensorPipeAgent + +from . import api, TensorPipeAgent + logger = logging.getLogger(__name__) + @contextmanager def _group_membership_management(store, name, is_join): token_key = "RpcGroupManagementToken" @@ -29,10 +31,17 @@ def _group_membership_management(store, name, is_join): try: store.wait([returned]) except RuntimeError: - logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned) + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) raise + def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) - ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) return ret diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index a33358eb0dc6..5fc9e61aa559 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,6 +1,4 @@ # mypy: allow-untyped-defs -__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", - "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] import collections import contextlib @@ -8,17 +6,10 @@ import inspect import logging import threading -from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING +from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar import torch -from torch.futures import Future - from torch._C._distributed_rpc import ( - PyRRef, - RemoteProfilerManager, - WorkerInfo, - TensorPipeAgent, - get_rpc_timeout, _cleanup_python_rpc_handler, _delete_all_user_and_unforked_owner_rrefs, _destroy_rref_context, @@ -32,18 +23,36 @@ _is_current_rpc_agent_set, _reset_current_rpc_agent, _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + TensorPipeAgent, + WorkerInfo, ) +from torch.futures import Future +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, PythonUDF, RPCExecMode, - _internal_rpc_pickler, - _build_rpc_profiling_key, ) -from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT -from ._utils import _group_membership_management, _update_group_membership +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + logger = logging.getLogger(__name__) @@ -59,6 +68,7 @@ _ignore_rref_leak = True _default_pickler = _internal_rpc_pickler + @contextlib.contextmanager def _use_rpc_pickler(rpc_pickler): r""" @@ -107,7 +117,9 @@ def __init__(self): _ALL_WORKER_NAMES: Set[Any] = set() _all_gather_dict_lock = threading.RLock() _all_gather_sequence_id: Dict[str, int] = {} -_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates) +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) def _init_rpc_states(agent): @@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() + _thread_local_var = threading.local() @@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): follower_name, _broadcast_to_followers, args=(sequence_id, states.gathered_objects), - timeout=rpc_timeout + timeout=rpc_timeout, ) worker_name_to_response_future_dict[follower_name] = fut @@ -283,9 +296,7 @@ def _barrier(worker_names): try: _all_gather(None, set(worker_names)) except RuntimeError as ex: - logger.error( - "Failed to complete barrier, got error %s", ex - ) + logger.error("Failed to complete barrier, got error %s", ex) @_require_initialized @@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): all_worker_infos = agent.get_worker_infos() for worker in all_worker_infos: if worker.name != my_name: - rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False)) + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) agent.join(shutdown=True, timeout=timeout) finally: # In case of errors, continue to complete the local shutdown. @@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True): return future -def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True): - fut = rpc_async( - rref.owner(), - _rref_typeof_on_owner, - args=(rref,), - timeout=timeout - ) +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) if blocking: return fut.wait() else: @@ -463,13 +475,16 @@ def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: boo if TYPE_CHECKING: + class RRef(PyRRef[T], Generic[T]): pass + else: try: # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass + except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases @@ -517,7 +532,9 @@ def method(self, *args, **kwargs): assert docstring is not None, "RRef user-facing methods should all have docstrings." # Do surgery on pybind11 generated docstrings. - docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef") + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) # Attach user-facing RRef method with modified docstring. new_method = method_factory(method_name, docstring) @@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): dst_worker_info = _to_worker_info(to) should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): func = wrapped if qualified_name is not None: - rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs) + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) elif isinstance(func, torch.jit.ScriptFunction): rref = _invoke_remote_torchscript( dst_worker_info.name, @@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): PythonUDF(func, args, kwargs) ) rref = _invoke_remote_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec ) # attach profiling information if should_profile: @@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): return rref -def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT): +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): if not callable(func): raise TypeError("function should be callable.") @@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = if qualified_name is not None: fut = _invoke_rpc_builtin( - dst_worker_info, - qualified_name, - rpc_timeout, - *args, - **kwargs + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs ) elif isinstance(func, torch.jit.ScriptFunction): fut = _invoke_rpc_torchscript( @@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = args, kwargs, rpc_timeout, - is_async_exec + is_async_exec, ) else: (pickled_python_udf, tensors) = _default_pickler.serialize( PythonUDF(func, args, kwargs) ) fut = _invoke_rpc_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - rpc_timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec ) if should_profile: assert torch.autograd._profiler_enabled() @@ -915,12 +928,15 @@ def _get_should_profile(): # Kineto profiler. ActiveProfilerType = torch._C._profiler.ActiveProfilerType return ( - torch.autograd._profiler_enabled() and - torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() + == ActiveProfilerType.LEGACY # type: ignore[attr-defined] ) -def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): ctx_manager = contextlib.nullcontext() if should_profile: diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 6290f9e8e205..a06f0276ede9 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] + import collections import enum @@ -7,13 +7,19 @@ import torch import torch.distributed as dist + +from . import api, constants as rpc_constants from ._utils import _group_membership_management, _update_group_membership -from . import api -from . import constants as rpc_constants -__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend", - "BackendValue", "BackendType"] +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] BackendValue = collections.namedtuple( "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] @@ -41,6 +47,7 @@ def _backend_type_repr(self): if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc + def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. @@ -80,7 +87,7 @@ def register_backend( init_backend_handler=init_backend_handler, ) }, - **existing_enum_dict + **existing_enum_dict, ) # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] @@ -90,20 +97,22 @@ def register_backend( BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] + def construct_rpc_backend_options( backend, rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, init_method=rpc_constants.DEFAULT_INIT_METHOD, - **kwargs + **kwargs, ): - return backend.value.construct_rpc_backend_options_handler( rpc_timeout, init_method, **kwargs ) + def init_backend(backend, *args, **kwargs): return backend.value.init_backend_handler(*args, **kwargs) + def _init_process_group(store, rank, world_size): # Initialize ProcessGroup. process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT @@ -115,22 +124,21 @@ def _init_process_group(store, rank, world_size): assert group is not None, "Failed to initialize default ProcessGroup." if (rank != -1) and (rank != group.rank()): - raise RuntimeError( - f"rank argument {rank} doesn't match pg rank {group.rank()}" - ) + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") if (world_size != -1) and (world_size != group.size()): raise RuntimeError( f"world_size argument {world_size} doesn't match pg size {group.size()}" ) return group + def _tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, _transports=None, _channels=None, - **kwargs + **kwargs, ): from . import TensorPipeRpcBackendOptions @@ -155,9 +163,9 @@ def _tensorpipe_validate_devices(devices, device_count): def _tensorpipe_exchange_and_check_all_device_maps( my_name, my_device_count, my_device_maps, my_devices, group ): - gathered: List[Tuple[ - str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device] - ]] = [("", 0, {}, []) for _ in range(group.size())] + gathered: List[ + Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] dist.all_gather_object( gathered, (my_name, my_device_count, my_device_maps, my_devices), group ) @@ -173,13 +181,15 @@ def _tensorpipe_exchange_and_check_all_device_maps( my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) return reverse_device_maps, my_devices -def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True): + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): for node in all_names: devices = all_devices[node] if len(set(devices)) != len(devices): raise ValueError( - f"Node {node} has duplicated devices\n" - f"devices = {devices}" + f"Node {node} has duplicated devices\n" f"devices = {devices}" ) if not _tensorpipe_validate_devices(devices, all_device_counts[node]): raise ValueError( @@ -190,7 +200,9 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev for source_node in all_names: # For dynamic group (non-static) do not check the target node name since it may not have joined yet - if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names): + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): raise ValueError( f"Node {source_node} has invalid target node names in its device maps\n" f"device maps = {all_device_maps[source_node].keys()}\n" @@ -238,6 +250,7 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev f"device count = {all_device_counts[target_node]}" ) + def _create_device_list(my_devices, my_device_maps, reverse_device_maps): if not my_devices: devices_set: Set[torch.device] = set() @@ -250,6 +263,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps): my_devices = sorted(my_devices, key=lambda d: d.index) return my_devices + def _create_reverse_mapping(my_name, all_names, all_device_maps): reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {} for node in all_names: @@ -259,8 +273,10 @@ def _create_reverse_mapping(my_name, all_names, all_device_maps): } return reverse_device_maps + def _get_device_infos(): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) opts = agent._get_backend_options() device_count = torch.cuda.device_count() @@ -268,8 +284,10 @@ def _get_device_infos(): torch.cuda.init() return device_count, opts.device_maps, opts.devices + def _set_devices_and_reverse_device_map(agent): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, agent) # Group state is retrieved from local agent # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid @@ -282,34 +300,52 @@ def _set_devices_and_reverse_device_map(agent): worker_name = worker_info.name if worker_name != my_name: # TODO: make async? - device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos) + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) else: opts = agent._get_backend_options() - device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) all_device_counts[worker_name] = device_count all_device_maps[worker_name] = device_map all_devices[worker_name] = devices all_names.append(worker_name) - _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False) + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps for worker_name in all_names: # Set device list for each worker - all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps) - api.rpc_sync(worker_name, _update_group_membership, - args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True)) + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions -def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options): - from . import TensorPipeAgent - from . import TensorPipeRpcBackendOptions if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, TensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" ) @@ -389,6 +425,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ raise return agent + register_backend( "TENSORPIPE", _tensorpipe_construct_rpc_backend_options_handler, diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index 3bc525b70d9b..56f6db4db259 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,5 +1,6 @@ from datetime import timedelta from typing import List + from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, _DEFAULT_NUM_WORKER_THREADS, @@ -17,7 +18,7 @@ DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS # Ensure that we don't time out when there are long periods of time without # any operations against the underlying ProcessGroup. -DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1) +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index c9e92980cf56..e48ea8cc534a 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -159,9 +159,11 @@ def async_execution(fn): >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() >>> print(ret) # prints tensor([4., 4.]) """ + @functools.wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) + # Can't declare and use attributes of function objects (mypy#2087) wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] return wrapper diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 2fc647c414d9..5faf7d14d0da 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -12,6 +12,7 @@ import torch.distributed as dist from torch._C._distributed_rpc import _get_current_rpc_agent + __all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] # Thread local tensor tables to store tensors while pickling torch.Tensor @@ -251,7 +252,9 @@ def _build_rpc_profiling_key( Returns: String representing profiling key """ - profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) return profile_key diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 70328f345969..53bf473ba562 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -3,6 +3,7 @@ import torch from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase + from . import constants as rpc_contants @@ -10,6 +11,7 @@ __all__ = ["TensorPipeRpcBackendOptions"] + def _to_device(device: DeviceType) -> torch.device: device = torch.device(device) if device.type != "cuda": diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index cdb0a5d22b74..85927b68bacb 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,20 +1,22 @@ # mypy: allow-untyped-defs from functools import partial -from . import functions -from . import rpc_async - import torch -from .constants import UNSET_RPC_TIMEOUT from torch.futures import Future +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + @functions.async_execution def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _rref_type_cont(rref_fut): rref_type = rref_fut.value() @@ -33,7 +35,7 @@ def _rref_type_cont(rref_fut): rref.owner(), _invoke_func, args=(rref, func_name, args, kwargs), - timeout=timeout + timeout=timeout, ) rref_fut = rref._get_type(timeout=timeout, blocking=False) @@ -63,6 +65,7 @@ def _complete_op(fut): rref_fut.then(_wrap_rref_type_cont) return result + # This class manages proxied RPC API calls for RRefs. It is entirely used from # C++ (see python_rpc_handler.cpp). class RRefProxy: @@ -72,4 +75,6 @@ def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rpc_timeout = timeout def __getattr__(self, func_name): - return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout) + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 0543ab56a877..b5d089d30525 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -2,18 +2,20 @@ # mypy: allow-untyped-defs import itertools +from typing import List import torch from torch.autograd.profiler_legacy import profile -from typing import List from . import ( _disable_server_process_global_profiler, _enable_server_process_global_profiler, ) + __all__: List[str] = [] + class _server_process_global_profile(profile): """ It has the same API as ``torch.autograd.profiler.profile`` class, @@ -123,7 +125,8 @@ def __enter__(self): False, False, False, - torch.profiler._ExperimentalConfig()) + torch.profiler._ExperimentalConfig(), + ) _enable_server_process_global_profiler(profiler_config) return self @@ -152,8 +155,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): process_global_function_events = [] for thread_local_events in process_global_events: # Parse from ``Event``s to ``FunctionEvent``s. - thread_local_function_events = torch.autograd.profiler_legacy._parse_legacy_records( - thread_local_events + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) ) thread_local_function_events.sort( key=lambda function_event: [ From a0e1e20c4157bb3e537fc784a51d7aef1e754157 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:45 +0800 Subject: [PATCH 26/64] [BE][Easy] enable UFMT for `torch/distributed/` (#128870) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870 Approved by: https://github.com/fegin ghstack dependencies: #128868, #128869 --- .lintrunner.toml | 27 - torch/distributed/__init__.py | 64 +-- .../_composable/fsdp/_fsdp_collectives.py | 1 + .../_composable/fsdp/_fsdp_common.py | 1 - .../_composable/fsdp/_fsdp_init.py | 2 +- .../_composable/fsdp/_fsdp_param.py | 3 +- .../_composable/fsdp/_fsdp_param_group.py | 3 +- .../_composable/fsdp/_fsdp_state.py | 3 +- .../_composable/fsdp/fully_shard.py | 1 - torch/distributed/_composable/fully_shard.py | 1 - torch/distributed/_composable/replicate.py | 1 + torch/distributed/_cuda_p2p/__init__.py | 3 +- torch/distributed/_functional_collectives.py | 2 + .../_functional_collectives_impl.py | 1 + torch/distributed/_sharded_tensor/__init__.py | 7 +- torch/distributed/_sharding_spec/__init__.py | 7 +- torch/distributed/_state_dict_utils.py | 1 + torch/distributed/_tools/memory_tracker.py | 19 +- torch/distributed/c10d_logger.py | 12 +- torch/distributed/collective_utils.py | 14 +- torch/distributed/constants.py | 7 +- torch/distributed/device_mesh.py | 3 +- torch/distributed/distributed_c10d.py | 520 +++++++++++++----- .../examples/memory_tracker_example.py | 2 +- torch/distributed/launcher/__init__.py | 2 +- torch/distributed/launcher/api.py | 13 +- torch/distributed/logging_handlers.py | 1 + torch/distributed/nn/__init__.py | 5 +- torch/distributed/nn/api/remote_module.py | 27 +- torch/distributed/nn/functional.py | 21 +- torch/distributed/pipelining/_IR.py | 6 +- torch/distributed/pipelining/__init__.py | 1 + torch/distributed/remote_device.py | 17 +- torch/distributed/rendezvous.py | 33 +- torch/distributed/run.py | 49 +- torch/distributed/utils.py | 1 + 36 files changed, 583 insertions(+), 298 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 99c04cac4fbb..2c3da39f80cc 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1389,33 +1389,6 @@ exclude_patterns = [ 'torch/contrib/_tensorboard_vis.py', "torch/cuda/_gpu_trace.py", 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable - 'torch/distributed/__init__.py', - 'torch/distributed/_composable_state.py', - 'torch/distributed/_sharded_tensor/__init__.py', - 'torch/distributed/_sharding_spec/__init__.py', - 'torch/distributed/_tools/__init__.py', - 'torch/distributed/_tools/memory_tracker.py', - 'torch/distributed/argparse_util.py', - 'torch/distributed/c10d_logger.py', - 'torch/distributed/collective_utils.py', - 'torch/distributed/constants.py', - 'torch/distributed/distributed_c10d.py', - 'torch/distributed/examples/memory_tracker_example.py', - 'torch/distributed/launch.py', - 'torch/distributed/launcher/__init__.py', - 'torch/distributed/launcher/api.py', - 'torch/distributed/logging_handlers.py', - 'torch/distributed/nn/__init__.py', - 'torch/distributed/nn/api/__init__.py', - 'torch/distributed/nn/api/remote_module.py', - 'torch/distributed/nn/functional.py', - 'torch/distributed/nn/jit/__init__.py', - 'torch/distributed/nn/jit/instantiator.py', - 'torch/distributed/nn/jit/templates/__init__.py', - 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/remote_device.py', - 'torch/distributed/rendezvous.py', - 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index eb339000e89e..93b701732206 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs -import sys import pdb +import sys import torch + def is_available() -> bool: """ Return ``True`` if the distributed package is available. @@ -29,31 +30,31 @@ def is_available() -> bool: if is_available(): from torch._C._distributed_c10d import ( - Store, - FileStore, - TCPStore, - ProcessGroup as ProcessGroup, - Backend as _Backend, - PrefixStore, - Reducer, - Logger, - BuiltinCommHookType, - GradBucket, - Work as _Work, - _DEFAULT_FIRST_BUCKET_BYTES, - _register_comm_hook, - _register_builtin_comm_hook, _broadcast_coalesced, _compute_bucket_assignment_by_size, - _verify_params_across_processes, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, DebugLevel, + FileStore, get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, set_debug_level, set_debug_level_from_env, - _make_nccl_premul_sum, - _ControlCollectives, - _StoreCollectives, + Store, + TCPStore, + Work as _Work, ) class _DistributedPdb(pdb.Pdb): @@ -63,10 +64,11 @@ class _DistributedPdb(pdb.Pdb): Usage: _DistributedPdb().set_trace() """ + def interaction(self, *args, **kwargs): _stdin = sys.stdin try: - sys.stdin = open('/dev/stdin') + sys.stdin = open("/dev/stdin") pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin @@ -98,37 +100,31 @@ def breakpoint(rank: int = 0): del guard if sys.platform != "win32": - from torch._C._distributed_c10d import ( - HashStore, - _round_robin_process_groups, - ) + from torch._C._distributed_c10d import _round_robin_process_groups, HashStore - from .distributed_c10d import * # noqa: F403 + from .device_mesh import DeviceMesh, init_device_mesh # Variables prefixed with underscore are not auto imported # See the comment in `distributed_c10d.py` above `_backend` on why we expose # this. - + from .distributed_c10d import * # noqa: F403 from .distributed_c10d import ( _all_gather_base, - _reduce_scatter_base, - _create_process_group_wrapper, - _rank_not_in_group, _coalescing_manager, _CoalescingManager, + _create_process_group_wrapper, _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, get_node_local_rank, ) - + from .remote_device import _remote_device from .rendezvous import ( - rendezvous, _create_store_from_options, register_rendezvous_handler, + rendezvous, ) - from .remote_device import _remote_device - from .device_mesh import init_device_mesh, DeviceMesh - set_debug_level_from_env() else: diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 1423cfd600fc..14f7f8a313fa 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -5,6 +5,7 @@ import torch.distributed as dist from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp + from ._fsdp_common import ( _get_dim0_padded_size, _raise_assert_with_print, diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 594ec483bd3b..36b181250f28 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import math import traceback - from dataclasses import dataclass from enum import auto, Enum from typing import Any, cast, List, Optional diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py index 07fd45e9e3d7..141addc6b719 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_init.py +++ b/torch/distributed/_composable/fsdp/_fsdp_init.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist import torch.nn as nn - from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed.device_mesh import _get_device_handle from torch.utils._python_dispatch import is_traceable_wrapper_subclass + from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo from ._fsdp_state import _get_module_fsdp_state diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index c56dc79e266b..6e0e815f7a53 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -7,12 +7,12 @@ import torch import torch._dynamo.compiled_autograd as ca import torch.nn as nn - from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed._tensor.device_mesh import _mesh_resources from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta + from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( _chunk_with_empty, @@ -24,6 +24,7 @@ HSDPMeshInfo, ) + """ [Note: FSDP tensors] FSDP considers the following tensors: diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 06fa90e060e7..6592a815bacf 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import contextlib - from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch @@ -11,6 +10,7 @@ from torch.profiler import record_function from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle + from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( AllGatherResult, @@ -21,6 +21,7 @@ from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState + _ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 79a09342704f..c6cdb2b29880 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import torch @@ -13,10 +12,12 @@ ) from torch.distributed.utils import _to_kwargs from torch.utils._pytree import tree_flatten, tree_map + from ._fsdp_api import MixedPrecisionPolicy from ._fsdp_common import _cast_fp_tensor, TrainingState from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + if TYPE_CHECKING: from ._fsdp_param import FSDPParam diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 61b7878d467f..e8ab3466118b 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 950a034071a4..06b121aef80a 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -8,7 +8,6 @@ from torch.distributed._composable_state import _get_module_state, _insert_module_state from torch.distributed.fsdp._common_utils import _FSDPState from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo - from torch.distributed.fsdp._init_utils import ( _init_buffer_state, _init_core_state, diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 0cb4ea79bc7d..6ba70cf7bfc9 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -9,6 +9,7 @@ from .contract import _get_registry, contract + _ROOT_MODULE_PREFIX = "" diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 1d3f24c80f08..a3998c8e1d3b 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,15 +1,14 @@ # mypy: allow-untyped-defs from collections import defaultdict from contextlib import contextmanager - from functools import partial from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.distributed._functional_collectives as funcol - import torch.distributed.distributed_c10d as c10d + if TYPE_CHECKING: from torch._C._distributed_c10d import _DistributedBackendOptions, Backend diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 9ac89166b25f..82ca3cb8b073 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -11,6 +11,7 @@ from . import _functional_collectives_impl as fun_col_impl + try: from torch.utils._cxx_pytree import tree_map_only except ImportError: @@ -1134,6 +1135,7 @@ def all_gather_inplace( reduce_scatter_tensor as legacy_reducescatter, ) + # This dict should contain sets of functions that dynamo is allowed to remap. # Functions in this set should accept the same args/kwargs 1:1 as their mapping. traceable_collective_remaps = { diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index c39cb4a9d50d..4bd193d662bd 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -4,6 +4,7 @@ import torch import torch.distributed.distributed_c10d as c10d + """ This file contains the op impls for the legacy (c10d_functional) functional collectives. These impls simply call into the native (_c10d_functional) functional collectives. diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 6c6694cfb081..5e6f4d2a1a6e 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharded_tensor import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -15,4 +16,6 @@ stacklevel=2, ) -sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor +sys.modules[ + "torch.distributed._sharded_tensor" +] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index 21c56d5dc849..c74dd3633e0f 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharding_spec import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -16,4 +17,6 @@ ) import torch.distributed._shard.sharding_spec as _sharding_spec -sys.modules['torch.distributed._sharding_spec'] = _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 2f9f0555be64..cb9def721686 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor + if dist.is_available() or TYPE_CHECKING: from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 10f70c9ce18e..e4d8aa6e762b 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -1,24 +1,14 @@ # mypy: allow-untyped-defs +import operator +import pickle from collections import defaultdict - from itertools import chain - -import pickle - -from typing import ( - Any, - Callable, - Dict, - List, - no_type_check, - Sequence, - TYPE_CHECKING, -) +from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING import torch import torch.nn as nn from torch.utils._python_dispatch import TorchDispatchMode -import operator + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle @@ -234,6 +224,7 @@ def load(self, path: str) -> None: def _create_pre_forward_hook(self, name: str) -> Callable: """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: self._cur_module_name = f"{name}.forward" if ( diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index c1cc67b40681..2c92176c53eb 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -15,9 +15,9 @@ import torch import torch.distributed as dist - from torch.distributed.logging_handlers import _log_handlers + __all__: List[str] = [] _DEFAULT_DESTINATION = "default" @@ -36,7 +36,9 @@ def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Lo return logger -def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[logging.Handler, str]: +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> Tuple[logging.Handler, str]: log_handler = _log_handlers[destination] log_handler_name = type(log_handler).__name__ return (log_handler, log_handler_name) @@ -69,8 +71,10 @@ def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: } return msg_dict -_T = TypeVar('_T') -_P = ParamSpec('_P') + +_T = TypeVar("_T") +_P = ParamSpec("_P") + def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(func) diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index ed6c93078299..78199e7a26f2 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -14,8 +14,10 @@ import torch.distributed as dist + T = TypeVar("T") + @dataclass class SyncPayload(Generic[T]): stage_name: Optional[str] @@ -23,6 +25,7 @@ class SyncPayload(Generic[T]): payload: T exception: Optional[Exception] = None + def broadcast( data_or_fn: Union[T, Callable[[], T]], *, @@ -55,10 +58,12 @@ def broadcast( """ if not success and data_or_fn is not None: - raise AssertionError("Data or Function is expected to be None if not successful") + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None # if no pg is passed then execute if rank is 0 if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): # determine if it is an executable function or data payload only @@ -119,7 +124,7 @@ def all_gather( >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) """ payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None success = True # determine if it is an executable function or data payload only if callable(data_or_fn): @@ -161,7 +166,8 @@ def all_gather( if len(exception_list) > 0: raise RuntimeError( # type: ignore[misc] - error_msg, exception_list) from exception_list[0] + error_msg, exception_list + ) from exception_list[0] return ret_list else: if not sync_obj.success: diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index 47b1f90e406c..b3754043644b 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -1,8 +1,10 @@ -from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT from datetime import timedelta from typing import Optional -__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout'] +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] # Default process group wide timeout, if applicable. # This only applies to the non-nccl backends @@ -16,6 +18,7 @@ try: from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT except ImportError: # if C++ NCCL support is not compiled, we don't have access to the default nccl value. diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index e46356a36894..a1fee846d254 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -6,10 +6,9 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch - from torch.distributed import is_available +from torch.utils._typing_utils import not_none -from ..utils._typing_utils import not_none __all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d44c3733a214..91e4cf9f540c 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,11 +1,11 @@ # mypy: allow-untyped-defs """Distributed Collective Communication (c10d).""" -import itertools import collections.abc import contextlib import hashlib import io +import itertools import logging import os import pickle @@ -14,19 +14,26 @@ import warnings from collections import namedtuple from datetime import timedelta -from typing import Any, Callable, Dict, Optional, Tuple, Union, List, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import deprecated import torch +from torch._C import _DistStoreError as DistStoreError from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, AllgatherOptions, AllreduceCoalescedOptions, AllreduceOptions, AllToAllOptions, - _DistributedBackendOptions, BarrierOptions, BroadcastOptions, + DebugLevel, GatherOptions, + get_debug_level, PrefixStore, ProcessGroup, ReduceOp, @@ -34,41 +41,88 @@ ReduceScatterOptions, ScatterOptions, Store, - DebugLevel, - get_debug_level, Work, - _register_process_group, - _resolve_process_group, - _unregister_all_process_groups, - _unregister_process_group, ) from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs -from .constants import default_pg_timeout, default_pg_nccl_timeout +from torch.utils._typing_utils import not_none + from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 -from ..utils._typing_utils import not_none -DistStoreError = torch._C._DistStoreError + __all__ = [ - 'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced', - 'all_gather_object', 'all_reduce', - 'all_reduce_coalesced', 'all_to_all', - 'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast', 'send_object_list', - 'recv_object_list', 'broadcast_object_list', 'destroy_process_group', - 'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank', - 'get_world_size', 'get_pg_count', 'group', 'init_process_group', 'irecv', - 'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available', - 'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available', - 'isend', 'monitored_barrier', 'new_group', 'new_subgroups', - 'new_subgroups_by_enumeration', 'recv', 'reduce', - 'reduce_scatter', 'scatter', - 'scatter_object_list', 'send', 'supports_complex', - 'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions', - 'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore', - 'ProcessGroup', 'ReduceOp', 'ReduceOptions', 'ReduceScatterOptions', - 'ScatterOptions', 'Store', 'DebugLevel', 'get_debug_level', 'Work', - 'default_pg_timeout', 'get_group_rank', 'get_global_rank', 'get_process_group_ranks', - 'reduce_op', 'all_gather_into_tensor', 'reduce_scatter_tensor', 'get_node_local_rank', + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", ] _MPI_AVAILABLE = True @@ -79,6 +133,7 @@ _pickler = pickle.Pickler _unpickler = pickle.Unpickler + # Change __module__ of all imported types from torch._C._distributed_c10d that are public def _export_c_types() -> None: _public_types_to_change_module = [ @@ -97,22 +152,25 @@ def _export_c_types() -> None: Store, DebugLevel, get_debug_level, - Work + Work, ] for type in _public_types_to_change_module: type.__module__ = "torch.distributed.distributed_c10d" + + _export_c_types() try: from torch._C._distributed_c10d import ProcessGroupMPI + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupMPI"] except ImportError: _MPI_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupNCCL - from torch._C._distributed_c10d import ProcessGroupCudaP2P + from torch._C._distributed_c10d import ProcessGroupCudaP2P, ProcessGroupNCCL + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"] @@ -120,8 +178,8 @@ def _export_c_types() -> None: _NCCL_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupGloo - from torch._C._distributed_c10d import _ProcessGroupWrapper + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupGloo"] except ImportError: @@ -129,6 +187,7 @@ def _export_c_types() -> None: try: from torch._C._distributed_c10d import ProcessGroupUCC + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupUCC"] except ImportError: @@ -191,20 +250,20 @@ class Backend(str): backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { - 'cpu' : GLOO, - 'cuda' : NCCL, + "cpu": GLOO, + "cuda": NCCL, } backend_capability: Dict[str, List[str]] = { - GLOO : ["cpu", "cuda"], - NCCL : ["cuda"], - UCC : ["cpu", "cuda"], - MPI : ["cpu", "cuda"], + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], } backend_type_map: Dict[str, ProcessGroup.BackendType] = { UNDEFINED: ProcessGroup.BackendType.UNDEFINED, - GLOO : ProcessGroup.BackendType.GLOO, + GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, UCC: ProcessGroup.BackendType.UCC, } @@ -220,7 +279,13 @@ def __new__(cls, name: str): return value @classmethod - def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None) -> None: + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, List[str]]] = None, + ) -> None: """ Register a new backend with the given name and instantiating function. @@ -247,19 +312,19 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio """ # Allow UCC plugin if Pytorch is not built with native support. # TODO: remove this exception once UCC plugin is fully deprecated. - if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())): - assert not hasattr(Backend, name.upper()), ( - f"{name.upper()} c10d backend already exist" - ) - assert name.upper() not in Backend._plugins, ( - f"{name.upper()} c10d backend creator function already exist" - ) + if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): + assert not hasattr( + Backend, name.upper() + ), f"{name.upper()} c10d backend already exist" + assert ( + name.upper() not in Backend._plugins + ), f"{name.upper()} c10d backend creator function already exist" setattr(Backend, name.upper(), name.lower()) Backend.backend_list.append(name.lower()) if devices is not None: for device in devices: - if device != 'cpu' and device != 'cuda': + if device != "cpu" and device != "cuda": Backend.default_device_backend_map[device] = name.lower() Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM @@ -281,6 +346,7 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + class BackendConfig: """Backend configuration class.""" @@ -294,7 +360,10 @@ def __init__(self, backend: Backend): # supported since PyTorch 2.0 for device, default_backend in Backend.default_device_backend_map.items(): if is_backend_available(default_backend): - if default_backend == Backend.NCCL and not torch.cuda.is_available(): + if ( + default_backend == Backend.NCCL + and not torch.cuda.is_available() + ): continue self.device_backend_map[device] = Backend(default_backend) elif backend.lower() in Backend.backend_list: @@ -316,12 +385,16 @@ def __init__(self, backend: Backend): for device_backend_pair_str in backend.lower().split(","): device_backend_pair = device_backend_pair_str.split(":") if len(device_backend_pair) != 2: - raise ValueError(f"Invalid device:backend pairing: \ - {device_backend_pair_str}. {backend_str_error_message}") + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) device, backend = device_backend_pair if device in self.device_backend_map: - raise ValueError(f"Duplicate device type {device} \ - in backend string: {backend}. {backend_str_error_message}") + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}" + ) self.device_backend_map[device] = Backend(backend) else: # User specified a single backend name whose device capability is @@ -334,23 +407,24 @@ def __init__(self, backend: Backend): ) backend_val = Backend(backend) self.device_backend_map = { - "cpu" : backend_val, - "cuda" : backend_val, - "xpu" : backend_val, + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, } - logger.info( - "Using backend config: %s", self.device_backend_map - ) + logger.info("Using backend config: %s", self.device_backend_map) def __repr__(self): """Return all the device:backend pairs separated by commas.""" - return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items()) + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) def get_device_backend_map(self) -> Dict[str, Backend]: """Return backend map of the device.""" return self.device_backend_map + class _reduce_op: r""" Deprecated enum-like class. @@ -397,8 +471,14 @@ class P2POp: tag (int, optional): Tag to match send with recv. """ - def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Init.""" self.op = op self.tensor = tensor @@ -406,8 +486,14 @@ def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, self.group = group self.tag = tag - def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Create and return a new instance of the class.""" _check_op(op) _check_single_tensor(tensor, "tensor") @@ -415,7 +501,9 @@ def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, def __repr__(self): my_group_rank = get_rank(self.group) - peer_group_rank = get_group_rank(self.group, self.peer) if self.group else self.peer + peer_group_rank = ( + get_group_rank(self.group, self.peer) if self.group else self.peer + ) op_name = self.op.__name__ group_name = self.group.group_name if self.group else "default_pg" if "send" in op_name: @@ -429,6 +517,7 @@ def __repr__(self): return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" + class _CollOp: """ A class to capture collective operations. @@ -441,8 +530,14 @@ class _CollOp: root (int, optional): root of broadcast or reduce. """ - def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torch.Tensor] = None, - redop: Optional[ReduceOp] = None, root: Optional[int] = None): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: Optional[torch.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): self.op = op self.tensor = tensor self.dst_tensor = dst_tensor @@ -462,6 +557,7 @@ def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torc _pg_to_tag: Dict[ProcessGroup, str] = {} _backend: Optional[str] = None + class _World: """ Container class for c10d process group state. @@ -597,6 +693,7 @@ def pg_config_info(self) -> List[Dict[str, Any]]: _world = _World() """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + class _WorldMeta(type): """ Meta class of ``group`` and ``GroupMember``. @@ -613,11 +710,13 @@ def WORLD(cls) -> Optional[ProcessGroup]: def WORLD(cls, pg: Optional[ProcessGroup]): _world.default_pg = pg + class group(metaclass=_WorldMeta): """Group class. Placeholder.""" pass + class GroupMember(metaclass=_WorldMeta): """Group member class.""" @@ -630,23 +729,28 @@ def _get_default_timeout(backend: Backend) -> timedelta: if not isinstance(default_pg_nccl_timeout, timedelta): # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was # changed to be a warning. We should fix the moco model. - warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled") + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) return default_pg_timeout return default_pg_nccl_timeout else: return default_pg_timeout + def _check_valid_timeout(timeout: Any) -> None: if not isinstance(timeout, timedelta): raise TypeError( f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" ) + # Default process group state _default_pg_init_method: Optional[str] = None STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: """ Return the device to use with ``group`` for control flow usage (object collectives, barrier). @@ -711,14 +815,20 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device _world.pg_default_device[group] = devices[0] logger.info( - "Using device %s for object " - "collectives.", _world.pg_default_device[group] + "Using device %s for object " "collectives.", _world.pg_default_device[group] ) return _world.pg_default_device[group] @_time_logger -def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)) -> None: +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: """ Store based barrier for synchronizing processes. @@ -755,7 +865,12 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log logger.debug( "Waiting in store based barrier to initialize process group for " "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", - rank, store_key, world_size, worker_count, timeout, e + rank, + store_key, + world_size, + worker_count, + timeout, + e, ) if timedelta(seconds=(time.time() - start)) > timeout: @@ -766,7 +881,10 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log ) logger.info( - "Rank %s: Completed store-based barrier for key:%s with %s nodes.", rank, store_key, world_size + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, ) @@ -803,13 +921,16 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int: if group is GroupMember.WORLD: return global_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) group_ranks = _world.pg_group_ranks[group] if global_rank not in group_ranks: raise ValueError(f"Global rank {global_rank} is not part of group {group}") return group_ranks[global_rank] + def get_global_rank(group: ProcessGroup, group_rank: int) -> int: """ Translate a group rank into a global rank. @@ -828,7 +949,9 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int: if group is GroupMember.WORLD: return group_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) for rank, grp_rank in _world.pg_group_ranks[group].items(): if grp_rank == group_rank: return rank @@ -858,6 +981,7 @@ def get_process_group_ranks(group: ProcessGroup) -> List[int]: """ return list(_world.pg_group_ranks[group].keys()) + def _get_group_size(group) -> int: """Get a given group's world size.""" if group is GroupMember.WORLD or group is None: @@ -906,13 +1030,16 @@ def _check_tensor_list(param, param_name) -> None: def _as_iterable(obj) -> collections.abc.Iterable: return obj if isinstance(obj, list) else (obj,) + def _ensure_all_tensors_same_dtype(*tensors) -> None: last_dtype = None for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): tensor_dtype = tensor.dtype # Mixing complex and its element type is allowed if tensor_dtype.is_complex: - tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) if last_dtype is None: last_dtype = tensor_dtype @@ -1049,6 +1176,7 @@ def _update_default_pg(pg) -> None: rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 torch._C._distributed_c10d._set_global_rank(rank) + def get_backend_config(group: Optional[ProcessGroup] = None) -> str: """ Return the backend configuration of the given process group. @@ -1071,6 +1199,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str: backend_config = _world.pg_backend_config.get(pg) return str(not_none(backend_config)) + def get_backend(group: Optional[ProcessGroup] = None) -> Backend: """ Return the backend of the given process group. @@ -1093,6 +1222,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend: pg_store = _world.pg_map[pg] if pg in _world.pg_map else None return Backend(not_none(pg_store)[0]) + def _get_process_group_uid(pg: ProcessGroup) -> int: backend = None try: @@ -1103,6 +1233,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int: return backend.uid return -1 + def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: """ Return the pg configuration of the given process group. @@ -1120,6 +1251,7 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: "ranks": get_process_group_ranks(pg), } + def _get_all_pg_configs() -> List[Dict[str, Any]]: """ Return the pg configuration of all the process groups. @@ -1130,6 +1262,7 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]: config_info.append(_get_pg_config(pg)) return config_info + def get_pg_count() -> int: """ Return the number of process groups. @@ -1137,6 +1270,7 @@ def get_pg_count() -> int: """ return _world.group_count + def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: """ Return the local rank of the current process relative to the node. @@ -1162,6 +1296,7 @@ def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: "assuming you are not running in a multi-device context and want the code to run locally instead." ) + def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: """ Set the timeout for the given process group when users want to use a different timeout instead of @@ -1349,7 +1484,14 @@ def init_process_group( ) default_pg, _ = _new_process_group_helper( - -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg" + -1, + -1, + [], + backend, + None, + group_name, + timeout=timeout, + group_desc="default_pg", ) _update_default_pg(default_pg) else: @@ -1375,7 +1517,7 @@ def init_process_group( pg_options=pg_options, timeout=timeout, device_id=device_id, - group_desc="default_pg" + group_desc="default_pg", ) _update_default_pg(default_pg) @@ -1394,7 +1536,9 @@ def _distributed_excepthook(*args): finally: sys.stderr = old_stderr msg = buf.getvalue() - msg = "\n".join(f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n")) + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) sys.stderr.write(msg) sys.stderr.flush() @@ -1421,6 +1565,7 @@ def _distributed_excepthook(*args): # default devices and messes up NCCL internal state. _store_based_barrier(rank, store, group_name, world_size, timeout) + def _get_split_source(pg): split_from = None if pg.bound_device_id: @@ -1442,6 +1587,7 @@ def _get_split_source(pg): return split_from + def _shutdown_backend(pg): """ Try to shut down the backend of a process group. @@ -1453,10 +1599,13 @@ def _shutdown_backend(pg): backend = pg._get_backend(torch.device("cuda")) except RuntimeError: pass - if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)): + if is_nccl_available() and isinstance( + backend, (ProcessGroupNCCL, ProcessGroupCudaP2P) + ): # explictly call shutdown to ensure that NCCL resources are released backend._shutdown() + def _new_process_group_helper( group_size, group_rank, @@ -1487,9 +1636,11 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != 'cuda'): - raise ValueError("init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu") + if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + raise ValueError( + "init_process_group device_id parameter must be a cuda device with an " + "id, e.g. cuda:0, not just cuda or cpu" + ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value _check_valid_timeout(timeout) @@ -1514,8 +1665,10 @@ def _new_process_group_helper( # ranks_. We can only know this if the group we are making is the # entire world or if we have bound a device id to the world (which # causes early connection initialization). - if (is_initialized() and - (len(global_ranks_in_group) == _get_default_group().size() or _get_default_group().bound_device_id)): + if is_initialized() and ( + len(global_ranks_in_group) == _get_default_group().size() + or _get_default_group().bound_device_id + ): split_from = _get_split_source(_get_default_group()) else: split_from = None @@ -1538,7 +1691,9 @@ def _new_process_group_helper( prefix_store = PrefixStore(f"{group_name}/", store) base_pg_options = ProcessGroup.Options(backend=str(backend)) base_pg_options._timeout = timeout - pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options) + pg: ProcessGroup = ProcessGroup( + prefix_store, group_rank, group_size, base_pg_options + ) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1561,12 +1716,19 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None # create new process group with accurate rank and size if pg.rank() == -1 and pg.size() == -1: - pg = ProcessGroup(backend_prefix_store, backend_class.rank(), backend_class.size(), base_pg_options) + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + base_pg_options, + ) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: # raise RuntimeError("GLOO options not supported") - backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupGloo( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.GLOO elif backend_str == Backend.NCCL: if not is_nccl_available(): @@ -1592,19 +1754,22 @@ def _new_process_group_helper( pg_options.global_ranks_in_group = global_ranks_in_group pg_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options) + backend_prefix_store, group_rank, group_size, pg_options + ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): # TODO: once UCC plugin is fully deprecated, remove # is_ucc_available() from above elif-condition and raise # RuntimeError if is_ucc_available() returns false. - backend_class = ProcessGroupUCC(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupUCC( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.UCC else: - assert backend_str.upper() in Backend._plugins, ( - f"Unknown c10d backend type {backend_str.upper()}" - ) + assert ( + backend_str.upper() in Backend._plugins + ), f"Unknown c10d backend type {backend_str.upper()}" backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -1612,7 +1777,9 @@ def _new_process_group_helper( backend_type = ProcessGroup.BackendType.CUSTOM if not extended_api: - backend_class = creator_fn(backend_prefix_store, group_rank, group_size, timeout) + backend_class = creator_fn( + backend_prefix_store, group_rank, group_size, timeout + ) else: dist_backend_opts = _DistributedBackendOptions() dist_backend_opts.store = backend_prefix_store @@ -1640,7 +1807,10 @@ def _new_process_group_helper( break # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set - if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] or backend_str.upper() in Backend._plugins: + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): # In debug mode and if GLOO is available, wrap in a wrapper PG that # enables enhanced collective checking for debuggability. if get_debug_level() == DebugLevel.DETAIL: @@ -1698,6 +1868,7 @@ def _new_process_group_helper( _world.pg_to_tag[pg] = pg_tag return pg, prefix_store + def destroy_process_group(group: Optional[ProcessGroup] = None): """ Destroy a given process group, and deinitialize the distributed package. @@ -1736,7 +1907,9 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): if group is None or group == GroupMember.WORLD: # shutdown all backends in the order of pg names. shutting down in order because # ncclCommAbort() was a 'collective' call in some versions of NCCL. - for pg_to_shutdown in sorted(_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True): + for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): _shutdown_backend(pg_to_shutdown) _update_default_pg(None) @@ -1832,7 +2005,9 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int: return _get_group_size(group) -def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: +def isend( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> Optional[Work]: """ Send a tensor asynchronously. @@ -1871,7 +2046,13 @@ def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, return pg.send([tensor], dst, tag) -def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: + +def irecv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> Optional[Work]: """ Receives a tensor asynchronously. @@ -1913,8 +2094,11 @@ def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proce group_src_rank = get_group_rank(pg, src) return pg.recv([tensor], group_src_rank, tag) + @_exception_logger -def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None: +def send( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> None: """ Send a tensor synchronously. @@ -1951,8 +2135,14 @@ def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, t group_dst_rank = get_group_rank(group, dst) group.send([tensor], group_dst_rank, tag).wait() + @_exception_logger -def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> int: +def recv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> int: """ Receives a tensor synchronously. @@ -2004,7 +2194,15 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces class _IllegalWork(Work): def __getattribute__(self, name): - if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]: + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: raise ValueError(f"Illegal to call {name} on IllegalWork object") @@ -2057,7 +2255,9 @@ def _coalescing_manager( group = group or _get_default_group() op_list = _world.pg_coalesce_state.setdefault(group, []) if op_list: - raise ValueError("ProcessGroup has non-empty op list at the start of coalescing") + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) if device: group._start_coalescing(device) cm = _CoalescingManager() @@ -2212,6 +2412,7 @@ def broadcast(tensor, src, group=None, async_op=False): else: work.wait() + @_exception_logger def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2292,6 +2493,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger @deprecated( "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " @@ -2359,6 +2561,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2404,6 +2607,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + def _object_to_tensor(obj, device, group): f = io.BytesIO() _pickler(f).dump(obj) @@ -2416,7 +2620,9 @@ def _object_to_tensor(obj, device, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) - logger.warning("_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash) + logger.warning( + "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash + ) local_size = torch.LongTensor([byte_tensor.numel()]).to(device) return byte_tensor, local_size @@ -2426,7 +2632,9 @@ def _tensor_to_object(tensor, tensor_size, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([tensor]) - logger.warning("_tensor_to_object size: %s hash value: %s", tensor.numel(), hash) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) tensor = tensor.cpu() buf = tensor.numpy().tobytes()[:tensor_size] return _unpickler(io.BytesIO(buf)).load() @@ -2709,7 +2917,9 @@ def send_object_list(object_list, dst, group=None, device=None): # sent to this device. current_device = device or _get_pg_default_device(group) # Serialize object_list elements to tensors on src rank. - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) # Send object sizes @@ -2793,7 +3003,9 @@ def recv_object_list(object_list, src=None, group=None, device=None): # case it is not ``None`` we move the size and object tensors to be # received to this device. current_device = device or _get_pg_default_device(group) - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Receive object sizes rank_sizes = recv(object_sizes_tensor, src=src, group=group) @@ -2802,11 +3014,13 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) rank_objects = recv(object_tensor, src=src, group=group) - assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + assert ( + rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -2816,6 +3030,7 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_list[i] = _tensor_to_object(obj_view, obj_size, group) return rank_objects + @_exception_logger def broadcast_object_list(object_list, src=0, group=None, device=None): """ @@ -2892,10 +3107,14 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): my_rank = get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -2912,7 +3131,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) broadcast(object_tensor, src=src, group=group) @@ -3000,7 +3219,10 @@ def scatter_object_list( pg_device = _get_pg_default_device(group) if my_rank == src: tensor_list, tensor_sizes = zip( - *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] ) tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) @@ -3015,7 +3237,9 @@ def scatter_object_list( broadcast(max_tensor_size, src=src, group=group) # Scatter actual serialized objects - output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) scatter( output_tensor, scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined] @@ -3033,7 +3257,9 @@ def scatter_object_list( ) # Deserialize back to object - scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) @_exception_logger @@ -3900,6 +4126,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False else: work.wait() + @_exception_logger def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): """ @@ -4041,15 +4268,18 @@ def _create_process_group_wrapper( wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) return wrapped_pg + # helper function for deterministically hashing a list of ranks def _hash_ranks(ranks: List[int]): return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + # Takes a list of ranks and computes an integer color def _process_group_color(ranks: List[int]) -> int: # Convert our hash to an int, but avoid negative numbers by shifting a bit. return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) + def _process_group_name(ranks, use_hashed_name): global _world if use_hashed_name: @@ -4061,6 +4291,7 @@ def _process_group_name(ranks, use_hashed_name): _world.group_count += 1 return pg_name + def _get_backend_from_str(backend: Optional[str] = None) -> Backend: # Default to the same backend as the global process group # if backend is not specified. @@ -4070,7 +4301,14 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend: @_time_logger -def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None): +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, +): """ Create a new distributed group. @@ -4137,6 +4375,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local group_desc=group_desc, ) + def _new_group_with_tag( ranks=None, timeout=None, @@ -4144,7 +4383,7 @@ def _new_group_with_tag( pg_options=None, pg_tag=None, use_local_synchronization=False, - group_desc=None + group_desc=None, ): """ Variant of ``new_group`` that exposes tag creation. @@ -4159,7 +4398,6 @@ def _new_group_with_tag( global_rank = default_pg.rank() global_world_size = default_pg.size() - # Default to the same backend as the global process group # if the backend is not specified. if not backend: @@ -4175,7 +4413,9 @@ def _new_group_with_tag( if use_local_synchronization: # MPI backend doesn't have have a way for us to perform a partial sync if backend == Backend.MPI: - raise ValueError("MPI backend doesn't support use_local_synchronization=True") + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) if ranks is not None and get_rank() not in ranks: return None @@ -4217,7 +4457,7 @@ def _new_group_with_tag( pg_options=pg_options, timeout=timeout, pg_tag=pg_tag, - group_desc=group_desc + group_desc=group_desc, ) # Create the global rank to group rank mapping @@ -4246,7 +4486,9 @@ def _new_group_with_tag( world_size = len(ranks) if use_local_synchronization else get_world_size() # Use store based barrier here since barrier() used a bunch of # default devices and messes up NCCL internal state. - _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout) + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) return pg @@ -4332,16 +4574,20 @@ def new_subgroups( """ if group_size is None: if not torch.cuda.is_available(): - raise ValueError("Default group size only takes effect when CUDA is available." - "If your subgroup using a backend that does not depend on CUDA," - "please pass in 'group_size' correctly.") + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) group_size = torch.cuda.device_count() if group_size <= 0: raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") world_size = get_world_size() if world_size < group_size: - raise ValueError(f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})") + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) if world_size % group_size != 0: raise ValueError("The world size must be divisible by 'group_size'") @@ -4364,10 +4610,7 @@ def new_subgroups( rank = get_rank() if rank in ranks_in_subgroup: cur_subgroup = subgroup - logger.info( - "Rank %s is assigned to subgroup %s", - rank, ranks_in_subgroup - ) + logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup) return cur_subgroup, subgroups @@ -4479,8 +4722,13 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro return group return None -def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup: - assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: List[int], stride: int +) -> ProcessGroup: + assert ( + len(ranks) % stride == 0 + ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" my_rank = get_rank() my_ranks = None @@ -4505,6 +4753,7 @@ def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) # TODO copy settings and timeout from default PG return _new_group_with_tag(my_ranks, pg_tag=tag) + def _get_group_tag(pg: ProcessGroup) -> str: """Return the tag associated with ``pg``.""" tag = _world.pg_to_tag[pg] @@ -4512,12 +4761,15 @@ def _get_group_tag(pg: ProcessGroup) -> str: tag = tag[5:] return tag + def _get_process_group_name(pg: ProcessGroup) -> str: return _world.pg_names.get(pg, "None") + def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + # This ops are not friendly to TorchDynamo. So, we decide to disallow these ops # in FX graph, allowing them to run them on eager, with torch.compile. dynamo_unsupported_distributed_c10d_ops = [ diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index cb2ba03777d8..e40cfb8b3f59 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs -import torch import torchvision +import torch from torch.distributed._tools import MemoryTracker diff --git a/torch/distributed/launcher/__init__.py b/torch/distributed/launcher/__init__.py index f0d25f8080c2..fb744a2b9361 100644 --- a/torch/distributed/launcher/__init__.py +++ b/torch/distributed/launcher/__init__.py @@ -8,7 +8,7 @@ from torch.distributed.launcher.api import ( # noqa: F401 - LaunchConfig, elastic_launch, launch_agent, + LaunchConfig, ) diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 937647f77828..a3bcd4073c9b 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -15,13 +15,18 @@ from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent -from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger -__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent'] + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] logger = get_logger(__name__) @@ -212,8 +217,8 @@ def launch_agent( "max_restarts": config.max_restarts, "monitor_interval": config.monitor_interval, "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] - "metrics_cfg": config.metrics_cfg - } + "metrics_cfg": config.metrics_cfg, + }, ) rdzv_parameters = RendezvousParameters( diff --git a/torch/distributed/logging_handlers.py b/torch/distributed/logging_handlers.py index 3c607fe45da7..021ad100f06a 100644 --- a/torch/distributed/logging_handlers.py +++ b/torch/distributed/logging_handlers.py @@ -9,6 +9,7 @@ import logging from typing import Dict, List + __all__: List[str] = [] _log_handlers: Dict[str, logging.Handler] = { diff --git a/torch/distributed/nn/__init__.py b/torch/distributed/nn/__init__.py index 3ed1b42cbe15..e15fb517052e 100644 --- a/torch/distributed/nn/__init__.py +++ b/torch/distributed/nn/__init__.py @@ -1,4 +1,7 @@ import torch + +from .functional import * # noqa: F403 + + if torch.distributed.rpc.is_available(): from .api.remote_module import RemoteModule -from .functional import * # noqa: F403 diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index de8a15dd65da..5583da8c3e8d 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -21,14 +21,15 @@ import torch import torch.distributed.rpc as rpc -from torch import Tensor, device, dtype, nn -from torch.distributed.nn.jit import instantiator +from torch import device, dtype, nn, Tensor from torch.distributed import _remote_device +from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.internal import _internal_rpc_pickler from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle + __all__ = ["RemoteModule"] _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -120,7 +121,6 @@ def _raise_not_supported(name: str) -> None: class _RemoteModule(nn.Module): - def __new__(cls, *args, **kwargs): # Use __new__ for logging purposes. torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module") @@ -370,7 +370,10 @@ def register_forward_pre_hook( # type: ignore[return] self, hook: Union[ Callable[[T, Tuple[Any, ...]], Optional[Any]], - Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + Callable[ + [T, Tuple[Any, ...], Dict[str, Any]], + Optional[Tuple[Any, Dict[str, Any]]], + ], ], prepend: bool = False, with_kwargs: bool = False, @@ -405,10 +408,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: ) def named_parameters( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Parameter]]: _raise_not_supported(self.named_parameters.__name__) @@ -416,10 +416,7 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[ret _raise_not_supported(self.buffers.__name__) def named_buffers( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) @@ -464,7 +461,11 @@ def _prepare_init(self, remote_device_str: str) -> bool: assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." remote_device = _remote_device(remote_device_str) - self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank() + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) self.device = str(remote_device.device()) agent = rpc._get_current_rpc_agent() # If the device map of the remote worker is set, diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index e90a78a69324..110df578552a 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,11 +2,13 @@ import torch import torch.distributed as dist from torch.autograd import Function + # The two imports below are not always available depending on the # USE_DISTRIBUTED compile flag. Make sure they raise import error # if we're trying to use them. from torch.distributed import group, ReduceOp + def broadcast(tensor, src, group=group.WORLD): """ Broadcasts the tensor to the whole group. @@ -116,6 +118,7 @@ def all_gather(tensor, group=group.WORLD): """ return _AllGather.apply(group, tensor) + def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): """ Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. @@ -340,6 +343,7 @@ def backward(ctx, *grad_outputs): gx = torch.sum(torch.stack(gxs), dim=0) return (None, gx) + class _AllGatherBase(Function): @staticmethod def forward(ctx, output_tensor, input_tensor, group): @@ -354,16 +358,19 @@ def backward(ctx, grad_output): out_size = list(grad_output.size()) if out_size[0] % world_size != 0: raise RuntimeError( - f'Tensor with dimensions: {out_size} does ' - f'not have first dimension divisible by world_size: {world_size}' + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" ) out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) - gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) else: raise RuntimeError("Backend not supported!") return (None, gx, None) + class _AlltoAll(Function): @staticmethod def forward(ctx, group, out_tensor_list, *tensors): @@ -391,7 +398,9 @@ def forward(ctx, group, out_tensor_list, *tensors): @staticmethod def backward(ctx, *grad_outputs): tensor_list = [ - torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) for size in ctx.input_tensor_size_list ] return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) @@ -415,7 +424,9 @@ def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): @staticmethod def backward(ctx, grad_output): - tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype) + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) return (None, None, None, None) + ( _AlltoAllSingle.apply( ctx.group, diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7d0aede8943e..81ddeb8bfe0a 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -5,7 +5,7 @@ import operator from collections import defaultdict from enum import Enum -from inspect import Parameter, signature, Signature +from inspect import Parameter, Signature, signature from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -21,6 +21,7 @@ ) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module + from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules from ._utils import PipeInfo @@ -1176,7 +1177,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): predecessor_module = getattr(predecessor_module, atom) except AttributeError as e: raise AttributeError( - f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' + f"Specified target {qualname} referenced " + f'nonexistent module {".".join(atoms[: i + 1])}' ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 18b3191add5b..5b1843a33f6f 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -8,6 +8,7 @@ ) from .stage import build_stage, PipelineStage + __all__ = [ "Pipe", "pipe_split", diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index da664f7408bb..bdb1974b1b37 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -47,7 +47,7 @@ def __init__(self, remote_device: Union[str, torch.device]): else: raise ValueError(PARSE_ERROR) else: - raise TypeError(f'Invalid type for remote_device: {type(remote_device)}') + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") # Do some basic sanity check (no empty string) if self._worker_name is not None and not self._worker_name: @@ -96,18 +96,18 @@ def device(self) -> torch.device: def __repr__(self): if self._device is not None: if self._worker_name is not None: - return f'{self._worker_name}/{self._device}' + return f"{self._worker_name}/{self._device}" elif self._rank is not None: - return f'rank:{self._rank}/{self._device}' + return f"rank:{self._rank}/{self._device}" else: return str(self._device) else: if self._worker_name is not None: - return f'{self._worker_name}' + return f"{self._worker_name}" elif self._rank is not None: - return f'{self._rank}' + return f"{self._rank}" else: - raise RuntimeError('Invalid state!') + raise RuntimeError("Invalid state!") def __eq__(self, other): if not isinstance(other, _remote_device): @@ -122,8 +122,5 @@ def __eq__(self, other): return False - def __hash__(self): - return hash(self._worker_name) ^ \ - hash(self._device) ^ \ - hash(self._rank) + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index e3266cb238ac..a944a75271b0 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -10,7 +10,7 @@ import os import sys from datetime import timedelta -from typing import Dict, Optional, Callable, Iterator, Tuple +from typing import Callable, Dict, Iterator, Optional, Tuple from torch.distributed import FileStore, PrefixStore, Store, TCPStore @@ -21,6 +21,7 @@ __all__ = ["register_rendezvous_handler", "rendezvous"] + def register_rendezvous_handler(scheme, handler): """ Register a new rendezvous handler. @@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler): """ global _rendezvous_handlers if scheme in _rendezvous_handlers: - raise RuntimeError( - f"Rendezvous handler for {scheme}:// already registered" - ) + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") _rendezvous_handlers[scheme] = handler # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} def _query_to_dict(query: str) -> Dict[str, str]: - return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: @@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store: +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. @@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True else: start_daemon = rank == 0 return TCPStore( - hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv + hostname, + port, + world_size, + start_daemon, + timeout, + multi_tenant=True, + use_libuv=use_libuv, ) @@ -208,7 +218,9 @@ def _error(msg): assert result.hostname is not None - store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) @@ -250,12 +262,13 @@ def _get_env_or_raise(env_var: str) -> str: else: world_size = int(_get_env_or_raise("WORLD_SIZE")) - master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) use_libuv = _get_use_libuv_from_query_dict(query_dict) - store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 5654693f3dfc..aa34891d1ecd 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -397,9 +397,9 @@ def main(): import os import sys import uuid -import importlib.metadata as metadata -from argparse import REMAINDER, ArgumentParser -from typing import Callable, List, Tuple, Type, Union, Optional, Set +from argparse import ArgumentParser, REMAINDER +from importlib import metadata +from typing import Callable, List, Optional, Set, Tuple, Type, Union import torch from torch.distributed.argparse_util import check_env, env @@ -408,9 +408,9 @@ def main(): from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from torch.distributed.launcher.api import elastic_launch, LaunchConfig from torch.utils.backend_registration import _get_custom_mod_func -import torch.multiprocessing + logger = get_logger(__name__) @@ -693,21 +693,26 @@ def determine_local_world_size(nproc_per_node: str): if torch.cuda.is_available(): num_proc = torch.cuda.device_count() device_type = "gpu" - elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \ - _get_custom_mod_func("is_available")(): + elif ( + hasattr(torch, torch._C._get_privateuse1_backend_name()) + and _get_custom_mod_func("is_available")() + ): num_proc = _get_custom_mod_func("device_count")() device_type = torch._C._get_privateuse1_backend_name() else: num_proc = os.cpu_count() device_type = "cpu" else: - raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e logger.info( - "Using nproc_per_node=%s," - " setting to %s since the instance " - "has %s %s", - nproc_per_node, num_proc, os.cpu_count(), device_type + "Using nproc_per_node=%s," " setting to %s since the instance " "has %s %s", + nproc_per_node, + num_proc, + os.cpu_count(), + device_type, ) return num_proc @@ -753,9 +758,13 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: logs_specs_cls = entrypoint_list[0].load() if logs_specs_cls is None: - raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key") + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) - logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)) + logging.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) else: logs_specs_cls = DefaultLogsSpecs @@ -768,7 +777,11 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str assert 0 < min_nodes <= max_nodes assert args.max_restarts >= 0 - if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint: + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + ): logger.warning( "master_addr is only used for static rdzv_backend and when rdzv_endpoint " "is not specified." @@ -784,7 +797,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str "please further tune the variable for optimal performance in " "your application as needed. \n" "*****************************************", - omp_num_threads + omp_num_threads, ) # This env variable will be passed down to the subprocesses os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) @@ -888,7 +901,9 @@ def run(args): "--rdzv-endpoint=%s " "--rdzv-id=%s\n" "**************************************\n", - args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, ) config, cmd, cmd_args = config_from_args(args) diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index f13d06641501..1a0b849f955d 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -21,6 +21,7 @@ from torch.nn.parallel.scatter_gather import _is_namedtuple from torch.nn.utils.rnn import PackedSequence + __all__ = [] # type: ignore[var-annotated] From d9c294c6726ec833406dfaf1a2cdee77c4a5785d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 18 Jun 2024 22:06:53 +0000 Subject: [PATCH 27/64] [Inductor] Fix arguments passed to triton kernel launch hooks (#128732) `binary.launch_enter_hook` is treated as an instance method and will add a `self` argument to the hooks. `CompiledKernel.launch_enter_hook` is a static method, which matches the hook calling convention of profilers (i.e., a single `LazyDict` argument only). Pull Request resolved: https://github.com/pytorch/pytorch/pull/128732 Approved by: https://github.com/shunting314, https://github.com/bertmaher --- test/inductor/test_profiler.py | 4 ++-- torch/_inductor/runtime/triton_heuristics.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index d2ff71dd73bb..9d0270a9aae8 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -158,10 +158,10 @@ def test_inductor_profiling_triton_hooks(self): hooks_called = {"enter": False, "exit": False} - def launch_enter_hook(*args): + def launch_enter_hook(lazy_dict): hooks_called["enter"] = True - def launch_exit_hook(*args): + def launch_exit_hook(lazy_dict): hooks_called["exit"] = True CompiledKernel.launch_enter_hook = launch_enter_hook diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5396ccf3e70d..82a25392b5e9 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -50,6 +50,7 @@ if triton is not None: from triton import Config + from triton.compiler import CompiledKernel from triton.runtime.autotuner import OutOfResources from triton.runtime.jit import KernelInterface @@ -453,8 +454,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): scope = { "grid_meta": cfg.kwargs, "bin": binary, - "launch_enter_hook": binary.launch_enter_hook, - "launch_exit_hook": binary.launch_exit_hook, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, "metadata": binary.packed_metadata if hasattr(binary, "packed_metadata") else binary.metadata, From ac5f565fa7010bd77b9e779415e8709d347234b6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 18 Jun 2024 11:41:03 -0700 Subject: [PATCH 28/64] [FSDP2] Added `set_post_optim_event` (#128975) This PR adds `set_post_optim_event` that allows power users to provide their own CUDA event that is recorded after the optimizer step for the FSDP root module to wait the all-gather streams on. ``` def set_post_optim_event(self, event: torch.cuda.Event) -> None: ``` By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models. To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event to wait the all-gather streams on. We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`). --- One possible way to use the API is to register a post-step hook on the optimizer. For example: https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128975 Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy ghstack dependencies: #128884 --- .../fsdp/test_fully_shard_overlap.py | 82 ++++++++++++++++--- .../fsdp/test_fully_shard_training.py | 41 ++++++++++ .../_composable/fsdp/_fsdp_state.py | 14 +++- .../_composable/fsdp/fully_shard.py | 19 +++++ 4 files changed, 142 insertions(+), 14 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index 99823883abfb..1fca6c3f3c5a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools from typing import Callable import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor.experimental import implicit_replication from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( FSDPTest, @@ -23,15 +25,6 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fully_shard_training_overlap(self): - class LinearWithSleep(nn.Module): - def __init__(self, dim: int, sleep_ms: int): - super().__init__() - self.weight = nn.Parameter(torch.randn((dim, dim))) - self.sleep_ms = sleep_ms - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) - torch.manual_seed(42) # Use non-trivial comm. time but still shorter than compute time @@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: fully_shard(model, reshard_after_forward=True) orig_all_gather_into_tensor = dist.all_gather_into_tensor - orig_reduce_scatter = dist.reduce_scatter_tensor + orig_reduce_scatter_tensor = dist.reduce_scatter_tensor comm_stream = torch.cuda.Stream() def delay_collective(): @@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs): def delayed_reduce_scatter(*args, **kwargs): delay_collective() - return orig_reduce_scatter(*args, **kwargs) + return orig_reduce_scatter_tensor(*args, **kwargs) inp = torch.randn((2, dim), device="cuda") loss = model(inp).sum() # warmup CUDA and allocator @@ -92,6 +85,63 @@ def fwd_bwd(): ) self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time) + @skip_if_lt_x_gpu(2) + def test_fully_shard_post_optim_event_overlap(self): + torch.manual_seed(42) + + # Use non-trivial comm. time but still shorter than compute time + dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10) + # Define the model to have a high-compute linear followed by a + # low-compute linear, where only the low-compute linear uses FSDP + model = nn.Sequential( + LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim) + ).cuda() + fully_shard(model[1], reshard_after_forward=False) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + orig_all_gather_into_tensor = dist.all_gather_into_tensor + + def delayed_all_gather(*args, **kwargs): + torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms())) + return orig_all_gather_into_tensor(*args, **kwargs) + + inp = torch.randn((2, dim), device="cuda") + + def run_train_steps(num_iters: int, use_post_optim_event: bool): + for _ in range(num_iters): + optim.zero_grad() + with patch_all_gather(delayed_all_gather): + loss = model(inp).sum() + loss.backward() + with implicit_replication(): + optim.step() + if use_post_optim_event: + post_optim_event = torch.cuda.current_stream().record_event() + model[1].set_post_optim_event(post_optim_event) + + run_train_steps(1, False) # warmup CUDA and allocator + num_iters = 5 + baseline_time = self._time_fn( + functools.partial(run_train_steps, num_iters, False) + ) + test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True)) + + buffer_ms = 4 # CPU delays and copies + # Baseline: FSDP all-gather is exposed since the FSDP module waits for + # the current stream and hence the high-compute linear + self.assertLessEqual( + baseline_time, + num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms), + ) + # Test: FSDP all-gather is overlapped with the high-compute linear + # since the FSDP module only waits for the post-optim event (except on + # the 1st iteration when no event has been recorded) + expected_test_time = ( + num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms + ) + self.assertLessEqual(test_time, expected_test_time) + self.assertGreater(baseline_time, expected_test_time) + def _time_fn(self, fn: Callable): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input, grad_weight, None +class LinearWithSleep(nn.Module): + def __init__(self, dim: int, sleep_ms: int): + super().__init__() + self.weight = nn.Parameter(torch.randn((dim, dim))) + self.sleep_ms = sleep_ms + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3dbaa6524379..abc579b40d62 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -532,6 +532,47 @@ def test_explicit_prefetching(self): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_post_optim_event(self): + torch.manual_seed(42) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + def step_post_hook( + fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs + ) -> None: + post_optim_event = torch.cuda.current_stream().record_event() + fsdp_module.set_post_optim_event(post_optim_event) + + optim.register_step_post_hook(functools.partial(step_post_hook, model)) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + # Track all losses and check for equality at the end to avoid a CPU + # sync point after each iteration + ref_losses: List[torch.Tensor] = [] + losses: List[torch.Tensor] = [] + for iter_idx in range(10): + ref_optim.zero_grad() + ref_losses.append(ref_model(inp).sum()) + ref_losses[-1].backward() + ref_optim.step() + for iter_idx in range(10): + optim.zero_grad() + losses.append(model(inp).sum()) + losses[-1].backward() + optim.step() + # Sleep after the optimizer step to allow CPU to run ahead into the + # next iteration's forward, exercising the post-optim stream sync + torch.cuda._sleep(int(25 * get_cycles_per_ms())) + for ref_loss, loss in zip(ref_losses, losses): + self.assertEqual(ref_loss, loss) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index c6cdb2b29880..f04e6f6d0929 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -36,6 +36,9 @@ def __init__(self): self.post_backward_final_callback_queued: bool = False # Whether to finalize backward in this backward's final callback self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.cuda.Event] = None def disable_if_config_true(func): @@ -84,9 +87,14 @@ def _root_pre_forward( self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): # Wait for optimizer before implicitly prefetched all-gathers - current_stream = torch.cuda.current_stream() - self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) - self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = torch.cuda.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) if self._device.type == "cuda": with torch.profiler.record_function("FSDP::inputs_to_device"): args_tuple, kwargs_tuple = _to_kwargs( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index e8ab3466118b..88180f40f792 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: module._get_fsdp_state() for module in modules ] + def set_post_optim_event(self, event: torch.cuda.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.cuda.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") From cb5e9183c6056a7f929a12f574372e87e879d29e Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 19 Jun 2024 00:05:50 +0000 Subject: [PATCH 29/64] [Caffe2] [2/N] Remove Caffe2 from tests (#128911) Follows #128675 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128911 Approved by: https://github.com/titaiwangms, https://github.com/r-barnes --- test/jit/test_tracer.py | 45 ----------- test/onnx/pytorch_test_common.py | 4 +- test/onnx/test_operators.py | 27 ------- test/quantization/core/test_quantized_op.py | 47 ------------ test/test_determination.py | 7 -- test/test_public_bindings.py | 1 - test/test_tensorboard.py | 83 +-------------------- test/test_torch.py | 17 +---- 8 files changed, 4 insertions(+), 227 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 5da8ab61c5b3..d5ef39ba0c8b 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -911,51 +911,6 @@ def forward(self, x): self.assertEqual(len(list(g.inputs())), 2) FileCheck().check("mul").check("add").run(str(g)) - def test_trace_c10_ops(self): - try: - _ = torch.ops._caffe2.GenerateProposals - except AttributeError: - self.skipTest("Skip the test since c2 ops are not registered.") - - class MyModel(torch.nn.Module): - def forward(self, scores, bbox_deltas, im_info, anchors): - a, b = torch.ops._caffe2.GenerateProposals( - (scores), - (bbox_deltas), - (im_info), - (anchors), - 2.0, - 6000, - 300, - 0.7, - 16, - True, - -90, - 90, - 1.0, - True, - ) - return a, b - - model = MyModel() - A = 4 - H = 10 - W = 8 - img_count = 3 - scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - bbox_deltas = torch.linspace( - 0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32 - ) - bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - im_info = torch.ones(img_count, 3, dtype=torch.float32) - anchors = torch.ones(A, 4, dtype=torch.float32) - inputs = (scores, bbox_deltas, im_info, anchors) - traced_model = torch.jit.trace(model, inputs) - self.assertEqual(traced_model(*inputs), model(*inputs)) - self.assertExportImportModule( - traced_model, (scores, bbox_deltas, im_info, anchors) - ) - def run_ge_tests(self, optimize, use_cuda): with enable_profiling_mode_for_profiling_tests(): with torch.jit.optimized_execution(optimize): diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 6fdbf4e92839..3b66750f45d8 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -340,8 +340,8 @@ def inner(self, *args, **kwargs): # skips tests for opset_versions listed in unsupported_opset_versions. -# if the caffe2 test cannot be run for a specific version, add this wrapper -# (for example, an op was modified but the change is not supported in caffe2) +# if the PyTorch test cannot be run for a specific version, add this wrapper +# (for example, an op was modified but the change is not supported in PyTorch) def skipIfUnsupportedOpsetVersion(unsupported_opset_versions): def skip_dec(func): @functools.wraps(func) diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 87ec424cf65d..b3c75486450a 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -873,33 +873,6 @@ def test_cumsum(self): x = torch.randn(2, 3, 4, requires_grad=True) self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11) - # Github Issue: https://github.com/pytorch/pytorch/issues/71095 - # def test_c2_op(self): - # class MyModel(torch.nn.Module): - # def __init__(self): - # super().__init__() - # - # def forward(self, scores, bbox_deltas, im_info, anchors): - # a, b = torch.ops._caffe2.GenerateProposals( - # (scores), (bbox_deltas), (im_info), (anchors), - # 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, - # ) - # return a, b - # - # model = MyModel() - # A = 4 - # H = 10 - # W = 8 - # img_count = 3 - # scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - # bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, - # dtype=torch.float32) - # bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - # im_info = torch.ones(img_count, 3, dtype=torch.float32) - # anchors = torch.ones(A, 4, dtype=torch.float32) - # inputs = (scores, bbox_deltas, im_info, anchors) - # self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0}) - def test_dict(self): class MyModel(torch.nn.Module): def forward(self, x_in): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 2e606938192d..25b062a7ab13 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4457,54 +4457,7 @@ def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimize self.assertEqual(unpacked_weight.q_per_channel_scales(), qweight.q_per_channel_scales()) self.assertEqual(unpacked_weight.q_per_channel_zero_points(), qweight.q_per_channel_zero_points()) - # compare against C2 to ensure numerical equivalency. - from caffe2.python import core, workspace - conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == torch.float32 else "HalfFloatToFused8BitRowwiseQuantized" - reverse_conversion_op = None - if bit_rate == 4: - conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused4BitRowwiseQuantized" - reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat" - elif bit_rate == 2: - conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused2BitRowwiseQuantized" - reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat" - - def get_c2_weights(weights, engine_str): - workspace.ResetWorkspace() - - workspace.FeedBlob("weights", weights) - workspace.RunOperatorOnce( - core.CreateOperator( - conversion_op, ["weights"], ["quantized_weights"], engine=engine_str - ) - ) - emb_q = workspace.FetchBlob("quantized_weights") - if bit_rate == 4 or bit_rate == 2: - workspace.RunOperatorOnce( - core.CreateOperator( - reverse_conversion_op, ["quantized_weights"], ["dequantized_weights"] - ) - ) - dequantized_data = torch.from_numpy(workspace.FetchBlob("dequantized_weights")) - else: - dequantized_data = torch.ops._caffe2.Fused8BitRowwiseQuantizedToFloat( - torch.tensor(emb_q) - ) - return torch.from_numpy(emb_q), dequantized_data - - if optimized_qparams: - engine = "GREEDY" - else: - engine = "" - - # C2 quantization needs the memory format of Tensor to be `continuous`, otherwise it will - # throw exceptions. torch.clone() will make the memory format to be `continuous` - c2_copy = torch.clone(weights) - w_packed_c2, w_unpacked_c2 = get_c2_weights(c2_copy, engine) - # Compare packed weights against C2. - np.testing.assert_allclose(w_packed.numpy(), w_packed_c2.numpy(), atol=1e-6, rtol=1e-6) - # Compare unpacked weights against C2 - np.testing.assert_allclose(w_unpacked.numpy(), w_unpacked_c2.numpy(), atol=1e-6, rtol=1e-6) def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, diff --git a/test/test_determination.py b/test/test_determination.py index 50cc2fa9975d..09a67de45dc6 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -121,13 +121,6 @@ def test_torch_file(self): ], ) - def test_caffe2_file(self): - """Caffe2 files trigger dependent tests""" - self.assertEqual(self.determined_tests(["caffe2/python/brew_test.py"]), []) - self.assertEqual( - self.determined_tests(["caffe2/python/context.py"]), self.TESTS - ) - def test_new_folder(self): """New top-level Python folder triggers all tests""" self.assertEqual(self.determined_tests(["new_module/file.py"]), self.TESTS) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 8ab2ac1f511f..65a5bf90b9f9 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -342,7 +342,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.rpc.rpc_test", "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", - "torch.utils.tensorboard._caffe2_graph", "torch._inductor.codegen.cuda.cuda_template", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.runtime.triton_helpers", diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index 3ce2ab2a172c..1e79a2bf910c 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -23,15 +23,6 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -TEST_CAFFE2 = True -try: - import caffe2.python.caffe2_pybind11_state as _caffe2_pybind11_state # noqa: F401 - from caffe2.python import brew, cnn, core, workspace - from caffe2.python.model_helper import ModelHelper -except ImportError: - TEST_CAFFE2 = False -skipIfNoCaffe2 = unittest.skipIf(not TEST_CAFFE2, "no caffe2") - TEST_MATPLOTLIB = True try: import matplotlib @@ -48,7 +39,6 @@ parametrize, TestCase, run_tests, - TEST_WITH_ASAN, TEST_WITH_CROSSREF, IS_WINDOWS, IS_MACOS, @@ -94,8 +84,6 @@ def tearDown(self): from torch.utils.tensorboard._pytorch_graph import graph from google.protobuf import text_format from PIL import Image -if TEST_TENSORBOARD and TEST_CAFFE2: - from torch.utils.tensorboard import _caffe2_graph as c2_graph class TestTensorBoardPyTorchNumpy(BaseTestCase): def test_pytorch_np(self): @@ -754,80 +742,11 @@ def test_scalar(self): res = make_np(np.int64(100000000000)) self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - @skipIfNoCaffe2 - def test_caffe2_np(self): - workspace.FeedBlob("testBlob", tensor_N(shape=(1, 3, 64, 64))) - self.assertIsInstance(make_np('testBlob'), np.ndarray) - - @skipIfNoCaffe2 - def test_caffe2_np_expect_fail(self): - with self.assertRaises(RuntimeError): - res = make_np('This_blob_does_not_exist') - def test_pytorch_np_expect_fail(self): with self.assertRaises(NotImplementedError): res = make_np({'pytorch': 1.0}) - @skipIfNoCaffe2 - @unittest.skipIf(TEST_WITH_ASAN, "Caffe2 failure with ASAN") - def test_caffe2_simple_model(self): - model = ModelHelper(name="mnist") - # how come those inputs don't break the forward pass =.=a - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - - with core.NameScope("conv1"): - conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5) - # Image size: 24 x 24 -> 12 x 12 - pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2) - # Image size: 12 x 12 -> 8 x 8 - conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5) - # Image size: 8 x 8 -> 4 x 4 - pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2) - with core.NameScope("classifier"): - # 50 * 4 * 4 stands for dim_out from previous layer multiplied by the image size - fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500) - relu = brew.relu(model, fc3, fc3) - pred = brew.fc(model, relu, 'pred', 500, 10) - softmax = brew.softmax(model, pred, 'softmax') - xent = model.LabelCrossEntropy([softmax, "label"], 'xent') - # compute the expected loss - loss = model.AveragedLoss(xent, "loss") - model.net.RunAllOnMKL() - model.param_init_net.RunAllOnMKL() - model.AddGradientOperators([loss], skip=1) - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) - - @skipIfNoCaffe2 - def test_caffe2_simple_cnnmodel(self): - model = cnn.CNNModelHelper("NCHW", name="overfeat") - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - with core.NameScope("conv1"): - conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4) - relu1 = model.Relu(conv1, conv1) - pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2) - with core.NameScope("classifier"): - fc = model.FC(pool1, "fc", 4096, 1000) - pred = model.Softmax(fc, "pred") - xent = model.LabelCrossEntropy([pred, "label"], "xent") - loss = model.AveragedLoss(xent, "loss") - - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) + class TestTensorProtoSummary(BaseTestCase): @parametrize( diff --git a/test/test_torch.py b/test/test_torch.py index f252ddf4a574..86844c77faf4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -41,7 +41,7 @@ skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, + bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( @@ -8632,21 +8632,6 @@ def test_allow_tensor_metadata_change(self): a = torch.ones(2, 3) # Metadata changes are allowed on view tensors that are created from detach(). - @skipIfNotRegistered("LayerNorm", "Skipping as LayerNorm is not registered") - def test_c10_layer_norm(self): - # test that we can call c10 ops and they return a reasonable result - X = torch.rand(5, 5, dtype=torch.float) - weight = torch.rand(*X.size()[1:], dtype=torch.float) - bias = torch.rand(*X.size()[1:], dtype=torch.float) - epsilon = 1e-4 - - expected_norm = torch.nn.functional.layer_norm( - X, X.size()[1:], weight=weight, bias=bias, eps=epsilon) - actual_norm, actual_mean, actual_stdev = \ - torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( - weight), torch.tensor(bias), 1, epsilon, True) - torch.testing.assert_close(expected_norm, actual_norm) - def test_memory_format(self): def test_helper(x, memory_format): y = x.contiguous(memory_format=memory_format) From c5e0b844847c5c34ee824b0de2adeda85ce64133 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:24 -0700 Subject: [PATCH 30/64] [dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428) Co-authored-by: Laith Sakka Pull Request resolved: https://github.com/pytorch/pytorch/pull/128428 Approved by: https://github.com/yanboliang, https://github.com/mlazos --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ---------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index dbcb259241fc..2329ab305e76 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index b5b12435a931..abbef02e63c6 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,26 +2669,6 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", - "torch.nn.functional._adaptive_max_pool1d", - "torch.nn.functional._adaptive_max_pool2d", - "torch.nn.functional._adaptive_max_pool3d", - "torch.nn.functional._canonical_mask", - "torch.nn.functional._fractional_max_pool2d", - "torch.nn.functional._fractional_max_pool3d", - "torch.nn.functional._get_softmax_dim", - "torch.nn.functional._in_projection_packed", - "torch.nn.functional._in_projection", - "torch.nn.functional._is_integer", - "torch.nn.functional._max_pool1d", - "torch.nn.functional._max_pool2d", - "torch.nn.functional._max_pool3d", - "torch.nn.functional._mha_shape_check", - "torch.nn.functional._no_grad_embedding_renorm_", - "torch.nn.functional._none_or_dtype", - "torch.nn.functional._threshold", - "torch.nn.functional._unpool_output_size", - "torch.nn.functional._verify_batch_size", - "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2786,15 +2766,7 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", - "torch.nn.modules.activation._arg_requires_grad", - "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", - "torch.nn.modules.container._addindent", - "torch.nn.modules.transformer._detect_is_causal_mask", - "torch.nn.modules.transformer._generate_square_subsequent_mask", - "torch.nn.modules.transformer._get_activation_fn", - "torch.nn.modules.transformer._get_clones", - "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 670b94c9c826756495b9e1ca34be1d43756d5296 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:25 -0700 Subject: [PATCH 31/64] [inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128484 Approved by: https://github.com/mlazos ghstack dependencies: #128428 --- test/inductor/test_mkldnn_pattern_matcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 810c22d037c5..a80d72398760 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,7 +37,8 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, + # Use floats for min/max, otherwise they can get converted to symints + torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 99f042d336b53844b509406f1ecf78cb6f5e5714 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:21:21 +0000 Subject: [PATCH 32/64] Revert "Forward fix to skip ROCm tests for #122836 (#128891)" This reverts commit 4061b3b8225f522ae0ed6db00111441e7d3cc3d5. Reverted https://github.com/pytorch/pytorch/pull/128891 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128891#issuecomment-2177291249)) --- test/test_nestedtensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index fa33a13ed495..6b9b8f3be45d 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5470,7 +5470,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): # shape (B, *, D) nt = random_nt_from_dims( @@ -5501,7 +5500,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 @@ -5538,7 +5536,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): # shape (B, *, D) # min seq len: 7 @@ -5575,7 +5572,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 From 35c78668b408046e032a1e025b01250875959cc6 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 18 Jun 2024 13:37:50 -0700 Subject: [PATCH 33/64] Improve the debugging message for when foreach mta_called (#128991) The hope that lives in this PR: I am currently trying to debug why the foreach tests are so flaky. It looks like every flaky test falls under this pattern: - a test is flaky due to the mta_called assertion, which gathers data from the profiler regarding whether the multi_tensor_apply_kernel has been called. - then, a later test fails deterministically, usually failing to compare two results. ``` ================== 1 failed, 241 deselected, 2 rerun in 1.76s ================== Got exit code 1 Stopping at first consistent failure The following tests failed and then succeeded when run in a new process ['test/test_foreach.py::TestForeachCUDA::test_binary_op_float_inf_nan__foreach_add_cuda_bfloat16'] The following tests failed consistently: ['test/test_foreach.py::TestForeachCUDA::test_binary_op_list_error_cases__foreach_add_cuda_bfloat16'] ``` So my suspicion is that the first causes the second, but what causes the first? Idk! So it would be nice to have the error message tell us what the profiler actually saw in case it's getting muddled. This change would help mostly because I have not been able to repro this flakiness locally. Also undo the useless changes in #128220 which are actually redundant as Joel and I realized that we set the seed during the setUp of every test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128991 Approved by: https://github.com/clee2000 --- test/test_foreach.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 567d09cff02d..99d4cbe5ec00 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -90,7 +90,7 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): mta_called = any("multi_tensor_apply_kernel" in k for k in keys) assert mta_called == ( expect_fastpath and (not zero_size) - ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}" + ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" else: actual = self.func(*inputs, **kwargs) if self.is_inplace: @@ -205,7 +205,6 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_parity(self, device, dtype, op, noncontiguous, inplace): - torch.manual_seed(2024) if inplace: _, _, func, ref = self._get_funcs(op) else: @@ -585,7 +584,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125035", ) def test_binary_op_list_error_cases(self, device, dtype, op): - torch.manual_seed(202406) foreach_op, foreach_op_, ref, ref_ = ( op.method_variant, op.inplace_variant, @@ -680,7 +678,6 @@ def test_binary_op_list_error_cases(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125775", ) def test_binary_op_list_slow_path(self, device, dtype, op): - torch.manual_seed(20240607) foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op) # 0-strides tensor1 = make_tensor((10, 10), dtype=dtype, device=device) @@ -799,7 +796,6 @@ def test_binary_op_list_slow_path(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_float_inf_nan(self, device, dtype, op): - torch.manual_seed(2024) inputs = ( [ torch.tensor([float("inf")], device=device, dtype=dtype), @@ -869,9 +865,6 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_tensors_on_different_devices(self, device, dtype, op): - torch.manual_seed(202406) - # `tensors1`: ['cuda', 'cpu'] - # `tensors2`: ['cuda', 'cpu'] _cuda_tensors = next( iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True)) ).input From 5ffb032be682a34b959c82ce289b457ea6c6e504 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:26:38 +0000 Subject: [PATCH 34/64] Revert "Backward support for unbind() with NJT (#128032)" This reverts commit 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74. Reverted https://github.com/pytorch/pytorch/pull/128032 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128032#issuecomment-2177296325)) --- test/test_nestedtensor.py | 19 ------------------- tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 ----------------- torch/csrc/autograd/FunctionsManual.h | 4 ---- torch/nested/_internal/ops.py | 11 ----------- 5 files changed, 1 insertion(+), 52 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6b9b8f3be45d..86f58b5a0de3 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5606,25 +5606,6 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - @dtypes(torch.float32, torch.double, torch.half) - def test_unbind_backward(self, device, dtype): - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 4, device=device), - torch.randn(5, 4, device=device), - torch.randn(3, 4, device=device), - ], - layout=torch.jagged, - requires_grad=True, - ) - - a, b, c = nt.unbind() - b.sum().backward() - - expected_grad = torch.zeros_like(nt) - expected_grad.unbind()[1].add_(1.0) - torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 02a3e6c518ad..76a7a0a1e42a 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index f51c2f047f93..9d897c667c90 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,23 +1014,6 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } -Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim) { - TORCH_INTERNAL_ASSERT( - dim == 0, "unbind_backward_nested_jagged() only supports dim=0") - auto grad_nt = at::zeros_like(self); - auto unbound_grads = grad_nt.unbind(); - for (int64_t i : c10::irange(static_cast(grads.size()))) { - if (grads[i].defined()) { - unbound_grads[i].copy_(static_cast(grads[i])); - } - } - - return grad_nt; -} - Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index ecf99bd09805..dedff70be1ba 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,10 +244,6 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); -at::Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8458f0371713..6f1c47dd6947 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,17 +472,6 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) -@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") -def zero__default(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - func(inp._values) - return inp - - @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From b0d2fe6299c4462d28b23ef73d872eb608d73d96 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:28:53 +0000 Subject: [PATCH 35/64] Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)" This reverts commit 2a41fc03903de63270d325bd1886a50faf32d7e4. Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245)) --- aten/src/ATen/FunctionalInverses.cpp | 9 +- aten/src/ATen/native/native_functions.yaml | 14 +- test/dynamo/test_subclasses.py | 6 +- ...asDecompTest.test_has_decomposition.expect | 2 - test/test_nestedtensor.py | 173 +---------------- tools/autograd/derivatives.yaml | 4 +- torch/nested/_internal/nested_tensor.py | 174 ++++-------------- torch/nested/_internal/ops.py | 37 +--- torch/nested/_internal/sdpa.py | 62 ++----- 9 files changed, 69 insertions(+), 412 deletions(-) diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index a1cf449cde7c..16b59333f918 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, return Tensor(); } -Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const c10::optional& min_seqlen, const c10::optional& max_seqlen) { +Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx) { auto values = at::_nested_get_values(mutated_view); if (inverse_return_mode != InverseReturnMode::NeverView) { return values; @@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const auto lengths = at::_nested_get_lengths(base); auto ragged_idx = at::_nested_get_ragged_idx(base); auto dummy = at::_nested_get_jagged_dummy(base); - auto min_seqlen = at::_nested_get_min_seqlen(base); - auto max_seqlen = at::_nested_get_max_seqlen(base); - auto nt = at::_nested_view_from_jagged( - mutated_view, offsets, dummy, lengths, ragged_idx, - (min_seqlen.defined() ? c10::optional(min_seqlen) : c10::nullopt), - (max_seqlen.defined() ? c10::optional(max_seqlen) : c10::nullopt)); + auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx); if (inverse_return_mode != InverseReturnMode::NeverView) { return nt; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b030141882c8..a2d9095d56a3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6185,12 +6185,12 @@ CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy autogen: _nested_view_from_buffer_copy.out -- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) variants: function device_check: NoCheck dispatch: {} -- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor +- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor variants: function device_check: NoCheck tags: view_copy @@ -6227,16 +6227,6 @@ device_check: NoCheck dispatch: {} -- func: _nested_get_min_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - -- func: _nested_get_max_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - - func: _nested_get_jagged_dummy(Tensor any) -> Tensor category_override: dummy dispatch: {} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index f16ef15990fd..302b07e4ddb7 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1616,15 +1616,15 @@ def backend(gm, args): guard_str, """\ Eq(s3 - 1, s0) -Eq(zf1, zf6)""", +Eq(zf1, zf4)""", ) else: self.assertExpectedInline( guard_str, """\ Eq(s4 - 1, s1) -Eq(s12 - 1, s7) -Eq(s11, s9)""", +Eq(s10 - 1, s5) +Eq(s9, s7)""", ) return gm diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 132d25a8b12f..1179142e15d9 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -446,8 +446,6 @@ aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out aten::_nested_get_jagged_dummy aten::_nested_get_lengths -aten::_nested_get_max_seqlen -aten::_nested_get_min_seqlen aten::_nested_get_offsets aten::_nested_get_ragged_idx aten::_nested_get_values diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 86f58b5a0de3..78d082702aec 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -67,21 +67,6 @@ def _iter_constructors(): yield torch.nested.nested_tensor -# Returns True if the function recompiles between inputs1 and inputs2 with the -# specified dynamic setting. -def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): - compile_count = [0] - - def counter(gm, example_inputs): - compile_count[0] += 1 - return gm - - compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) - compiled_f(*inputs1) - compiled_f(*inputs2) - return compile_count[0] > 1 - - # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -4833,18 +4818,19 @@ def fn(values, same_size): check_results(fn, compiled_fn, generate_inp(20)) self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views + @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT @unittest.skipIf( TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @dtypes( - *( - [torch.float16, torch.bfloat16, torch.float32] - if SM80OrLater - else [torch.float16, torch.float32] - ) + @parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32], ) def test_sdpa(self, device, dtype): batch_size = 1 @@ -5187,6 +5173,8 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): ) self.assertEqual(output._values, output_dense) + # Doesn't work until we have real views + @xfailIfTorchDynamo @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, @@ -5463,149 +5451,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): padded, [offsets_wrong], total_L ) - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_preserves_metadata_cache(self, device, dtype): - # shape (B, *, D) - nt = random_nt_from_dims( - [4, None, 3, 16], - device=device, - dtype=dtype, - layout=torch.jagged, - requires_grad=True, - ) - - # expect min / max seqlen to be stored here - cache = dict(nt._metadata_cache) - - @torch.compile - def f(nt): - q = nt.transpose(-3, -2) - output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2) - return output - - output = f(nt) - output.backward(torch.ones_like(output)) - self.assertEqual(output._metadata_cache, cache) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_max_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_min_seq_len(self, device, dtype): - # shape (B, *, D) - # min seq len: 7 - nt = torch.nested.nested_tensor( - [ - torch.randn(7, 5), - torch.randn(8, 5), - torch.randn(9, 5), - ], - layout=torch.jagged, - ) - - # min seq len: 8 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(8, 5), - torch.randn(9, 5), - torch.randn(10, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_min_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - nt2 = nt.sin() + 1 - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt2) * nt2._get_max_seqlen() - - ref = f(nt) - output = torch.compile(f, fullgraph=True, dynamic=False)(nt) - self.assertEqual(ref, output) - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a..1e9b9091a20e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2794,14 +2794,14 @@ nested_size: non_differentiable nested_strides: non_differentiable -- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) self: grad.values() offsets: non_differentiable lengths: non_differentiable dummy: non_differentiable - name: _nested_get_values(Tensor(a) self) -> Tensor(a) - self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional(at::_nested_get_min_seqlen(self)) : c10::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional(at::_nested_get_max_seqlen(self)) : c10::nullopt)" + self: _nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self)) # Transformers - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 92423cf32b2f..66d25eacc7ad 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -27,15 +27,6 @@ def _get_sdpa_extreme_seqlen(func, tensor): return int(func(tensor).item()) -def _store_val_in_tensor(val) -> torch.Tensor: - # hack to get dynamic shapes support: store in a (val, 0) shaped tensor - return torch.zeros(val, 0) - - -def _load_val_from_tensor(t: torch.Tensor): - return t.shape[0] - - class NestedTensor(torch.Tensor): _values: torch.Tensor # type: ignore[assignment] _offsets: torch.Tensor @@ -131,14 +122,6 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) - # min / max sequence length should be dynamic if present - max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None) - if max_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(max_seqlen_tensor, 0) - min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None) - if min_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(min_seqlen_tensor, 0) - def values(self): # dispatch to get proper view relationship return torch._nested_get_values(self) # type: ignore[attr-defined] @@ -149,56 +132,25 @@ def offsets(self): def lengths(self): return self._lengths - # Private accessor functions for min / max sequence length. They're - # purposefully not @properties because those don't work with PT2 (yet). - # These compute / cache if not present. - # TODO: Revisit this when @properties are better supported by PT2. I think the ideal - # state would be to have public @properties for min / max sequence length that compile - # (including setters). - def _get_max_seqlen(self): - max_seqlen_tensor = self._max_seqlen_tensor - if max_seqlen_tensor is None: + @property + def _max_seqlen(self): + if "max_seqlen" not in self._metadata_cache: # compute & cache - max_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["max_seqlen"] = _get_sdpa_extreme_seqlen( torch.max, self._offsets.diff() if self._lengths is None else self._lengths, ) - max_seqlen_tensor = _store_val_in_tensor(max_val) - self._metadata_cache["max_seqlen"] = max_seqlen_tensor - return _load_val_from_tensor(max_seqlen_tensor) + return self._metadata_cache["max_seqlen"] - def _get_min_seqlen(self): - min_seqlen_tensor = self._min_seqlen_tensor - if min_seqlen_tensor is None: + @property + def _min_seqlen(self): + if "min_seqlen" not in self._metadata_cache: # compute & cache - min_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["min_seqlen"] = _get_sdpa_extreme_seqlen( torch.min, self._offsets.diff() if self._lengths is None else self._lengths, ) - min_seqlen_tensor = _store_val_in_tensor(min_val) - self._metadata_cache["min_seqlen"] = min_seqlen_tensor - return _load_val_from_tensor(min_seqlen_tensor) - - # Private accessors used for treating min / max seqlen as inner tensors for - # flatten / unflatten. These must be properties to work with the traceable wrapper - # subclass logic. These do not compute / cache if not present. - @property - def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("max_seqlen", None) - - @property - def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("min_seqlen", None) - - # These are old private @property accessors that are kept around for internal BC - # reasons. TODO: Remove these! - @property - def _max_seqlen(self): - return self._get_max_seqlen() - - @property - def _min_seqlen(self): - return self._get_min_seqlen() + return self._metadata_cache["min_seqlen"] def __repr__(self): # We should implement this in torch/_tensor_str.py instead @@ -218,7 +170,6 @@ def __reduce_ex__(self, proto): del state["_size"] del state["_strides"] - # TODO: Update this to handle the other inner tensors func = NestedTensor args = (self._values, self._offsets) return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state)) @@ -226,33 +177,22 @@ def __reduce_ex__(self, proto): def __tensor_flatten__(self): ctx = { "requires_grad": self.requires_grad, + # TODO: Don't guard on this! + "metadata_cache": self._metadata_cache, "ragged_idx": self._ragged_idx, } inner_tensors = ["_values", "_offsets"] if self._lengths is not None: inner_tensors.append("_lengths") - if self._min_seqlen_tensor is not None: - inner_tensors.append("_min_seqlen_tensor") - if self._max_seqlen_tensor is not None: - inner_tensors.append("_max_seqlen_tensor") return inner_tensors, ctx @staticmethod def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): - # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] - assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5 + # inner tensors: _values, _offsets, [_lengths] + assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3 values = inner_tensors["_values"] offsets = inner_tensors["_offsets"] lengths = inner_tensors.get("_lengths", None) - min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None) - max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None) - - metadata_cache = {} - if min_seqlen_tensor is not None: - metadata_cache["min_seqlen"] = min_seqlen_tensor - if max_seqlen_tensor is not None: - metadata_cache["max_seqlen"] = max_seqlen_tensor - ragged_idx = meta["ragged_idx"] # Note that we cannot simply check if is_fake(values) because @@ -271,7 +211,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): lengths=lengths, requires_grad=meta["requires_grad"], _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, + _metadata_cache=meta["metadata_cache"], ) @classmethod @@ -336,15 +276,6 @@ def forward( offsets: torch.Tensor, metadata_cache: Optional[Dict[str, Any]] = None, ): # type: ignore[override] - # maintain BC with this usages of this where the seqlens are stuffed - # directly into the metadata cache as non-Tensors / ints - if metadata_cache is not None: - min_seqlen = metadata_cache.get("min_seqlen", None) - max_seqlen = metadata_cache.get("max_seqlen", None) - if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor): - metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen) - if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor): - metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen) return NestedTensor( values.detach(), offsets=offsets, @@ -412,12 +343,12 @@ def jagged_from_list( ] ) - # compute this now since it's easy - min_seqlen = min([t.shape[0] for t in tensors]) - max_seqlen = max([t.shape[0] for t in tensors]) - ret_nt = nested_view_from_values_offsets( - values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen - ) + ret_nt = nested_view_from_values_offsets(values, offsets) + ret_nt._metadata_cache = { + # compute this now since it's easy + "max_seqlen": max(t.shape[0] for t in tensors), + "min_seqlen": min(t.shape[0] for t in tensors), + } return (ret_nt, offsets) # type: ignore[return-value] @@ -474,19 +405,16 @@ def jagged_from_tensor_and_lengths( if is_contiguous: ret_nt = nested_view_from_values_offsets( - values[offsets[0] : offsets[-1]], - offsets - offsets[0], - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, + values[offsets[0] : offsets[-1]], offsets - offsets[0] ) else: - ret_nt = nested_view_from_values_offsets_lengths( - values, - offsets, - length_list, - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, - ) + ret_nt = nested_view_from_values_offsets_lengths(values, offsets, length_list) + + # populate metadata cache with computed seqlen extremes + ret_nt._metadata_cache = { + "max_seqlen": actual_max_seqlen, + "min_seqlen": min_seqlen, + } return (ret_nt, offsets, None if is_contiguous else length_list) @@ -508,45 +436,13 @@ def _nt_view_dummy() -> torch.Tensor: return _dummy_instance -def nested_view_from_values_offsets( - values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) - - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) - +def nested_view_from_values_offsets(values, offsets, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - None, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] - - -def nested_view_from_values_offsets_lengths( - values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + values, offsets, _nt_view_dummy(), None, ragged_idx + ) - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) +def nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - lengths, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] + values, offsets, _nt_view_dummy(), lengths, ragged_idx + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd6947..6ec3ba538f97 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1088,7 +1088,7 @@ def values_default(func, *args, **kwargs): @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, - "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", + "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?", ) def _nested_view_from_jagged_default(func, *args, **kwargs): _, new_kwargs = normalize_function( @@ -1101,21 +1101,8 @@ def _nested_view_from_jagged_default(func, *args, **kwargs): new_kwargs["lengths"], ) ragged_idx = new_kwargs["ragged_idx"] - min_seqlen = new_kwargs["min_seqlen"] - max_seqlen = new_kwargs["max_seqlen"] - metadata_cache = {} - if min_seqlen is not None: - metadata_cache["min_seqlen"] = min_seqlen - if max_seqlen is not None: - metadata_cache["max_seqlen"] = max_seqlen - return NestedTensor( - values, - offsets, - lengths=lengths, - _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, - ) + return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx) @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") @@ -1148,26 +1135,6 @@ def _nested_get_ragged_idx(func, *args, **kwargs): return inp._ragged_idx -@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") -def _nested_get_min_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("min_seqlen", None) - - -@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") -def _nested_get_max_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("max_seqlen", None) - - # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 8f2eba4db3e4..b7c69c905e9a 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -15,7 +15,7 @@ ) from torch.nn.attention import SDPBackend -from .nested_tensor import NestedTensor +from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer log = logging.getLogger(__name__) @@ -125,7 +125,7 @@ def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( return False # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] - if param._get_min_seqlen() == 0: + if param._min_seqlen == 0: if debug: log.warning( "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", @@ -315,7 +315,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in if qkv.lengths() is None: # TODO: Explore performance impact of copying cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen n_elem = qkv.values().shape[0] else: # TODO: Explore performance impact of copying @@ -323,7 +323,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) ) batch_size = qkv.size(0) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen # TODO: Explore performance impact when compiling n_elem = int(cumulative_seqlen[-1].item()) return cumulative_seqlen, max_seqlen, n_elem @@ -364,7 +364,7 @@ def _view_as_dense( tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int ) -> torch.Tensor: if tensor.is_nested: - return tensor.values() + return buffer_from_jagged(tensor) return tensor.view(Nnz, num_heads, head_dim) @@ -567,8 +567,8 @@ def _sdpa_nested_preprocessing(query, key, value): output_nt_info = { "offsets": q_t.offsets(), - "_max_seqlen": q_t._get_max_seqlen(), - "_min_seqlen": q_t._get_min_seqlen(), + "_max_seqlen": q_t._max_seqlen, + "_min_seqlen": q_t._min_seqlen, } return ( @@ -694,14 +694,9 @@ def jagged_scaled_dot_product_attention( False, scale=og_scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - attention = nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + attention = ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -737,14 +732,9 @@ def jagged_scaled_dot_product_attention( scale=scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - return nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + return ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output @@ -754,19 +744,12 @@ def jagged_scaled_dot_product_attention( d1 = query._size[1] d2 = value._size[-1] - min_seqlen_tensor = query._metadata_cache.get( - "min_seqlen", None - ) # type: ignore[attr-defined] - max_seqlen_tensor = query._metadata_cache.get( - "max_seqlen", None - ) # type: ignore[attr-defined] - # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1] transpose = torch.transpose(jagged_layout_nt, 1, 2) - tensor_list = transpose.values().split(list(lengths), dim=0) + tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0) strided_nt = torch.nested.as_nested_tensor(list(tensor_list)) strided_nt = strided_nt.transpose(1, 2).contiguous() return strided_nt @@ -779,28 +762,11 @@ def get_strided_layout_nested_tensor(jagged_layout_nt): query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] - from torch.nested._internal.nested_tensor import ( - _load_val_from_tensor, - nested_view_from_values_offsets, - ) - # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) - attn_out = nested_view_from_values_offsets( - attn_out, - offsets, - min_seqlen=( - None - if min_seqlen_tensor is None - else _load_val_from_tensor(min_seqlen_tensor) - ), - max_seqlen=( - None - if max_seqlen_tensor is None - else _load_val_from_tensor(max_seqlen_tensor) - ), - ).transpose(1, 2) + attn_out = ViewNestedFromBuffer.apply(attn_out, offsets) + attn_out = attn_out.transpose(1, 2) return attn_out else: From 2458f79f83e865a0469f844e87a64edfcecc7065 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Mon, 17 Jun 2024 12:40:38 -0700 Subject: [PATCH 36/64] [Inductor UT][Intel GPU] Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU (#128881) Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU because it have not implemented reduction kernel split. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128881 Approved by: https://github.com/blaine-rister, https://github.com/EikanWang, https://github.com/malfet --- test/inductor/test_torchinductor_strided_blocks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index bd859802892d..bf96ad8d486d 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, @@ -214,6 +215,7 @@ def get_input(view_size: Tuple[int]) -> torch.Tensor: # Expect 3 block pointers: 2 inputs one output self.run_and_compare(foo, x, y, expected_num_block_pointers=3) + @skipIfXpu @parametrize( "view_size,num_block_pointers,num_triton_kernels", [ From eda375a49078f5fecc90f28ca8ff949e8e5811e9 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Mon, 17 Jun 2024 19:54:34 -0700 Subject: [PATCH 37/64] [Inductor] Remove min/max from inductor opinfo test (#128925) **Summary** Remove `max.binary, min.binary, maximum, minimum` from `inductor_one_sample` op list as we fix the bool vectorization issue in https://github.com/pytorch/pytorch/pull/126841. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_maximum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_minimum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_min_binary python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_max_binary ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128925 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 29be591dc006..c7153b5b6d84 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -425,11 +425,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "logspace": {f16}, "logspace.tensor_overload": {f16, f32, f64, i32, i64}, "masked_logsumexp": {i64}, - "max.binary": {b8}, "max_pool2d_with_indices_backward": {f16, f32, f64}, - "maximum": {b8}, - "min.binary": {b8}, - "minimum": {b8}, "new_empty_strided": {f16}, "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, From 4bc90185fb77438717d59b2d9bb63096ae682935 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Wed, 19 Jun 2024 01:17:05 +0000 Subject: [PATCH 38/64] fix: Print statements causing parse error (#128969) The print statements for the get_workflow_type script is problematic because the shell script calling this script is expecting the output to only be JSON. This PR resolves this by removing all print statements to covert them to a message field in the JSON return output so that the output can continue to expect to be JSON while giving us the debug data we are looking for. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128969 Approved by: https://github.com/tylertitsworth, https://github.com/ZainRizvi --- .github/scripts/get_workflow_type.py | 47 ++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/.github/scripts/get_workflow_type.py b/.github/scripts/get_workflow_type.py index 4a5303ae9212..5384ef92c12f 100644 --- a/.github/scripts/get_workflow_type.py +++ b/.github/scripts/get_workflow_type.py @@ -1,6 +1,6 @@ import json from argparse import ArgumentParser -from typing import Any +from typing import Any, Tuple from github import Auth, Github from github.Issue import Issue @@ -9,6 +9,8 @@ WORKFLOW_LABEL_META = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation LABEL_TYPE_KEY = "label_type" +MESSAGE_KEY = "message" +MESSAGE = "" # Debug message to return to the caller def parse_args() -> Any: @@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_workflow_type(issue: Issue, username: str) -> str: +def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]: try: user_list = issue.get_comments()[0].body.split() if user_list[0] == "!": - print("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = "LF Workflows are disabled for everyone. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE elif user_list[0] == "*": - print("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = "LF Workflows are enabled for everyone. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE elif username in user_list: - print(f"LF Workflows are enabled for {username}. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE else: - print(f"LF Workflows are disabled for {username}. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE except Exception as e: - print( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" + return WORKFLOW_LABEL_META, MESSAGE def main() -> None: args = parse_args() if is_exception_branch(args.github_branch): - print(f"Exception branch: '{args.github_branch}', using meta runners") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners", + } else: try: gh = get_gh_client(args.github_token) # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 issue = get_issue(gh, args.github_repo, args.github_issue) - - output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)} + label_type, message = get_workflow_type(issue, args.github_user) + output = { + LABEL_TYPE_KEY: label_type, + MESSAGE_KEY: message, + } except Exception as e: - print(f"Failed to get issue. Falling back to meta runners. Exception: {e}") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}", + } json_output = json.dumps(output) print(json_output) From df85f34a14dd30f784418624b05bd52b12ab8b0b Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 14 Jun 2024 01:51:17 -0700 Subject: [PATCH 39/64] Add test to xfail_list only for abi_compatible (#128506) https://github.com/pytorch/pytorch/pull/126717 will skip the tests in both ABI compatible and non-ABI compatible mode. It's not expected to skip them in non-ABI compatible mode since they can actually run successfully in such mode but only have issues in ABI compatible mode. We leverage the existing `xfail_list` for those that will only fail in ABI compatible mode. - `test_qlinear_add` is already in the `xfail_list`. - `test_linear_packed` doesn't fail either in my local run (running with `TORCHINDUCTOR_ABI_COMPATIBLE=1`) or in the CI of this PR so I didn't add it into `xfail_list`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128506 Approved by: https://github.com/jgong5, https://github.com/desertfire --- test/inductor/test_cpu_cpp_wrapper.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8bf9b1e6a61f..0a2b75ddb554 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -95,7 +95,9 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_relu_cpu", "test_qlinear_cpu", "test_qlinear_add_cpu", + "test_qlinear_add_relu_cpu", "test_qlinear_dequant_promotion_cpu", + "test_qlinear_gelu_cpu", "test_qlinear_relu_cpu", ] for test_name in xfail_list: @@ -125,7 +127,6 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, - skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -134,8 +135,6 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func - if skip: - func = unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -183,7 +182,6 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} - skip: str = None for item in [ BaseTest("test_add_complex"), @@ -242,9 +240,7 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest( - "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" - ), + BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -318,21 +314,18 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -388,7 +381,6 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, - skip=item.skip, ) test_torchinductor.copy_tests( From ed5b8432cdf8451520c064b16d9b0e971c5a5211 Mon Sep 17 00:00:00 2001 From: Alnis Murtovi Date: Wed, 19 Jun 2024 03:12:15 +0000 Subject: [PATCH 40/64] Enable mixed_mm only if casting from lower-bitwidth type to a higher one (#128899) This PR changes the behavior of `cuda_and_enabled_mixed_mm` such that mixed_mm is only enabled if we are casting from a lower-bitwidth type to a higher one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128899 Approved by: https://github.com/eellison --- test/inductor/test_pattern_matcher.py | 26 ++++++++++++++++++-------- torch/_inductor/fx_passes/post_grad.py | 9 +++++++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 58f1ff88499f..d4570a8a2dbc 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -442,7 +442,15 @@ def fn(a, b): .sub(8), ) - args_list = [ + def check_uint4x2_mixed_mm(args, expect_mixed_mm): + torch._dynamo.reset() + counters.clear() + ref = fn(*args) + test, (code,) = run_and_get_code(torch.compile(fn), *args) + torch.testing.assert_close(ref, test) + self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm) + + args_expect_mixed_mm = [ ( torch.randn(8, 8, device="cuda"), torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"), @@ -454,6 +462,13 @@ def fn(a, b): .contiguous() .t(), ), + ] + + for args in args_expect_mixed_mm: + check_uint4x2_mixed_mm(args, True) + + # mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one + args_expect_no_mixed_mm = [ ( torch.randn(8, 8, device="cuda"), torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"), @@ -464,13 +479,8 @@ def fn(a, b): ), ] - for args in args_list: - torch._dynamo.reset() - counters.clear() - ref = fn(*args) - test, (code,) = run_and_get_code(torch.compile(fn), *args) - torch.testing.assert_close(ref, test) - self.assertTrue("uint4x2_mixed_mm" in code) + for args in args_expect_no_mixed_mm: + check_uint4x2_mixed_mm(args, False) @unittest.skipIf(not SM80OrLater, "need sm_80") @inductor_config.patch(use_mixed_mm=True) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 4bb0244b97f3..c67471c55ab7 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -229,8 +229,13 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): def cuda_and_enabled_mixed_mm(match): - return (config.use_mixed_mm or config.mixed_mm_choice != "default") and getattr( - match.kwargs["mat1"].meta.get("val"), "is_cuda", False + return ( + (config.use_mixed_mm or config.mixed_mm_choice != "default") + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and ( + match.kwargs["mat2_dtype"].itemsize + > match.kwargs["mat2"].meta.get("val").dtype.itemsize + ) ) From 8771e3429c3d7327f08c48d547ad73546d5603b3 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Tue, 18 Jun 2024 14:24:22 -0700 Subject: [PATCH 41/64] Introduce a prototype for SymmetricMemory (#128582) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them. ### SymmetricMemory `SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers). ### Python API Example ```python from torch._C.distributed_c10d import _SymmetricMemory # Set a store for rendezvousing symmetric allocations on a group of devices # identified by group_name. The concept of groups is logical; users can # utilize predefined groups (e.g., a group of device identified by a # ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator # backends might employ a more efficient communication channel for the actual # rendezvous process and only use the store for bootstrapping purposes. _SymmetricMemory.set_group_info(group_name, rank, world_size, store) # Identical to empty_strided, but allows symmetric memory access to be # established for the allocated tensor via _SymmetricMemory.rendezvous(). # This function itself is not a collective operation. t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name) # Users can write Python custom ops that leverages the symmetric memory access. # Below are examples of things users can do (assuming the group's world_size is 2). # Establishes symmetric memory access on tensors allocated via # _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process, # and the mapping between a local memory region and the associated SymmetricMemory # object is unique. Subsequent calls to rendezvous() with the same tensor will receive # the cached SymmetricMemory object. # # The function has a collective semantic and must be invoked simultaneously # from all rendezvous participants. symm_mem = _SymmetricMemory.rendezvous(t) # This represents the allocation on rank 0 and is accessible from all devices. buf = symm_mem.get_buffer(0, (64, 64), torch.float32) if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) assert buf.eq(42).all() else: # The remote buffer can be used as a regular tensor buf.fill_(42) symm_mem.put_signal(dst_rank=0) symm_mem.barrier() if symm_mem.rank == 0: symm_mem.barrier() assert buf.eq(43).all() else: new_val = torch.empty_like(buf) new_val.fill_(43) # Contiguous copies to/from a remote buffer utilize copy engines # which bypasses SMs (i.e. no need to load the data into registers) buf.copy_(new_val) symm_mem.barrier() ``` ### Custom CUDA Comm Kernels Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels. ```cpp TORCH_API c10::intrusive_ptr get_symmetric_memory( const at::Tensor& tensor); class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { public: ... virtual std::vector get_buffer_ptrs() = 0; virtual std::vector get_signal_pad_ptrs() = 0; virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; ... }; ``` ### Limitations of IntraNodeComm and ProcessGroupCudaP2p Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach: - Leads to awkward UX in which the required workspace needs to be specified upfront. - Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather). - Prevents torch.compile from eliminating all copies. In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels. * __->__ #128582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582 Approved by: https://github.com/wanchaol --- .lintrunner.toml | 1 + BUILD.bazel | 1 + build_variables.bzl | 2 + c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 + test/distributed/test_symmetric_memory.py | 156 +++++ torch/_C/_distributed_c10d.pyi | 30 + .../distributed/c10d/CUDASymmetricMemory.cu | 539 ++++++++++++++++++ .../distributed/c10d/CUDASymmetricMemory.hpp | 107 ++++ .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 + .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ++++++ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 +++++ torch/csrc/distributed/c10d/init.cpp | 39 ++ .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +--- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 1252 insertions(+), 111 deletions(-) create mode 100644 test/distributed/test_symmetric_memory.py create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index 2c3da39f80cc..76dedf9ea0bd 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,6 +68,7 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', + 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index 10c065f5084c..c563c52d861e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,6 +744,7 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index ceb28707897e..793b611a0a6f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,6 +501,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", + "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -684,6 +685,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 43bcbd1d70ba..cbbdf16823ec 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,14 +18,17 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ + _(cuMemGetAllocationGranularity) \ + _(cuMemExportToShareableHandle) \ + _(cuMemImportFromShareableHandle) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 89c31fab1134..8426741609fe 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py new file mode 100644 index 000000000000..a768e059044f --- /dev/null +++ b/test/distributed/test_symmetric_memory.py @@ -0,0 +1,156 @@ +# Owner(s): ["module: c10d"] + +import torch + +import torch.distributed as dist +from torch._C._distributed_c10d import _SymmetricMemory +from torch.distributed.distributed_c10d import _get_process_group_store + +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + skip_but_pass_in_sandcastle_if, + skipIfRocm, +) + + +def requires_cuda_p2p_access(): + cuda_p2p_access_available = ( + torch.cuda.is_available() and torch.cuda.device_count() >= 2 + ) + num_devices = torch.cuda.device_count() + for i in range(num_devices - 1): + for j in range(i + 1, num_devices): + if not torch.cuda.can_device_access_peer(i, j): + cuda_p2p_access_available = False + break + if not cuda_p2p_access_available: + break + + return skip_but_pass_in_sandcastle_if( + not cuda_p2p_access_available, + "cuda p2p access is not available", + ) + + +@instantiate_parametrized_tests +@requires_cuda_p2p_access() +class SymmetricMemoryTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 2 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + _SymmetricMemory.set_group_info( + "0", + self.rank, + self.world_size, + _get_process_group_store(dist.GroupMember.WORLD), + ) + + def _verify_symmetric_memory(self, symm_mem): + self.assertEqual(symm_mem.world_size, 2) + + buf = symm_mem.get_buffer(0, (64, 64), torch.float32) + if symm_mem.rank == 0: + symm_mem.wait_signal(src_rank=1) + self.assertTrue(buf.eq(42).all()) + else: + buf.fill_(42) + symm_mem.put_signal(dst_rank=0) + + symm_mem.barrier() + + if symm_mem.rank == 0: + symm_mem.barrier() + self.assertTrue(buf.eq(43).all()) + else: + buf.fill_(43) + symm_mem.barrier() + + symm_mem.barrier() + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name) + + t = torch.empty(shape, dtype=dtype, device=device) + with self.assertRaises(RuntimeError): + _SymmetricMemory.rendezvous(t) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + del t + self._verify_symmetric_memory(symm_mem) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p_persistent(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + alloc_id = 42 # Persistent allocation + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name, alloc_id) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + data_ptr = t.data_ptr() + + # Verify that persistent allocation would fail if there's an active + # allocation with the same alloc_id. + with self.assertRaises(RuntimeError): + _SymmetricMemory.empty_strided_p2p(*alloc_args) + + # Verify that persistent allocation would succeed in lieu of activate + # allocations with the same alloc_id, and the returned tensor would + # have the same data pointer. + del t + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + self.assertEqual(t.data_ptr(), data_ptr) + + # Verify that get_symmetric_memory would fail if called before + # rendezvous. + with self.assertRaises(RuntimeError): + _SymmetricMemory.get_symmetric_memory(t) + + symm_mem_0 = _SymmetricMemory.rendezvous(t) + symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) + self.assertEqual(id(symm_mem_0), id(symm_mem_1)) + + self._verify_symmetric_memory(symm_mem_0) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index cffbf22219c8..0095b5af434b 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,3 +637,33 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class _SymmetricMemory: + @staticmethod + def set_group_info( + group_name: str, rank: int, world_size: int, store: Store + ) -> None: ... + @staticmethod + def empty_strided_p2p( + size: torch.types._size, + stride: torch.types._size, + dtype: torch.dtype, + device: torch.device, + group_name: str, + ) -> torch.Tensor: ... + @property + def rank(self) -> int: ... + @property + def world_size(self) -> int: ... + @staticmethod + def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... + def get_buffer( + self, + rank: int, + sizes: torch.Size, + dtype: torch.dtype, + storage_offset: Optional[int] = 0, + ) -> torch.Tensor: ... + def barrier(self, channel: int = 0) -> None: ... + def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... + def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu new file mode 100644 index 000000000000..d923fb6044f2 --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -0,0 +1,539 @@ +#include + +#include +#include +#include +#include + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#endif + +#include +#include + +namespace { + +constexpr size_t signal_pad_size = 2048; +const std::string store_comm_prefix = "CUDASymmetricMemory"; + +static size_t store_comm_seq_id = 0; + +template +std::vector store_all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; + peer_keys.push_back(oss.str()); + } + ++store_comm_seq_id; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; +} + +void store_barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + store_all_gather(store, rank, world_size, 0); +} + +int import_remote_fd(int pid, int fd) { +#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) + int pidfd = syscall(SYS_pidfd_open, pid, 0); + return syscall(SYS_pidfd_getfd, pidfd, fd, 0); +#else + TORCH_CHECK( + false, + "CUDASymmetricMemory requires pidfd_open ", + "and pidfd_getfd support"); +#endif +} + +void map_block( + void** ptr, + c10d::symmetric_memory::HandleType handle, + size_t size, + int device_idx) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + auto dev_ptr = reinterpret_cast(ptr); + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); + + CUmemAccessDesc desc; + desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + desc.location.id = static_cast(device_idx); + desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +CUDASymmetricMemory::CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : handles_(std::move(handles)), + block_size_(block_size), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + + c10::cuda::CUDAGuard guard(local_device_idx); + AT_CUDA_CHECK(cudaMemcpy( + buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); + AT_CUDA_CHECK(cudaMemcpy( + signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); +} + +CUDASymmetricMemory::~CUDASymmetricMemory() { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + c10::cuda::CUDAGuard guard(local_device_idx_); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + auto driver_api = c10::cuda::DriverAPI::get(); + for (int r = 0; r < world_size_; ++r) { + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(buffers_[r]), block_size_)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); + } + c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); + c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +std::vector CUDASymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** CUDASymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t CUDASymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t CUDASymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +at::Tensor CUDASymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const auto numel = + std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "CUDASymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storage_offset) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +__device__ __forceinline__ void release_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); +#endif +} + +__device__ __forceinline__ void acquire_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); +#endif +} + +static __global__ void barrier_kernel( + uint32_t** signal_pads, + int channel, + int rank, + int world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + release_signal(signal_pads[target_rank] + world_size * channel + rank); + acquire_signal(signal_pads[rank] + world_size * channel + target_rank); + } +} + +void CUDASymmetricMemory::barrier(int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void put_signal_kernel( + uint32_t** signal_pads, + int dst_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + release_signal(signal_pads[dst_rank] + world_size * channel + rank); + } +} + +void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + dst_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void wait_signal_kernel( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + acquire_signal(signal_pads[rank] + world_size * channel + src_rank); + } + __threadfence_system(); +} + +void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + src_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +int CUDASymmetricMemory::get_rank() { + return rank_; +} + +int CUDASymmetricMemory::get_world_size() { + return world_size_; +} + +void* CUDASymmetricMemoryAllocator::alloc( + size_t size, + int device_idx, + const std::string& group_name) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + prop.location.id = device_idx; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + size_t signal_pad_offset = at::round_up(size, 16UL); + size_t block_size = signal_pad_offset + signal_pad_size; + + size_t granularity; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + block_size = at::round_up(block_size, granularity); + + HandleType handle; + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); + + void* ptr = nullptr; + map_block(&ptr, handle, block_size, device_idx); + + c10::cuda::CUDAGuard guard(device_idx); + AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); + + auto block = c10::make_intrusive( + handle, device_idx, block_size, size, signal_pad_offset, group_name); + { + std::unique_lock lock(mutex_); + ptr_to_block_.emplace(ptr, std::move(block)); + } + return ptr; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +void CUDASymmetricMemoryAllocator::free(void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + if (block == nullptr) { + return; + } + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. + if (block->symm_mem == nullptr) { + auto driver_api = c10::cuda::DriverAPI::get(); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(ptr), block->block_size)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); + } + { + std::unique_lock lock(mutex_); + ptr_to_block_.erase(ptr); + } +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->buffer_size; +} + +struct RendezvousRequest { + int device_idx; + int block_fd; + int pid; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; +}; + +void validate_rendezvous_requests( + const std::vector reqs, + int world_size) { + TORCH_CHECK(reqs.size() == (size_t)world_size); + + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + if (device_indices.size() < (size_t)world_size) { + TORCH_CHECK( + false, + "CUDASymmetricMemoryAllocator::rendezvous: ", + "detected allocations from overlapping devices ", + "from different ranks."); + } + + for (int r = 1; r < world_size; ++r) { + TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); + TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); + TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); + } +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( + void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + + if (block->symm_mem != nullptr) { + return block->symm_mem; + } + + auto group_info = get_group_info(block->group_name); + auto driver_api = c10::cuda::DriverAPI::get(); + int block_fd; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .block_fd = block_fd, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset}; + auto reqs = store_all_gather( + group_info.store, group_info.rank, group_info.world_size, local_req); + validate_rendezvous_requests(reqs, group_info.world_size); + + std::vector handles(group_info.world_size); + std::vector buffers(group_info.world_size, nullptr); + std::vector signal_pads(group_info.world_size, nullptr); + for (int r = 0; r < group_info.world_size; ++r) { + if (r == group_info.rank) { + handles[r] = block->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } + int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)(uintptr_t)imported_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + map_block(&buffers[r], handles[r], block->block_size, block->device_idx); + signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + close(imported_fd); + } + store_barrier(group_info.store, group_info.rank, group_info.world_size); + close(block_fd); + + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. So that outstanding + // references to the CUDASymmetricMemory object can keep the allocation + // alive. + block->symm_mem = c10::make_intrusive( + std::move(handles), + block->block_size, + std::move(buffers), + std::move(signal_pads), + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + return block->symm_mem; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->symm_mem != nullptr; +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterCUDASymmetricMemoryAllocator { + RegisterCUDASymmetricMemoryAllocator() { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); + } +}; + +static RegisterCUDASymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp new file mode 100644 index 000000000000..82e75d22c84f --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +using HandleType = CUmemGenericAllocationHandle; +#else +using HandleType = void*; +#endif + +class CUDASymmetricMemory : public SymmetricMemory { + public: + CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~CUDASymmetricMemory() override; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) override; + + void barrier(int channel) override; + void put_signal(int dst_rank, int channel) override; + void wait_signal(int src_rank, int channel) override; + + int get_rank() override; + int get_world_size() override; + + private: + std::vector handles_; + size_t block_size_; + std::vector buffers_; + std::vector signal_pads_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::optional> finalizer_; +}; + +struct Block : public c10::intrusive_ptr_target { + HandleType handle; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::string group_name; + c10::intrusive_ptr symm_mem = nullptr; + + Block( + HandleType handle, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::string& group_name) + : handle(handle), + device_idx(device_idx), + block_size(block_size), + buffer_size(buffer_size), + signal_pad_offset(signal_pad_offset), + group_name(group_name), + symm_mem(nullptr) {} +}; + +class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc(size_t size, int device_idx, const std::string& group_name) + override; + + void free(void* ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous(void* ptr) override; + bool is_rendezvous_completed(void* ptr) override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index cff4ad09b706..7c41414c4e4e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,6 +10,7 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp new file mode 100644 index 000000000000..b3d9f31bb034 --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -0,0 +1,189 @@ +#include + +namespace { + +using namespace c10d::symmetric_memory; + +class AllocatorMap { + public: + static AllocatorMap& get() { + static AllocatorMap instance; + return instance; + } + + void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + map_[device_type] = std::move(allocator); + } + + c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + auto it = map_.find(device_type); + TORCH_CHECK( + it != map_.end(), + "SymmetricMemory does not support device type ", + device_type); + return it->second; + } + + ~AllocatorMap() { + for (auto& it : map_) { + it.second.release(); + } + } + + private: + AllocatorMap() = default; + AllocatorMap(const AllocatorMap&) = delete; + AllocatorMap& operator=(const AllocatorMap&) = delete; + + std::unordered_map< + c10::DeviceType, + c10::intrusive_ptr> + map_; +}; + +static std::unordered_map group_info_map{}; + +// Data structures for tracking persistent allocations +static std::unordered_map alloc_id_to_dev_ptr{}; +static std::unordered_map> + alloc_id_to_storage{}; + +static at::Tensor empty_strided_p2p_persistent( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + uint64_t alloc_id) { + // Make the allocation fails if a previous allocation with the same alloc_id + // is still active. + auto storage = alloc_id_to_storage.find(alloc_id); + if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { + TORCH_CHECK( + false, + "SymmetricMemory::empty_strided_p2p_persistent: ", + "can not allocate with alloc_id == ", + alloc_id, + " because a previous allocation with the same alloc_id " + "is still active."); + } + + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = nullptr; + if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { + dev_ptr = alloc_id_to_dev_ptr[alloc_id]; + TORCH_CHECK( + alloc_size == allocator->get_alloc_size(dev_ptr), + "SymmetricMemory::empty_strided_p2p_persistent: ", + "requested allocation size (", + alloc_size, + ") is different from the size of a previous allocation ", + "with the same alloc_id ", + allocator->get_alloc_size(dev_ptr)); + } else { + dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + alloc_id_to_dev_ptr[alloc_id] = dev_ptr; + } + + auto options = at::TensorOptions().dtype(dtype).device(device); + auto allocated = at::from_blob(dev_ptr, size, stride, options); + + // Track the allocation's activeness + alloc_id_to_storage.erase(alloc_id); + alloc_id_to_storage.emplace( + alloc_id, allocated.storage().getWeakStorageImpl()); + return allocated; +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + return AllocatorMap::get().register_allocator( + device_type, std::move(allocator)); +} + +c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + return AllocatorMap::get().get_allocator(device_type); +} + +void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store) { + TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); + GroupInfo group_info; + group_info.rank = rank; + group_info.world_size = world_size; + group_info.store = std::move(store); + group_info_map.emplace(group_name, std::move(group_info)); +} + +const GroupInfo& get_group_info(const std::string& group_name) { + TORCH_CHECK( + group_info_map.find(group_name) != group_info_map.end(), + "get_group_info: no group info associated with the group name ", + group_name); + return group_info_map[group_name]; +} + +at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id) { + if (alloc_id.has_value()) { + return empty_strided_p2p_persistent( + size, stride, dtype, device, group_name, *alloc_id); + } + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::from_blob( + dev_ptr, + size, + stride, + [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, + options); +} + +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + return allocator->rendezvous(tensor.data_ptr()); +} + +c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + TORCH_CHECK( + allocator->is_rendezvous_completed(tensor.data_ptr()), + "SymmetricMemory: must invoke rendezvous on a tensor ", + "before calling get_symmetric_memory on it"); + return allocator->rendezvous(tensor.data_ptr()); +} + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp new file mode 100644 index 000000000000..344b86ea5c7e --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include +#include + +namespace c10d { +namespace symmetric_memory { + +// SymmetricMemory represents symmetric allocations across a group of devices. +// The allocations represented by a SymmetricMemory object are accessible by +// all devices in the group. The class can be used for op-level custom +// communication patterns (via the get_buffer APIs and the synchronization +// primitives), as well as custom communication kernels (via the buffer and +// signal_pad device pointers). +// +// To acquire a SymmetricMemory object, each rank first allocates +// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes +// SymmetricMemoryAllocator::rendezvous() on the memory to establish the +// association across peer buffers. The rendezvous is a one-time process, and +// the mapping between a local memory memory and the associated SymmetricMemory +// object is unique. +// +// NOTE [symmetric memory signal pad] +// Signal pads are P2P-accessible memory regions designated for +// synchronization. SymmetricMemory offers built-in synchronization primitives +// such as barriers, put_signal, and wait_signal, which are all based on signal +// pads. Users may utilize signal pads for their own synchronization logic, +// provided that the signal pads remain zero-filled following successful +// synchronization. +// +// NOTE [symmetric memory synchronization channel] +// Synchronization channels allow users to use a single SymmetricMemory object +// to perform isolated synchronizations on different streams. For example, +// consider the case in which two barriers are issued on two streams for +// different purposes. Without the concept of channels, we cannot guarantee the +// correctness of the barriers since signals issued from barrier on stream A +// can be received by the barrier on stream B. By specifying different channels +// for these two barriers, they can operate correctly in parallel. +class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemory() {} + + virtual std::vector get_buffer_ptrs() = 0; + virtual std::vector get_signal_pad_ptrs() = 0; + + // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer + // to a device array of size world_size, containing buffer pointers and + // signal pad pointers, respectively. + virtual void** get_buffer_ptrs_dev() = 0; + virtual void** get_signal_pad_ptrs_dev() = 0; + virtual size_t get_buffer_size() = 0; + virtual size_t get_signal_pad_size() = 0; + + virtual at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) = 0; + + virtual void barrier(int channel) = 0; + virtual void put_signal(int dst_rank, int channel) = 0; + virtual void wait_signal(int src_rank, int channel) = 0; + + virtual int get_rank() = 0; + virtual int get_world_size() = 0; +}; + +class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemoryAllocator(){}; + + virtual void* alloc( + size_t size, + int device_idx, + const std::string& group_name) = 0; + + virtual void free(void* ptr) = 0; + virtual size_t get_alloc_size(void* ptr) = 0; + virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; + virtual bool is_rendezvous_completed(void* ptr) = 0; +}; + +C10_EXPORT void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator); + +C10_EXPORT c10::intrusive_ptr get_allocator( + c10::DeviceType device_type); + +// Set a store for rendezvousing symmetric allocations on a group of devices +// identified by `group_name`. The concept of groups is logical; users can +// utilize predefined groups (e.g., a group of device identified by a +// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator +// backends might employ a more efficient communication channel for the actual +// rendezvous process and only use the store for bootstrapping purposes. +TORCH_API void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store); + +struct GroupInfo { + int rank; + int world_size; + c10::intrusive_ptr store; +}; + +C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); + +// Identical to empty_strided, but allows symmetric memory access to be +// established for the allocated tensor via SymmetricMemory::rendezvous(). This +// function itself is not a collective operation. It invokes +// SymmetricMemoryAllocator::alloc() for the requested device under the hood. +// +// NOTE [symmetric memory persistent allocation] +// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent +// allocation. This makes the function cache allocated memory and ensure that +// invocations with the same `alloc_id` receive tensors backed by the same +// memory address. For safety, if a previous persistent allocation is still +// active (i.e., the storage of the returned tensor is still alive), persistent +// allocations with the same `alloc_id` will fail. This determinism coupled +// with memory planning of communication buffers (e.g., by Inductor) allows +// communication algorithms to reliably reuse previously established remote +// memory access. +TORCH_API at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id); + +// Establishes symmetric memory access on tensors allocated via +// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a +// one-time process, and the mapping between a local memory region and the +// associated SymmetricMemory object is unique. Subsequent calls to +// rendezvous() with the same tensor, or tensors allocated with +// empty_strided_p2p_persistent() using the same alloc_id, will receive the +// cached SymmetricMemory object. +// +// The function has a collective semantic and must be invoked simultaneously +// from all rendezvous participants. +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor); + +// Returns the SymmetricMemory object associated with the tensor. It can only +// be invoked after rendezvous() but does not need to be invoked collectively. +TORCH_API c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f1b28886b98..db5778efcf35 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include @@ -975,6 +976,44 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); + using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; + py::class_>( + module, "_SymmetricMemory") + .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) + .def_static( + "empty_strided_p2p", + ::c10d::symmetric_memory::empty_strided_p2p, + py::arg("size"), + py::arg("stride"), + py::arg("dtype"), + py::arg("device"), + py::arg("group_name"), + py::arg("alloc_id") = py::none()) + .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) + .def_static( + "get_symmetric_memory", + &::c10d::symmetric_memory::get_symmetric_memory) + .def_property_readonly("rank", &SymmetricMemory::get_rank) + .def_property_readonly("world_size", &SymmetricMemory::get_world_size) + .def( + "get_buffer", + &SymmetricMemory::get_buffer, + py::arg("rank"), + py::arg("sizes"), + py::arg("dtype"), + py::arg("storage_offset") = 0) + .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) + .def( + "put_signal", + &SymmetricMemory::put_signal, + py::arg("dst_rank"), + py::arg("channel") = 0) + .def( + "wait_signal", + &SymmetricMemory::wait_signal, + py::arg("src_rank"), + py::arg("channel") = 0); + auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 85136a91e025..9d7ba5abf951 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,23 +218,8 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - // Intentionally releasing resources without synchronizing devices. The - // teardown logic is safe for propoerly sync'd user program. We don't want - // improperly sync'd user program to hang here. - for (size_t r = 0; r < worldSize_; ++r) { - if (r == rank_) { - continue; - } - AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); - AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); - } - AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); - AT_CUDA_CHECK(cudaFree(buffers_[rank_])); - if (topoInfo_ != nullptr) { - AT_CUDA_CHECK(cudaFree(topoInfo_)); - } - AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); - AT_CUDA_CHECK(cudaFree(buffersDev_)); + auto allocator = get_allocator(c10::DeviceType::CUDA); + allocator->free(symmetricMemoryPtr_); } bool IntraNodeComm::isEnabled() { @@ -344,83 +329,19 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - // Initialize p2p state - auto p2pState = initP2pState(); - - // Allocate buffer - void* buffer = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); - - // Second handshake: exchange topology and CUDA IPC handles - struct IpcInfo { - NvlMesh nvlMesh; - Topology topology; - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - }; - - // Make p2p state and buffer available for IPC - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); - AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); - - IpcInfo ipcInfo{ - .nvlMesh = nvlMesh, - .topology = topology, - .p2pStateHandle = p2pStateHandle, - .bufferHandle = bufferHandle}; - - auto peerIpcInfos = - storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); - - for (const auto& info : peerIpcInfos) { - if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || - info.topology != peerIpcInfos.front().topology) { - LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " - "participants are observing different topologies (" - << int(info.topology) << " and " << int(topology) << ")"; - AT_CUDA_CHECK(cudaFree(p2pState)); - AT_CUDA_CHECK(cudaFree(buffer)); - return false; - } - } - - std::array p2pStates = {}, buffers = {}; - for (size_t r = 0; r < peerIpcInfos.size(); ++r) { - if (r == rank_) { - p2pStates[r] = p2pState; - buffers[r] = buffer; - } else { - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &p2pStates[r], - peerIpcInfos[r].p2pStateHandle, - cudaIpcMemLazyEnablePeerAccess)); - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &buffers[r], - peerIpcInfos[r].bufferHandle, - cudaIpcMemLazyEnablePeerAccess)); - } - } - void* p2pStatesDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); - AT_CUDA_CHECK(cudaMemcpy( - p2pStatesDev, - p2pStates.data(), - sizeof(p2pStates), - cudaMemcpyHostToDevice)); - - void* buffersDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); - AT_CUDA_CHECK(cudaMemcpy( - buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); + set_group_info("IntraNodeComm", rank_, worldSize_, store_); + auto allocator = get_allocator(c10::DeviceType::CUDA); + symmetricMemoryPtr_ = + allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); + symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); + TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); - std::copy(buffers.begin(), buffers.end(), buffers_.begin()); - p2pStatesDev_ = p2pStatesDev; - buffersDev_ = buffersDev; + p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); + buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index 51fc6252d223..ac751ff7be1e 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,6 +132,8 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; +static_assert(sizeof(P2pState) <= kP2pStateSize); + template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -522,7 +524,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -582,7 +584,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -632,7 +634,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -755,15 +757,7 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); - const auto elementSize = c10::elementSize(dtype); - TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); - auto options = at::TensorOptions().dtype(dtype).device( - at::kCUDA, at::cuda::current_device()); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storageOffset) - .options(options) - .make_tensor(); + return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index 5d7e2d426d30..a67df5c34586 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,12 +4,16 @@ #include #include #include +#include #include namespace c10d::intra_node_comm { +using namespace c10d::symmetric_memory; + constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; +constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -27,6 +31,7 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -97,8 +102,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - std::array p2pStates_{}; - std::array buffers_{}; + void* symmetricMemoryPtr_ = nullptr; + c10::intrusive_ptr symmetricMemory_ = nullptr; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From eb9f4da11e86882b5c628cea539112de9638760a Mon Sep 17 00:00:00 2001 From: chilli Date: Tue, 18 Jun 2024 11:23:49 -0700 Subject: [PATCH 42/64] Modified template indexing to broadcast indices to out instead of mask and some other flexattention micro-opts (#128938) For headdim=64 and headdim=128 Old: image New: image Note, this does regress headdim=256. We can unregress it by special casing `headdim=256`, but ehh.... we can do it later Pull Request resolved: https://github.com/pytorch/pytorch/pull/128938 Approved by: https://github.com/drisspg --- benchmarks/transformer/score_mod.py | 4 +-- torch/_inductor/kernel/flex_attention.py | 32 +++++++++--------- torch/_inductor/select_algorithm.py | 42 +++++++++++++++++++----- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 29d951bc1dee..135f26b0df2d 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -264,7 +264,7 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]: batch_sizes = [2, 8, 16] num_heads = [16] q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)] - head_dims = [64, 128, 256] + head_dims = [64, 128] dtypes = [ torch.bfloat16, ] @@ -302,8 +302,6 @@ def main(dynamic: bool, calculate_bwd: bool): results.append( Experiment(config, run_single_experiment(config, dynamic=dynamic)) ) - for config in tqdm(generate_experiment_configs(calculate_bwd)): - results.append(Experiment(config, run_single_experiment(config))) print_results(results) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 987dc6d89328..edb69068f0cd 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -242,10 +242,8 @@ def build_subgraph_buffer( start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk) + qk = tl.dot(q, k) # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ m = offs_m[:, None] n = start_n + offs_n[None, :] @@ -265,24 +263,26 @@ def build_subgraph_buffer( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -- compute scaling constant --- - row_max = tl.max(post_mod_scores, 1) - m_i_new = tl.maximum(m_i, row_max) + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(post_mod_scores - m_i_new[:, None]) + alpha = tl.math.exp2(m_i - m_ij) + p = tl.math.exp2(post_mod_scores - m_ij[:, None]) if not ROWS_GUARANTEED_SAFE: - masked_out_rows = (m_i_new == float("-inf")) + masked_out_rows = (m_ij == float("-inf")) alpha = tl.where(masked_out_rows, 0, alpha) p = tl.where(masked_out_rows[:, None], 0, p) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc) - - # -- update m_i and l_i -- + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new + # # -- scale and update acc -- + acc = acc * alpha[:, None] + v = tl.load(V_block_ptr) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc) + + # -- update m_i + m_i = m_ij # update pointers K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -294,8 +294,8 @@ def build_subgraph_buffer( idx_m = offs_m[:, None] idx_d = tl.arange(0, BLOCK_DMODEL)[None, :] + mask = idx_m < Q_LEN # TODO generalize and add proper mask support - mask = (idx_m != -1) & (idx_d != -1) {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}} # TODO dont want to write this if we dont require grad diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index dd78d2869ce2..fb43e7da1d13 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -13,6 +13,7 @@ import sys import textwrap import time +from collections import namedtuple from concurrent.futures import ThreadPoolExecutor from io import StringIO @@ -102,6 +103,16 @@ def finalize_all(self) -> str: return self.code +SubgraphInfo = namedtuple( + "SubgraphInfo", + [ + "body", + "template_mask", + "template_out", + ], +) + + class TritonTemplateKernel(TritonKernel): def __init__( self, @@ -132,7 +143,6 @@ def __init__( self.named_input_nodes = {} # type: ignore[var-annotated] self.defines = defines self.kernel_name = kernel_name - self.template_mask = None self.use_jit = use_jit self.num_stages = num_stages self.num_warps = num_warps @@ -147,21 +157,34 @@ def __init__( self.triton_meta: Optional[Dict[str, object]] = None # For Templated Attention this can be a list of ir.Subgraph self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: Dict[str, SubgraphInfo] = {} + self.body: IndentedBuffer = FakeIndentedBuffer() - self.subgraph_bodies: Dict[str, IndentedBuffer] = {} + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None @contextlib.contextmanager def set_subgraph_body(self, body_name: str): - old_body = self.body + old_body, old_mask, old_out = self.body, self.template_mask, self.template_out assert body_name in self.subgraph_bodies, body_name - self.body = self.subgraph_bodies[body_name] + self.body, self.template_mask, self.template_out = self.subgraph_bodies[ + body_name + ] yield - self.body = old_body + self.subgraph_bodies[body_name] = SubgraphInfo( + self.body, self.template_mask, self.template_out + ) + self.body, self.template_mask, self.template_out = old_body, old_mask, old_out @contextlib.contextmanager def create_subgraph_body(self, body_name: str): assert body_name not in self.subgraph_bodies - self.subgraph_bodies[body_name] = IndentedBuffer() + self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None) with self.set_subgraph_body(body_name): yield @@ -406,7 +429,8 @@ def store_output( self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) ).set_name("xindex") - self.template_mask = mask # type: ignore[assignment] + self.template_mask = mask + self.template_out = val self.template_indices = indices output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) @@ -492,7 +516,9 @@ def indexing( return super().indexing( index, dense_indexing=False, - copy_shape=self.template_mask, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out, override_mask=self.template_mask, block_ptr=block_ptr, ) From acefc5c0160d8e37858b3c28fff07e6513b78e10 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 19 Jun 2024 03:45:41 +0000 Subject: [PATCH 43/64] [torch.compile] Enable bwd compilation metrics (#128973) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128973 Approved by: https://github.com/dshi7 --- torch/_dynamo/utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9fa70e0c98d5..e283308aa37d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -714,18 +714,15 @@ def record_compilation_metrics( name = "compilation_metrics" else: name = "bwd_compilation_metrics" - # Currently only record fwd compilation metrics, will add bwd compilation metrics - # after the internal Scuba logging changes finish. - if isinstance(compilation_metrics, CompilationMetrics): - torch._logging.trace_structured( - name, - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, - ) - if config.log_compilation_metrics: - log_compilation_event(compilation_metrics) + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + ) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) def set_compilation_metrics_limit(new_size: int) -> None: From 1f0a68b57290afff9691d823829fda6ba4f73cbb Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Wed, 19 Jun 2024 03:56:20 +0000 Subject: [PATCH 44/64] [ROCm] Fix fp32 atomicAdd for non-MI100 GPUs (#128750) Current implementation is very specific to MI100. This is causing performance degradation for other GPUs. Fixes #128631 Benchmarking on MI300X: ``` Before: 1918.5126953125 ms After: 0.8285150527954102 ms ``` Co-authored-by: Jeff Daily Pull Request resolved: https://github.com/pytorch/pytorch/pull/128750 Approved by: https://github.com/xw285cornell --- aten/src/ATen/cuda/Atomic.cuh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 56ee8f87e253..c8f5e91d3ff7 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -334,7 +334,13 @@ static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) /* Special case fp32 atomic. */ #if defined(USE_ROCM) -static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); } +static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { +#if defined(__gfx908__) + atomicAddNoRet(address, val); +#else + (void)unsafeAtomicAdd(address, val); +#endif +} #else static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); } #endif From 2f88597aad145f6b37d0208bc58f087510e14565 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 19 Jun 2024 04:28:27 +0000 Subject: [PATCH 45/64] [inductor] For internal, allow multiple workers if the method is "subprocess" (#129002) Summary: This does not change the current default behavior in fbcode ("fork" if unspecified and no worker processes if unspecified). But it allows us to more easily test the subprocess-based parallel if we override the start method to subprocess. Test Plan: Set `TORCHINDUCTOR_WORKER_START=subprocess` and locally ran all torchbench models listed [here](https://www.internalfb.com/intern/wiki/PyTorch/Teams/PyTorch_Perf_Infra/TorchBench/#torchbench-internal-mode) Differential Revision: D58755021 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129002 Approved by: https://github.com/eellison --- torch/_inductor/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 6be72fb8ea20..5e0c64c03197 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -449,12 +449,14 @@ def decide_compile_threads(): Here are the precedence to decide compile_threads 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by setting this to 1 to make pdb happy. - 2. Set to 1 if it's win32 platform or it's a fbcode build + 2. Set to 1 if it's win32 platform 3. decide by the number of CPU cores """ if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) - elif sys.platform == "win32" or is_fbcode(): + elif sys.platform == "win32": + return 1 + elif is_fbcode() and worker_start_method != "subprocess": return 1 else: cpu_count = ( From fcf2a1378b599003ec8990dae519b726e254825f Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 19 Jun 2024 04:49:39 +0000 Subject: [PATCH 46/64] Enable fp8 rowwise scaling kernel on cuda, TAKE 2: #125204 (#128989) # Summary First PR got reverted and needed a redo This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met: - `x`'s scale should be a 1-dimensional tensor of length `M`. - `y`'s scale should be a 1-dimensional tensor of length `N`. It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row". The following two PRs were required to enable local builds: - [PR #126185](https://github.com/pytorch/pytorch/pull/126185) - [PR #125523](https://github.com/pytorch/pytorch/pull/125523) ### Todo We still do not build our Python wheels with this architecture. @ptrblck @malfet, should we replace `sm_90` with `sm_90a`? The NVRTC TMA shadowing feels wrong, but I a not sure the right way to spoof the symbol for this compilation unit: https://github.com/pytorch/pytorch/pull/125204/files#r1586986954 #### ifdef I tried to use : `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \ defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this.. so I am not really sure the right way to do this Kernel Credit: @jwfromm Pull Request resolved: https://github.com/pytorch/pytorch/pull/128989 Approved by: https://github.com/yangsiyu007, https://github.com/vkuzo --- aten/src/ATen/CMakeLists.txt | 1 + aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 37 ++ aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 15 +- aten/src/ATen/native/cuda/Blas.cpp | 122 ++++- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 536 +++++++++++++++++++ aten/src/ATen/native/cuda/RowwiseScaledMM.h | 15 + test/test_matmul_cuda.py | 175 ++++-- third_party/cutlass.BUILD | 14 +- 8 files changed, 867 insertions(+), 48 deletions(-) create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.cu create mode 100644 aten/src/ATen/native/cuda/RowwiseScaledMM.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 0087dd95d96e..5cd6aacf2463 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -473,6 +473,7 @@ endif() if(USE_CUDA AND NOT USE_ROCM) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) + list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 1b85e7776e22..75c503d48d51 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int); CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +CUresult CUDAAPI +cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + auto fn = reinterpret_cast( + getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuTensorMapEncodeTiled"); + lazyNVRTC.cuTensorMapEncodeTiled = fn; + return fn( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + +#endif + // Irregularly shaped functions CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 574b2c41c264..cb34d10db254 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -59,16 +59,25 @@ namespace at { namespace cuda { _(cuLinkAddData) \ _(cuLinkComplete) \ _(cuFuncSetAttribute) \ - _(cuFuncGetAttribute) + _(cuFuncGetAttribute) \ + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) \ + _(cuTensorMapEncodeTiled) +#else +#define AT_FORALL_NVRTC_EXTENDED(_) \ + AT_FORALL_NVRTC_BASE(_) +#endif #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) \ + AT_FORALL_NVRTC_EXTENDED(_) \ _(nvrtcGetCUBINSize) \ _(nvrtcGetCUBIN) #else #define AT_FORALL_NVRTC(_) \ - AT_FORALL_NVRTC_BASE(_) + AT_FORALL_NVRTC_EXTENDED(_) #endif #else diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index ff8eb60b290b..7d796c3d67e2 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1,3 +1,7 @@ +#include +#include +#include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -10,6 +14,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -819,24 +824,106 @@ static bool _scaled_mm_allowed_device() { #endif } +namespace{ + +enum class ScalingType { + TensorWise, + RowWise, + Error +}; +/* + * Scaling Type Determination: + * --------------------------- + * Conditions and corresponding Scaling Types: + * + * - If scale_a.numel() == 1 && scale_b.numel() == 1: + * - Returns TensorWise. + * + * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: + * - Returns RowWise. + * + * - Otherwise: + * - Returns Error. + */ + +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_n) { + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + + // Check the singluar scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; + } else if (scale_a.dim() == 1 && scale_a.size(0) == dim_m) { +// Check the per-row scaling case +#if !defined(USE_ROCM) && !defined(_MSC_VER) || \ + (defined(USE_ROCM) && ROCM_VERSION >= 60000) + TORCH_CHECK( + scale_a.dim() == 1 && scale_b.dim() == 1, + "Both scale_a and scale_b must be 1-dimensional tensors"); + TORCH_CHECK( + scale_b.size(0) == dim_n, + "For row-wise scaling, scale_b must have size ", + dim_n, + " but got ", + scale_b.size(0), + "."); + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && + // ROCM_VERSION >= 60000) + } else { + // Prettier Error Case messaging + TORCH_CHECK( + false, + "For row-wise scaling, scale_a must be size ", + dim_m, + " but got ", + scale_a.numel(), + " and scale_b must be size ", + dim_n, + " but got ", + scale_b.numel(), + "."); + // Unreachable + return ScalingType::RowWise; + } + return ScalingType::Error; +} + +} // namespace + // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 -// +// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0) +// and scale_b should have size = to mat2.size(1) // Arguments: // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type -// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type +// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor -// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace Tensor& _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, @@ -855,10 +942,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - TORCH_CHECK((scale_a.numel() == 1 && scale_a.scalar_type() == kFloat), - "scale_a must be float scalar"); - TORCH_CHECK((scale_b.numel() == 1 && scale_b.scalar_type() == kFloat), - "scale_b must be a float scalar"); + + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], @@ -899,11 +987,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, {scale_result_, "scale_result", 6}}; checkAllSameGPU(__func__, targs); } - + // Validation checks have passed lets resize the output to actual size IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + // We are doing row-wise scaling + if (scaling_choice == ScalingType::RowWise) { + TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); + at::cuda::detail::f8f8bf16_rowwise( + mat1, + mat2, + scale_a, + scale_b, + bias, + use_fast_accum, + out); + return out; + } + cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu new file mode 100644 index 000000000000..84655d281afc --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -0,0 +1,536 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +// Determine if the architecture supports rowwise scaled mm +// Currenlty failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571 +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + +#define BUILD_ROWWISE_FP8_KERNEL +#endif + +#if defined(BUILD_ROWWISE_FP8_KERNEL) + +// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader +static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) { + return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( + tensorMap, + tensorDataType, + tensorRank, + globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); +} + + +#include +#include +#include +#include +#include +#include +#include + +// Rename the global function symbol +#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled +#include +#undef cuTensorMapEncodeTiled +// Set everything back to normal + +#include +#include +#include + +#include +#include +#include +#include + + +namespace { +// Cutlass rowwise kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +void f8f8bf16_rowwise_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK( + WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) && + WQ.stride(0) == 1); + + // auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, 1)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, 1)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. + (ElementOutput*)out.data_ptr(), + stride_output, + (ElementOutput*)out.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr())}, // bias + // compute_1 + { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// FP8 Rowwise Cutlass kernel dispatch. +enum class KernelMode { Small, Large, Default }; + +KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto M = XQ.size(0); + auto K = XQ.size(1); + auto N = WQ.size(0); + // Use a large kernel if at least two shapes are large.... + bool use_large_kernel = + ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + (K >= 2048 && N >= 2048)); + if (M <= 128 || N <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + +template +void dispatch_fp8_rowwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + at::Tensor out) { + KernelMode kernel = get_kernel_mode(XQ, WQ); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, out); + } +} + +} // namespace + +#endif // !defined(USE_ROCM) + +namespace at::cuda::detail { +void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out) { +#if defined(BUILD_ROWWISE_FP8_KERNEL) + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + // Extract problem size. + int M = XQ.size(0); + int N = WQ.size(1); + int K = XQ.size(1); + + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. + bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype."); + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + } +#else // BUILD_ROWWISE_FP8_KERNEL + TORCH_CHECK(false, "Rowwise scaling is not currenlty supported on your device"); +#endif +} + +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.h b/aten/src/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 000000000000..4d9054108c85 --- /dev/null +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + c10::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // at::cuda::detail diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 83e0f9e80a68..14f0cb60a77b 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -204,7 +204,6 @@ def _expand_to_batch(t: torch.Tensor): self.assertEqual(out1_gpu, out2_gpu[0]) - f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices" if torch.version.hip: @@ -256,8 +255,12 @@ def amax_to_scale( scale.copy_(res) return scale -def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): - amax = torch.max(torch.abs(x)) +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): + if dim is None: + amax = torch.max(torch.abs(x)) + else: + amax = torch.max(torch.abs(x), dim=dim).values + return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: @@ -316,35 +319,12 @@ def mm_float8( def to_fp8_saturated( x: torch.Tensor, - x_scale: torch.tensor, fp8_dtype: torch.dtype ): - """ - Converts a tensor to a saturated fp8 tensor. - - Args: - a: Input Tensor. - b: Input Tensor. - a_scale: scale associated with `a`. - b_scale: scale associated with `b`. - output_dtype: dtype of result. - output_scale: the output tensor's scale, precomputed. - - Returns: - (torch.Tensor, torch.Tensor): (result of the matrix multiplication, associated amax) - Note: - The default behavior in PyTorch for casting to `e4m3_type` - and `e5m2_type` is to not saturate. In this context, we should - saturate. A common case where we want to saturate is when the history - of a tensor has a maximum value of `amax1`, and the current amax value - is `amax2`, where `amax1 < amax2`. - """ - x_scaled = x * x_scale - if fp8_dtype == e4m3_type: - x = x_scaled.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: - x = x_scaled.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) else: raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") @@ -353,8 +333,6 @@ def to_fp8_saturated( @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") class TestFP8MatmulCuda(TestCase): - - @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, @@ -418,8 +396,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x, x_scale, input_dtype) - y_fp8 = to_fp8_saturated(y, y_scale, input_dtype) + x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) + y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) # Calculate actual F8 mm out_scaled_mm = mm_float8( @@ -449,7 +427,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype): if base_dtype in {torch.bfloat16, torch.float16}: atol, rtol = 7e-2, 7e-2 else: - atol, rtol = 2e-3, 2e-3 + atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) @@ -535,6 +513,137 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + @parametrize("use_fast_accum", [True, False]) + def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_scales = torch.ones(x.shape[0], device=device, dtype=torch.float32) + y_scales = torch.ones(y.shape[0], device=device, dtype=torch.float32) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + out_fp8 = torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=x_scales, + scale_b=y_scales, + out_dtype=torch.bfloat16, + use_fast_accum=use_fast_accum, + ) + self.assertEqual( + out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) + ) + + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + def test_float8_error_messages(self, device) -> None: + M, K, N = (1024, 512, 2048) + fill_value = 0.5 + x = torch.full((M, K), fill_value, device=device) + y = torch.full((N, K), fill_value, device=device) + + x_fp8 = x.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).t() + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_a must be size 1024 but got 1 and scale_b must be size 2048 but got 2", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((), device="cuda"), + scale_b=torch.ones((2), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling, scale_b must have size 2048 but got 2049.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N + 1), device="cuda"), + out_dtype=torch.bfloat16, + ) + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be 1-dimensional tensors", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "Both scale_a and scale_b must be contiguous.", + ): + torch._scaled_mm( + x_fp8, + y_fp8, + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N * 2), device="cuda")[::2], + out_dtype=torch.bfloat16, + ) + + with self.assertRaisesRegex( + RuntimeError, + "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.", + ): + torch._scaled_mm( + x_fp8, + y_fp8.to(torch.float8_e5m2), + scale_a=torch.ones((M), device="cuda"), + scale_b=torch.ones((N), device="cuda"), + out_dtype=torch.bfloat16, + ) + + @unittest.skipIf(not scaled_mm_supported_device() or IS_WINDOWS, f8_msg) + @skipIfRocm() + @parametrize("base_dtype", [torch.bfloat16]) + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): + torch.manual_seed(42) + input_dtype = e4m3_type + output_dtype = base_dtype + + x = torch.randn(16, 16, device="cuda", dtype=base_dtype) + y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + + x_scales = tensor_to_scale(x, input_dtype, dim=1).float() + y_scales = tensor_to_scale(y, input_dtype, dim=0).float() + + x_fp8 = to_fp8_saturated(x * x_scales[:, None], e4m3_type) + y_fp8 = to_fp8_saturated(y * y_scales[None, :], e4m3_type) + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated = mm_float8_emulated( + x_fp8, x_scales[:, None], y_fp8, y_scales[None, :], output_dtype + ) + + if base_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 7e-2, 7e-2 + else: + atol, rtol = 2e-3, 2e-3 + + torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e712d59597cc..e3e7b7b288e7 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -5,7 +5,17 @@ load("@rules_cc//cc:defs.bzl", "cc_library") cc_library( name = "cutlass", - hdrs = glob(["include/**/*.h", "include/**/*.hpp"]), - includes = ["include/"], + hdrs = glob([ + "include/**/*.h", + "include/**/*.hpp", + "include/**/*.inl", + "tools/util/include/**/*.h", + "tools/util/include/**/*.hpp", + "tools/util/include/**/*.inl", + ]), + includes = [ + "include/", + "tools/util/include/", + ], visibility = ["//visibility:public"], ) From a584b2a389b9fcaa1cd4d92dcc6914f9cf92491b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 04:59:10 +0000 Subject: [PATCH 47/64] Revert "Add test to xfail_list only for abi_compatible (#128506)" This reverts commit df85f34a14dd30f784418624b05bd52b12ab8b0b. Reverted https://github.com/pytorch/pytorch/pull/128506 on behalf of https://github.com/huydhn due to The failure shows up in trunk https://hud.pytorch.org/pytorch/pytorch/commit/df85f34a14dd30f784418624b05bd52b12ab8b0b ([comment](https://github.com/pytorch/pytorch/pull/128506#issuecomment-2177744578)) --- test/inductor/test_cpu_cpp_wrapper.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0a2b75ddb554..8bf9b1e6a61f 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -95,9 +95,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_relu_cpu", "test_qlinear_cpu", "test_qlinear_add_cpu", - "test_qlinear_add_relu_cpu", "test_qlinear_dequant_promotion_cpu", - "test_qlinear_gelu_cpu", "test_qlinear_relu_cpu", ] for test_name in xfail_list: @@ -127,6 +125,7 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, + skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -135,6 +134,8 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func + if skip: + func = unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -182,6 +183,7 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} + skip: str = None for item in [ BaseTest("test_add_complex"), @@ -240,7 +242,9 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), + BaseTest( + "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" + ), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -314,18 +318,21 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -381,6 +388,7 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, + skip=item.skip, ) test_torchinductor.copy_tests( From 3a185778edb18abfbad155a87ff3b2d716e4c220 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Tue, 18 Jun 2024 14:41:35 -0700 Subject: [PATCH 48/64] [aotinductor] Add torch.polar fallback op for shim v2 (#128722) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compilation error: ``` $ TORCHINDUCTOR_C_SHIM_VERSION=2 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_LOGS_FORMAT="%(pathname)s:%(lineno)s: %(message)s" TORCH_LOGS="+output_code" python test/inductor/test_cpu_cpp_wrapper.py -k test_polar /tmp/tmp2sp128xj/dy/cdypvu3hvgg3mwxydwbiuddsnmuoi37it3mrpjktcnu6vt4hr3ki.cpp:59:33: error: ‘aoti_torch_cpu_polar’ was not declared in this scope; did you mean ‘aoti_torch_cpu_topk’? ``` Steps: 1. Add aten.polar 2. run `python torchgen/gen.py --update-aoti-c-shim`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128722 Approved by: https://github.com/chenyang78, https://github.com/desertfire --- test/inductor/test_cpu_cpp_wrapper.py | 1 + test/inductor/test_torchinductor.py | 10 ++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 1 + torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h | 1 + torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h | 1 + torchgen/aoti/fallback_ops.py | 1 + 6 files changed, 15 insertions(+) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8bf9b1e6a61f..ee792fb7fa0b 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -235,6 +235,7 @@ class BaseTest(NamedTuple): BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()), BaseTest("test_linear1"), BaseTest("test_linear2"), + BaseTest("test_polar"), BaseTest( "test_linear_binary", "", diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 38b1abc6aa08..15b284abdd13 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4697,6 +4697,16 @@ def fn(x): self.common(fn, (x,)) + def test_polar(self): + def fn(dist, angle): + return torch.polar(dist, angle) + + inp = ( + torch.tensor([1, 2], dtype=torch.float64), + torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64), + ) + self.common(fn, (*inp,)) + def test_cauchy(self): def fn(x, y): return torch.sum(1 / (torch.unsqueeze(x, -1) - y)) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 07d48141a0e2..595c6c52fd3b 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -237,6 +237,7 @@ def run(*ex, **kwargs): "test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)), "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_single_elem_dynamic_shapes": TestFailure(("cpu",)), diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 770042f477bb..52576b28b6b9 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -94,6 +94,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 1eba22f85c97..cedc09778599 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -101,6 +101,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 9d086cc326a5..e4b110601da2 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -93,6 +93,7 @@ "aten.ormqr.default", "aten._pdist_backward.default", "aten._pdist_forward.default", + "aten.polar.default", "aten.pow.Scalar", "aten.pow.Tensor_Scalar", "aten.pow.Tensor_Tensor", From ba92f5277fe7a736f713b4937b88f01a866d8fdf Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 18 Jun 2024 08:30:26 -0700 Subject: [PATCH 49/64] [inductor][refactor] Unify the use of generate_kernel_call (#128467) Summary: Refactor TritonTemplateKernel.call_kernel and ForeachKernel.call_kernel to use wrapper.generate_kernel_call to generate kernel calls instead of explicitly composing the kernel call string. This consolidates the entry point of generate_kernel_call and similifies later changes in this PR stack. Differential Revision: [D58733631](https://our.internmc.facebook.com/intern/diff/D58733631) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128467 Approved by: https://github.com/shunting314 --- torch/_inductor/codegen/cpp_wrapper_cuda.py | 3 ++ torch/_inductor/codegen/cuda/cuda_kernel.py | 2 - torch/_inductor/codegen/multi_kernel.py | 13 +++---- torch/_inductor/codegen/triton.py | 12 +----- torch/_inductor/codegen/triton_foreach.py | 27 ++++---------- torch/_inductor/codegen/wrapper.py | 41 +++++++++++++++------ torch/_inductor/select_algorithm.py | 36 +++++------------- 7 files changed, 56 insertions(+), 78 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index ad8c8eafbbd1..ed36bdb6df18 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -183,6 +183,9 @@ def generate_kernel_call( name, call_args, grid, device_index, cuda, triton, arg_types ) + device_index, call_args = self.prepare_triton_kernel_call( + device_index, call_args + ) params = CudaKernelParamCache.get(name) assert ( params is not None diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 12b7b21de61e..b6256c0ccd0f 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -177,11 +177,9 @@ def call_kernel( else: call_args.append("None") - current_device = V.graph.scheduler.get_current_device_or_throw() wrapper.generate_kernel_call( name, call_args, - device_index=current_device.index, cuda=True, triton=False, arg_types=arg_types, diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 84279191ceac..d4e348fba74b 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -198,11 +198,12 @@ def call_kernel(self, kernel_name): for the multi-kernel. """ assert kernel_name == self.kernel_name - call_args_list, arg_types = zip( - *[kernel.get_call_args() for kernel in self.kernels] - ) - call_args_list = list(call_args_list) - arg_types_list = list(arg_types) + call_args_list = [] + arg_types_list = [] + for kernel in self.kernels: + _, call_args, _, arg_types = kernel.args.python_argdefs() + call_args_list.append(call_args) + arg_types_list.append(arg_types) all_call_args, arg_types = get_all_call_args(call_args_list, arg_types_list) grid: List[Any] = [] @@ -223,12 +224,10 @@ def call_kernel(self, kernel_name): ) grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid) - current_device = V.graph.scheduler.get_current_device_or_throw() V.graph.wrapper_code.generate_kernel_call( kernel_name, final_call_args, grid, - current_device.index, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 381bbebf3acc..9a098a776e05 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2657,19 +2657,9 @@ def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): if tree.grid_dim is not None: grid.append(expr) - def get_call_args(self): - # arg_types is needed for cpp wrapper codegen - _, call_args, _, arg_types = self.args.python_argdefs() - # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar - for i in range(len(call_args)): - if V.graph.is_unspec_arg(call_args[i]): - call_args[i] = call_args[i] + ".item()" - - return call_args, arg_types - def call_kernel(self, name: str, node: Optional[IRNode] = None): wrapper = V.graph.wrapper_code - call_args, arg_types = self.get_call_args() + _, call_args, _, arg_types = self.args.python_argdefs() grid: List[Any] = [] self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) current_device = V.graph.scheduler.get_current_device_or_throw() diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py index c2d6817fb215..a174d5448b4b 100644 --- a/torch/_inductor/codegen/triton_foreach.py +++ b/torch/_inductor/codegen/triton_foreach.py @@ -229,23 +229,10 @@ def codegen_kernel(self, name=None): def call_kernel(self, code, name: str): _, call_args, _, arg_types = self.args.python_argdefs() - # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar - for i in range(len(call_args)): - if V.graph.is_unspec_arg(call_args[i]): - call_args[i] = call_args[i] + ".item()" - current_device = V.graph.scheduler.get_current_device_or_throw() - if V.graph.cpp_wrapper: - V.graph.wrapper_code.generate_kernel_call( - name, - call_args, - device_index=current_device.index, - grid=self.grid(), - arg_types=arg_types, - ) - else: - # TODO: refactor generate_kernel_call - call_args_str = ", ".join(call_args) - stream_name = code.write_get_raw_stream(current_device.index, V.graph) - code.writeline( - f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})" - ) + V.graph.wrapper_code.generate_kernel_call( + name, + call_args, + grid=self.grid(), + arg_types=arg_types, + grid_fn="", + ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 092dfd4e0b9c..581bfd9249ab 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -673,7 +673,7 @@ def generate_extern_kernel_out( def generate_user_defined_triton_kernel( self, kernel_name, grid, configs, args, triton_meta, arg_types=None ): - grid, code = user_defined_kernel_grid_fn_code( + grid_fn, code = user_defined_kernel_grid_fn_code( kernel_name, configs, grid, wrapper=self ) # Must happen after free symbols are already codegened @@ -681,11 +681,7 @@ def generate_user_defined_triton_kernel( for line in code.split("\n"): self.writeline(line) - current_device = V.graph.scheduler.get_current_device_or_throw() - stream_name = self.write_get_raw_stream(current_device.index, V.graph) - self.writeline( - f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})" - ) + self.generate_kernel_call(kernel_name, args, grid_fn=grid_fn) def generate_scatter_fallback( self, @@ -1363,6 +1359,24 @@ def generate_save_uncompiled_kernels(self): def generate_default_grid(self, name: str, grid_args: List[Any]): return grid_args + def prepare_triton_kernel_call(self, device_index, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if V.graph.is_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return pexpr(V.graph.sizevars.simplify(arg)) + + call_args = [wrap_arg(arg) for arg in call_args] + + if device_index is None: + current_device = V.graph.scheduler.get_current_device_or_throw() + device_index = current_device.index + + return device_index, call_args + def generate_kernel_call( self, name, @@ -1385,12 +1399,17 @@ def generate_kernel_call( Only valid when cuda == True. """ if cuda: - call_args_str = ", ".join(pexpr(item) for item in call_args) - current_device = V.graph.scheduler.get_current_device_or_throw() - stream_name = self.write_get_raw_stream(current_device.index, V.graph) + device_index, call_args = self.prepare_triton_kernel_call( + device_index, call_args + ) + call_args_str = ", ".join(call_args) + stream_name = self.write_get_raw_stream(device_index, V.graph) if triton: - grid_str = ", ".join(pexpr(item) for item in grid) - grid_str = f"{grid_fn}({grid_str})" + if grid is None: + grid_str = grid_fn + else: + grid_str = ", ".join(pexpr(item) for item in grid) + grid_str = f"{grid_fn}({grid_str})" self.writeline( f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index fb43e7da1d13..612bddb68b86 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -42,7 +42,6 @@ ) from .codegen.triton_utils import config_of, signature_to_meta -from .codegen.wrapper import pexpr from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType from .runtime.hints import DeviceProperties @@ -529,47 +528,30 @@ def codegen_range_tree(self): def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): wrapper = V.graph.wrapper_code _, call_args, _, arg_types = self.args.python_argdefs() - call_args = [str(a) for a in call_args] - - for i in range(len(call_args)): - if V.graph.is_unspec_arg(call_args[i]): - call_args[i] = call_args[i] + ".item()" - if isinstance(call_args[i], sympy.Symbol): - call_args[i] = texpr(call_args[i]) - - current_device = V.graph.scheduler.get_current_device_or_throw() - if V.graph.cpp_wrapper: # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime # if any dynamic dimension is involved. We rely on the Python version # of the grid function to generate those grid configs, which may contain # symbolic values. The wrapper will use cexpr to print out C++ code # appropriately for the grid configs. - grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [ - self.meta - ] - grid = self.grid_fn(*grid_args) - + grid = self.call_sizes + [self.meta] wrapper.generate_kernel_call( name, call_args, - device_index=current_device.index, + grid=self.grid_fn(*grid), arg_types=arg_types, - grid=grid, triton_meta=self.triton_meta, ) else: - stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph) - wrapper.add_import_once(f"import {self.grid_fn.__module__}") meta = wrapper.add_meta_once(self.meta) - - grid_call = [ - pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes - ] + [meta] - grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" - wrapper.writeline( - f"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})" + grid = self.call_sizes + [meta] + wrapper.generate_kernel_call( + name, + call_args, + grid=grid, + grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}", + triton_meta=self.triton_meta, ) From d3e8b8bf47206c27b6c5fdc021f7c2c3a8009521 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Wed, 19 Jun 2024 08:09:31 +0000 Subject: [PATCH 50/64] Remove cuda check in the CUDAGraph destructor (#127382) Fixes #125804 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127382 Approved by: https://github.com/eqy, https://github.com/eellison --- aten/src/ATen/cuda/CUDAGeneratorImpl.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index 7e19ce98fbf9..c75275baa8c8 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -152,9 +152,6 @@ void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) { * Unregisters a CUDA graph from the RNG state. */ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) { - // Ensures that the RNG state is not currently being captured. - at::cuda::assertNotCapturing( - "Cannot unregister the state during capturing stage."); // Verify the graph was previously registered. TORCH_CHECK( registered_graphs_.find(graph) != registered_graphs_.end(), From 50567f7081bd82c081d2d3f27f5753139b533918 Mon Sep 17 00:00:00 2001 From: Daulet Askarov Date: Wed, 19 Jun 2024 08:50:46 +0000 Subject: [PATCH 51/64] Pass device to is_pinned call inside TensorProperties.create_from_tensor (#128896) Summary: The default input device for is_pinned function is Cuda. This can unnecessarily create Cuda context for CPU tensors when just generating TensorProperties, bloating memory usage. Passing the device to the is_pinned call site inside def create_from_tensor solves this issue. This also fixes Model Store test https://www.internalfb.com/intern/test/844425019931542?ref_report_id=0 which is currently broken on memory usage assertions. Test Plan: UT Differential Revision: D58695006 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128896 Approved by: https://github.com/fegin --- test/distributed/checkpoint/test_utils.py | 9 +++++++++ torch/distributed/_shard/sharded_tensor/metadata.py | 2 +- torch/distributed/checkpoint/metadata.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index 78d97f069958..5e3d60cc4cf4 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import sys +from unittest.mock import MagicMock import torch @@ -123,5 +124,13 @@ def test_sharded_tensor_lookup(self): find_state_dict_object(state_dict, MetadataIndex("st", [1])) +class TestTensorProperties(TestCase): + def test_create_from_tensor_correct_device(self): + t = torch.randn([10, 2], device="cpu") + t.is_pinned = MagicMock(return_value=True) + TensorProperties.create_from_tensor(t) + t.is_pinned.assert_called_with(device=torch.device("cpu")) + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index e53ac25fa55d..ee519de4500f 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -76,7 +76,7 @@ def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned(), + pin_memory=tensor.is_pinned(device=tensor.device), ) diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index d1f87e2d9cba..335ada4fa787 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -105,7 +105,7 @@ def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned(), + pin_memory=tensor.is_pinned(device=tensor.device), ) From 7fac03aee94adffb3799939b5746ffd5b675b5f9 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Wed, 19 Jun 2024 03:59:07 -0500 Subject: [PATCH 52/64] [ALI] Use lf runners for Lint (#128978) --- .github/workflows/lint.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e0e4d3c20cd8..2bbb6b1ab0ea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,10 +19,10 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed + # to run git rev-parse HEAD~:.ci/docker when a new image is needed. fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -35,7 +35,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -49,7 +49,7 @@ jobs: quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} From e49525275df0ed5b19e676d6897668ee77b8cfb4 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 19 Jun 2024 09:58:46 +0800 Subject: [PATCH 53/64] Make TraceUtils.h to be device-agnostic (#126969) Some features of third-party devices depend on TraceUtils.h, so some of the CUDA code was removed and split into NCCLUtils files. In addition, some common functions still remain in TraceUtils.h since I'm not sure if other devices will use them later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126969 Approved by: https://github.com/c-p-i-o --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 50 +++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 445 ++++++++++++++++++- torch/csrc/distributed/c10d/TraceUtils.h | 497 ---------------------- 3 files changed, 494 insertions(+), 498 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index d3a997625e14..931e66a2c42e 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -289,6 +289,56 @@ control_plane::RegisterHandler dumpHandler{ "application/octet-stream"); }}; +void DebugInfoWriter::write(const std::string& ncclTrace) { + // Open a file for writing. The ios::binary flag is used to write data as + // binary. + std::ofstream file(filename_, std::ios::binary); + + // Check if the file was opened successfully. + if (!file.is_open()) { + LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " + << filename_; + return; + } + + file.write(ncclTrace.data(), ncclTrace.size()); + LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; +} + +DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { + if (writer_ == nullptr) { + std::string fileNamePrefix = getCvarString( + {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); + // Using std::unique_ptr here to auto-delete the writer object + // when the pointer itself is destroyed. + std::unique_ptr writerPtr( + new DebugInfoWriter(fileNamePrefix, rank)); + DebugInfoWriter::registerWriter(std::move(writerPtr)); + } + return *writer_; +} + +void DebugInfoWriter::registerWriter(std::unique_ptr writer) { + TORCH_CHECK_WITH( + DistBackendError, + hasWriterRegistered_.load() == false, + "debugInfoWriter already registered"); + hasWriterRegistered_.store(true); + writer_ = std::move(writer); +} + +std::unique_ptr DebugInfoWriter::writer_ = nullptr; +std::atomic DebugInfoWriter::hasWriterRegistered_(false); + +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent) { + TORCH_CHECK( + ncclEndEvent.query(), + "getDuration can only be called after work is succeeded.") + return ncclStartEvent.elapsed_time(ncclEndEvent); +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 9ce25b55dc13..626e97dc86b4 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,9 +10,11 @@ #include #include +#include #include #include #include +#include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) @@ -172,6 +174,39 @@ namespace c10d { +static c10::IValue entries_key = "entries"; +static c10::IValue nccl_comm_key = "nccl_comm_state"; +static c10::IValue version_key = "version"; +// Update whenever changing contents or formatting of the dump +// (minor when adding fields, major when changing existing fields) +static c10::IValue version_val = "2.2"; +static c10::IValue pg_config_key = "pg_config"; +static c10::IValue record_id_key = "record_id"; +static c10::IValue pg_id_key = "pg_id"; +static c10::IValue pg_name_key = "process_group"; +static c10::IValue collective_seq_id_key = "collective_seq_id"; +static c10::IValue p2p_seq_id_key = "p2p_seq_id"; +static c10::IValue is_p2p_key = "is_p2p"; +static c10::IValue op_id_key = "op_id"; +static c10::IValue profiling_name_key = "profiling_name"; +static c10::IValue input_sizes_key = "input_sizes"; +static c10::IValue input_dtypes_key = "input_dtypes"; +static c10::IValue output_sizes_key = "output_sizes"; +static c10::IValue output_dtypes_key = "output_dtypes"; +static c10::IValue time_created_key = "time_created_ns"; +static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; + +static c10::IValue frames_key = "frames"; +static c10::IValue state_key = "state"; +static c10::IValue line_key = "line"; +static c10::IValue name_key = "name"; +static c10::IValue filename_key = "filename"; +static c10::IValue retired_key = "retired"; +static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; +static c10::IValue time_discovered_completed_key = + "time_discovered_completed_ns"; + TORCH_API size_t hashTensors(const std::vector& tensors); TORCH_API std::string getNcclVersion(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); @@ -195,7 +230,7 @@ TORCH_API std::string getNcclErrorDetailStr( // auto-registered). class TORCH_API DebugInfoWriter { public: - virtual ~DebugInfoWriter(); + virtual ~DebugInfoWriter() = default; virtual void write(const std::string& ncclTrace); static DebugInfoWriter& getWriter(int rank); static void registerWriter(std::unique_ptr writer); @@ -518,6 +553,414 @@ struct ncclRedOpRAII { bool premul_sum_ = false; }; +/* Helper used by work::getDuration() and nccl flight recorder */ +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent); + +struct NCCLTraceBuffer { + static NCCLTraceBuffer* get() { + // intentionally leak on exit + // because this will hold python state that may get destructed + static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); + return instance; + } + NCCLTraceBuffer() { + max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); + capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); + enabled_ = max_entries_ > 0; + } + using Event = at::cuda::CUDAEvent; + struct Entry { + size_t id_; // incremented id in the trace buffer + // used to figure out where in the circular entries + // buffer this entry will be located to + // update state information + size_t pg_id_; + std::tuple pg_name_; // + + // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 + // per coalesced group). + // collective_seq_id only increments for true collective operations (over + // all ranks in the group). p2p_seq_id only increments over non-collective + // operations in the group. op_id refers to logical operations (e.g. one per + // op inside coalesced group) + size_t collective_seq_id_; + size_t p2p_seq_id_; + size_t op_id_; + std::string profiling_name_; + + std::shared_ptr traceback_; + // we borrow pointers to start_ and end_ so we can query the state + // on reporting. However, once the event is completed, the call + // to `complete` will clear these. + Event *start_, *end_; + + // timestamp when the entry was created, likely close to the time the work + // was 'enqueued'- not necessarily started + c10::time_t time_created_; + + // configured timeout for this entry + c10::time_t timeout_ms_; + + // Is this a P2P event? + bool isP2P_; + + std::optional duration_; + + // timestamp when our CPU threads discovered that the kernel started. + // will always be _after_ it actually started, and can be very late + // if the watchdog thread got stuck on CUDA APIs. + std::optional time_discovered_started_; + + // timestamp when our CPU threads discovered that the kernel completed. + // will always be _after_ it actually complated, and can be the same time + // as the discovery of the start if the watchdog thread is stuck on CUDA + // APIs + std::optional time_discovered_completed_; + + // size information for input/output tensors + c10::SmallVector input_dims_; + std::vector input_dtypes_; + c10::SmallVector output_dims_; + std::vector output_dtypes_; + c10::SmallVector sizes_; // flattened from inputs, outputs + bool retired_ = false; // is this work entry no longer in the workMetaList_? + // a retired but not completed event has timed out + }; + + bool enabled_ = false; + bool capture_cpp_stack_ = false; + std::mutex mutex_; + std::vector entries_; + size_t max_entries_ = 0; + size_t next_ = 0; + size_t id_ = 0; + std::map, std::vector> + pg_name_to_ranks_ = {}; + + std::optional record( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + Event* start, + Event* end, + std::chrono::milliseconds timeout_ms, + bool isP2P) { + if (!enabled_) { + return c10::nullopt; + } + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); + std::lock_guard guard(mutex_); + + auto te = Entry{ + id_, + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + std::move(traceback), + std::move(start), + std::move(end), + c10::getTime(), + timeout_ms.count(), + isP2P, + std::nullopt, + std::nullopt, + std::nullopt, + {}, + {}, + {}, + {}, + {}, + false}; + + for (const auto& input : inputs) { + c10::IntArrayRef sizes = input.sizes(); + te.input_dtypes_.push_back(input.dtype().toScalarType()); + te.input_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + for (const auto& output : outputs) { + c10::IntArrayRef sizes = output.sizes(); + te.output_dtypes_.push_back(output.dtype().toScalarType()); + te.output_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + if (entries_.size() < max_entries_) { + entries_.emplace_back(std::move(te)); + } else { + entries_[next_++] = std::move(te); + if (next_ == max_entries_) { + next_ = 0; + } + } + return id_++; + } + + void record_pg_ranks( + const std::tuple& pg_name, + std::vector ranks) { + if (!enabled_) { + return; + } + std::lock_guard guard(mutex_); + pg_name_to_ranks_[pg_name] = ranks; + } + + void update_state(Entry& r) { + if (r.start_ != nullptr) { + bool started = r.start_->query(); + if (started && !r.time_discovered_started_) { + r.time_discovered_started_ = c10::getTime(); + } + } + if (r.end_ != nullptr) { + bool completed = r.end_->query(); + if (completed && !r.time_discovered_completed_) { + r.time_discovered_completed_ = c10::getTime(); + } + } + } + + std::vector dump_entries() { + std::lock_guard guard(mutex_); + std::vector result; + result.reserve(entries_.size()); + result.insert(result.end(), entries_.begin() + next_, entries_.end()); + result.insert(result.end(), entries_.begin(), entries_.begin() + next_); + // query any remaining events + for (auto& r : result) { + update_state(r); + r.start_ = r.end_ = nullptr; + } + return result; + } + + /* + Mark an Event as completed and free its events. + This is called by the watchdog thread, and is asynchronous from the + perspective of the main thread. + compute_duration defaults to true since retire_id is only called in the + watchdog thread, which is currently a place we call cuda APIs which may hang, + but care should be taken to avoid computing duration in any function that must + never hang. (timing must also be enabled for compute_duration - see + TORCH_NCCL_ENABLE_TIMING). + */ + void retire_id(std::optional id, bool compute_duration = true) { + if (!enabled_ || !id) { + return; + } + + bool can_compute_duration = false; + Event* startEvent = nullptr; + Event* endEvent = nullptr; + std::optional duration = c10::nullopt; + + std::unique_lock guard(mutex_); + + Entry* entry = &entries_.at(*id % max_entries_); + if (entry->id_ == *id) { + update_state(*entry); + + if (compute_duration) { + can_compute_duration = entry->time_discovered_completed_.has_value() && + entry->start_ && entry->end_; + startEvent = entry->start_; + endEvent = entry->end_; + } + } + + if (can_compute_duration) { + // Compute duration without without holding the lock, because + // cudaEventDuration() can hang, and we need to acquire the lock before we + // can dump(), which we never want to block. + guard.unlock(); + duration = getDurationFromEvent(*startEvent, *endEvent); + guard.lock(); + + // Refresh the entry pointer, see if the entry has been overwritten + entry = &entries_.at(*id % max_entries_); + if (entry->id_ != *id) { + LOG(INFO) + << "retire_id abandoned for id " << *id + << ", event was overwritten while waiting to compute duration."; + return; + } + if (duration.has_value()) { + entry->duration_ = duration.value(); + } + } + + entry->retired_ = true; + entry->start_ = entry->end_ = nullptr; + } + + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { + auto entries = new_list(); + auto result = dump_entries(); + std::vector tracebacks; + torch::SymbolizedTracebacks stracebacks; + std::vector all_frames; + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + } + for (auto i : c10::irange(result.size())) { + auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + + dict.insert(record_id_key, int64_t(e.id_)); + dict.insert(pg_id_key, int64_t(e.pg_id_)); + dict.insert(pg_name_key, e.pg_name_); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); + dict.insert(op_id_key, int64_t(e.op_id_)); + dict.insert(profiling_name_key, e.profiling_name_); + dict.insert(time_created_key, int64_t(e.time_created_)); + if (e.duration_) { + dict.insert(duration_key, *e.duration_); + } + + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = new_list(); + for (auto dim : dims) { + auto arg_sizes = new_list(); + for (auto i : c10::irange(dim)) { + (void)i; + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + + dict.insert(input_sizes_key, read_sizes(e.input_dims_)); + std::vector input_dtypes_strs; + input_dtypes_strs.reserve(e.input_dtypes_.size()); + for (const auto& input_dtype : e.input_dtypes_) { + input_dtypes_strs.push_back(c10::toString(input_dtype)); + } + dict.insert(input_dtypes_key, input_dtypes_strs); + dict.insert(output_sizes_key, read_sizes(e.output_dims_)); + std::vector output_dtypes_strs; + output_dtypes_strs.reserve(e.output_dtypes_.size()); + for (const auto& output_dtype : e.output_dtypes_) { + output_dtypes_strs.push_back(c10::toString(output_dtype)); + } + dict.insert(output_dtypes_key, output_dtypes_strs); + if (e.time_discovered_completed_.has_value()) { + dict.insert(state_key, "completed"); + } else if (e.time_discovered_started_.has_value()) { + dict.insert(state_key, "started"); + } else { + dict.insert(state_key, "scheduled"); + } + + dict.insert( + time_discovered_started_key, + e.time_discovered_started_.has_value() + ? int64_t(*e.time_discovered_started_) + : c10::IValue()); + dict.insert( + time_discovered_completed_key, + e.time_discovered_completed_.has_value() + ? int64_t(*e.time_discovered_completed_) + : c10::IValue()); + dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); + dict.insert(is_p2p_key, e.isP2P_); + + entries.push_back(dict); + } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { + auto pg_config = new_dict(); + for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { + auto pg_info = new_dict(); + pg_info.insert("name", std::get<0>(pg_name)); + pg_info.insert("desc", std::get<1>(pg_name)); + pg_info.insert("ranks", ranks_str(ranks)); + pg_config.insert(std::get<0>(pg_name), pg_info); + } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } + + // convert ncclDumpMap into a dictionary + auto per_comm_dict = new_dict(); + if (ncclDumpMap.has_value()) { + for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { + auto inner_dict = new_dict(); + for (const auto& [key, value] : ncclDump) { + inner_dict.insert(key, value); + } + per_comm_dict.insert(ncclId, inner_dict); + } + } + if (per_comm_dict.size() > 0) { + result.insert(nccl_comm_key, per_comm_dict); + } + return pickle_str(result); + } +}; + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index de623d77fe9e..9c469dbd5bc6 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -10,11 +10,6 @@ #include #include -#ifdef USE_C10D_NCCL -#include -#include -#endif - #include #include #include @@ -24,41 +19,6 @@ namespace c10d { -static c10::IValue entries_key = "entries"; -static c10::IValue nccl_comm_key = "nccl_comm_state"; -static c10::IValue version_key = "version"; -// Update whenever changing contents or formatting of the dump -// (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.2"; -static c10::IValue pg_config_key = "pg_config"; -static c10::IValue record_id_key = "record_id"; -static c10::IValue pg_id_key = "pg_id"; -static c10::IValue pg_name_key = "process_group"; -static c10::IValue collective_seq_id_key = "collective_seq_id"; -static c10::IValue p2p_seq_id_key = "p2p_seq_id"; -static c10::IValue is_p2p_key = "is_p2p"; -static c10::IValue op_id_key = "op_id"; -static c10::IValue profiling_name_key = "profiling_name"; -static c10::IValue input_sizes_key = "input_sizes"; -static c10::IValue input_dtypes_key = "input_dtypes"; -static c10::IValue output_sizes_key = "output_sizes"; -static c10::IValue output_dtypes_key = "output_dtypes"; -static c10::IValue time_created_key = "time_created_ns"; -static c10::IValue duration_key = "duration_ms"; -static c10::IValue timeout_key = "timeout_ms"; - -static c10::IValue frames_key = "frames"; -static c10::IValue state_key = "state"; -static c10::IValue line_key = "line"; -static c10::IValue name_key = "name"; -static c10::IValue filename_key = "filename"; -static c10::IValue retired_key = "retired"; -static c10::IValue time_discovered_started_key = "time_discovered_started_ns"; -static c10::IValue time_discovered_completed_key = - "time_discovered_completed_ns"; - -/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ - inline std::string getTraceStartKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_start"; } @@ -303,66 +263,6 @@ inline std::string retrieveDesyncReport( return report; } -/* Trace Utils Related to Flight Recorder */ - -/* Note: this is only used by PGNCCL (could be generalized in an ideal world but - * wasn't done that way, so isn't expected to be fully general at the moment) */ - -#ifdef USE_C10D_NCCL - -/* Helper used by work::getDuration() and nccl flight recorder */ -float getDurationFromEvent( - at::cuda::CUDAEvent& ncclStartEvent, - at::cuda::CUDAEvent& ncclEndEvent) { - TORCH_CHECK( - ncclEndEvent.query(), - "getDuration can only be called after work is succeeded.") - return ncclStartEvent.elapsed_time(ncclEndEvent); -} - -DebugInfoWriter::~DebugInfoWriter() = default; - -void DebugInfoWriter::write(const std::string& ncclTrace) { - // Open a file for writing. The ios::binary flag is used to write data as - // binary. - std::ofstream file(filename_, std::ios::binary); - - // Check if the file was opened successfully. - if (!file.is_open()) { - LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " - << filename_; - return; - } - - file.write(ncclTrace.data(), ncclTrace.size()); - LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; -} - -DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { - if (writer_ == nullptr) { - std::string fileNamePrefix = getCvarString( - {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); - // Using std::unique_ptr here to auto-delete the writer object - // when the pointer itself is destroyed. - std::unique_ptr writerPtr( - new DebugInfoWriter(fileNamePrefix, rank)); - DebugInfoWriter::registerWriter(std::move(writerPtr)); - } - return *writer_; -} - -void DebugInfoWriter::registerWriter(std::unique_ptr writer) { - TORCH_CHECK_WITH( - DistBackendError, - hasWriterRegistered_.load() == false, - "debugInfoWriter already registered"); - hasWriterRegistered_.store(true); - writer_ = std::move(writer); -} - -std::unique_ptr DebugInfoWriter::writer_ = nullptr; -std::atomic DebugInfoWriter::hasWriterRegistered_(false); - inline std::string pickle_str(const c10::IValue& v) { std::vector result; { @@ -421,401 +321,4 @@ inline std::string ranks_str(const std::vector& ranks) { return c10::str("[", str, "]"); } -struct NCCLTraceBuffer { - static NCCLTraceBuffer* get() { - // intentionally leak on exit - // because this will hold python state that may get destructed - static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); - return instance; - } - NCCLTraceBuffer() { - max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); - capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); - enabled_ = max_entries_ > 0; - } - using Event = at::cuda::CUDAEvent; - struct Entry { - size_t id_; // incremented id in the trace buffer - // used to figure out where in the circular entries - // buffer this entry will be located to - // update state information - size_t pg_id_; - std::tuple pg_name_; // - - // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 - // per coalesced group). - // collective_seq_id only increments for true collective operations (over - // all ranks in the group). p2p_seq_id only increments over non-collective - // operations in the group. op_id refers to logical operations (e.g. one per - // op inside coalesced group) - size_t collective_seq_id_; - size_t p2p_seq_id_; - size_t op_id_; - std::string profiling_name_; - - std::shared_ptr traceback_; - // we borrow pointers to start_ and end_ so we can query the state - // on reporting. However, once the event is completed, the call - // to `complete` will clear these. - Event *start_, *end_; - - // timestamp when the entry was created, likely close to the time the work - // was 'enqueued'- not necessarily started - c10::time_t time_created_; - - // configured timeout for this entry - c10::time_t timeout_ms_; - - // Is this a P2P event? - bool isP2P_; - - std::optional duration_; - - // timestamp when our CPU threads discovered that the kernel started. - // will always be _after_ it actually started, and can be very late - // if the watchdog thread got stuck on CUDA APIs. - std::optional time_discovered_started_; - - // timestamp when our CPU threads discovered that the kernel completed. - // will always be _after_ it actually complated, and can be the same time - // as the discovery of the start if the watchdog thread is stuck on CUDA - // APIs - std::optional time_discovered_completed_; - - // size information for input/output tensors - c10::SmallVector input_dims_; - std::vector input_dtypes_; - c10::SmallVector output_dims_; - std::vector output_dtypes_; - c10::SmallVector sizes_; // flattened from inputs, outputs - bool retired_ = false; // is this work entry no longer in the workMetaList_? - // a retired but not completed event has timed out - }; - - bool enabled_ = false; - bool capture_cpp_stack_ = false; - std::mutex mutex_; - std::vector entries_; - size_t max_entries_ = 0; - size_t next_ = 0; - size_t id_ = 0; - std::map, std::vector> - pg_name_to_ranks_ = {}; - - std::optional record( - size_t pg_id, - const std::tuple& pg_name, - size_t collective_seq_id, - size_t p2p_seq_id, - size_t op_id, - std::string profiling_name, - const std::vector& inputs, - const std::vector& outputs, - Event* start, - Event* end, - std::chrono::milliseconds timeout_ms, - bool isP2P) { - if (!enabled_) { - return c10::nullopt; - } - auto traceback = - torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); - - auto te = Entry{ - id_, - pg_id, - pg_name, - collective_seq_id, - p2p_seq_id, - op_id, - std::move(profiling_name), - std::move(traceback), - std::move(start), - std::move(end), - c10::getTime(), - timeout_ms.count(), - isP2P}; - - for (const auto& input : inputs) { - c10::IntArrayRef sizes = input.sizes(); - te.input_dtypes_.push_back(input.dtype().toScalarType()); - te.input_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - for (const auto& output : outputs) { - c10::IntArrayRef sizes = output.sizes(); - te.output_dtypes_.push_back(output.dtype().toScalarType()); - te.output_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - if (entries_.size() < max_entries_) { - entries_.emplace_back(std::move(te)); - } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } - } - return id_++; - } - - void record_pg_ranks( - const std::tuple& pg_name, - std::vector ranks) { - if (!enabled_) { - return; - } - std::lock_guard guard(mutex_); - pg_name_to_ranks_[pg_name] = ranks; - } - - void update_state(Entry& r) { - if (r.start_ != nullptr) { - bool started = r.start_->query(); - if (started && !r.time_discovered_started_) { - r.time_discovered_started_ = c10::getTime(); - } - } - if (r.end_ != nullptr) { - bool completed = r.end_->query(); - if (completed && !r.time_discovered_completed_) { - r.time_discovered_completed_ = c10::getTime(); - } - } - } - - std::vector dump_entries() { - std::lock_guard guard(mutex_); - std::vector result; - result.reserve(entries_.size()); - result.insert(result.end(), entries_.begin() + next_, entries_.end()); - result.insert(result.end(), entries_.begin(), entries_.begin() + next_); - // query any remaining events - for (auto& r : result) { - update_state(r); - r.start_ = r.end_ = nullptr; - } - return result; - } - - /* - Mark an Event as completed and free its events. - - This is called by the watchdog thread, and is asynchronous from the - perspective of the main thread. - - compute_duration defaults to true since retire_id is only called in the - watchdog thread, which is currently a place we call cuda APIs which may hang, - but care should be taken to avoid computing duration in any function that must - never hang. (timing must also be enabled for compute_duration - see - TORCH_NCCL_ENABLE_TIMING). - */ - void retire_id(std::optional id, bool compute_duration = true) { - if (!enabled_ || !id) { - return; - } - - bool can_compute_duration = false; - Event* startEvent = nullptr; - Event* endEvent = nullptr; - std::optional duration = c10::nullopt; - - std::unique_lock guard(mutex_); - - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { - update_state(*entry); - - if (compute_duration) { - can_compute_duration = entry->time_discovered_completed_.has_value() && - entry->start_ && entry->end_; - startEvent = entry->start_; - endEvent = entry->end_; - } - } - - if (can_compute_duration) { - // Compute duration without without holding the lock, because - // cudaEventDuration() can hang, and we need to acquire the lock before we - // can dump(), which we never want to block. - guard.unlock(); - duration = getDurationFromEvent(*startEvent, *endEvent); - guard.lock(); - - // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { - LOG(INFO) - << "retire_id abandoned for id " << *id - << ", event was overwritten while waiting to compute duration."; - return; - } - if (duration.has_value()) { - entry->duration_ = duration.value(); - } - } - - entry->retired_ = true; - entry->start_ = entry->end_ = nullptr; - } - - const c10::List getCollectiveTrace( - bool includeStacktraces, - bool onlyActive) { - auto entries = new_list(); - auto result = dump_entries(); - std::vector tracebacks; - torch::SymbolizedTracebacks stracebacks; - std::vector all_frames; - if (includeStacktraces) { - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - stracebacks = torch::symbolize(tracebacks); - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } - } - for (auto i : c10::irange(result.size())) { - auto dict = new_dict(); - auto& e = result.at(i); - // Skip completed events - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - - if (includeStacktraces) { - auto& tb = stracebacks.tracebacks.at(i); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); - } - - dict.insert(record_id_key, int64_t(e.id_)); - dict.insert(pg_id_key, int64_t(e.pg_id_)); - dict.insert(pg_name_key, e.pg_name_); - dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); - dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); - dict.insert(op_id_key, int64_t(e.op_id_)); - dict.insert(profiling_name_key, e.profiling_name_); - dict.insert(time_created_key, int64_t(e.time_created_)); - if (e.duration_) { - dict.insert(duration_key, *e.duration_); - } - - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = new_list(); - for (auto dim : dims) { - auto arg_sizes = new_list(); - for (auto i : c10::irange(dim)) { - (void)i; - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - - dict.insert(input_sizes_key, read_sizes(e.input_dims_)); - std::vector input_dtypes_strs; - input_dtypes_strs.reserve(e.input_dtypes_.size()); - for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.push_back(c10::toString(input_dtype)); - } - dict.insert(input_dtypes_key, input_dtypes_strs); - dict.insert(output_sizes_key, read_sizes(e.output_dims_)); - std::vector output_dtypes_strs; - output_dtypes_strs.reserve(e.output_dtypes_.size()); - for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.push_back(c10::toString(output_dtype)); - } - dict.insert(output_dtypes_key, output_dtypes_strs); - if (e.time_discovered_completed_.has_value()) { - dict.insert(state_key, "completed"); - } else if (e.time_discovered_started_.has_value()) { - dict.insert(state_key, "started"); - } else { - dict.insert(state_key, "scheduled"); - } - - dict.insert( - time_discovered_started_key, - e.time_discovered_started_.has_value() - ? int64_t(*e.time_discovered_started_) - : c10::IValue()); - dict.insert( - time_discovered_completed_key, - e.time_discovered_completed_.has_value() - ? int64_t(*e.time_discovered_completed_) - : c10::IValue()); - dict.insert(retired_key, e.retired_); - dict.insert(timeout_key, e.timeout_ms_); - dict.insert(is_p2p_key, e.isP2P_); - - entries.push_back(dict); - } - return entries; - } - - // dump pg_entries - const c10::Dict getPgConfig() { - auto pg_config = new_dict(); - for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { - auto pg_info = new_dict(); - pg_info.insert("name", std::get<0>(pg_name)); - pg_info.insert("desc", std::get<1>(pg_name)); - pg_info.insert("ranks", ranks_str(ranks)); - pg_config.insert(std::get<0>(pg_name), pg_info); - } - return pg_config; - } - - // dump all collectives + ncclDumpMap - std::string dump( - const std::optional>>& ncclDumpMap, - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - auto result = new_dict(); - // common values - result.insert(version_key, version_val); - result.insert(pg_config_key, getPgConfig()); - - // collective trace - if (includeCollectives) { - result.insert( - entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); - } - - // convert ncclDumpMap into a dictionary - auto per_comm_dict = new_dict(); - if (ncclDumpMap.has_value()) { - for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { - auto inner_dict = new_dict(); - for (const auto& [key, value] : ncclDump) { - inner_dict.insert(key, value); - } - per_comm_dict.insert(ncclId, inner_dict); - } - } - if (per_comm_dict.size() > 0) { - result.insert(nccl_comm_key, per_comm_dict); - } - return pickle_str(result); - } -}; - -#endif } // namespace c10d From 118f9ceb7c9ec608a845b40c2142f1a1720b73c9 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 19 Jun 2024 12:10:50 +0000 Subject: [PATCH 54/64] [inductor][ci] Fix torchbench dependency issue with numpy (#128968) For some reason, pip will always upgrade the numpy version even when an older version has been installed. We have to lock numpy version to the old version to make this constraint explicit. Torchbench commit: https://github.com/pytorch/benchmark/commit/23512dbebd44a11eb84afbf53c3c071dd105297e Second attempt to fix #128845 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128968 Approved by: https://github.com/eellison --- .github/ci_commit_pins/torchbench.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt index 4a60ff3d38d4..dcf750b7fae0 100644 --- a/.github/ci_commit_pins/torchbench.txt +++ b/.github/ci_commit_pins/torchbench.txt @@ -1 +1 @@ -0dab1dd97709096e8129f8a08115ee83f64f2194 +23512dbebd44a11eb84afbf53c3c071dd105297e From 3397d5ef9089023b3a8ff781273f71d127cd8de8 Mon Sep 17 00:00:00 2001 From: Jean Schmidt <4520845+jeanschmidt@users.noreply.github.com> Date: Wed, 19 Jun 2024 14:48:16 +0000 Subject: [PATCH 55/64] Revert "[ALI] Use lf runners for Lint" (#129070) Reverts pytorch/pytorch#128978 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129070 Approved by: https://github.com/atalman --- .github/workflows/lint.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2bbb6b1ab0ea..e0e4d3c20cd8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,10 +19,10 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: lf.linux.2xlarge + runner: linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed. + # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -35,7 +35,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: lf.linux.2xlarge + runner: linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -49,7 +49,7 @@ jobs: quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: - runner: lf.linux.2xlarge + runner: linux.2xlarge docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} From ffb50fb69124f3061dcd67e527702fa00ff00aa9 Mon Sep 17 00:00:00 2001 From: lyb <2632839426@qq.com> Date: Wed, 19 Jun 2024 15:39:02 +0000 Subject: [PATCH 56/64] [ONNX] Add onnx::Gelu support for version 20 (#128773) Fixes https://github.com/pytorch/pytorch/issues/128772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128773 Approved by: https://github.com/justinchuby --- test/onnx/test_utility_funs.py | 4 ++++ torch/onnx/symbolic_opset20.py | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index e7c8f4078103..ee63e0079c1d 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1358,6 +1358,8 @@ def forward(self, input, other): iter = graph.nodes() self.assertEqual(next(iter).kind(), "custom_namespace::custom_op") + # gelu is exported as onnx::Gelu for opset >= 20 + @skipIfUnsupportedMaxOpsetVersion(19) def test_custom_opsets_gelu(self): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9) @@ -1382,6 +1384,8 @@ def gelu(g, self, approximate): self.assertEqual(graph.opset_import[1].domain, "com.microsoft") self.assertEqual(graph.opset_import[1].version, 1) + # gelu is exported as onnx::Gelu for opset >= 20 + @skipIfUnsupportedMaxOpsetVersion(19) def test_register_aten_custom_op_symbolic(self): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9) diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index 9557b5f2828e..d16379296958 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -32,7 +32,7 @@ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py -__all__ = ["_grid_sampler", "_affine_grid_generator"] +__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] def convert_grid_sample_mode(mode_s): @@ -84,3 +84,10 @@ def _affine_grid_generator( size, align_corners_i=int(align_corners), ) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +@_beartype.beartype +def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): + return g.op("Gelu", self, approximate_s=approximate) From 7d33ff59ba4ed920f590cb3e8f3e1bd571c78f62 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Jun 2024 21:46:16 -0700 Subject: [PATCH 57/64] [Split Build]Use same package (#127934) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes the second separate package we were using for the libtorch wheel. In terms of testing that this works we will look use the PRs above this in the stack. As for sanity checking these are the wheels that are produced by running ``` python setup.py clean && BUILD_LIBTORCH_WHL=1 with-proxy python setup.py bdist_whee l && BUILD_PYTHON_ONLY=1 with-proxy python setup.py bdist_wheel --cmake ``` ``` sahanp@devgpu086 ~/pytorch ((5f15e171…))> ls -al dist/ (pytorch-3.10) total 677236 drwxr-xr-x 1 sahanp users 188 Jun 4 12:19 ./ drwxr-xr-x 1 sahanp users 1696 Jun 4 12:59 ../ -rw-r--r-- 1 sahanp users 81405742 Jun 4 12:19 torch-2.4.0a0+gitca0a73c-cp310-cp310-linux_x86_64.whl -rw-r--r-- 1 sahanp users 612076919 Jun 4 12:19 libtorch-2.4.0a0+gitca0a73c-py3-none-any.whl ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127934 Approved by: https://github.com/atalman --- .ci/pytorch/build.sh | 25 ++++++++-- setup.py | 95 ++++++++++++++++---------------------- tools/setup_helpers/env.py | 2 - torch/CMakeLists.txt | 45 ++++++++++++++++++ torch/__init__.py | 41 ---------------- 5 files changed, 105 insertions(+), 103 deletions(-) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 187e6d788bdd..e370c46305e9 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -284,12 +284,26 @@ else # Which should be backward compatible with Numpy-1.X python -mpip install --pre numpy==2.0.0rc1 fi - WERROR=1 python setup.py bdist_wheel + + WERROR=1 python setup.py clean + + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel + BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 python setup.py bdist_wheel --cmake + else + WERROR=1 python setup.py bdist_wheel + fi else + python setup.py clean if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then source .ci/pytorch/install_cache_xla.sh fi - python setup.py bdist_wheel + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + echo "USE_SPLIT_BUILD cannot be used with xla or rocm" + exit 1 + else + python setup.py bdist_wheel + fi fi pip_install_whl "$(echo dist/*.whl)" @@ -328,9 +342,10 @@ else CUSTOM_OP_TEST="$PWD/test/custom_operator" python --version SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')" + mkdir -p "$CUSTOM_OP_BUILD" pushd "$CUSTOM_OP_BUILD" - cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \ + cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \ -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM" make VERBOSE=1 popd @@ -343,7 +358,7 @@ else SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')" mkdir -p "$JIT_HOOK_BUILD" pushd "$JIT_HOOK_BUILD" - cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \ + cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \ -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM" make VERBOSE=1 popd @@ -355,7 +370,7 @@ else python --version mkdir -p "$CUSTOM_BACKEND_BUILD" pushd "$CUSTOM_BACKEND_BUILD" - cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \ + cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \ -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM" make VERBOSE=1 popd diff --git a/setup.py b/setup.py index 07d80a7e1392..479b103a3a17 100644 --- a/setup.py +++ b/setup.py @@ -199,7 +199,6 @@ # Builds pytorch as a wheel using libtorch.so from a seperate wheel import os -import pkgutil import sys if sys.platform == "win32" and sys.maxsize.bit_length() == 31: @@ -210,19 +209,6 @@ import platform - -def _get_package_path(package_name): - loader = pkgutil.find_loader(package_name) - if loader: - # The package might be a namespace package, so get_data may fail - try: - file_path = loader.get_filename() - return os.path.dirname(file_path) - except AttributeError: - pass - return None - - BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1" BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1" @@ -237,6 +223,7 @@ def _get_package_path(package_name): import filecmp import glob import importlib +import importlib.util import json import shutil import subprocess @@ -253,15 +240,24 @@ def _get_package_path(package_name): from tools.build_pytorch_libs import build_caffe2 from tools.generate_torch_version import get_torch_version from tools.setup_helpers.cmake import CMake -from tools.setup_helpers.env import ( - build_type, - IS_DARWIN, - IS_LINUX, - IS_WINDOWS, - LIBTORCH_PKG_NAME, -) +from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS from tools.setup_helpers.generate_linker_script import gen_linker_script + +def _get_package_path(package_name): + spec = importlib.util.find_spec(package_name) + if spec: + # The package might be a namespace package, so get_data may fail + try: + loader = spec.loader + if loader is not None: + file_path = loader.get_filename() # type: ignore[attr-defined] + return os.path.dirname(file_path) + except AttributeError: + pass + return None + + # set up appropriate env variables if BUILD_LIBTORCH_WHL: # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so @@ -271,7 +267,7 @@ def _get_package_path(package_name): if BUILD_PYTHON_ONLY: os.environ["BUILD_LIBTORCHLESS"] = "ON" - os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path(LIBTORCH_PKG_NAME)}/lib" + os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib" ################################################################################ # Parameters parsed from environment @@ -347,9 +343,12 @@ def report(*args): # Version, create_version_file, and package_name ################################################################################ -DEFAULT_PACKAGE_NAME = LIBTORCH_PKG_NAME if BUILD_LIBTORCH_WHL else "torch" +package_name = os.getenv("TORCH_PACKAGE_NAME", "torch") +LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "libtorch") +if BUILD_LIBTORCH_WHL: + package_name = LIBTORCH_PKG_NAME + -package_name = os.getenv("TORCH_PACKAGE_NAME", DEFAULT_PACKAGE_NAME) package_type = os.getenv("PACKAGE_TYPE", "wheel") version = get_torch_version() report(f"Building wheel {package_name}-{version}") @@ -472,7 +471,6 @@ def build_deps(): check_submodules() check_pydep("yaml", "pyyaml") build_python = not BUILD_LIBTORCH_WHL - build_caffe2( version=version, cmake_python_library=cmake_python_library, @@ -1125,8 +1123,6 @@ def main(): raise RuntimeError( "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. Set one to 0 and rerun." ) - - # the list of runtime dependencies required by this built package install_requires = [ "filelock", "typing-extensions>=4.8.0", @@ -1141,7 +1137,7 @@ def main(): install_requires.append("setuptools") if BUILD_PYTHON_ONLY: - install_requires.append(LIBTORCH_PKG_NAME) + install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}") use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", "")) if ( @@ -1190,7 +1186,6 @@ def main(): entry_points, extra_install_requires, ) = configure_extension_build() - install_requires += extra_install_requires extras_require = { @@ -1219,6 +1214,7 @@ def main(): "utils/data/*.pyi", "utils/data/datapipes/*.pyi", "lib/*.pdb", + "lib/*shm*", "lib/torch_shm_manager", "lib/*.h", "include/*.h", @@ -1383,15 +1379,15 @@ def main(): "utils/model_dump/*.mjs", ] - if BUILD_PYTHON_ONLY: + if not BUILD_LIBTORCH_WHL: torch_package_data.extend( [ - "lib/libtorch_python*", - "lib/*shm*", - "lib/libtorch_global_deps*", + "lib/libtorch_python.so", + "lib/libtorch_python.dylib", + "lib/libtorch_python.dll", ] ) - else: + if not BUILD_PYTHON_ONLY: torch_package_data.extend( [ "lib/*.so*", @@ -1442,28 +1438,18 @@ def main(): "packaged/autograd/*", "packaged/autograd/templates/*", ] + package_data = { + "torch": torch_package_data, + } - if BUILD_LIBTORCH_WHL: - modified_packages = [] - for package in packages: - parts = package.split(".") - if parts[0] == "torch": - modified_packages.append(DEFAULT_PACKAGE_NAME + package[len("torch") :]) - packages = modified_packages - package_dir = {LIBTORCH_PKG_NAME: "torch"} - torch_package_dir_name = LIBTORCH_PKG_NAME - package_data = {LIBTORCH_PKG_NAME: torch_package_data} - extensions = [] + if not BUILD_LIBTORCH_WHL: + package_data["torchgen"] = torchgen_package_data + package_data["caffe2"] = [ + "python/serialized_test/data/operator_test/*.zip", + ] else: - torch_package_dir_name = "torch" - package_dir = {} - package_data = { - "torch": torch_package_data, - "torchgen": torchgen_package_data, - "caffe2": [ - "python/serialized_test/data/operator_test/*.zip", - ], - } + # no extensions in BUILD_LIBTORCH_WHL mode + extensions = [] setup( name=package_name, @@ -1481,7 +1467,6 @@ def main(): install_requires=install_requires, extras_require=extras_require, package_data=package_data, - package_dir=package_dir, url="https://pytorch.org/", download_url="https://github.com/pytorch/pytorch/tags", author="PyTorch Team", diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index eed5198ca9f2..d87e97a2bb5a 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -21,8 +21,6 @@ BUILD_DIR = "build" -LIBTORCH_PKG_NAME = "libtorchsplit" - def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 10a44af747be..080e977044d6 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -309,6 +309,51 @@ if(HAVE_SOVERSION) set_target_properties(torch_python PROPERTIES VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) endif() + +# in case of the split build we need to add compile definitions +if(BUILD_LIBTORCHLESS) + if(USE_UCC AND USE_C10D_UCC) + target_compile_definitions(torch_python PRIVATE USE_C10D_UCC) + endif() + + if(USE_UCC AND USE_C10D_NCCL) + target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) + endif() + + if(USE_DISTRIBUTED) + target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) + endif() + + if(USE_MPI AND USE_C10D_MPI) + target_compile_definitions(torch_python PRIVATE USE_C10D_MPI) + endif() + + if(USE_GLOO AND USE_C10D_GLOO) + target_compile_definitions(torch_python PRIVATE USE_C10D_GLOO) + endif() + + if(NOT WIN32) + target_compile_definitions(torch_python PRIVATE USE_RPC) + endif() + + if(USE_TENSORPIPE) + target_compile_definitions(torch_python PRIVATE USE_TENSORPIPE) + endif() + + set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING + "Experimental option to use a single thread pool for inter- and intra-op parallelism") + if("${EXPERIMENTAL_SINGLE_THREAD_POOL}") + target_compile_definitions(torch_python PRIVATE "-DAT_EXPERIMENTAL_SINGLE_THREAD_POOL=1") + endif() + + if(MSVC AND NOT BUILD_SHARED_LIBS) + target_compile_definitions(torch_python PRIVATE "AT_CORE_STATIC_WINDOWS=1") + endif() + + + +endif() + add_dependencies(torch_python torch_python_stubs) add_dependencies(torch_python flatbuffers) diff --git a/torch/__init__.py b/torch/__init__.py index e50844bafc43..b78bd8f6707e 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -271,38 +271,6 @@ def _preload_cuda_deps(lib_folder, lib_name): # See Note [Global dependencies] def _load_global_deps() -> None: - LIBTORCH_PKG_NAME = "libtorchsplit" - - def find_package_path(package_name): - spec = importlib.util.find_spec(package_name) - if spec: - # The package might be a namespace package, so get_data may fail - try: - loader = spec.loader - if loader is not None: - file_path = loader.get_filename() # type: ignore[attr-defined] - return os.path.dirname(file_path) - except AttributeError: - pass - return None - - def load_shared_libraries(library_path): - lib_dir = os.path.join(library_path, "lib") - if not os.path.exists(lib_dir): - return - - # Find all shared library files with the appropriate extension - library_files = [f for f in os.listdir(lib_dir) if f.endswith(lib_ext)] - if not library_files: - return - - for lib_file in library_files: - lib_path = os.path.join(lib_dir, lib_file) - try: - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) - except OSError as err: - print(f"Failed to load {lib_path}: {err}") - if _running_with_deploy() or platform.system() == "Windows": return @@ -312,11 +280,6 @@ def load_shared_libraries(library_path): here = os.path.abspath(__file__) global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) - split_build_lib_name = LIBTORCH_PKG_NAME - library_path = find_package_path(split_build_lib_name) - - if library_path: - global_deps_lib_path = os.path.join(library_path, "lib", lib_name) try: ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: @@ -344,10 +307,6 @@ def load_shared_libraries(library_path): _preload_cuda_deps(lib_folder, lib_name) ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) - if library_path: - # loading libtorch_global_deps first due its special logic - load_shared_libraries(library_path) - if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( _running_with_deploy() or platform.system() != "Windows" From 236fbcbdf44ea429b11ac44c1814af56ca243605 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Jun 2024 21:46:16 -0700 Subject: [PATCH 58/64] [Split Build] Test split build in pull CI workflow (#126813) This PR builds the split build in the pull workflow and runs the appropriate tests against them. A single linux cpu and single gpu build were chosen arbitrarily to not add too many tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126813 Approved by: https://github.com/atalman ghstack dependencies: #127934 --- .ci/pytorch/common_utils.sh | 22 ++++++++++++++++++- .ci/pytorch/test.sh | 3 +++ .github/actions/linux-build/action.yml | 21 +++++++++++++++++- .github/workflows/_linux-build-label.yml | 8 +++++++ .github/workflows/_linux-build.yml | 10 +++++++++ .github/workflows/pull.yml | 28 ++++++++++++++++++++++++ caffe2/CMakeLists.txt | 5 +++-- torch/__init__.py | 1 - 8 files changed, 93 insertions(+), 5 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 91c2d1b5dd3b..6f8f874fc8e3 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -56,9 +56,29 @@ function assert_git_not_dirty() { function pip_install_whl() { # This is used to install PyTorch and other build artifacts wheel locally # without using any network connection - python3 -mpip install --no-index --no-deps "$@" + + # Convert the input arguments into an array + local args=("$@") + + # Check if the first argument contains multiple paths separated by spaces + if [[ "${args[0]}" == *" "* ]]; then + # Split the string by spaces into an array + IFS=' ' read -r -a paths <<< "${args[0]}" + # Loop through each path and install individually + for path in "${paths[@]}"; do + echo "Installing $path" + python3 -mpip install --no-index --no-deps "$path" + done + else + # Loop through each argument and install individually + for path in "${args[@]}"; do + echo "Installing $path" + python3 -mpip install --no-index --no-deps "$path" + done + fi } + function pip_install() { # retry 3 times # old versions of pip don't have the "--progress-bar" flag diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 4a38ebefa6cb..80e4ae9285ac 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -289,6 +289,9 @@ test_python_shard() { # Bare --include flag is not supported and quoting for lint ends up with flag not being interpreted correctly # shellcheck disable=SC2086 + + # modify LD_LIBRARY_PATH to ensure it has the conda env. + # This set of tests has been shown to be buggy without it for the split-build time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION assert_git_not_dirty diff --git a/.github/actions/linux-build/action.yml b/.github/actions/linux-build/action.yml index c0f74160507b..6921778d3957 100644 --- a/.github/actions/linux-build/action.yml +++ b/.github/actions/linux-build/action.yml @@ -52,6 +52,13 @@ inputs: description: Hugging Face Hub token required: false default: "" + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false outputs: docker-image: value: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -144,6 +151,7 @@ runs: DEBUG: ${{ inputs.build-with-debug == 'true' && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ inputs.HUGGING_FACE_HUB_TOKEN }} + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} shell: bash run: | # detached container should get cleaned up by teardown_ec2_linux @@ -163,6 +171,7 @@ runs: -e PR_LABELS \ -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ + -e USE_SPLIT_BUILD \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -183,7 +192,7 @@ runs: - name: Store PyTorch Build Artifacts on S3 uses: seemethere/upload-artifact-s3@v5 - if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' + if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build != 'true' with: name: ${{ inputs.build-environment }} retention-days: 14 @@ -191,6 +200,16 @@ runs: path: artifacts.zip s3-bucket: ${{ inputs.s3-bucket }} + - name: Store PyTorch Build Artifacts on S3 for split build + uses: seemethere/upload-artifact-s3@v5 + if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build == 'true' + with: + name: ${{ inputs.build-environment }}-experimental-split-build + retention-days: 14 + if-no-files-found: error + path: artifacts.zip + s3-bucket: ${{ inputs.s3-bucket }} + - name: Upload sccache stats if: steps.build.outcome != 'skipped' uses: seemethere/upload-artifact-s3@v5 diff --git a/.github/workflows/_linux-build-label.yml b/.github/workflows/_linux-build-label.yml index 427f993b4853..037473b50e82 100644 --- a/.github/workflows/_linux-build-label.yml +++ b/.github/workflows/_linux-build-label.yml @@ -56,6 +56,13 @@ on: required: false type: string default: "" + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false secrets: HUGGING_FACE_HUB_TOKEN: required: false @@ -107,3 +114,4 @@ jobs: aws-role-to-assume: ${{ inputs.aws-role-to-assume }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + use_split_build: ${{ inputs.use_split_build }} diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index c3bcb0d888df..be2186139195 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -64,6 +64,14 @@ on: required: false type: string default: "" + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false + secrets: HUGGING_FACE_HUB_TOKEN: required: false @@ -181,6 +189,7 @@ jobs: DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} run: | # detached container should get cleaned up by teardown_ec2_linux container_name=$(docker run \ @@ -199,6 +208,7 @@ jobs: -e PR_LABELS \ -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ + -e USE_SPLIT_BUILD \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index dc74571852e9..464458638a15 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -487,3 +487,31 @@ jobs: build-environment: linux-jammy-py3-clang12-executorch docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} + + linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build: + name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build + uses: ./.github/workflows/_linux-build-label.yml + with: + use_split_build: true + build-environment: linux-focal-cuda12.1-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build-test: + name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build + docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.test-matrix }} diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8426741609fe..65cd196b0063 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -750,6 +750,9 @@ if(BUILD_LIBTORCHLESS) find_library(TORCH_XPU_LIB torch_xpu PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH) endif() add_subdirectory(../torch torch) + # ---[ Torch python bindings build + set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE) + set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE) else() set(TORCH_LIB torch) set(TORCH_CPU_LIB torch_cpu) @@ -1270,12 +1273,10 @@ install(FILES ${PROJECT_BINARY_DIR}/TorchConfig.cmake DESTINATION share/cmake/Torch) - # ---[ Torch python bindings build add_subdirectory(../torch torch) set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE) set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE) - # ========================================================== # END formerly-libtorch flags # ========================================================== diff --git a/torch/__init__.py b/torch/__init__.py index b78bd8f6707e..8042862fd364 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -50,7 +50,6 @@ def _running_with_deploy(): else: from torch.torch_version import __version__ as __version__ - __all__ = [ "BoolStorage", "BoolTensor", From 1b92bdd0ea326cd30bc3945602701ffe28c85fd5 Mon Sep 17 00:00:00 2001 From: Jean Schmidt Date: Wed, 19 Jun 2024 16:10:51 +0000 Subject: [PATCH 59/64] [ALI] [Reland] Use LF runners for Lint (#129071) Quick experiment with using LF runners for lint jobs. Picking a set of jobs where infra failures would be obvious to most people (lint) Pull Request resolved: https://github.com/pytorch/pytorch/pull/129071 Approved by: https://github.com/malfet --- .github/workflows/lint.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e0e4d3c20cd8..2bbb6b1ab0ea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,10 +19,10 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed + # to run git rev-parse HEAD~:.ci/docker when a new image is needed. fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -35,7 +35,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -49,7 +49,7 @@ jobs: quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: - runner: linux.2xlarge + runner: lf.linux.2xlarge docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} From 0fc603ece429b67468bc4aa94d5dc47f227298d4 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 19 Jun 2024 01:14:06 -0700 Subject: [PATCH 60/64] [optim] Fused implementation stability table (#129006) I'd like to discuss the criteria that we regard an implementation as stable. If there is no existing standard, my initial proposal would be a 6 month period after the commit to regard it as stable. As a result, now Adam and AdamW on CUDA would be considered as stable, while the rest are of beta. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129006 Approved by: https://github.com/malfet --- docs/source/optim.rst | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/docs/source/optim.rst b/docs/source/optim.rst index f6d6ffb923c3..88b0e364c244 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -164,10 +164,10 @@ horizontally and fused implementations as fusing vertically on top of that. In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. So when applicable, we default to foreach over for-loop. Applicable means the foreach implementation is available, the user has not specified any implementation-specific kwargs -(e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that -while fused should be even faster than foreach, the implementations are newer and we would -like to give them more bake-in time before flipping the switch everywhere. You are welcome -to try them out though! +(e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused +should be even faster than foreach, the implementations are newer and we would like to give +them more bake-in time before flipping the switch everywhere. We summarize the stability status +for each implementation on the second table below, you are welcome to try them out though! Below is a table showing the available and default implementations of each algorithm: @@ -177,7 +177,7 @@ Below is a table showing the available and default implementations of each algor :delim: ; :class:`Adadelta`;foreach;yes;no - :class:`Adagrad`;foreach;yes;no + :class:`Adagrad`;foreach;yes;yes (cpu only) :class:`Adam`;foreach;yes;yes :class:`AdamW`;foreach;yes;yes :class:`SparseAdam`;for-loop;no;no @@ -188,7 +188,28 @@ Below is a table showing the available and default implementations of each algor :class:`RAdam`;foreach;yes;no :class:`RMSprop`;foreach;yes;no :class:`Rprop`;foreach;yes;no - :class:`SGD`;foreach;yes;no + :class:`SGD`;foreach;yes;yes (CPU and CUDA only) + +Below table is showing the stability status for fused implementations: + +.. csv-table:: + :header: "Algorithm", "CPU", "CUDA", "MPS" + :widths: 25, 25, 25, 25 + :delim: ; + + :class:`Adadelta`;unsupported;unsupported;unsupported + :class:`Adagrad`;beta;unsupported;unsupported + :class:`Adam`;beta;stable;beta + :class:`AdamW`;beta;stable;beta + :class:`SparseAdam`;unsupported;unsupported;unsupported + :class:`Adamax`;unsupported;unsupported;unsupported + :class:`ASGD`;unsupported;unsupported;unsupported + :class:`LBFGS`;unsupported;unsupported;unsupported + :class:`NAdam`;unsupported;unsupported;unsupported + :class:`RAdam`;unsupported;unsupported;unsupported + :class:`RMSprop`;unsupported;unsupported;unsupported + :class:`Rprop`;unsupported;unsupported;unsupported + :class:`SGD`;beta;beta;unsupported How to adjust learning rate --------------------------- From 0707811286d1846209676435f4f86f2b4b3d1a17 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 19 Jun 2024 16:45:27 +0000 Subject: [PATCH 61/64] [export] experimental joint graph API. (#128847) Summary: WARNING: This API is highly unstable and will be subject to change in the future. Add a protoype to "decompose" an ExportedProgram into a joint graph form, so that we can compute the gradients on this graph. Test Plan: buck test mode/opt caffe2/torch/fb/export:test_experimental Differential Revision: D55657917 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128847 Approved by: https://github.com/tugsbayasgalan --- docs/source/export.rst | 2 + test/export/test_experimental.py | 71 ++++++ torch/export/experimental/__init__.py | 52 ++++ torch/export/exported_program.py | 349 ++++++++++++++------------ 4 files changed, 319 insertions(+), 155 deletions(-) create mode 100644 torch/export/experimental/__init__.py diff --git a/docs/source/export.rst b/docs/source/export.rst index 29069d3228e4..984394fe254e 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -719,3 +719,5 @@ API Reference :members: .. automodule:: torch.export.custom_obj + +.. automodule:: torch.export.experimental diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 0fe6f17db448..3c2c7e332991 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -11,6 +11,7 @@ from torch._functorch.aot_autograd import aot_export_module from torch.export._trace import _convert_ts_to_export_experimental +from torch.export.experimental import _export_forward_backward from torch.testing import FileCheck @@ -194,6 +195,76 @@ def forward(self, x: Dict[str, torch.Tensor]): MDict, ({"0": torch.randn(4), "1": torch.randn(4)},) ) + def test_joint_basic(self) -> None: + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.loss = torch.nn.CrossEntropyLoss() + + def forward(self, x): + return self.loss( + self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0]) + ) + + m = Module() + example_inputs = (torch.randn(3),) + m(*example_inputs) + ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) + joint_ep = _export_forward_backward(ep) + print(joint_ep) + + """ + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): + # No stacktrace found for following nodes + view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None + t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None + addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None + view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None + _softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None + detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax) + clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None + detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None + _log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None + detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax) + mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None + sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None + neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None + div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None + ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format) + div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None + neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None + expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None + mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None + _log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None + _softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None + view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None + t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2) + mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None + t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None + sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None + view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None + t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None + return (div, t_3, view_3) + + Graph signature: ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='div'), target=None), + OutputSpec(kind=, arg=TensorArgument(name='t_3'), target='linear.weight'), + OutputSpec(kind=, arg=TensorArgument(name='view_3'), target='linear.bias') + ] + ) + Range constraints: {} + """ + if __name__ == "__main__": run_tests() diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py new file mode 100644 index 000000000000..11bfdc1d0d12 --- /dev/null +++ b/torch/export/experimental/__init__.py @@ -0,0 +1,52 @@ +import copy + +import torch +from torch.export import ExportedProgram +from torch.export.exported_program import ( + _decompose_exported_program, + _get_updated_range_constraints, +) +from torch.export.graph_signature import ( + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) + + +def _remove_detach_pass(gm: torch.fx.GraphModule, sig: ExportGraphSignature) -> None: + with gm._set_replace_hook(sig.get_replace_hook()): + for node in list(reversed(gm.graph.nodes)): + if node.op != "call_function": + continue + if ( + node.target == torch.ops.aten.detach.default + and len(node.users) == 1 + and next(iter(node.users)).target == torch.ops.aten.detach.default + ): + next(iter(node.users)).replace_all_uses_with(node) + + gm.graph.eliminate_dead_code() + gm.recompile() + + +def _export_forward_backward( + ep: ExportedProgram, joint_loss_index: int = 0 +) -> ExportedProgram: + """ + WARNING: This API is highly unstable and will be subject to change in the future. + """ + from torch._decomp import core_aten_decompositions + + ep = _decompose_exported_program( + ep, decomp_table=core_aten_decompositions(), joint_loss_index=joint_loss_index + ) + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + _remove_detach_pass(gm, new_graph_signature) + + return ep._update(gm, new_graph_signature) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 280d3da5ad68..59618a3cfb10 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -189,6 +189,196 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: subgraph.recompile() +def _decompose_exported_program( + ep, + *, + decomp_table: Dict[torch._ops.OperatorBase, Callable], + joint_loss_index: Optional[int], +): + from torch._export.passes.lift_constants_pass import ( + ConstantAttrMap, + lift_constants_pass, + ) + from torch._functorch.aot_autograd import aot_export_module + + old_placeholders = [ + node for node in ep.graph_module.graph.nodes if node.op == "placeholder" + ] + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(ep.graph_module, name) + # TODO(zhxhchen17) Return the new graph_signature directly. + from torch.export._trace import _ignore_backend_decomps + + with _ignore_backend_decomps(): + gm, graph_signature = aot_export_module( + ep.graph_module, + fake_args, + decompositions=decomp_table, + trace_joint=True if joint_loss_index is not None else False, + output_loss_index=joint_loss_index + if joint_loss_index is not None + else None, + ) + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + new_outputs = list(gm.graph.nodes)[-1].args[0] + + # rename the placeholders + assert len(new_placeholders) == len(old_placeholders) + for old_ph, new_ph in zip(old_placeholders, new_placeholders): + new_ph.name = new_ph.target = old_ph.name + + # handle name collisions with newly decomposed graph nodes + name_map = {ph.name: ph.name for ph in new_placeholders} + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(ep.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(ep.graph_signature.input_specs) + ] + output_specs = [ + OutputSpec( + spec.kind, + update_arg(spec.arg, new_outputs[i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + for i, spec in enumerate(ep.graph_signature.output_specs) + ] + + if joint_loss_index is not None: + assert graph_signature.backward_signature is not None + gradients = graph_signature.backward_signature.gradients_to_user_inputs + assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs) + specs = { + graph_signature.user_inputs[i]: spec + for i, spec in enumerate(ep.graph_signature.input_specs) + if isinstance(spec.arg, TensorArgument) + } + for i, node in enumerate(new_outputs[len(output_specs) :]): + source = gradients[node.name] + spec = specs[source] # type: ignore[index] + if spec.kind == InputKind.PARAMETER: + kind = OutputKind.GRADIENT_TO_PARAMETER + target = spec.target + elif spec.kind == InputKind.USER_INPUT: + kind = OutputKind.GRADIENT_TO_USER_INPUT + target = source + else: + raise AssertionError(f"Unknown input kind: {spec.kind}") + output_specs.append( + OutputSpec( + kind, + TensorArgument(name=node.name), + target, + ) + ) + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. + for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + gm.meta.update(ep.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints( + gm, + ep.range_constraints, + _is_executorch=False, + ) + + constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap()) + for k, v in constants.items(): + assert k not in ep.constants + ep.constants[k] = v + + from torch._dynamo import config as _dynamo_config + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=ep.state_dict, + range_constraints=new_range_constraints, + module_call_graph=copy.deepcopy(ep.module_call_graph), + example_inputs=ep.example_inputs, + verifier=ep.verifier, + constants=ep.constants, + ) + return exported_program + + class ExportedProgram: """ Package of a program from :func:`export`. It contains @@ -537,166 +727,15 @@ def run_decompositions( For now, we do not decompose joint graphs. """ from torch._decomp import core_aten_decompositions - from torch._export.passes.lift_constants_pass import ( - ConstantAttrMap, - lift_constants_pass, - ) - from torch._functorch.aot_autograd import aot_export_module - - def _get_placeholders(gm): - placeholders = [] - for node in gm.graph.nodes: - if node.op != "placeholder": - break - placeholders.append(node) - return placeholders if decomp_table is None: decomp_table = core_aten_decompositions() - old_placeholders = _get_placeholders(self.graph_module) - fake_args = [node.meta["val"] for node in old_placeholders] - - buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] - for name in buffers_to_remove: - delattr(self.graph_module, name) - # TODO(zhxhchen17) Return the new graph_signature directly. - from torch.export._trace import _ignore_backend_decomps - - with _ignore_backend_decomps(): - gm, graph_signature = aot_export_module( - self.graph_module, - fake_args, - decompositions=decomp_table, - trace_joint=False, - ) - - # Update the signatures with the new placeholder names in case they - # changed when calling aot_export - def update_arg(old_arg, new_ph): - if isinstance(old_arg, ConstantArgument): - return old_arg - elif isinstance(old_arg, TensorArgument): - return TensorArgument(name=new_ph.name) - elif isinstance(old_arg, SymIntArgument): - return SymIntArgument(name=new_ph.name) - raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") - - new_placeholders = _get_placeholders(gm) - new_outputs = list(gm.graph.nodes)[-1].args[0] - - # rename the placeholders - assert len(new_placeholders) == len(old_placeholders) - for old_ph, new_ph in zip(old_placeholders, new_placeholders): - new_ph.name = new_ph.target = old_ph.name - - # handle name collisions with newly decomposed graph nodes - name_map = {ph.name: ph.name for ph in new_placeholders} - for node in gm.graph.nodes: - if node.op == "placeholder": - continue - node.name = _rename_without_collisions(name_map, node.name, node.name) - - # propagate names to higher order op subgraphs - _name_hoo_subgraph_placeholders(gm) - - # To match the output target with correct input for input mutations - # need to find the old to new placeholder map - old_new_placeholder_map = { - spec.arg.name: new_placeholders[i].name - for i, spec in enumerate(self.graph_signature.input_specs) - if not isinstance(spec.arg, ConstantArgument) - } - - input_specs = [ - InputSpec( - spec.kind, - update_arg(spec.arg, new_placeholders[i]), - spec.target, - spec.persistent, - ) - for i, spec in enumerate(self.graph_signature.input_specs) - ] - output_specs = [ - OutputSpec( - spec.kind, - update_arg(spec.arg, new_outputs[i]), - old_new_placeholder_map.get(spec.target, spec.target), - ) - for i, spec in enumerate(self.graph_signature.output_specs) - ] - - assert len(new_placeholders) == len(old_placeholders) - - new_graph_signature = ExportGraphSignature( - input_specs=input_specs, output_specs=output_specs - ) - # NOTE: aot_export adds symint metadata for placeholders with int - # values; since these become specialized, we replace such metadata with - # the original values. - # Also, set the param/buffer metadata back to the placeholders. - for old_node, new_node in zip(old_placeholders, new_placeholders): - if not isinstance(old_node.meta["val"], torch.Tensor): - new_node.meta["val"] = old_node.meta["val"] - - if ( - new_node.target in new_graph_signature.inputs_to_parameters - or new_node.target in new_graph_signature.inputs_to_buffers - ): - for k, v in old_node.meta.items(): - new_node.meta[k] = v - - # TODO unfortunately preserving graph-level metadata is not - # working well with aot_export. So we manually copy it. - # (The node-level meta is addressed above.) - gm.meta.update(self.graph_module.meta) - - new_range_constraints = _get_updated_range_constraints( - gm, - self.range_constraints, - _is_executorch=False, - ) - - constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap()) - for k, v in constants.items(): - assert k not in self.constants - self.constants[k] = v - - from torch._dynamo import config as _dynamo_config - from torch._export.passes._node_metadata_hook import ( - _node_metadata_hook, - _set_node_metadata_hook, - ) - - if not _dynamo_config.do_not_emit_runtime_asserts: - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - shape_env = _get_shape_env(gm) - if shape_env is not None: - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, - ) - - exported_program = ExportedProgram( - root=gm, - graph=gm.graph, - graph_signature=new_graph_signature, - state_dict=self.state_dict, - range_constraints=new_range_constraints, - module_call_graph=copy.deepcopy(self.module_call_graph), - example_inputs=self.example_inputs, - verifier=self.verifier, - constants=self.constants, + return _decompose_exported_program( + self, + decomp_table=decomp_table, + joint_loss_index=None, ) - return exported_program def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": pm = PassManager(list(passes)) From bafd68b4fcc049264c69d33f373fd0877acdb052 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 19 Jun 2024 17:51:32 +0000 Subject: [PATCH 62/64] [inductor] fix windows python module ext and func export declaration (#129059) I have run the first inductor case on Windows base on the exploration code: https://github.com/pytorch/pytorch/pull/128330 Due to some fundamental PR still need pass `fb_code`: https://github.com/pytorch/pytorch/pull/128303 This PR would land some part of exploration code: 1. Fix Windows python module ext type: pyd. 2. Add function export declaration for Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129059 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/codegen/cpp.py | 9 ++++++++- torch/_inductor/cpp_builder.py | 21 ++++++--------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 2c800b41adff..8c171cf96676 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -73,6 +73,7 @@ value_to_cpp, ) +_IS_WINDOWS = sys.platform == "win32" schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} @@ -3941,6 +3942,9 @@ def get_num_args(self): args_num = len(arg_defs) return args_num + def get_export_declaration(self): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + def codegen_group(self, name=None) -> str: self.stack.close() if not self.scheduled_nodes: @@ -3960,7 +3964,10 @@ def codegen_group(self, name=None) -> str: kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name arg_defs, _, _ = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) - code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})') + func_export_decl = self.get_export_declaration() + code.writeline( + f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + ) # 3. Function body with code.indent(): diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index a574b3347342..4ac464ae0c14 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1006,11 +1006,11 @@ class CppBuilder: 3. Final target file: output_dir/name.ext """ - def get_shared_lib_ext(self) -> str: - SHARED_LIB_EXT = ".dll" if _IS_WINDOWS else ".so" + def __get_python_module_ext(self) -> str: + SHARED_LIB_EXT = ".pyd" if _IS_WINDOWS else ".so" return SHARED_LIB_EXT - def get_object_ext(self) -> str: + def __get_object_ext(self) -> str: EXT = ".obj" if _IS_WINDOWS else ".o" return EXT @@ -1048,7 +1048,9 @@ def __init__( self._compile_only = BuildOption.get_compile_only() file_ext = ( - self.get_object_ext() if self._compile_only else self.get_shared_lib_ext() + self.__get_object_ext() + if self._compile_only + else self.__get_python_module_ext() ) self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") @@ -1157,17 +1159,6 @@ def format_build_command( def get_target_file_path(self): return self._target_file - def convert_to_cpp_extension_args(self): - include_dirs = self._include_dirs_args - cflags = ( - self._cflags_args - + self._definations_args - + self._passthough_parameters_args - ) - ldflags = self._ldflags_args + self._libraries_args + self._libraries_dirs_args - - return include_dirs, cflags, ldflags - def build(self) -> Tuple[int, str]: """ It is must need a temperary directory to store object files in Windows. From b5d541609d1884bafc06715997bb2bbeadadae76 Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Wed, 19 Jun 2024 18:05:39 +0000 Subject: [PATCH 63/64] [Memory Snapshot] Add recordAnnotations to capture record_function annotations (#129072) Summary: Add new traceEvents into Memory Snapshot for record_function annotations. These will capture both the profiler's step annotation as well as user annotations. Test Plan: CI Pulled By: aaronenyeshi Differential Revision: D55941362 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129072 Approved by: https://github.com/zdevito --- c10/core/Allocator.h | 1 + c10/cuda/CUDACachingAllocator.cpp | 10 ++++++++ c10/cuda/CUDACachingAllocator.h | 10 ++++++-- torch/csrc/cuda/Module.cpp | 27 ++++++++++++++++++++++ torch/csrc/cuda/memory_snapshot.cpp | 3 +++ torch/csrc/profiler/combined_traceback.cpp | 18 +++++++++++---- torch/csrc/profiler/combined_traceback.h | 5 ++++ 7 files changed, 68 insertions(+), 6 deletions(-) diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h index 412412557a0d..929e1243034c 100644 --- a/c10/core/Allocator.h +++ b/c10/core/Allocator.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 11bea6056e9d..d90f51671b18 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -958,6 +958,10 @@ class DeviceCachingAllocator { } } + void recordAnnotation(const std::shared_ptr& name) { + record_trace(TraceEntry::USER_DEFINED, 0, 0, nullptr, 0, name); + } + bool isHistoryEnabled() { return record_history; } @@ -3026,6 +3030,12 @@ class NativeCachingAllocator : public CUDAAllocator { } } + void recordAnnotation(const std::shared_ptr& name) override { + for (auto& allocator : device_allocator) { + allocator->recordAnnotation(name); + } + } + bool isHistoryEnabled() override { c10::DeviceIndex device = 0; C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 438ed8d77f75..1109965c08fd 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -170,8 +170,9 @@ struct TraceEntry { SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments) SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace // events - OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free - // bytes reported by cuda) + OOM, // the allocator threw an OutOfMemoryError (addr_ is the amount of free + // bytes reported by cuda) + USER_DEFINED // a call made from user defined API such as record_function }; TraceEntry( Action action, @@ -289,6 +290,7 @@ class CUDAAllocator : public Allocator { CreateContextFn context_recorder, size_t alloc_trace_max_entries, RecordContext when) = 0; + virtual void recordAnnotation(const std::shared_ptr& name){}; virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; // Attached AllocatorTraceTracker callbacks will be called while the @@ -428,6 +430,10 @@ inline void recordHistory( enabled, context_recorder, alloc_trace_max_entries, when); } +inline void recordAnnotation(const std::shared_ptr& name) { + return get()->recordAnnotation(name); +} + inline bool isHistoryEnabled() { return get()->isHistoryEnabled(); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 4197c2aa5e81..4bdf2fd8ba82 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -738,6 +740,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { py::str snapshot_s = "snapshot"; py::str oom_s = "oom"; py::str device_free_s = "device_free"; + py::str user_defined_s = "user_defined"; using namespace c10::cuda::CUDACachingAllocator; @@ -761,6 +764,8 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { return segment_unmap_s; case TraceEntry::SEGMENT_MAP: return segment_map_s; + case TraceEntry::USER_DEFINED: + return user_defined_s; } throw std::runtime_error("unreachable"); }; @@ -962,6 +967,28 @@ static void registerCudaDeviceProperties(PyObject* module) { const std::string&, size_t)>(torch::cuda::_record_memory_history)); + // Save user annotations to CCA memory snapshot tool + at::addThreadLocalCallback(at::RecordFunctionCallback( + [](const at::RecordFunction& fn) -> std::unique_ptr { + if (fn.scope() != at::RecordScope::USER_SCOPE) { + return nullptr; // only record user-defined scopes. + } + unwind::Frame frame{fn.name(), "START", 0}; + auto r = std::make_shared(); + r->recordUserDefinedFrame(frame); + c10::cuda::CUDACachingAllocator::recordAnnotation(r); + return nullptr; + }, + [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) { + if (fn.scope() != at::RecordScope::USER_SCOPE) { + return; // only record user-defined scopes. + } + unwind::Frame frame{fn.name(), "END", 0}; + auto r = std::make_shared(); + r->recordUserDefinedFrame(frame); + c10::cuda::CUDACachingAllocator::recordAnnotation(r); + })); + m.def("_cuda_isHistoryEnabled", []() { return c10::cuda::CUDACachingAllocator::isHistoryEnabled(); }); diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index 82696abaee22..ca9d7985bce9 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -275,6 +275,7 @@ std::string _memory_snapshot_pickled() { IValue snapshot_s = "snapshot"; IValue oom_s = "oom"; IValue device_free_s = "device_free"; + IValue user_defined_s = "user_defined"; using namespace c10::cuda::CUDACachingAllocator; @@ -298,6 +299,8 @@ std::string _memory_snapshot_pickled() { return segment_unmap_s; case TraceEntry::SEGMENT_MAP: return segment_map_s; + case TraceEntry::USER_DEFINED: + return user_defined_s; } throw std::runtime_error("unreachable"); }; diff --git a/torch/csrc/profiler/combined_traceback.cpp b/torch/csrc/profiler/combined_traceback.cpp index c727f58d5284..63d1641d1a24 100644 --- a/torch/csrc/profiler/combined_traceback.cpp +++ b/torch/csrc/profiler/combined_traceback.cpp @@ -91,8 +91,10 @@ SymbolizedTracebacks symbolize( for (const auto& e : to_symbolize) { if (e->python_) { if (cur_python != e->python_ && !cur_py_frames.empty()) { - // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) - cur_python->appendSymbolized(cur_py_frames, r); + if (cur_python) { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + cur_python->appendSymbolized(cur_py_frames, r); + } cur_py_frames.clear(); } cur_python = e->python_; @@ -105,8 +107,10 @@ SymbolizedTracebacks symbolize( } } if (!cur_py_frames.empty()) { - // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) - cur_python->appendSymbolized(cur_py_frames, r); + if (cur_python) { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + cur_python->appendSymbolized(cur_py_frames, r); + } cur_py_frames.clear(); } std::vector> python_frame_fragments = @@ -171,6 +175,12 @@ SymbolizedTracebacks symbolize( for (; py_it != py_end; ++py_it) { append_python(*py_it); } + + // Gather all user defined frames + for (const auto& f : sc->user_defined_frames_) { + r.tracebacks.back().push_back(r.all_frames.size()); + r.all_frames.emplace_back(f); + } } return r; } diff --git a/torch/csrc/profiler/combined_traceback.h b/torch/csrc/profiler/combined_traceback.h index e84a51e9c4a3..9048a4b9ece0 100644 --- a/torch/csrc/profiler/combined_traceback.h +++ b/torch/csrc/profiler/combined_traceback.h @@ -58,10 +58,15 @@ struct TORCH_API CapturedTraceback : public c10::GatheredContext { int traversePython(visitproc visit, void* arg); int clearPython(); + void recordUserDefinedFrame(const unwind::Frame& frame) { + user_defined_frames_.push_back(frame); + } + private: std::vector frames_; std::vector cpp_frames_; std::vector script_frames_; + std::vector user_defined_frames_; friend TORCH_API SymbolizedTracebacks symbolize(const std::vector& to_symbolize); From df94d57c0afad89339d768aa40ef58d91c7e4aae Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 19:04:36 +0000 Subject: [PATCH 64/64] Revert "[export] experimental joint graph API. (#128847)" This reverts commit 0707811286d1846209676435f4f86f2b4b3d1a17. Reverted https://github.com/pytorch/pytorch/pull/128847 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/128847#issuecomment-2179326891)) --- docs/source/export.rst | 2 - test/export/test_experimental.py | 71 ------ torch/export/experimental/__init__.py | 52 ---- torch/export/exported_program.py | 349 ++++++++++++-------------- 4 files changed, 155 insertions(+), 319 deletions(-) delete mode 100644 torch/export/experimental/__init__.py diff --git a/docs/source/export.rst b/docs/source/export.rst index 984394fe254e..29069d3228e4 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -719,5 +719,3 @@ API Reference :members: .. automodule:: torch.export.custom_obj - -.. automodule:: torch.export.experimental diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 3c2c7e332991..0fe6f17db448 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -11,7 +11,6 @@ from torch._functorch.aot_autograd import aot_export_module from torch.export._trace import _convert_ts_to_export_experimental -from torch.export.experimental import _export_forward_backward from torch.testing import FileCheck @@ -195,76 +194,6 @@ def forward(self, x: Dict[str, torch.Tensor]): MDict, ({"0": torch.randn(4), "1": torch.randn(4)},) ) - def test_joint_basic(self) -> None: - class Module(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - self.loss = torch.nn.CrossEntropyLoss() - - def forward(self, x): - return self.loss( - self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0]) - ) - - m = Module() - example_inputs = (torch.randn(3),) - m(*example_inputs) - ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) - joint_ep = _export_forward_backward(ep) - print(joint_ep) - - """ - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): - # No stacktrace found for following nodes - view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None - t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None - addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None - view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None - _softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None - detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax) - clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None - detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None - _log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax) - mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None - sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None - neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None - div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None - ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format) - div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None - neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None - expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None - mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None - _log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None - _softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None - view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None - t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2) - mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None - t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None - sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None - view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None - t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None - return (div, t_3, view_3) - - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='div'), target=None), - OutputSpec(kind=, arg=TensorArgument(name='t_3'), target='linear.weight'), - OutputSpec(kind=, arg=TensorArgument(name='view_3'), target='linear.bias') - ] - ) - Range constraints: {} - """ - if __name__ == "__main__": run_tests() diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py deleted file mode 100644 index 11bfdc1d0d12..000000000000 --- a/torch/export/experimental/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -import copy - -import torch -from torch.export import ExportedProgram -from torch.export.exported_program import ( - _decompose_exported_program, - _get_updated_range_constraints, -) -from torch.export.graph_signature import ( - ConstantArgument, - ExportGraphSignature, - InputKind, - InputSpec, - OutputKind, - OutputSpec, - SymIntArgument, - TensorArgument, -) - - -def _remove_detach_pass(gm: torch.fx.GraphModule, sig: ExportGraphSignature) -> None: - with gm._set_replace_hook(sig.get_replace_hook()): - for node in list(reversed(gm.graph.nodes)): - if node.op != "call_function": - continue - if ( - node.target == torch.ops.aten.detach.default - and len(node.users) == 1 - and next(iter(node.users)).target == torch.ops.aten.detach.default - ): - next(iter(node.users)).replace_all_uses_with(node) - - gm.graph.eliminate_dead_code() - gm.recompile() - - -def _export_forward_backward( - ep: ExportedProgram, joint_loss_index: int = 0 -) -> ExportedProgram: - """ - WARNING: This API is highly unstable and will be subject to change in the future. - """ - from torch._decomp import core_aten_decompositions - - ep = _decompose_exported_program( - ep, decomp_table=core_aten_decompositions(), joint_loss_index=joint_loss_index - ) - gm = copy.deepcopy(ep.graph_module) - new_graph_signature = copy.deepcopy(ep.graph_signature) - _remove_detach_pass(gm, new_graph_signature) - - return ep._update(gm, new_graph_signature) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 59618a3cfb10..280d3da5ad68 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -189,196 +189,6 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: subgraph.recompile() -def _decompose_exported_program( - ep, - *, - decomp_table: Dict[torch._ops.OperatorBase, Callable], - joint_loss_index: Optional[int], -): - from torch._export.passes.lift_constants_pass import ( - ConstantAttrMap, - lift_constants_pass, - ) - from torch._functorch.aot_autograd import aot_export_module - - old_placeholders = [ - node for node in ep.graph_module.graph.nodes if node.op == "placeholder" - ] - fake_args = [node.meta["val"] for node in old_placeholders] - - buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()] - for name in buffers_to_remove: - delattr(ep.graph_module, name) - # TODO(zhxhchen17) Return the new graph_signature directly. - from torch.export._trace import _ignore_backend_decomps - - with _ignore_backend_decomps(): - gm, graph_signature = aot_export_module( - ep.graph_module, - fake_args, - decompositions=decomp_table, - trace_joint=True if joint_loss_index is not None else False, - output_loss_index=joint_loss_index - if joint_loss_index is not None - else None, - ) - - # Update the signatures with the new placeholder names in case they - # changed when calling aot_export - def update_arg(old_arg, new_ph): - if isinstance(old_arg, ConstantArgument): - return old_arg - elif isinstance(old_arg, TensorArgument): - return TensorArgument(name=new_ph.name) - elif isinstance(old_arg, SymIntArgument): - return SymIntArgument(name=new_ph.name) - raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") - - new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] - new_outputs = list(gm.graph.nodes)[-1].args[0] - - # rename the placeholders - assert len(new_placeholders) == len(old_placeholders) - for old_ph, new_ph in zip(old_placeholders, new_placeholders): - new_ph.name = new_ph.target = old_ph.name - - # handle name collisions with newly decomposed graph nodes - name_map = {ph.name: ph.name for ph in new_placeholders} - for node in gm.graph.nodes: - if node.op == "placeholder": - continue - node.name = _rename_without_collisions(name_map, node.name, node.name) - - # propagate names to higher order op subgraphs - _name_hoo_subgraph_placeholders(gm) - - # To match the output target with correct input for input mutations - # need to find the old to new placeholder map - old_new_placeholder_map = { - spec.arg.name: new_placeholders[i].name - for i, spec in enumerate(ep.graph_signature.input_specs) - if not isinstance(spec.arg, ConstantArgument) - } - - input_specs = [ - InputSpec( - spec.kind, - update_arg(spec.arg, new_placeholders[i]), - spec.target, - spec.persistent, - ) - for i, spec in enumerate(ep.graph_signature.input_specs) - ] - output_specs = [ - OutputSpec( - spec.kind, - update_arg(spec.arg, new_outputs[i]), - old_new_placeholder_map.get(spec.target, spec.target), - ) - for i, spec in enumerate(ep.graph_signature.output_specs) - ] - - if joint_loss_index is not None: - assert graph_signature.backward_signature is not None - gradients = graph_signature.backward_signature.gradients_to_user_inputs - assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs) - specs = { - graph_signature.user_inputs[i]: spec - for i, spec in enumerate(ep.graph_signature.input_specs) - if isinstance(spec.arg, TensorArgument) - } - for i, node in enumerate(new_outputs[len(output_specs) :]): - source = gradients[node.name] - spec = specs[source] # type: ignore[index] - if spec.kind == InputKind.PARAMETER: - kind = OutputKind.GRADIENT_TO_PARAMETER - target = spec.target - elif spec.kind == InputKind.USER_INPUT: - kind = OutputKind.GRADIENT_TO_USER_INPUT - target = source - else: - raise AssertionError(f"Unknown input kind: {spec.kind}") - output_specs.append( - OutputSpec( - kind, - TensorArgument(name=node.name), - target, - ) - ) - - assert len(new_placeholders) == len(old_placeholders) - - new_graph_signature = ExportGraphSignature( - input_specs=input_specs, output_specs=output_specs - ) - # NOTE: aot_export adds symint metadata for placeholders with int - # values; since these become specialized, we replace such metadata with - # the original values. - # Also, set the param/buffer metadata back to the placeholders. - for old_node, new_node in zip(old_placeholders, new_placeholders): - if not isinstance(old_node.meta["val"], torch.Tensor): - new_node.meta["val"] = old_node.meta["val"] - - if ( - new_node.target in new_graph_signature.inputs_to_parameters - or new_node.target in new_graph_signature.inputs_to_buffers - ): - for k, v in old_node.meta.items(): - new_node.meta[k] = v - - # TODO unfortunately preserving graph-level metadata is not - # working well with aot_export. So we manually copy it. - # (The node-level meta is addressed above.) - gm.meta.update(ep.graph_module.meta) - - new_range_constraints = _get_updated_range_constraints( - gm, - ep.range_constraints, - _is_executorch=False, - ) - - constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap()) - for k, v in constants.items(): - assert k not in ep.constants - ep.constants[k] = v - - from torch._dynamo import config as _dynamo_config - from torch._export.passes._node_metadata_hook import ( - _node_metadata_hook, - _set_node_metadata_hook, - ) - - if not _dynamo_config.do_not_emit_runtime_asserts: - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - shape_env = _get_shape_env(gm) - if shape_env is not None: - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, - ) - - exported_program = ExportedProgram( - root=gm, - graph=gm.graph, - graph_signature=new_graph_signature, - state_dict=ep.state_dict, - range_constraints=new_range_constraints, - module_call_graph=copy.deepcopy(ep.module_call_graph), - example_inputs=ep.example_inputs, - verifier=ep.verifier, - constants=ep.constants, - ) - return exported_program - - class ExportedProgram: """ Package of a program from :func:`export`. It contains @@ -727,15 +537,166 @@ def run_decompositions( For now, we do not decompose joint graphs. """ from torch._decomp import core_aten_decompositions + from torch._export.passes.lift_constants_pass import ( + ConstantAttrMap, + lift_constants_pass, + ) + from torch._functorch.aot_autograd import aot_export_module + + def _get_placeholders(gm): + placeholders = [] + for node in gm.graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return placeholders if decomp_table is None: decomp_table = core_aten_decompositions() - return _decompose_exported_program( - self, - decomp_table=decomp_table, - joint_loss_index=None, + old_placeholders = _get_placeholders(self.graph_module) + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(self.graph_module, name) + # TODO(zhxhchen17) Return the new graph_signature directly. + from torch.export._trace import _ignore_backend_decomps + + with _ignore_backend_decomps(): + gm, graph_signature = aot_export_module( + self.graph_module, + fake_args, + decompositions=decomp_table, + trace_joint=False, + ) + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = _get_placeholders(gm) + new_outputs = list(gm.graph.nodes)[-1].args[0] + + # rename the placeholders + assert len(new_placeholders) == len(old_placeholders) + for old_ph, new_ph in zip(old_placeholders, new_placeholders): + new_ph.name = new_ph.target = old_ph.name + + # handle name collisions with newly decomposed graph nodes + name_map = {ph.name: ph.name for ph in new_placeholders} + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(self.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(self.graph_signature.input_specs) + ] + output_specs = [ + OutputSpec( + spec.kind, + update_arg(spec.arg, new_outputs[i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + for i, spec in enumerate(self.graph_signature.output_specs) + ] + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. + for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + gm.meta.update(self.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints( + gm, + self.range_constraints, + _is_executorch=False, + ) + + constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap()) + for k, v in constants.items(): + assert k not in self.constants + self.constants[k] = v + + from torch._dynamo import config as _dynamo_config + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=self.state_dict, + range_constraints=new_range_constraints, + module_call_graph=copy.deepcopy(self.module_call_graph), + example_inputs=self.example_inputs, + verifier=self.verifier, + constants=self.constants, ) + return exported_program def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": pm = PassManager(list(passes))