nn.Module: use swap_tensors for Tensor subclasses #122755
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122755
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit ae90732 with merge base 628dcde: FLAKY - the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/nn/modules/module.py (Outdated)

@@ -802,8 +802,13 @@ def compute_should_use_set_data(tensor, tensor_applied):
    with torch.no_grad():
        param_applied = fn(param)
    p_should_use_set_data = compute_should_use_set_data(param, param_applied)

    # subclasses may have multiple child tensors so we need to use swap_tensors
    is_subclass = isinstance(param_applied, torch.Tensor) and type(param_applied) is not torch.Tensor
cc @albanD @mikaylagawarecki we are enabling this by default for tensor subclasses; wondering if this makes sense or whether we should only enable this when loading the DTensor module?
I wonder if there exist non-wrapper subclasses where we would still want this behavior; probably not, since we already have the future flag.
Do we have a full list of subclasses (wrapper and non-wrapper)? I can add some more tests / sanity-check this.
A safer way to check is to utilize this check function, which only applies to "wrapper" subclasses that are already well defined (e.g. DTensor, Float8Tensor).
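Assuming the check function being referred to is `is_traceable_wrapper_subclass` from `torch.utils._python_dispatch` (the one the updated diff further down ends up using), a minimal sketch of its behavior: it returns True only for wrapper subclasses that implement the `__tensor_flatten__` protocol, and False for plain tensors and `nn.Parameter`.

```python
# Sketch only: illustrates the behavior of is_traceable_wrapper_subclass,
# assuming it is the check function being referenced above.
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

t = torch.randn(2, 2)
p = torch.nn.Parameter(torch.randn(2, 2))

print(is_traceable_wrapper_subclass(t))  # False: plain torch.Tensor
print(is_traceable_wrapper_subclass(p))  # False: Parameter subclasses Tensor but is not a wrapper subclass
# A wrapper subclass such as DTensor defines __tensor_flatten__/__tensor_unflatten__,
# so the same call would return True for it.
```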
Making the swap_tensors path the default for wrapper subclasses sounds reasonable to me and was on my list. I don't think we should make this the default for non-wrapper subclasses yet (if I understand correctly, the .data setting is not a problem for non-wrapper subclasses). The `is_traceable_wrapper_subclass` function Wanchao linked sounds good to me here!
Could you elaborate more on what you mean by "it doesn't work with Parameter"?

`torch.utils.swap_tensors` does have certain constraints:

- tensors swapped cannot have python weakrefs to them
- `__slots__` of tensors swapped must match
- `use_count` and `weak_use_count` of the `THPVariable->cdata` associated with the python tensor must be 1
pytorch/torch/utils/__init__.py, Lines 34 to 42 in 5af839f:

# Ensure there are no weakrefs
if weakref.getweakrefs(t1):
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
if weakref.getweakrefs(t2):
    raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
t1_slots = set(copyreg._slotnames(t1.__class__))  # type: ignore[attr-defined]
t2_slots = set(copyreg._slotnames(t2.__class__))  # type: ignore[attr-defined]
if t1_slots != t2_slots:
    raise RuntimeError("Cannot swap t1 and t2 if they have different slots")
Lines 370 to 386 in 5af839f:

TORCH_CHECK(
    a->cdata->use_count() == 1,
    "Expected single reference to a's Tensor object but got ",
    a->cdata->use_count());
TORCH_CHECK(
    b->cdata->use_count() == 1,
    "Expected single reference to b's Tensor object but got ",
    b->cdata->use_count());
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
    a->cdata->weak_use_count() == 1,
    "Expected no weakrefs to a's Tensor object but got ",
    a->cdata->weak_use_count() - 1);
TORCH_CHECK(
    b->cdata->weak_use_count() == 1,
    "Expected no weakrefs to b's Tensor object but got ",
    b->cdata->weak_use_count() - 1);
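A quick illustration of those constraints in practice (a sketch, assuming a PyTorch build that exposes `torch.utils.swap_tensors`): a plain swap succeeds, while a tensor with a live weakref is rejected by the checks quoted above.

```python
# Sketch: exercising torch.utils.swap_tensors and its weakref constraint.
import weakref
import torch

a = torch.randn(3)
b = torch.zeros(3)
torch.utils.swap_tensors(a, b)  # OK: no weakrefs, matching __slots__, single references
print(a)                         # a now holds the zeros tensor

c = torch.randn(3)
d = torch.ones(3)
ref = weakref.ref(c)             # keep a weakref to c alive
try:
    torch.utils.swap_tensors(c, d)
except RuntimeError as e:
    print(e)                     # "Cannot swap t1 because it has weakref associated with it"
```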
The issue with RNN/GRU/LSTM that you might have seen in the failures is that these modules hold weakrefs to some of the params, which is a known issue. #122800 might be a fix for this if CI is happy with it.
I'm not sure of the exact issue, but it seems like .to is being called on an nn.Parameter (which thus passed my initial subclass check). The parameter causing issues is LayerNorm.weight -- RuntimeError: Expected single reference to a's Tensor object but got 3
<class 'torch.nn.parameter.Parameter'>
E
======================================================================
ERROR: test_transformerencoder_batch_first_False_training_True_enable_nested_tensor_False_cpu (__main__.TestTransformersCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 819, in _apply
torch.utils.swap_tensors(param, param_applied)
File "/home/tristanr/pytorch/torch/utils/__init__.py", line 68, in swap_tensors
torch._C._swap_tensor_impl(t1, t2)
RuntimeError: Expected single reference to a's Tensor object but got 3
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper
method(*args, **kwargs)
File "/home/tristanr/pytorch/torch/testing/_internal/common_device_type.py", line 432, in instantiated_test
raise rte
File "/home/tristanr/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
result = test(self, **param_kwargs)
File "/home/tristanr/pytorch/test/test_transformers.py", line 643, in test_transformerencoder
_test(batch_first, training, enable_nested_tensor)
File "/home/tristanr/pytorch/test/test_transformers.py", line 615, in _test
enable_nested_tensor=enable_nested_tensor).to(device)
File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 1174, in to
return self._apply(convert)
File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 778, in _apply
module._apply(fn)
File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 823, in _apply
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
RuntimeError: _apply(): Couldn't swap LayerNorm.weight
To execute this test, run the following from the base repo dir:
python test/test_transformers.py -k test_transformerencoder_batch_first_False_training_True_enable_nested_tensor_False_cpu
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
So it seems to be a refcount issue -- not sure why LayerNorm.weight has multiple refs.
Hm, I have tested the behavior of `swap_tensors` with `_apply` on all nn.Modules in these tests, and the only ones with issues are RNN/GRU/LSTM (due to the weakref issue I mentioned above). So I do not think it is an issue with the LayerNorm/TransformerEncoder module itself holding multiple refs, though perhaps there is some logic in the test in question that causes this; it is certainly worth investigating.
Filed #122803 to investigate this
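For anyone following along, a hedged diagnostic sketch (not from the PR) for hunting down stray references to a parameter before `.to()` is called. Python-level refcounts are only a hint, since the failing check is on the C++ Tensor's use_count, but they often point at the culprit:

```python
# Diagnostic sketch: look for extra Python references / weakrefs to a parameter.
import gc
import sys
import weakref
import torch

m = torch.nn.LayerNorm(4)
p = m.weight
print(sys.getrefcount(p))                      # includes the temporary ref from the call itself
print(weakref.getweakrefs(p))                  # any weakrefs would also block swap_tensors
print([type(r) for r in gc.get_referrers(p)])  # objects that still hold the parameter
```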
    device_mesh,
)
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
    output = replica_model(dt)
Nice to see autocast test works!
yup, it was a pleasant surprise :)
Force-pushed from a5bd7c4 to befa87c.
lgtm, thanks for enabling this!
@@ -802,8 +803,12 @@ def compute_should_use_set_data(tensor, tensor_applied):
    with torch.no_grad():
        param_applied = fn(param)
    p_should_use_set_data = compute_should_use_set_data(param, param_applied)

    # subclasses may have multiple child tensors so we need to use swap_tensors
    p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
btw sorry I just caught this, but wanted to note that this only updates the condition to use `swap_tensors` for parameters, not for their associated `.grad` field, if it exists. `should_use_swap_tensors` is used again below for gradients on L837; might we want to also update the condition for gradients?
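In other words (a sketch under the assumption that the gradient branch mirrors the parameter branch above; this is not the actual module.py diff), the same predicate would be applied to the converted gradient as well as the converted parameter:

```python
# Sketch: the same "swap if global flag or wrapper subclass" predicate,
# applied to both the converted parameter and its converted gradient.
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def should_swap(applied: torch.Tensor, global_flag: bool) -> bool:
    # Swap when the global future flag is set, or when the converted tensor is a
    # traceable wrapper subclass (it may hold multiple child tensors).
    return global_flag or is_traceable_wrapper_subclass(applied)

# global_flag stands in for torch.__future__.get_swap_module_params_on_conversion()
global_flag = False
param_applied = torch.randn(2)  # stands in for fn(param)
grad_applied = torch.randn(2)   # stands in for fn(param.grad)
print(should_swap(param_applied, global_flag), should_swap(grad_applied, global_flag))
```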
thanks for catching that -- updated!
Force-pushed from befa87c to ae90732.
Thanks!
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the Bot commands wiki. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Failing test is passing locally. Flaky? Seems unrelated to this change.

It looks like the pytorchbot detected those as flaky and will let you merge without more intervention.
We do not need to set the flag after #122755. Pull Request resolved: #122962 Approved by: https://github.com/mikaylagawarecki
This fixes a bug when casting a module that has DTensor parameters. The old behavior swaps the .data field of the Tensor subclass, which is incorrect when dealing with tensor subclasses that may have multiple child tensors. This uses the `swap_tensors` method to swap all of the tensors, not just the .data field.

Test plan:
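```
pytest test/distributed/_tensor/test_api.py -k 'test_distribute_module_casting'
python test/distributed/fsdp/test_wrap.py -k test_auto_wrap_smoke_test_cuda_init_mode1_cpu_offload0_use_device_id_True
```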
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang