
Add torch.__future__._overwrite_module_params_on_conversion global flag, and check it in nn.Module._apply() #21613

Closed
wants to merge 60 commits

Conversation

yf225
Contributor

@yf225 yf225 commented Jun 10, 2019

#17072 breaks `model.to(xla_device)`, because moving `model` to XLA device involves changing its parameters' TensorImpl type, and the current implementation of `nn.Module.to()` doesn't support changing module parameters' TensorImpl type:

```python
# https://github.com/pytorch/pytorch/blob/6dc445e1a84dc5d093d640de54f038f021d13227/torch/nn/modules/module.py#L192-L208
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```

To fix this problem, we decided to gradually change the behavior of `nn.Module._apply()` from "updating the existing parameters in-place" to "overwriting the existing parameters with new tensors". For any `fn` passed into `_apply()` that changes the TensorImpl type of the parameters, `_apply()` will overwrite the existing parameters. For any `fn` that doesn't change the TensorImpl type, we check the newly added `torch.__future__.get_overwrite_module_params_on_conversion()` flag to decide whether `_apply()` should overwrite the existing parameters or update them in-place.
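
A minimal sketch of how the flag is meant to be toggled, assuming the getter above ships with a matching setter `torch.__future__.set_overwrite_module_params_on_conversion()`:

```python
# Hedged sketch: toggling the global flag and observing its effect on parameter
# identity after a dtype conversion (which does not change the TensorImpl type).
import torch

m = torch.nn.Linear(2, 2)
w = m.weight

# Default (BC) behavior: parameters are updated in-place, so the same Parameter
# objects survive the conversion.
torch.__future__.set_overwrite_module_params_on_conversion(False)
m.double()
print(m.weight is w)  # expected: True

# Opt-in behavior: _apply() overwrites parameters with new tensors, so the
# module now holds different Parameter objects.
torch.__future__.set_overwrite_module_params_on_conversion(True)
m.float()
print(m.weight is w)  # expected: False

torch.__future__.set_overwrite_module_params_on_conversion(False)  # restore the default
```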

This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.

[xla ci]

cc. @resistor @ailzhang

@yf225 yf225 requested a review from gchanan June 10, 2019 21:31
@pytorchbot pytorchbot added the module: nn Related to torch.nn label Jun 10, 2019
@ailzhang
Contributor

@pytorchbot retest this please

@ailzhang
Contributor

ailzhang commented Jun 10, 2019

[edit]: In fact that won't work since XLA is still carrying the old patch. We'll have to manually test this out.
@yf225 would you mind pushing an empty commit with "[xla ci]" in the commit message to trigger the test?

Contributor

@gchanan gchanan left a comment


I changed my thinking on this a bit, let me know what you think.

Essentially we argued that we should introduce to_, cpu_, etc. to have the new semantics, and for one release to, cpu, etc. would keep the old semantics. But realistically, what are we going to do with to, cpu, etc. later? Every model that uses devices calls those functions, but few would probably break with the semantics change, which means we should probably leave the old APIs. But then we have two sets of APIs that do the same thing.

Instead, can we just do essentially what your implementation does here, which is take a parameter to control the behavior? Since XLA is new, we can even make it obvious that the parameter only applies to CPU/CUDA (do we support AMD in the frontend?), e.g. force_move_params_cpu_cuda or something.

What do you think?

torch/nn/modules/module.py (outdated)
```python
def convert(t):
    return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

return self._apply(convert, update_params_inplace=False)
```
Contributor


See the note above, but we need equivalents for cpu, cuda, etc.

test/test_nn.py (outdated)
torch/nn/modules/module.py (outdated)
@pytorchbot pytorchbot added module: autograd Related to torch.autograd, and the autograd engine in general module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries labels Jun 11, 2019
Contributor

@facebook-github-bot facebook-github-bot left a comment


@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@yf225 merged this pull request in 6b97279.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 19, 2019
…flag, and check it in `nn.Module._apply()` (#21613)

Pull Request resolved: pytorch/pytorch#21613

Differential Revision: D15895387

Pulled By: yf225

fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
freud14 added a commit to GRAAL-Research/poutyne that referenced this pull request Jun 22, 2019
@apaszke
Contributor

apaszke commented Jun 23, 2019

Why do we even have this flag? What's so special about XLA that we cannot overwrite the parameters with new objects, but we can e.g. in the case of CUDA?

NB I would be against adding to_ in addition to to, because even to works in-place on the module (just not on the tensors!), so the semantics would be very confusing.
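
For reference, a minimal runnable illustration of that in-place-on-the-module behavior: the existing `to()`/`cpu()`/`cuda()`/`double()` family mutates the module it is called on and returns `self`, even when the tensors it holds get converted or replaced.

```python
# Module.to()-style conversions return the very same module object; only the
# tensors inside it are converted (in-place or by replacement).
import torch

m = torch.nn.Linear(2, 2)
m2 = m.double()   # same family of methods as .to()/.cpu()/.cuda()
print(m2 is m)    # True: the module itself is modified in place and returned
```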

@gchanan
Contributor

gchanan commented Jun 24, 2019

@apaszke: it's the combination of (1) XLA tensors using a different TensorImpl than CPU/CUDA (which share the same one) and (2) no longer having VariableImpl to do the extra indirection (by design). So basically, CPUModule.to('xla') without overwriting the parameters would require changing the type of the TensorImpl or hacking around with the pyobj.
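
A rough sketch of the two update strategies being contrasted, using a dtype conversion as a stand-in for the CPU -> XLA case (the XLA case is the one where the in-place path breaks down, because the converted tensor needs a different TensorImpl subclass):

```python
# Illustrative only: both strategies work for a dtype conversion; only the
# overwrite strategy can also handle conversions that change the TensorImpl type.
import torch

param = torch.nn.Parameter(torch.randn(3))

# In-place strategy (old _apply()): reuse the same Parameter object and swap its
# data; the parameter has to keep a compatible TensorImpl.
param.data = param.data.double()

# Overwrite strategy (this PR): build a brand-new Parameter and rebind it on the
# module, which places no constraint on the new tensor's TensorImpl type.
new_param = torch.nn.Parameter(param.data.float(), requires_grad=param.requires_grad)
```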

@gchanan
Contributor

gchanan commented Jun 24, 2019

And note we don't really support moving parameters in-place anyway; as demonstrated in the tests, views on such parameters are broken, we don't bump version counters, etc.

So the flag is just a BC thing before we move everything to overwrite semantics, which are well supported.

freud14 added a commit to GRAAL-Research/poutyne that referenced this pull request Jun 25, 2019
ljk53 added a commit to ljk53/pytorch that referenced this pull request Jul 15, 2019
Summary:
We introduced RTTI in a recent change: pytorch#21613

For the internal mobile build we don't enable '-frtti' yet. This diff replaces the RTTI check with an alternative approach.

According to dzhulgakov, we can compare two tensors' type_id directly in most cases, which is stricter than comparing TensorImpl subclass types (the TensorImpl -> type_id mapping is 1-to-n) but more appropriate for this use case.

The only two cases where we can relax the direct type comparison (for legacy reasons) are:
1. CPUTensor <-> CUDATensor;
2. SparseCPUTensor <-> SparseCUDATensor;

Differential Revision: D16212472

fbshipit-source-id: 5946ca605e86820329762f84761db9142fd06a29
facebook-github-bot pushed a commit that referenced this pull request Jul 16, 2019
Pull Request resolved: #22773

Differential Revision: D16277696

Pulled By: ljk53

fbshipit-source-id: 043e264fbacc37b7a11af2046983c70ddb62a599
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 16, 2019