
nn.Module: use swap_tensors for Tensor subclasses #122755

Closed · wants to merge 1 commit

Conversation

@d4l3k (Collaborator) commented Mar 27, 2024

This fixes a bug when casting a module that has DTensor parameters. The old behavior swaps the .data field of the Tensor subclass, which is incorrect for tensor subclasses that may hold multiple child tensors.

This change uses the swap_tensors method to swap the whole tensor objects, not just the .data field.
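For illustration, a minimal sketch of the difference (assuming a PyTorch build that provides torch.utils.swap_tensors; plain tensors stand in for the DTensor parameters):

```python
import torch

t1 = torch.randn(4)                        # stands in for the module parameter
t2 = torch.randn(4, dtype=torch.bfloat16)  # stands in for fn(param), e.g. a dtype cast

# Old behavior: only reassigns the storage/metadata behind .data. For a wrapper
# subclass like DTensor, its other child tensors would be left untouched.
# t1.data = t2

# New behavior for wrapper subclasses: exchange the full tensor objects in place,
# so every field of the subclass travels with the swap.
torch.utils.swap_tensors(t1, t2)
print(t1.dtype)  # torch.bfloat16
```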

Test plan:

```
pytest test/distributed/_tensor/test_api.py -k 'test_distribute_module_casting'
python test/distributed/fsdp/test_wrap.py -k test_auto_wrap_smoke_test_cuda_init_mode1_cpu_offload0_use_device_id_True
```

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

pytorch-bot (bot) commented Mar 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122755

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit ae90732 with merge base 628dcde:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the oncall: distributed label Mar 27, 2024
```
@@ -802,8 +802,13 @@ def compute_should_use_set_data(tensor, tensor_applied):
            with torch.no_grad():
                param_applied = fn(param)
            p_should_use_set_data = compute_should_use_set_data(param, param_applied)

            # subclasses may have multiple child tensors so we need to use swap_tensors
            is_subclass = isinstance(param_applied, torch.Tensor) and type(param_applied) is not torch.Tensor
```
Contributor

cc @albanD @mikaylagawarecki -- we are enabling this by default for tensor subclasses; wondering if this makes sense, or whether we should only enable it when loading the DTensor module?

Contributor

I wonder if there exist non-wrapper subclasses where we would still want this behavior -- probably not, since we already have the future flag.

Collaborator (Author)

Do we have a full list of subclasses (wrapper and non-wrapper)? I can add some more tests / sanity-check this.

Contributor

A safer way to check is to use this check function (is_traceable_wrapper_subclass), which only applies to "wrapper" subclasses that are already well defined (i.e. DTensor, Float8Tensor).
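A rough sketch of the suggested gating, under the assumption that the check in question is torch.utils._python_dispatch.is_traceable_wrapper_subclass and that the existing global opt-in lives in torch.__future__ (an illustration, not the exact diff):

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def should_swap(param_applied: torch.Tensor) -> bool:
    # Swap when the user opted in globally, or when the converted parameter is a
    # traceable wrapper subclass (e.g. DTensor) whose child tensors would be lost
    # by a plain .data assignment.
    global_flag = torch.__future__.get_swap_module_params_on_conversion()
    return global_flag or is_traceable_wrapper_subclass(param_applied)
```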

Contributor

Making the swap_tensors path the default for wrapper subclasses sounds reasonable to me and was on my list. I don't think we should make this the default for non-wrapper subclasses yet (if I understand correctly, the .data setting is not a problem for non-wrapper subclasses). The is_traceable_wrapper_subclass function Wanchao linked sounds good to me here!
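For context, a small illustration of the distinction (assuming is_traceable_wrapper_subclass keys off the __tensor_flatten__/__tensor_unflatten__ protocol that DTensor and Float8Tensor implement; PlainSubclass below is a made-up example, not a real PyTorch class):

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

class PlainSubclass(torch.Tensor):
    """A non-wrapper subclass: no __tensor_flatten__/__tensor_unflatten__."""

p = torch.nn.Parameter(torch.randn(2))
x = torch.randn(2).as_subclass(PlainSubclass)

print(is_traceable_wrapper_subclass(p))  # False -- Parameter is not a wrapper subclass
print(is_traceable_wrapper_subclass(x))  # False -- plain subclass, .data setting still works
# A DTensor parameter would return True and take the swap_tensors path.
```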

Contributor

Could you elaborate more on what you mean by "it doesn't work with Parameter"?

torch.utils.swap_tensors does have certain constraints:

  • the tensors being swapped cannot have Python weakrefs to them
  • the __slots__ of the tensors being swapped must match
  • the use_count and weak_use_count of the THPVariable->cdata associated with each Python tensor must be 1

```python
# Ensure there are no weakrefs
if weakref.getweakrefs(t1):
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
if weakref.getweakrefs(t2):
    raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
t1_slots = set(copyreg._slotnames(t1.__class__))  # type: ignore[attr-defined]
t2_slots = set(copyreg._slotnames(t2.__class__))  # type: ignore[attr-defined]
if t1_slots != t2_slots:
    raise RuntimeError("Cannot swap t1 and t2 if they have different slots")
```

```cpp
TORCH_CHECK(
    a->cdata->use_count() == 1,
    "Expected single reference to a's Tensor object but got ",
    a->cdata->use_count());
TORCH_CHECK(
    b->cdata->use_count() == 1,
    "Expected single reference to b's Tensor object but got ",
    b->cdata->use_count());
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
    a->cdata->weak_use_count() == 1,
    "Expected no weakrefs to a's Tensor object but got ",
    a->cdata->weak_use_count() - 1);
TORCH_CHECK(
    b->cdata->weak_use_count() == 1,
    "Expected no weakrefs to b's Tensor object but got ",
    b->cdata->weak_use_count() - 1);
```

The issue with RNN/GRU/LSTM that you might have seen in the failures is that these modules hold weakrefs to some of their params, which is a known issue. #122800 might be a fix for this if CI is happy with it.
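A minimal sketch of the weakref constraint (assuming a PyTorch build that exposes torch.utils.swap_tensors; the explicit weakref here stands in for the internal references RNN-style modules keep to their flat weights):

```python
import weakref
import torch

t1 = torch.randn(2)
t2 = torch.randn(2)
ref = weakref.ref(t1)  # simulate a module holding a weakref to its parameter

try:
    torch.utils.swap_tensors(t1, t2)
except RuntimeError as e:
    print(e)  # "Cannot swap t1 because it has weakref associated with it"
```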

Collaborator (Author)

I'm not sure of the exact issue, but it seems like .to is being called on an nn.Parameter (which passed my initial subclass check).

The Parameter causing issues is LayerNorm.weight -- RuntimeError: Expected single reference to a's Tensor object but got 3

```
<class 'torch.nn.parameter.Parameter'>
E
======================================================================
ERROR: test_transformerencoder_batch_first_False_training_True_enable_nested_tensor_False_cpu (__main__.TestTransformersCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 819, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/tristanr/pytorch/torch/utils/__init__.py", line 68, in swap_tensors
    torch._C._swap_tensor_impl(t1, t2)
RuntimeError: Expected single reference to a's Tensor object but got 3

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper
    method(*args, **kwargs)
  File "/home/tristanr/pytorch/torch/testing/_internal/common_device_type.py", line 432, in instantiated_test
    raise rte
  File "/home/tristanr/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/tristanr/pytorch/test/test_transformers.py", line 643, in test_transformerencoder
    _test(batch_first, training, enable_nested_tensor)
  File "/home/tristanr/pytorch/test/test_transformers.py", line 615, in _test
    enable_nested_tensor=enable_nested_tensor).to(device)
  File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 1174, in to
    return self._apply(convert)
  File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 778, in _apply
    module._apply(fn)
  File "/home/tristanr/pytorch/torch/nn/modules/module.py", line 823, in _apply
    raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
RuntimeError: _apply(): Couldn't swap LayerNorm.weight

To execute this test, run the following from the base repo dir:
     python test/test_transformers.py -k test_transformerencoder_batch_first_False_training_True_enable_nested_tensor_False_cpu

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Collaborator (Author)

So it seems to be a refcount issue -- I'm not sure why LayerNorm.weight has multiple refs.

Contributor

Hm, I have tested the behavior of swap_tensors with _apply on all nn.Modules in these tests, and the only ones with issues are RNN/GRU/LSTM (due to the weakref issue I mentioned above).

So I do not think the LayerNorm/TransformerEncoder module itself holds multiple refs; perhaps some logic in the test in question causes this, but it is certainly worth investigating.

Contributor

Filed #122803 to investigate this

```python
    device_mesh,
)
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
    output = replica_model(dt)
```
Contributor

Nice to see the autocast test works!

Collaborator (Author)

yup, it was a pleasant surprise :)

@wanchaol (Contributor) left a comment

lgtm, thanks for enabling this!

```
@@ -802,8 +803,12 @@ def compute_should_use_set_data(tensor, tensor_applied):
            with torch.no_grad():
                param_applied = fn(param)
            p_should_use_set_data = compute_should_use_set_data(param, param_applied)

            # subclasses may have multiple child tensors so we need to use swap_tensors
            p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
```
Contributor

btw, sorry I just caught this, but I wanted to note that this only updates the swap_tensors condition for parameters, not for their associated .grad field (if it exists).

should_use_swap_tensors is used again below for gradients, on L837.

Might we also want to update the condition for gradients?
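A rough sketch of what extending the condition to gradients could look like (a hypothetical helper mirroring the parameter-side check above; it is an illustration, not the exact code in Module._apply):

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def apply_to_grad(param: torch.nn.Parameter, fn, should_use_swap_tensors: bool) -> None:
    """Hypothetical helper: apply fn to param.grad with the same subclass-aware gating."""
    if param.grad is None:
        return
    with torch.no_grad():
        grad_applied = fn(param.grad)
    # Same wrapper-subclass-aware condition as for the parameter itself, so a
    # DTensor gradient is swapped wholesale instead of having .data reassigned.
    if should_use_swap_tensors or is_traceable_wrapper_subclass(grad_applied):
        torch.utils.swap_tensors(param.grad, grad_applied)
    else:
        param.grad.data = grad_applied
```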

Collaborator (Author)

thanks for catching that -- updated!

@albanD removed their request for review March 27, 2024 21:59
@mikaylagawarecki (Contributor) left a comment

Thanks!

@d4l3k (Collaborator, Author) commented Mar 27, 2024

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Mar 27, 2024
@pytorchmergebot
Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@d4l3k added the topic: not user facing label Mar 27, 2024
@d4l3k (Collaborator, Author) commented Mar 27, 2024

@pytorchbot merge

@pytorchmergebot
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@d4l3k (Collaborator, Author) commented Mar 27, 2024

The failing test passes locally. Flaky? It seems unrelated to this change.

```
tristanr@devvm17560 ~/pytorch (tristanr/dmodule_casting)> pytest test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_fft_rfft2_cpu_int32    (pytorch-3.10)
======================================================================================== test session starts ========================================================================================
platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.4.0
rootdir: /home/tristanr/pytorch
configfile: pytest.ini
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0
collected 6929 items / 6928 deselected / 1 selected
Running 1 items in this shard

test/inductor/test_torchinductor_opinfo.py .                                                                                                                                                  [100%]

================================================================================ 1 passed, 6928 deselected in 12.67s ================================================================================
```

It is failing on a lot of PRs: https://hud.pytorch.org/failure?name=pull%20%2F%20linux-jammy-py3.8-gcc11%20%2F%20test%20(default%2C%202%2C%203%2C%20linux.2xlarge)&jobName=linux-jammy-py3.8-gcc11%20%2F%20test%20(default%2C%202%2C%203%2C%20linux.2xlarge)&failureCaptures=%5B%22inductor%2Ftest_torchinductor_opinfo.py%3A%3ATestInductorOpInfoCPU%3A%3Atest_comprehensive_fft_ihfftn_cpu_int32%22%5D

@awgu (Contributor) commented Mar 28, 2024

It looks like the pytorchbot detected those as flaky and will let you merge without more intervention.
#122755 (comment)

@d4l3k d4l3k deleted the tristanr/dmodule_casting branch March 28, 2024 16:53
awgu added a commit to pytorch/torchtitan that referenced this pull request Mar 29, 2024
…eeded anymore"


pytorch/pytorch#122755 is in nightlies. We can remove the global flag now.

[ghstack-poisoned]
awgu added a commit to pytorch/torchtitan that referenced this pull request Mar 29, 2024
…re (#175)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #175

pytorch/pytorch#122755 is in nightlies. We can
remove the global flag now.
pytorchmergebot pushed a commit that referenced this pull request Mar 29, 2024
We do not need to set the flag after #122755.

Pull Request resolved: #122962
Approved by: https://github.com/mikaylagawarecki
wanchaol pushed a commit that referenced this pull request Apr 1, 2024
Pull Request resolved: #122755
Approved by: https://github.com/wanchaol, https://github.com/mikaylagawarecki

(cherry picked from commit e6ee832)
huydhn pushed a commit that referenced this pull request Apr 2, 2024
Pull Request resolved: #122755
Approved by: https://github.com/wanchaol, https://github.com/mikaylagawarecki

(cherry picked from commit e6ee832)

Co-authored-by: Tristan Rice <rice@fn.lc>
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Pull Request resolved: pytorch#122755
Approved by: https://github.com/wanchaol, https://github.com/mikaylagawarecki
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · topic: not user facing (topic category)