Factor meta conversion into real tensor -> serializable metadata -> fake tensor; create fresh ShapeEnv when retracing #121085

Closed
ezyang opened this issue Mar 2, 2024 · 0 comments
Labels: module: dynamic shapes, oncall: pt2, triaged

ezyang commented Mar 2, 2024

🐛 Describe the bug

I want to try to solve three distinct problems at the same time:

  1. The motivation for "Add Stateful/Stateless symbolic contexts, use fresh fake mode for dynamo backends (#113926)" (#114526) was that when we finish tracing in Dynamo, we need to reinitialize a new fake mode from scratch in order to "undo" any metadata mutation that may have occurred. This reinitialization process is fairly complicated, and made more complicated by the fact that we currently try to maintain the same ShapeEnv on inner passes.
  2. The after-dynamo minifier works substantially less well than the after-aot minifier, because there is a lot of complicated setup in the fake tensors and ShapeEnv that we cannot serialize in a way that lets the reproducer retrigger the problem. Similarly, the dynamic shapes minifier has worked poorly because we have no way of serializing the state of the ShapeEnv.
  3. The unbacked SymInt reallocation problem is that when we retrace a graph with data-dependent compute, we reallocate unbacked SymInts for all of the data-dependent compute. This ends up with duplicate copies of the unbacked SymInts when we reuse the ShapeEnv.

The first technical capability we want to build is relatively simple. Today, when we convert real tensors to fake tensors, we do this all in one go with some fairly complicated policy information from Dynamo; this means that to recreate an equivalent set of fake tensors (e.g., for the Dynamo to AOTAutograd handoff), we need to keep around the real tensors and the extra policy information. This is annoyingly complicated (#114526) and doesn't work for the minifier use case (you can't recreate fake tensors across a process boundary). So instead, we slice the implementation of the meta converter into a two-step process. First, we extract all of the information from the real tensors / policy that we are going to use to actually create the fake tensors, and represent it in a format that is well documented, as complicated as necessary, and serializable. Then, the second phase actually creates fake tensors from it. This should be a mechanical refactor that is easy to test, and by forcing the second phase to use the serialized format, we ensure there is no backchannel for extra information we haven't serialized.

With this change, we can now reinitialize a new fake mode from scratch by replaying the second phase on the serializable metadata (solving problem 1), and we can save the metadata to disk and then restore the fake tensors in the minifier (solving problem 2, modulo dynamic shapes).
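
To make the two-phase split concrete, here is a minimal pure-Python sketch. Everything in it (`RealTensor`, `TensorDesc`, `FakeTensor`, `Describer`, `rebuild`, `roundtrip`) is a hypothetical stand-in invented for illustration, not the actual PyTorch classes; the only point is that phase 2 sees nothing but the serialized description, and that aliasing survives via a stable storage id.

```python
import json
from dataclasses import dataclass, asdict


# Hypothetical stand-ins for illustration only; not the real torch classes.
@dataclass
class RealTensor:
    shape: tuple
    stride: tuple
    dtype: str
    storage: object  # two tensors alias if they share this object


@dataclass
class TensorDesc:
    id: int
    storage_id: int
    shape: list
    stride: list
    dtype: str


@dataclass
class FakeTensor:
    shape: tuple
    stride: tuple
    dtype: str
    storage: object


class Describer:
    """Phase 1: read everything off the real tensor into plain data."""

    def __init__(self):
        self._storage_ids = {}
        self._next_tensor_id = 0

    def describe(self, t):
        # Number each distinct storage once so aliasing is recorded as data.
        sid = self._storage_ids.setdefault(id(t.storage), len(self._storage_ids))
        desc = TensorDesc(self._next_tensor_id, sid, list(t.shape), list(t.stride), t.dtype)
        self._next_tensor_id += 1
        return desc


def rebuild(descs):
    """Phase 2: create fake tensors from the descriptions alone."""
    storages = {}
    out = []
    for d in descs:
        storage = storages.setdefault(d.storage_id, object())
        out.append(FakeTensor(tuple(d.shape), tuple(d.stride), d.dtype, storage))
    return out


def roundtrip(tensors):
    describer = Describer()
    # Forcing a trip through JSON guarantees there is no hidden backchannel.
    payload = json.dumps([asdict(describer.describe(t)) for t in tensors])
    return rebuild([TensorDesc(**d) for d in json.loads(payload)])
```

Because `rebuild` never receives a `RealTensor`, the same payload could in principle be written to disk and replayed in another process, which is the property the minifier needs.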

Next, we also want to create a fresh ShapeEnv when we create a new fake tensor mode. Some of the ShapeEnv initialization information is naturally part of the serializable metadata for fake tensors, as this information includes dynamic shape policy. The rest of the information consists of all of the internal state in ShapeEnv that is consulted by produce_guards: most notably value ranges, guards, and deferred runtime asserts. It is important to realize that guards in the ShapeEnv are not necessarily implied by the output graph; for example, if a user performs a conditional in their Python program, that can induce a guard that is not reflected in the tensor compute at all.
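
As a toy illustration of a guard that the output graph does not imply, consider the sketch below. The classes `GuardRecordingEnv`, `SymInt` and `SymBool` here are simplified, hypothetical models written for this example and are not the real symbolic shapes implementation; the assumption is only that branching on a symbolic value records a guard.

```python
class GuardRecordingEnv:
    """Toy stand-in for a shape environment: it only collects guards."""

    def __init__(self):
        self.guards = []


class SymInt:
    def __init__(self, name, hint, env):
        self.name, self.hint, self.env = name, hint, env

    def __gt__(self, other):
        return SymBool(f"{self.name} > {other}", self.hint > other, self.env)


class SymBool:
    def __init__(self, expr, value, env):
        self.expr, self.value, self.env = expr, value, env

    def __bool__(self):
        # Branching on a symbolic value specializes the trace, so record a guard.
        self.env.guards.append(self.expr if self.value else f"not ({self.expr})")
        return self.value


def trace(fn, hint):
    env = GuardRecordingEnv()
    graph = []
    fn(SymInt("s0", hint, env), graph)
    return graph, env.guards


def user_fn(n, graph):
    if n > 8:   # plain Python branch on a symbolic size
        pass    # e.g. logging; no tensor op is recorded
    graph.append("relu")
```

Tracing `user_fn` yields the same one-op graph for any hint, yet the recorded guard differs, so the guard has to be serialized separately from the tensor compute.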

Our current strategy for deferred runtime asserts is to reinsert them into the graph (done by both Dynamo and Export) at the point where the unbacked SymInt is allocated, because they reflect something you have to do at runtime and intuitively go in the graph. I could go either way for guards: putting guards in the graph as well is nicely uniform, but they do not actually need to be run "in the graph" (they are run ahead of time) and eventually someone would want to optimize them out. Optimizing them out is not difficult for Inductor, however, and maybe this is a much more logical serialization format than a side-band FX graph (or worse, prints of Sympy expressions). In any case, by ensuring all of this metadata is serializable, we address the dynamic shapes part of (2). Furthermore, because we have a fresh ShapeEnv, this resolves (3): it doesn't matter that we need to reallocate unbacked SymInts. Of course we need to reallocate them; the old ones do not exist anymore.
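
The reallocation problem in (3), and why a fresh ShapeEnv sidesteps it, can be sketched with a toy model. `CountingShapeEnv` and `trace` below are hypothetical names made up for this example; the real ShapeEnv tracks far more state than a counter.

```python
import itertools


class CountingShapeEnv:
    """Toy model: hands out unbacked symbols u0, u1, ... and remembers them."""

    def __init__(self):
        self._counter = itertools.count()
        self.unbacked = []

    def create_unbacked_symint(self):
        name = f"u{next(self._counter)}"
        self.unbacked.append(name)
        return name


def trace(ops, env):
    """Return the symbol allocated for each data-dependent op ("item")."""
    return [env.create_unbacked_symint() for op in ops if op == "item"]


ops = ["item", "add", "item"]

# Retracing with the same environment leaves stale duplicates behind.
shared = CountingShapeEnv()
first = trace(ops, shared)
second = trace(ops, shared)

# Retracing with a fresh environment has nothing to duplicate.
fresh = CountingShapeEnv()
third = trace(ops, fresh)
```

In the shared case the environment ends up holding four symbols for two data-dependent ops; in the fresh case it holds exactly two.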

One major downside to this proposal is compile time: we have to "do everything over again", although we can ensure we only insert optimized guards (with redundant guards eliminated) into the graph, which should make later constraints easier. Note that we already "do everything over again" when we retrace the graph in AOTAutograd.

I got here thinking about #121079 and, unfortunately, this plan doesn't completely solve the problem of re-propagating fake tensor metadata in Inductor passes. At the Inductor level, there is no good use case for recreating the fake tensor environment, because at this point tensor metadata can no longer change (but see also @bdhirsh and @jansel's work on storage resize to zero, which unfortunately is making it to Inductor). But if we re-prop data-dependent compute, we will end up with the reallocation problem again. @Chillee's incremental repropagator could potentially help, since we would be required not to repropagate item() calls. However, there is something fundamentally difficult about passes that want to operate on data-dependent compute and replace it with something else: if the new unbacked SymInt is truly not the same thing as the old one, then the pass is also responsible for propagating all old runtime asserts to the new one... this might not be so easy! Especially if the pass eliminated the unbacked SymInt entirely! We can probably deal with that when we get there though.

cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @avikchaudhuri @lezcano @tugsbayasgalan

Versions

main

williamwen42 added the triaged label Mar 6, 2024
ezyang added a commit that referenced this issue Mar 13, 2024
Context: #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c532bc370dd0092a0120f7632be24b59c67dc570
Pull Request resolved: #121821
ezyang added a commit that referenced this issue Mar 17, 2024
Fixes #121085

This PR is pretty involved, so pay attention to this description.  At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we preserve
  this aliasing upon deserialization.

  The way I have implemented this changes the weak reference behavior.
  Previously, when either the real Tensor OR the meta/fake Tensor went
  dead, we would remove the entry from the memo table.  Now, this only
  removes entries from one of the two memo tables.  This semantically
  makes sense, because the user may have held on to the stable int id
  out of band, and may expect a real Tensor to continue to be numbered
  consistently / expect to be able to look up a meta/fake tensor from
  this id.  If this is unacceptable, it may be possible to rejigger
  the memo tables so that we have real Tensor -> stable int id
  and real Tensor -> meta/fake Tensor, but TBH I find the new
  implementation a lot simpler, and arranging the memo tables in this
  way means that I have to muck around with the real tensor to save
  to the memo table; in the current implementation, I never pass the
  Tensor to the meta_tensor function AT ALL, which means it is impossible
  to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked via the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor, and I have almost assuredly gotten it wrong on my
  first try (I'll be grinding through it on CI; a careful audit will
  help too, by auditing that I've tested all the same conditionals that
  the original access was guarded by.)

- I originally submitted #121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint); maybe there
  is a better way to structure it, but what I really want is to
  just read sizes/strides/offset directly off of MetaTensorDesc; I
  don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT
  serializable".  ctx/type should morally be serializable and I just
  need to set up a contract with subclasses to let them be serialized.
  The fake_mode is used solely to test if we are refakefying with
  a pre-existing ShapeEnv and we want to reuse the SymInt
  directly--serializing this case is hopeless but I am kind of hoping
  after this refactor we do not need this at all.  view_func is not
  serializable because it's a bound C-implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.
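
For readers trying to picture the two memo tables and the changed weak
reference behavior from the first bullet, here is a minimal sketch.
The names (TwoLevelMemo, get_id, get_fake, set_fake) and the toy
RealTensor/FakeTensor classes are hypothetical and chosen for this
example only; the actual MetaConverter code may be organized
differently.

```python
import weakref


class RealTensor:
    """Hypothetical stand-in for a real tensor."""


class FakeTensor:
    """Hypothetical stand-in for a meta/fake tensor."""


class TwoLevelMemo:
    def __init__(self):
        # real tensor -> stable int id; entry dies with the real tensor
        self._ids = weakref.WeakKeyDictionary()
        # stable int id -> fake tensor; entry dies with the fake tensor
        self._fakes = weakref.WeakValueDictionary()
        self._next_id = 0

    def get_id(self, real):
        if real not in self._ids:
            self._ids[real] = self._next_id
            self._next_id += 1
        return self._ids[real]

    def get_fake(self, tensor_id):
        # Lookup needs only the id, never the real tensor itself.
        return self._fakes.get(tensor_id)

    def set_fake(self, tensor_id, fake):
        self._fakes[tensor_id] = fake
```

When the fake tensor dies, only the id -> fake entry disappears; the
real tensor keeps its number.  This mirrors the "only removes entries
from one of the two memo tables" behavior described above.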

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 7c1391b20ac187f8337b2b10cd9c3c2e1d8e8cc3
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 17, 2024
Fixes #121085

ghstack-source-id: a1fac70c60180d52ecffa611b6755428e85d7a28
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

ghstack-source-id: 695003be2d8e24d84ce6b487636e8cebd9e332d7
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

ghstack-source-id: 8e737b10c25c44769ce7ca2f95de3d2d9c32b6cd
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

ghstack-source-id: 14b2028f6ce92d9f299ea61666b471416adbc504
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

ghstack-source-id: 22ea030bc4342cf1c38835a82bf1d75cfc6cc833
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

This PR is pretty involved, so pay attention to this description.  At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we preserve
  this aliasing upon deserialization.

  The way I have implemented this changes the weak reference behavior.
  Previously, when either the real Tensor OR the meta/fake Tensor went
  dead, we would remove the entry from the memo table.  Now, this only
  removes entries from one of the two memo tables.  This semantically
  makes sense, because the user may have held on to the stable int id
  out of band, and may expect a real Tensor to continue to be numbered
  consistently / expect to be able to look up a meta/fake tensor from
  this id.  If this is unacceptable, it may be possible to rejigger
  the memo tables so that we have real Tensor -> stable int id
  and real Tensor -> meta/fake Tensor, but TBH I find the new
  implementation a lot simpler, and arranging the memo tables in this
  way means that I have to muck around with the real tensor to save
  to the memo table; in the current implementation, I never pass the
  Tensor to the meta_tensor function AT ALL, which means it is impossible
  to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked via the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor, and I have almost assuredly gotten it wrong on my
  first try (I'll be grinding through it on CI; a careful audit will
  help too, checking that I test all the same conditionals that
  guarded the original accesses.)

- I originally submitted #121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint); maybe there
  is a better way to structure it, but what I really want is to
  just read sizes/strides/offset directly off of MetaTensorDesc; I
  don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT
  serializable".  ctx/type should morally be serializable and I just
  need to set up a contract with subclasses to let them be serialized.
  The fake_mode is used solely to test if we are refakefying with
  a pre-existing ShapeEnv and we want to reuse the SymInt
  directly--serializing this case is hopeless but I am kind of hoping
  after this refactor we do not need this at all.  view_func is not
  serializable because it's a bound C-implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.
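
The two memo tables from the first bullet can be sketched roughly as
follows. This is an illustration under assumptions, not the code in the
PR: TwoLevelMemo and ToyTensor are hypothetical names, and the real
implementation may use different weak containers. It shows the changed
weak reference behavior: when the real object dies, only the
real -> id table loses its entry, and the id -> fake entry survives
until the fake object itself dies.

```python
import itertools
import weakref


class ToyTensor:
    """Hypothetical stand-in for a tensor; supports weak references."""


class TwoLevelMemo:
    """real object -> stable int id -> fake object (hypothetical)."""

    def __init__(self):
        self._next_id = itertools.count()
        # real -> id: the entry disappears when the real object dies.
        self._ids = weakref.WeakKeyDictionary()
        # id -> fake: the entry disappears when the fake object dies.
        self._fakes = weakref.WeakValueDictionary()

    def get_id(self, real):
        # Hand out a fresh id on first sight, then keep it stable.
        if real not in self._ids:
            self._ids[real] = next(self._next_id)
        return self._ids[real]

    def set_fake(self, ident, fake):
        self._fakes[ident] = fake

    def get_fake(self, ident):
        # Lookup needs only the id, never the real object.
        return self._fakes.get(ident)
```

A caller that kept the int id out of band can still call get_fake(id)
after the real object has been collected, which matches the behavior
described above.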

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: ff09d52430a365d1954e35fe475b1dc26fdead75
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e829a78185e1028bd8f06bb80539e8339387aa2a
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 18, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: bd9b8be52bd97524d2aa54221e9ed54d63b40e08
Pull Request resolved: #122044
@ezyang ezyang self-assigned this Mar 19, 2024
pytorchmergebot pushed a commit that referenced this issue Mar 22, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: cd0bc57d6788b6f941ad4670a22fafb7c92f5dc0
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 24, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 82ecf0201ad28a8d74f2be3ab1fa338893dcad1b
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 24, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c221c02d99c79617a7c5f63b36b8072e572ab253
Pull Request resolved: #122044
ezyang added a commit that referenced this issue Mar 25, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 2bfbfa4450be8f939e38aa1f35ddf9a594b3efb4
Pull Request resolved: #122044
pytorchmergebot pushed a commit that referenced this issue Mar 25, 2024
Fixes #121085

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #122044
Approved by: https://github.com/eellison
pytorch-bot bot pushed a commit that referenced this issue Apr 22, 2024
Fixes #121085

This PR pretty involved so pay attention to this description.  At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we preserve
  this aliasing upon deserialization.

  The way I have implemented changes the weak reference behavior.
  Previously, when either the real Tensor OR the meta/fake Tensor went
  dead, we would remove the entry from the memo table.  Now, this only
  removes entries from one of the two memo tables.  This semantically
  makes sense, because the user may have held on to the stable int id
  out of band, and may expect a real Tensor to continue to be numbered
  consistently / expect to be able to lookup a meta/fake tensor from
  this id.  If this is unacceptable, it may be possible to rejigger
  the memo tables so that we have real Tensor -> stable int id
  and real Tensor -> meta/fake Tensor, but TBH I find the new
  implementation a lot simpler, and arranging the memo tables in this
  way means that I have to muck around with the real tensor to save
  to the memo table; in the current implementation, I never pass the
  Tensor to meta_tensor function AT ALL, which means it is impossible
  to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked via the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor, and I have almost assuredly gotten it wrong on my
  first try (I'll be grinding through it on CI; a careful audit will
  help too, checking that I've tested all the same conditionals that
  the original accesses were guarded by).

- I originally submitted #121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint).  Maybe there
  is a better way to structure it, but what I really want is to
  just read sizes/strides/offset directly off of MetaTensorDesc; I
  don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT
  serializable".  ctx/type should morally be serializable and I just
  need to set up a contract with subclasses to let them be serialized.
  The fake_mode is used solely to test if we are refakefying with
  a pre-existing ShapeEnv and we want to reuse the SymInt
  directly--serializing this case is hopeless but I am kind of hoping
  after this refactor we do not need this at all.  view_func is not
  serializable because it's a bound C-implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.
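
To make the memo table split and the guarded field reads above more
concrete, here is a minimal sketch in plain Python.  It is NOT the actual
MetaConverter code: ToyStorage, ToyTensor, ToyTensorDesc, ToyDescriber
and ToyConverter are hypothetical stand-ins invented for illustration,
and the real MetaTensorDesc carries many more fields.  The sketch only
assumes what is described above: a describe step that assigns stable int
ids and reads fields only when they are valid, and a separate step that
builds the meta object from the description alone.

```python
import itertools
import weakref
from dataclasses import dataclass
from typing import Optional, Tuple


class ToyStorage:
    """Hypothetical stand-in for a storage; sharing one models aliasing."""


class ToyTensor:
    """Hypothetical stand-in for a real tensor (not torch.Tensor)."""

    def __init__(self, size, storage, is_sparse=False):
        self.size = tuple(size)
        self.storage = storage
        self.is_sparse = is_sparse

    def stride(self):
        # Mimics an attribute that is invalid to read for some tensors.
        if self.is_sparse:
            raise RuntimeError("sparse tensors have no strides")
        strides, acc = [], 1
        for dim in reversed(self.size):
            strides.append(acc)
            acc *= dim
        return tuple(reversed(strides))


@dataclass(frozen=True)
class ToyTensorDesc:
    """Plain data only, so it could be serialized and rebuilt elsewhere."""
    id: int
    storage_id: int
    size: Tuple[int, ...]
    is_sparse: bool
    stride: Optional[Tuple[int, ...]]  # None when the field is not valid


class ToyDescriber:
    """First memo table: real object -> stable int id (weak on the key)."""

    def __init__(self):
        self._next_id = itertools.count()
        self._tensor_ids = weakref.WeakKeyDictionary()
        self._storage_ids = weakref.WeakKeyDictionary()

    def _id_for(self, table, obj):
        if obj not in table:
            table[obj] = next(self._next_id)
        return table[obj]

    def describe(self, t):
        return ToyTensorDesc(
            id=self._id_for(self._tensor_ids, t),
            storage_id=self._id_for(self._storage_ids, t.storage),
            size=t.size,
            is_sparse=t.is_sparse,
            # Replicate the original guard: only poke stride() when valid.
            stride=None if t.is_sparse else t.stride(),
        )


class ToyMeta:
    """Hypothetical stand-in for the meta/fake tensor that gets created."""

    def __init__(self, desc, storage):
        self.size, self.stride, self.storage = desc.size, desc.stride, storage


class ToyConverter:
    """Second memo table: stable int id -> meta object (weak on the value)."""

    def __init__(self):
        self._meta_by_id = weakref.WeakValueDictionary()
        self._storage_by_id = weakref.WeakValueDictionary()

    def meta_tensor(self, desc):
        # Note: only the description is passed in, never the real tensor.
        meta = self._meta_by_id.get(desc.id)
        if meta is None:
            storage = self._storage_by_id.get(desc.storage_id)
            if storage is None:
                storage = ToyStorage()
                self._storage_by_id[desc.storage_id] = storage
            meta = ToyMeta(desc, storage)
            self._meta_by_id[desc.id] = meta
        return meta


shared = ToyStorage()
a, b = ToyTensor((2, 3), shared), ToyTensor((3, 2), shared)
describer, converter = ToyDescriber(), ToyConverter()
desc_a, desc_b = describer.describe(a), describer.describe(b)
meta_a, meta_b = converter.meta_tensor(desc_a), converter.meta_tensor(desc_b)
```

In this sketch, a and b get distinct tensor ids but the same storage id,
so meta_a and meta_b end up sharing one storage object.  Because the two
tables are weak on different sides, dropping the real tensor only clears
the entry in the describer, while the id -> meta entry lives as long as
the meta object does.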

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #122044
Approved by: https://github.com/eellison
ghstack dependencies: #122018
pytorch-bot bot pushed a commit that referenced this issue Apr 22, 2024