Memory optimization for DSD for TorchTune LoRA #134025
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134025
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 80cef74 with merge base 333890b:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…2323)

Summary:

**Context:** Currently we have a helper to print out AtenTensor in [shim_common.cpp](https://github.com/pytorch/pytorch/blob/v2.4.0-rc4/torch/csrc/inductor/aoti_torch/shim_common.cpp#L866). The way we were using this function was a "manual" process: we inject this function into the generated output.cpp file, then recompile and reload the file. This diff automates the value-printing process.

**Changes:**
1. Added a simple initial debug printer helper to print out tensor values.
2. Added a filter option to selectively dump tensor values.

**Usage:**

Sample cmd:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, +schedule, output_code" python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda
```

Sample outputs:
```
[ before_launch - triton_poi_fused_0 - buf0 ]:
0.6331 1.6358 -0.3459 1.0196 -0.4122 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[ after_launch - triton_poi_fused_0 - buf0 ]:
0.6331 1.6358 -0.3459 1.0196 -0.4122 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[ before_launch - aoti_torch_cuda_addmm_out - buf1 ]:
Min value: -2.25655
Max value: 2.32996
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0

[ before_launch - aoti_torch_cuda_addmm_out - buf0 ]:
0.6331 1.6358 -0.3459 1.0196 -0.4122 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

[ after_launch - aoti_torch_cuda_addmm_out - buf1 ]:
Min value: -12.0839
Max value: 11.6878
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0

[ after_launch - aoti_torch_cuda_addmm_out - buf0 ]:
0.6331 1.6358 -0.3459 1.0196 -0.4122 1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0

stats [('calls_captured', 1), ('unique_graphs', 1)]
inductor [('pattern_matcher_count', 2), ('pattern_matcher_nodes', 2), ('extern_calls', 2)]
.
----------------------------------------------------------------------
Ran 1 test in 10.867s

OK
```

The user can filter which kernels' values are printed by setting the env var `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT`; the available kernel names are listed in a log message like the one below:
```
torch/_inductor/graph.py:1642] Finished codegen for all nodes. The list of kernel names available: ['triton_poi_fused_0', 'aoti_torch_cuda_addmm_out']
```

In a follow-up diff, `torch.save()` will be added to dump the intermediate tensors into individual `.pt` files that can then be loaded back with `torch.load()`.

Test Plan: Run unit tests in OSS (similar cmd to the usage section above):
`AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, output_code" python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda`

Differential Revision: D60538496

Pull Request resolved: pytorch#132323
Approved by: https://github.com/ColinPeppler
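A minimal sketch of driving the same printer from Python rather than from the shell. The environment variable names come from the commit message above; the model, the kernel name used for filtering, the assumption that the filter accepts a plain name string, and the `torch._export.aot_compile` entry point (which may differ across versions) are illustrative only:

```python
import os

# Enable the AOTInductor intermediate-value printer before compilation so the
# generated wrapper prints tensor stats around each kernel launch.
os.environ["AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER"] = "1"
# Optionally restrict printing to specific kernels. Kernel names are listed in
# the "Finished codegen for all nodes" log line; the exact value format for
# this variable is an assumption, not documented behavior.
os.environ["AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT"] = "aoti_torch_cuda_addmm_out"

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(6, 6)

    def forward(self, x):
        return self.linear(x)


if torch.cuda.is_available():
    model = Model().cuda()
    example_inputs = (torch.randn(16, 6, device="cuda"),)
    # AOT-compile the model; the debug prints are emitted when the compiled
    # artifact is built and run.
    so_path = torch._export.aot_compile(model, example_inputs)
    print(so_path)
```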
if local_state is None:
    continue
elif isinstance(local_state, DTensor):
    local_state_dict[key] = (local_state, full_tensor)
Postpone the `full_tensor` generation to avoid overlapping memory cost.
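A minimal sketch of the idea behind this comment, not the actual distributed state_dict code: rather than calling `full_tensor()` on every DTensor while building the state dict (which keeps all gathered full tensors alive at once), keep the DTensor and materialize each full tensor only where it is consumed. The helper names below are hypothetical, and the DTensor import path varies by release.

```python
from torch.distributed._tensor import DTensor  # torch.distributed.tensor in newer releases


def gather_lazily(local_state_dict):
    """Keep DTensors as-is; do not materialize full tensors eagerly."""
    gathered = {}
    for key, local_state in local_state_dict.items():
        if local_state is None:
            continue
        gathered[key] = local_state  # full_tensor() is postponed
    return gathered


def iterate_full_tensors(gathered):
    """Materialize one full tensor at a time, right before it is used."""
    for key, value in gathered.items():
        if isinstance(value, DTensor):
            value = value.full_tensor()  # all-gather happens here, per tensor
        yield key, value
```

This way the peak memory is roughly one full tensor plus the sharded state, instead of the full unsharded state dict.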
_IncompatibleKeys,

_state_dict_fn(model, "load_state_dict")(
-    state_dict=state_dict, strict=info.strict
+    state_dict=state_dict, strict=info.strict, assign=assign
`assign=True` avoids the memory cost of `model.to_empty(device=device)` for models on the meta device. TorchTune uses `assign=True` already. `load_state_dict` takes `assign=False` by default, so we only set `assign=True` when the device is found to be the meta device.
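A minimal sketch of the `assign=True` path described above, using a plain `nn.Linear` rather than a TorchTune LoRA model: a module built on the meta device has no parameter storage, and `load_state_dict(..., assign=True)` adopts the incoming tensors directly instead of requiring `model.to_empty(device=...)` followed by an in-place copy.

```python
import torch

# Build the module on the meta device: parameters have shapes but no storage.
with torch.device("meta"):
    model = torch.nn.Linear(1024, 1024)

state_dict = {
    "weight": torch.randn(1024, 1024),
    "bias": torch.randn(1024),
}

# With the default assign=False we would first need real storage to copy into,
# e.g. model.to_empty(device="cpu"). assign=True instead swaps the meta
# parameters for the loaded tensors, avoiding the extra allocation.
model.load_state_dict(state_dict, assign=True)
print(model.weight.device)  # cpu: the parameter now shares storage with state_dict
```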
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the labeling guidance in the PyTorch wiki. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Optimize memory cost at [PR#129635](pytorch#129635)

There are two main parts to the optimization here:
1. Optimize the tensor-distributing part by postponing the full_tensor generation, which avoids the memory overlap and saves around 50% of peak memory in the 2-param test case.
2. Apply `assign=True` for `load_state_dict`, which saves memory during state-dict loading by assigning the input param; around 50% of peak memory at the loading part.

Future work: memory optimization for the optimizer will be conducted in the next PR.

Pull Request resolved: pytorch#134025
Approved by: https://github.com/fegin

Co-authored-by: Rachel Guo <guorachel@meta.com>
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn