[NestedTensor] Graph breaks with SDPA + NT constructor #126472
Comments
I ran into a similar error, which prompted the fix in the linked PR. FWIW I was able to get your repro working without graph breaks using a combination of:
Note: this doesn't always happen - but an easy way to force a repro is to simply return the NJT.
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with the metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created and associated with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. Fake / functional `offsets` / `lengths` should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. Neither is done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
  * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
  * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
  * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
  * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple of places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
  * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should then have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
  * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Based on #130292; playing around with different designs, etc. **Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`). **Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up. **This PR:** * During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details: * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`. * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv. * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry. * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now). 
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates: * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager. * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
🐛 Describe the bug
When we use SDPA, we need `max_seqlen` and `min_seqlen`. Computing `max_seqlen` / `min_seqlen` normally requires a `.item()` call (which usually graph breaks, I think?).
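For context, here is a plain-Python sketch (no PyTorch) of what `max_seqlen` / `min_seqlen` mean for a jagged layout. The `offsets` metadata stores cumulative sequence boundaries, so consecutive differences give per-sequence lengths; on a real tensor the final reduction would end in a scalar extraction, which is where the graph break comes from.

```python
# Conceptual sketch: `offsets` is a cumulative boundary list like the
# NJT offsets tensor, so sequence i spans [offsets[i], offsets[i+1])
# and its length is the difference of adjacent entries.
def seqlen_bounds(offsets):
    lengths = [b - a for a, b in zip(offsets, offsets[1:])]
    # On a real offsets tensor this would look something like
    # offsets.diff().max().item() -- the .item() is the scalar
    # extraction that tends to graph break under torch.compile.
    return max(lengths), min(lengths)

print(seqlen_bounds([0, 3, 7, 8]))  # (4, 1)
```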
So this focuses on removing graph breaks in this path.
General repro: the approach is to call `nested_view_from_values_offsets_lengths` with `max_seqlen` and `min_seqlen` passed in:
Failure 1: With #122836 (rebased onto 7f1d5ab)
Failure 2: Based on that failure, I tried with @soulitzer's PR #124624 patched on top:
Failure 3: Based on this, I tried a quick patch: this change
I haven't gotten around to investigating this yet. Maybe #126198 is related (just based on unbacked symint <-> NT issues).
Failure 4: One other attempt - I figured I'd try #124803 to see if it would fix the issue without unbacked symint issues, but it runs into other issues where we get multiple NestedInts for the same dimension. (So we should probably just go with #124624 and figure out what the unbacked symint issue is about)
Versions
Described above, but these were all built on 7f1d5ab for H100.
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer