fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs#4297
fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs#4297Dev-next-gen wants to merge 1 commit intopytorch:mainfrom
Conversation
…_get_to_kwargs ## Problem `_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from `torch._C._nn._parse_to`, with a comment saying it is "not very useful for most tensor subclasses". As a result, any call to `tensor.to(device, non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a blocking transfer at the inner-tensor level. This matters in practice for async CPU→GPU offloading workflows such as `diffusers` `enable_group_offload(use_stream=True)`: the diffusers hook schedules copies with `non_blocking=True` so that the transfer stream and the compute stream can overlap. Because the flag was dropped, all copies became blocking, negating the overlap benefit. On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a separate stream-ordering race (fixed in huggingface/diffusers#13502): the default stream can race ahead of "blocking" copies that the OS scheduler hasn't committed yet, producing device-mismatch errors in the first matmul. ## Fix 1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict. 2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and forward it to every inner `.to()` call for both `tensor_data_names` and `optional_tensor_data_names`. The change is backward-compatible: when `non_blocking=False` (the default), behaviour is identical to before. ## Tested on - 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7 - FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)` - Companion fix in diffusers: huggingface/diffusers#13502
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4297
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Hi @Dev-next-gen! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
|
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
|
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
Problem
_get_to_kwargsexplicitly discarded thenon_blockingargument parsed fromtorch._C._nn._parse_to, with a comment saying it is "not very useful for most tensor subclasses". As a result, any call totensor.to(device, non_blocking=True)on aTorchAOBaseTensorsubclass silently became a blocking transfer at the inner-tensor level.This matters in practice for async CPU→GPU offloading workflows such as
diffusersenable_group_offload(use_stream=True): the diffusers hook schedules copies withnon_blocking=Trueso that the transfer stream and the compute stream can overlap. Because the flag was dropped, all copies became blocking, negating the overlap benefit.On AMD ROCm (gfx1xxx) the missing
non_blockingalso interacts with a separate stream-ordering race (fixed in huggingface/diffusers#13502): the default stream can race ahead of "blocking" copies that the OS scheduler hasn't committed yet, producing device-mismatch errors in the first matmul.Fix
_get_to_kwargs: includenon_blockingin the returned kwargs dict.TorchAOBaseTensor._to_copy.default: popnon_blockingfrom kwargs and forward it to every inner.to()call for bothtensor_data_namesandoptional_tensor_data_names.The change is backward-compatible: when
non_blocking=False(the default), behaviour is identical to before.Tested on
| GPU | 5× AMD RX 7800 XT (gfx1101) |
| ROCm | 7.1 |
| PyTorch | 2.7 |
| Model | FLUX.1-dev,
Int8WeightOnlyConfigvia torchao || Config | block-level group offload +
use_stream=True(diffusers) |Related