
Conversation

@mori360 (Contributor) commented Aug 20, 2024

Optimize the memory cost of PR#129635

There are two main parts to the optimization:

  1. Optimize the tensor-distribution step by postponing the full_tensor generation, which avoids the memory overlap and saves around 50% peak memory in the 2-parameter test case (a sketch follows this list).
  2. Apply assign=True in load_state_dict, which saves memory during state-dict loading by assigning the input parameters in place, around 50% peak memory in the loading step.
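
A minimal sketch of the postponement idea, under stated assumptions: the helper names and the plain dict loop are hypothetical, not the actual torch.distributed.checkpoint.state_dict code. The point is that materializing every DTensor up front keeps sharded and full copies alive simultaneously, while in-place, one-at-a-time materialization overlaps at most one key.

```
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases

def materialize_eagerly(local_sd: dict) -> dict:
    # Higher peak memory: the dict of full tensors is built while local_sd
    # still holds every sharded DTensor, so both copies coexist for all keys.
    return {
        k: v.full_tensor() if isinstance(v, DTensor) else v
        for k, v in local_sd.items()
    }

def materialize_postponed(local_sd: dict) -> dict:
    # Lower peak memory: each shard reference is overwritten as soon as its
    # full tensor exists, so at most one key holds both copies at a time.
    for k, v in local_sd.items():
        if isinstance(v, DTensor):
            local_sd[k] = v.full_tensor()
    return local_sd
```

In the eager variant the sharded and replicated copies of every parameter are alive at once; the postponed variant releases each shard as it goes, which is consistent with the roughly 50% peak-memory saving reported above.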

Future work:
Memory optimization for the optimizer will be conducted in the next PR.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn

@pytorch-bot (bot) commented Aug 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134025

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 80cef74 with merge base 333890b:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels Aug 20, 2024
@mori360 mori360 added the ciflow/trunk and ciflow/periodic labels Aug 20, 2024
@mori360 mori360 changed the title from "Mem" to "Memory optimization for DSD for TorchTune LoRA" Aug 20, 2024
mori360 and others added 2 commits August 20, 2024 14:50
…2323)

Summary:
**Context:**

Currently we have a helper to print out an AtenTensor in [shim_common.cpp](https://github.com/pytorch/pytorch/blob/v2.4.0-rc4/torch/csrc/inductor/aoti_torch/shim_common.cpp#L866).

The way we were using this function was a manual process: we injected it into the generated output.cpp file by hand, then recompiled and reloaded the file. This diff automates the value-printing process.

**Changes:**

1. Added a simple initial debug printer helper to print out tensor values.

2. Added a filter option to selectively dump tensor values.

**Usage:**

Sample command:

```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, +schedule, output_code"  python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda
```

Sample output:
```
[  before_launch - triton_poi_fused_0 - buf0  ]:
 0.6331
 1.6358
-0.3459
 1.0196
-0.4122
 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[  after_launch - triton_poi_fused_0 - buf0  ]:
 0.6331
 1.6358
-0.3459
 1.0196
-0.4122
 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[ before_launch - aoti_torch_cuda_addmm_out - buf1  ]:
Min value: -2.25655
Max value: 2.32996
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0

[  before_launch - aoti_torch_cuda_addmm_out - buf0  ]:
 0.6331
 1.6358
-0.3459
 1.0196
-0.4122
 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[  after_launch - aoti_torch_cuda_addmm_out - buf1  ]:
Min value: -12.0839
Max value: 11.6878
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0

[  after_launch - aoti_torch_cuda_addmm_out - buf0  ]:
 0.6331
 1.6358
-0.3459
 1.0196
-0.4122
 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

stats [('calls_captured', 1), ('unique_graphs', 1)]
inductor [('pattern_matcher_count', 2), ('pattern_matcher_nodes', 2), ('extern_calls', 2)]
.
----------------------------------------------------------------------
Ran 1 test in 10.867s

OK

```

The user can filter which kernels have their values printed by setting the env var `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT`; the available kernel names appear in a log message like the one below:
```
torch/_inductor/graph.py:1642] Finished codegen for all nodes. The list of kernel names available: ['triton_poi_fused_0', 'aoti_torch_cuda_addmm_out']

```
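
For example, to restrict printing to the addmm fallback kernel listed above (assuming the env var takes a kernel name, or a comma-separated list of names; the exact value format is not spelled out in this PR):

```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 \
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="aoti_torch_cuda_addmm_out" \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 \
python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda
```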

In a follow-up diff, we will add `torch.save()` to dump the intermediate tensors into individual `.pt` files that can later be loaded back with `torch.load()`.
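
The planned round-trip would look roughly like this (the file name and tensor are illustrative):

```
import torch

buf0 = torch.randn(6, device="cuda")  # stand-in for an intermediate buffer
torch.save(buf0, "buf0.pt")           # dump each intermediate to its own .pt file

# Later, e.g. in a REPL, reload the dump for closer inspection.
restored = torch.load("buf0.pt")
print(restored.min(), restored.max())
```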

Test Plan:
Run unit tests in OSS (similar command to the one in the Usage section above):

 `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, output_code"  python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda`

Differential Revision: D60538496

Pull Request resolved: pytorch#132323
Approved by: https://github.com/ColinPeppler
```
if local_state is None:
    continue
elif isinstance(local_state, DTensor):
    local_state_dict[key] = (local_state, full_tensor)
```
@mori360 (Contributor Author) commented:

Postpone the full_tensor generation to avoid overlapping memory cost (a sketch of the idea follows).
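
A hypothetical sketch of how such a deferred `(local_state, full_tensor)` entry might be consumed later, assuming the tuple records the DTensor alongside its pending full-tensor request (the actual consumer lives elsewhere in the DSD code):

```
# Hypothetical second pass, run after the distribution bookkeeping above
# has finished: materialize deferred entries one at a time so the full
# tensors for all keys never coexist with all of their shards.
for key, value in local_state_dict.items():
    if isinstance(value, tuple):  # a deferred (DTensor, full_tensor-request) entry
        local_state, _ = value
        local_state_dict[key] = local_state.full_tensor()
```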

```
 _IncompatibleKeys,
 _state_dict_fn(model, "load_state_dict")(
-    state_dict=state_dict, strict=info.strict
+    state_dict=state_dict, strict=info.strict, assign=assign
```
@mori360 (Contributor Author) commented Aug 21, 2024:

assign=True avoids the memory cost of model.to_empty(device=device) for models on the meta device; TorchTune already uses assign=True. load_state_dict takes assign=False by default, so we only set assign=True when the device is found to be the meta device (see the sketch below).
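
A minimal sketch of why this helps, assuming a model built on the meta device (the module and shapes are illustrative, not the DSD code path):

```
import torch
import torch.nn as nn

# Parameters on the meta device carry shape/dtype metadata but no storage.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

state_dict = {
    "weight": torch.randn(4096, 4096),
    "bias": torch.randn(4096),
}

# With the default assign=False, real storage would have to be materialized
# first, e.g. via model.to_empty(device="cuda"), and then copied into.
# assign=True adopts the incoming tensors as the parameters directly,
# skipping that extra materialization pass and its memory cost.
model.load_state_dict(state_dict, assign=True)
```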

@mori360 mori360 requested a review from fegin August 21, 2024 18:04
@mori360 mori360 requested review from awgu and weifengpy August 21, 2024 18:04
@mori360 mori360 marked this pull request as ready for review August 21, 2024 20:27
@mori360 mori360 marked this pull request as draft August 23, 2024 19:22
@mori360 mori360 marked this pull request as ready for review August 23, 2024 19:39
@mori360 (Contributor Author) commented Aug 26, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@mori360 (Contributor Author) commented Aug 26, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Optimize the memory cost of [PR#129635](pytorch#129635)

There are two main parts to the optimization:
1. Optimize the tensor-distribution step by postponing the full_tensor generation, which avoids the memory overlap and saves around 50% peak memory in the 2-parameter test case.
2. Apply `assign=True` for `load_state_dict`, which saves memory during state-dict loading by assigning the input parameters in place, around 50% peak memory in the loading step.

Future work:
Memory optimization for the optimizer will be conducted in the next PR.

Pull Request resolved: pytorch#134025
Approved by: https://github.com/fegin

Co-authored-by: Rachel Guo <guorachel@meta.com>
@mori360 mori360 deleted the mem branch November 25, 2024 19:18

Labels

ciflow/inductor, ciflow/periodic, ciflow/trunk, Merged, module: inductor, oncall: distributed, topic: not user facing
