Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor to fake clone snapshot for immutable source of truth #100128

Closed

Conversation

voznesenskym
Copy link
Collaborator

@voznesenskym voznesenskym commented Apr 26, 2023

Stack from ghstack (oldest at bottom):

There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand).

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors.

So, for example, if you do t_ on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time.

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3). Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that x.stride()[0] == 3.

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.

cc @soumith @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 26, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100128

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ba90acc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

voznesenskym added a commit that referenced this pull request Apr 26, 2023
Finish clone map

Lint

ghstack-source-id: f0b61d0952862be33b7580bb08d656d362d69f64
Pull Request resolved: #100128
…truth"

Finish clone map

Lint

cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this pull request Apr 26, 2023
Finish clone map

Lint

ghstack-source-id: 81f7442b002462edbcc6de12fd7abb775ab0887c
Pull Request resolved: #100128

Rm trash
@albanD albanD removed their request for review April 26, 2023 22:19
…truth"


There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand). 

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors. 

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

```
class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3).  Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`. 

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this pull request Apr 26, 2023
Finish clone map

Lint

ghstack-source-id: 5f02091baaddf615c7e1502069e1c009d7cd3cb0
Pull Request resolved: #100128

Rm trash

Replace w/ update
@@ -87,6 +87,7 @@ class OutputGraphState(NamedTuple):
param_name_to_source: Optional[Dict[str, Source]]
side_effects: SideEffects
timestamp: int
tensor_id_to_fake_clone: Dict[int, torch._subclasses.FakeTensor]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please use a weakref instead of an id. This will help prevent awful bugs where the input Tensor is not kept live and then an id gets reused.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch/utils/weak.py has the stuff you want

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg!

@@ -1250,6 +1250,7 @@ def wrap_to_fake_tensor_and_record(
)
if is_tensor and not (static_shapes and source.is_nn_module()):
tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims))
tx.output.tensor_id_to_fake_clone[id(e)] = fake_e.clone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how you're using this, there is no reason to clone the entire fake tensor. Just extract and save the sizes and strides.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can do that.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a fine unblock. Please check comments.

…truth"


There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand). 

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors. 

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

```
class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3).  Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`. 

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this pull request Apr 27, 2023
Finish clone map

Lint

ghstack-source-id: 8b1c36ca7c4780921b4b7522823cb77e17b823db
Pull Request resolved: #100128

Rm trash

Replace w/ update

Breaks
…truth"


There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand). 

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors. 

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

```
class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3).  Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`. 

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this pull request Apr 27, 2023
Finish clone map

Lint

ghstack-source-id: 168818ad91dbfe9d6e9bb6d430408693e516f0fa
Pull Request resolved: #100128

Rm trash

Replace w/ update

Breaks

Feedback
…truth"


There's a longstanding, well known mutability bug in dynamo, #93610 (and more issues, but this is the one I had at hand). 

Ops that do in place mutation of tensors will mutate their corresponding FakeTensors. 

So, for example, if you do `t_` on a tensor, you will reverse its strides. This, in turn, means that the FakeTensors strides are now also reversed, say, if you are trying to torch.compile:

```
class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.t_()
                y = y.t_()
                return (x + y,)
```

However, we recently introduced accessing the fake_tensor memo/cache to get the symbolic shape values for sizes and strides during guard installation time. 

This means that tensors captured with a given size and stride, say, for x above, size:(3,3) stride:(3, 1), will get their memo updates to size(3, 3), stride(1, 3).  Now, whenever you access this value for anything, it reflects it's current state in the tracing, as opposed to the state at which we initially started tracing on.

This causes us to produce guards that are never valid, for the example above, that `x.stride()[0] == 3`. 

The solution is to not allow mutation to affect the fake tensors we use as source of truth here. We can do this by forcing a clone of the fake tensor at builder time, and storing that as the source of truth for our dynamic sizes and strides during guard installation.



cc soumith penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
voznesenskym added a commit that referenced this pull request Apr 27, 2023
Finish clone map

Lint

ghstack-source-id: 070f7080a6ab63fe25caf5ce612f97e3838b5be5
Pull Request resolved: #100128

Rm trash

Replace w/ update

Breaks

Feedback

Lint
@voznesenskym voznesenskym added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 27, 2023
@voznesenskym
Copy link
Collaborator Author

@pytorchbot rebase

@facebook-github-bot facebook-github-bot deleted the gh/voznesenskym/122/head branch June 8, 2023 19:00
int3 added a commit that referenced this pull request Nov 10, 2023
…o_sizes_strides with a weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 10, 2023
…with a weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 11, 2023
…s_strides with a weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 11, 2023
… weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 12, 2023
…s_strides with a weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 12, 2023
… weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Nov 13, 2023
…113412)

Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

Pull Request resolved: #113412
Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
ghstack dependencies: #113413, #113518, #113519
int3 added a commit that referenced this pull request Nov 14, 2023
…s_strides with a weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 14, 2023
… weak dict"


Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
#100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…ytorch#113412)

Spotted while working on getting output_graph.py to typecheck.

The type hint indicates that it was intended to be initialized with a
WeakIdKeyDictionary, but the actual runtime value was a regular dict.
Not sure if there's some kind of test we should add for this fix.

Looks like the code was originally added in
pytorch#100128.

Pull Request resolved: pytorch#113412
Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
ghstack dependencies: pytorch#113413, pytorch#113518, pytorch#113519
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that forwards a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* Graph break for things in torch/nn/utils/parametrize.py. Before, Dynamo's trace rules would try to inline everything in torch.nn, so this is an exception to that behavior.
* Avoid introducing `call_module()` node when tracing the forward func of parametrized modules, preferring inlining. This is an exception to the general rule that modules in torch.nn should have `call_module()` nodes in the Dynamo graph for the forward func call.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* Graph break for things in torch/nn/utils/parametrize.py. Before, Dynamo's trace rules would try to inline everything in torch.nn, so this is an exception to that behavior.
* Avoid introducing `call_module()` node when tracing the forward func of parametrized modules, preferring inlining. This is an exception to the general rule that modules in torch.nn should have `call_module()` nodes in the Dynamo graph for the forward func call.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* Graph break for things in torch/nn/utils/parametrize.py. Before, Dynamo's trace rules would try to inline everything in torch.nn, so this is an exception to that behavior.
* Avoid introducing `call_module()` node when tracing the forward func of parametrized modules, preferring inlining. This is an exception to the general rule that modules in torch.nn should have `call_module()` nodes in the Dynamo graph for the forward func call.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* Graph break for things in torch/nn/utils/parametrize.py. Before, Dynamo's trace rules would try to inline everything in torch.nn, so this is an exception to that behavior.
* Avoid introducing `call_module()` node when tracing the forward func of parametrized modules, preferring inlining. This is an exception to the general rule that modules in torch.nn should have `call_module()` nodes in the Dynamo graph for the forward func call.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 6, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 7, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 26, 2024
…arametrization"


Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 26, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 26, 2024
…arametrization"


Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Mar 26, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Mar 26, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

Pull Request resolved: #121041
Approved by: https://github.com/anijain2305
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
Fixes #118795

This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
    def forward(X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.

Pull Request resolved: #121041
Approved by: https://github.com/anijain2305
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants