Factor meta conversion into real tensor -> serializable metadata -> fake tensor; create fresh ShapeEnv when retracing #121085
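The title describes a pipeline of real tensor -> serializable metadata -> fake tensor. Below is a minimal pure-Python sketch of that idea. `Storage`, `RealTensor`, `TensorDesc`, `Describer`, `FakeStorage`, `FakeTensor` and `fakify` are hypothetical stand-ins, not PyTorch's `MetaConverter` or `MetaTensorDesc` API. The sketch shows why stable int ids matter: once the descriptions are plain JSON, the ids are the only way to tell that two tensors share a storage, so that the aliasing can be rebuilt on the fake side.

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict, List


class Storage:
    """Hypothetical stand-in for a real storage (not a torch class)."""
    def __init__(self, nbytes: int):
        self.nbytes = nbytes


class RealTensor:
    """Hypothetical stand-in for a real tensor: a view onto a Storage."""
    def __init__(self, storage: Storage, size, stride, offset: int):
        self.storage, self.size, self.stride, self.offset = storage, tuple(size), tuple(stride), offset


@dataclass
class TensorDesc:
    """Plain, JSON-serializable description (hypothetical field names)."""
    id: int
    storage_id: int
    storage_nbytes: int
    size: List[int]
    stride: List[int]
    offset: int


class Describer:
    """Assigns stable int ids to tensors and storages.

    This sketch keeps strong references, so ids are never reused; the
    commit message describes weak-reference memo tables instead.
    """
    def __init__(self):
        self._tensor_ids: Dict[RealTensor, int] = {}
        self._storage_ids: Dict[Storage, int] = {}

    @staticmethod
    def _id(table: dict, obj) -> int:
        if obj not in table:  # plain objects hash by identity
            table[obj] = len(table)
        return table[obj]

    def describe(self, t: RealTensor) -> TensorDesc:
        return TensorDesc(
            id=self._id(self._tensor_ids, t),
            storage_id=self._id(self._storage_ids, t.storage),
            storage_nbytes=t.storage.nbytes,
            size=list(t.size), stride=list(t.stride), offset=t.offset,
        )


class FakeStorage:
    def __init__(self, nbytes: int):
        self.nbytes = nbytes


class FakeTensor:
    def __init__(self, storage: FakeStorage, size, stride, offset: int):
        self.storage, self.size, self.stride, self.offset = storage, tuple(size), tuple(stride), offset


def fakify(descs: List[TensorDesc]) -> List[FakeTensor]:
    """Rebuild fake tensors, sharing a FakeStorage wherever storage ids match."""
    storages: Dict[int, FakeStorage] = {}
    tensors: Dict[int, FakeTensor] = {}
    out = []
    for d in descs:
        if d.id not in tensors:
            storage = storages.get(d.storage_id)
            if storage is None:
                storage = storages[d.storage_id] = FakeStorage(d.storage_nbytes)
            tensors[d.id] = FakeTensor(storage, d.size, d.stride, d.offset)
        out.append(tensors[d.id])
    return out


# Usage: a 4x4 base and a row view that aliases the same storage.
base = RealTensor(Storage(64), size=[4, 4], stride=[4, 1], offset=0)
view = RealTensor(base.storage, size=[4], stride=[1], offset=4)
describer = Describer()
payload = json.dumps([asdict(describer.describe(base)), asdict(describer.describe(view))])
descs = [TensorDesc(**x) for x in json.loads(payload)]
fake_base, fake_view = fakify(descs)
```

After the JSON round trip, `fake_base` and `fake_view` share one `FakeStorage`, because both descriptions carry the same `storage_id`.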
Labels
module: dynamic shapes
oncall: pt2
triaged
Comments
williamwen42 added the triaged label on Mar 6, 2024
ezyang added a commit that referenced this issue on Mar 17, 2024
Fixes #121085

This PR is pretty involved, so pay attention to this description. At a high level, the refactor is intended to be mechanical: anywhere in MetaConverter where we previously took a Tensor as an argument, we now take a MetaTensorDesc, which contains all of the information that we would have queried off of the Tensor, but placed into a separate data structure that we can serialize or use to recreate a fake tensor in a separate fake tensor mode with exact fidelity to the original. However, this transformation is not always entirely mechanical. Here is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken into two memo tables: real Tensor -> stable int id -> meta/fake Tensor. The stable int id is needed so that, when we serialize, we know when tensors/storages alias each other and can preserve this aliasing upon deserialization. The way I have implemented this changes the weak reference behavior. Previously, when either the real Tensor OR the meta/fake Tensor went dead, we would remove the entry from the memo table. Now, this only removes entries from one of the two memo tables. This makes semantic sense, because the user may have held on to the stable int id out of band, and may expect a real Tensor to continue to be numbered consistently, or expect to be able to look up a meta/fake tensor from this id. If this is unacceptable, it may be possible to rejigger the memo tables so that we have real Tensor -> stable int id and real Tensor -> meta/fake Tensor, but TBH I find the new implementation a lot simpler. Arranging the memo tables that way would also mean mucking around with the real tensor to save to the memo table; in the current implementation, I never pass the Tensor to the meta_tensor function AT ALL, which means it is impossible to accidentally depend on it.
- When I fill in the fields of MetaTensorDesc in describe_tensor, I need to be careful not to poke fields when they are not valid. Previously, preconditions were implicitly checked via the conditional structure ("is this sparse? is this nested?") that is tested before we start reading attributes. This structure has to be replicated in describe_tensor, and I have almost assuredly gotten it wrong on my first try (I'll be grinding through it on CI; a careful audit will help too, checking that I've tested all the same conditionals that guarded the original accesses).
- I originally submitted #121821 for the symbolic shapes change, but it turned out the way I did it there didn't actually work so well for this PR. I ended up just inlining the symbolic shapes allocation logic into MetaConverter (look for calls to maybe_specialize_sym_int_with_hint). Maybe there is a better way to structure it, but what I really want is to read sizes/strides/offset directly off of MetaTensorDesc; I don't want another intermediate data structure.
- Some fields aren't serializable. These are documented as "NOT serializable". ctx/type should morally be serializable; I just need to set up a contract with subclasses to let them be serialized. The fake_mode is used solely to test if we are refakefying with a pre-existing ShapeEnv and want to reuse the SymInt directly; serializing this case is hopeless, but I am kind of hoping that after this refactor we will not need it at all. view_func is not serializable because it's a bound C-implemented method. Joel has promised me that this is not too difficult to expose as a true data structure, but this is the edgiest of edge cases and there is no reason to deal with it right now.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 7c1391b20ac187f8337b2b10cd9c3c2e1d8e8cc3
Pull Request resolved: #122044
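The first bullet of the #122044 commit message splits one memo table into real Tensor -> stable int id and stable int id -> meta/fake Tensor, and notes that a dead object now removes an entry from only one of the two tables. A minimal sketch of that weak-reference behavior follows, assuming plain Python objects in place of tensors. `MemoTables` and `Obj` are hypothetical names; only `weakref.WeakKeyDictionary` and `weakref.WeakValueDictionary` are real (standard library) APIs.

```python
import gc
import weakref
from typing import Optional


class MemoTables:
    """Hypothetical two-level memo: real -> stable id -> fake."""

    def __init__(self):
        # real object -> stable int id; the entry disappears when the real object dies
        self._real_to_id = weakref.WeakKeyDictionary()
        # stable int id -> fake object; the entry disappears when the fake object dies
        self._id_to_fake = weakref.WeakValueDictionary()
        self._next_id = 0

    def get_id(self, real) -> int:
        i = self._real_to_id.get(real)
        if i is None:  # compare against None: 0 is a valid id
            i = self._next_id
            self._next_id += 1
            self._real_to_id[real] = i
        return i

    def get_fake(self, i: int) -> Optional[object]:
        return self._id_to_fake.get(i)

    def set_fake(self, i: int, fake) -> None:
        self._id_to_fake[i] = fake

    def num_live_reals(self) -> int:
        return len(self._real_to_id)


class Obj:
    """Hypothetical placeholder for a real or fake tensor (weak-referenceable)."""


memo = MemoTables()
real = Obj()
i = memo.get_id(real)
fake = Obj()
memo.set_fake(i, fake)

# The fake dies: only the id -> fake entry goes away; the real keeps its number.
del fake
gc.collect()  # CPython frees on del already; collect() helps other runtimes

# The real dies: the id -> fake entry survives, so an id held out of band still works.
real2 = Obj()
j = memo.get_id(real2)
fake2 = Obj()
memo.set_fake(j, fake2)
del real2
gc.collect()
```

The alternative the commit message mentions (real Tensor -> stable int id and real Tensor -> meta/fake Tensor) would key both tables on the real object, which is exactly the coupling this layout avoids.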
ezyang added a commit that referenced this issue on Mar 17, 2024 (commit message unchanged; ghstack-source-id: a1fac70c60180d52ecffa611b6755428e85d7a28, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: 695003be2d8e24d84ce6b487636e8cebd9e332d7, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: 8e737b10c25c44769ce7ca2f95de3d2d9c32b6cd, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: 14b2028f6ce92d9f299ea61666b471416adbc504, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: 22ea030bc4342cf1c38835a82bf1d75cfc6cc833, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: ff09d52430a365d1954e35fe475b1dc26fdead75, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: e829a78185e1028bd8f06bb80539e8339387aa2a, Pull Request resolved: #122044)
ezyang added a commit that referenced this issue on Mar 18, 2024 (commit message unchanged; ghstack-source-id: bd9b8be52bd97524d2aa54221e9ed54d63b40e08, Pull Request resolved: #122044)
pytorchmergebot
pushed a commit
that referenced
this issue
Mar 22, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: cd0bc57d6788b6f941ad4670a22fafb7c92f5dc0 Pull Request resolved: #122044
ezyang
added a commit
that referenced
this issue
Mar 24, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: 82ecf0201ad28a8d74f2be3ab1fa338893dcad1b Pull Request resolved: #122044
ezyang
added a commit
that referenced
this issue
Mar 24, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: c221c02d99c79617a7c5f63b36b8072e572ab253 Pull Request resolved: #122044
ezyang
added a commit
that referenced
this issue
Mar 25, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: 2bfbfa4450be8f939e38aa1f35ddf9a594b3efb4 Pull Request resolved: #122044
pytorchmergebot
pushed a commit
that referenced
this issue
Mar 25, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #122044 Approved by: https://github.com/eellison
pytorch-bot bot
pushed a commit
that referenced
this issue
Apr 22, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #122044 Approved by: https://github.com/eellison ghstack dependencies: #122018
pytorch-bot bot
pushed a commit
that referenced
this issue
Apr 22, 2024
Fixes #121085. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #122044 Approved by: https://github.com/eellison
🐛 Describe the bug
I want to try to solve three distinct problems at the same time:
The first technical capability we want to build is relatively simple. Today, when we convert real tensors to fake tensors, we do it all in one go, with some fairly complicated policy information from Dynamo. This means that to recreate an equivalent set of fake tensors (e.g., for the Dynamo to AOTAutograd handoff), we need to keep the real tensors and the extra policy information around. This is annoyingly complicated (#114526) and doesn't work for the minifier use case (you can't recreate fake tensors across a process boundary). So instead, we split the meta converter implementation into a two-step process. First, we extract all of the information from the real tensors and the policy that we will actually use to create the fake tensors, and represent it in a well-documented, serializable format (as complicated as necessary). Then, the second phase actually creates the fake tensors from that format. This should be a mechanical refactor that is easy to test, and by forcing the second phase to use only the serialized format, we ensure there is no backchannel for information we haven't serialized.
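To make the two-phase split concrete, here is a minimal pure-Python sketch. Everything in it is hypothetical: `RealTensor`, `TensorDesc`, `Describer`, `materialize` and `round_trip` are toy stand-ins, not PyTorch's `MetaConverter` or `MetaTensorDesc` API. Phase 1 records only plain data plus stable integer ids (so aliasing survives serialization), and phase 2 rebuilds fake tensors from that data alone, so it has no way to consult the real tensors.

```python
import json
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple


class Storage:
    """Stand-in for a real storage; only its identity matters here."""


@dataclass
class RealTensor:
    """Hypothetical stand-in for a real tensor (not torch.Tensor)."""
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]
    storage_offset: int
    dtype: str
    storage: Storage


@dataclass
class TensorDesc:
    """Phase-1 output: plain, JSON-serializable metadata (hypothetical)."""
    id: int
    storage_id: int
    sizes: List[int]
    strides: List[int]
    storage_offset: int
    dtype: str


class Describer:
    """Phase 1: real tensor -> TensorDesc, assigning stable int ids.
    (Tensor-level memoization is omitted for brevity.)"""

    def __init__(self) -> None:
        # id(storage) -> (stable id, storage); holding the storage keeps
        # its id() from being reused while this table is alive.
        self._storages: Dict[int, Tuple[int, Storage]] = {}
        self._next_id = 0

    def describe(self, t: RealTensor) -> TensorDesc:
        key = id(t.storage)
        if key not in self._storages:
            self._storages[key] = (len(self._storages), t.storage)
        desc = TensorDesc(self._next_id, self._storages[key][0], list(t.sizes),
                          list(t.strides), t.storage_offset, t.dtype)
        self._next_id += 1
        return desc


@dataclass
class FakeTensor:
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]
    storage_offset: int
    dtype: str
    storage: object  # fresh fake storage, shared between aliasing tensors


def materialize(descs: List[TensorDesc]) -> List[FakeTensor]:
    """Phase 2: build fake tensors from descriptions ONLY (no real tensors)."""
    storages: Dict[int, object] = {}
    out = []
    for d in descs:
        storage = storages.setdefault(d.storage_id, object())
        out.append(FakeTensor(tuple(d.sizes), tuple(d.strides),
                              d.storage_offset, d.dtype, storage))
    return out


def round_trip(descs: List[TensorDesc]) -> List[TensorDesc]:
    """Serialize to JSON text and back, e.g. to cross a process boundary."""
    text = json.dumps([asdict(d) for d in descs])
    return [TensorDesc(**obj) for obj in json.loads(text)]


if __name__ == "__main__":
    shared = Storage()
    base = RealTensor((4,), (1,), 0, "float32", shared)
    view = RealTensor((2,), (1,), 2, "float32", shared)  # aliases `base`
    other = RealTensor((3,), (1,), 0, "float32", Storage())
    d = Describer()
    fakes = materialize(round_trip([d.describe(t) for t in (base, view, other)]))
    print(fakes[0].storage is fakes[1].storage)  # True: aliasing preserved
    print(fakes[0].storage is fakes[2].storage)  # False
```

Because `materialize` never sees a `RealTensor`, anything the describer forgets to record simply cannot leak into the fake tensors, which is the "no backchannel" property the design relies on.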
With this change, we can now reinitialize a new fake mode from scratch by replaying the second phase on the serializable metadata (solving problem 1), and we can save the metadata to disk and then restore the fake tensors in the minifier (solving problem 2, modulo dynamic shapes).
Next, we also want to create a fresh ShapeEnv when we create a new fake tensor mode. Some of the ShapeEnv initialization information is naturally part of the serializable fake tensor metadata, since that metadata includes the dynamic shape policy. The rest consists of all the internal ShapeEnv state consulted by produce_guards, most notably value ranges, guards, and deferred runtime asserts. It is important to realize that guards in the ShapeEnv are not necessarily implied by the output graph: for example, if a user performs a conditional in their Python program, that can induce a guard that is not reflected in the tensor compute at all.
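The point that guards need not be implied by the graph can be illustrated with a toy. `ToyShapeEnv`, `ToySym` and `user_program` below are hypothetical and far simpler than the real `ShapeEnv`; the sketch only shows that a Python branch on a symbolic size records a guard while leaving no node in the traced graph, so that state has to be carried separately when a fresh environment is built.

```python
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Tuple


@dataclass
class ToyShapeEnv:
    """Hypothetical toy, NOT torch's ShapeEnv: only the kinds of state
    that guard production would need to consult."""
    value_ranges: Dict[str, Tuple[int, int]] = field(default_factory=dict)
    guards: List[str] = field(default_factory=list)
    deferred_runtime_asserts: List[str] = field(default_factory=list)

    def export_state(self) -> dict:
        return asdict(self)  # plain, serializable data

    @classmethod
    def from_state(cls, state: dict) -> "ToyShapeEnv":
        # Build a *fresh* env; copy containers so the two never share state.
        return cls(dict(state["value_ranges"]), list(state["guards"]),
                   list(state["deferred_runtime_asserts"]))


@dataclass
class ToySym:
    """A symbolic size carrying a concrete hint, as seen while tracing."""
    name: str
    hint: int
    env: ToyShapeEnv

    def __gt__(self, other: int) -> bool:
        result = self.hint > other
        # Branching on a symbolic size records a guard in the env...
        self.env.guards.append(f"{self.name} > {other}" if result
                               else f"{self.name} <= {other}")
        return result


def user_program(n: ToySym, graph: List[str]) -> None:
    # ...but the branch itself leaves no trace in the recorded graph.
    if n > 5:
        graph.append("relu")
    else:
        graph.append("sigmoid")


if __name__ == "__main__":
    env = ToyShapeEnv(value_ranges={"s0": (2, 1024)})
    graph: List[str] = []
    user_program(ToySym("s0", 8, env), graph)
    print(graph)       # ['relu']
    print(env.guards)  # ['s0 > 5']
    fresh = ToyShapeEnv.from_state(env.export_state())
    print(fresh == env, fresh is env)  # True False
```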
Our current strategy for deferred runtime asserts is to reinsert them into the graph (done by both Dynamo and Export) at the point where the unbacked SymInt is allocated, because they reflect something you have to do at runtime and intuitively belong in the graph. I could go either way for guards. Putting guards in the graph too would be nicely uniform, but they do not actually need to be run "in the graph" (they are run ahead of time), and eventually someone would want to optimize them out. Optimizing them out is not difficult for Inductor, however, and this may be a much more logical serialization format than a side-band FX graph (or worse, printed Sympy expressions). In any case, by ensuring all of this metadata is serializable, we address the dynamic shapes part of (2). Furthermore, because we have a fresh ShapeEnv, this resolves (3): it no longer matters that we need to reallocate unbacked SymInts. Of course we need to reallocate them; the old ones no longer exist.
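A sketch of the placement rule for deferred runtime asserts, over a hypothetical toy graph (a list of `(op, allocated_symbol)` tuples rather than a real FX graph). It assumes each assert lists the unbacked symbols it mentions, and emits the assert immediately after the last of those symbols has been allocated.

```python
from typing import FrozenSet, List, Optional, Tuple

# Hypothetical toy encoding, not torch.fx:
Node = Tuple[str, Optional[str]]        # (op text, unbacked symbol it allocates)
Assert = Tuple[str, FrozenSet[str]]     # (expression, symbols it mentions)


def insert_runtime_asserts(nodes: List[Node], asserts: List[Assert]) -> List[Node]:
    """Re-insert each deferred runtime assert right after the node that
    allocates the last unbacked symbol the assert depends on."""
    allocated = set()
    pending = list(asserts)
    out: List[Node] = []
    for op, sym in nodes:
        out.append((op, sym))
        if sym is None:
            continue
        allocated.add(sym)
        still_pending = []
        for expr, syms in pending:
            if syms <= allocated:           # every symbol now exists
                out.append((f"assert {expr}", None))
            else:
                still_pending.append((expr, syms))
        pending = still_pending
    return out


if __name__ == "__main__":
    nodes = [("x = placeholder", None), ("u0 = x.item()", "u0"),
             ("u1 = x.sum().item()", "u1"), ("y = x * u0", None)]
    asserts = [("u0 >= 0", frozenset({"u0"})),
               ("u0 <= u1", frozenset({"u0", "u1"}))]
    for op, _ in insert_runtime_asserts(nodes, asserts):
        print(op)
```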
One major downside to this proposal is compile time: we have to "do everything over again". However, we can ensure that we only insert optimized guards (with redundant guards eliminated) into the graph, which should make later constraint handling easier. Note that we already "do everything over again" when we retrace the graph in AOTAutograd.
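As a rough illustration of eliminating redundant guards before inserting them, the toy below handles only interval-style guards, encoded as hypothetical `(symbol, op, constant)` tuples, and keeps the tightest lower and upper bound per symbol. It is not how ShapeEnv simplifies guards; it only shows why the number of inserted guards can shrink.

```python
from typing import Dict, List, Tuple

Guard = Tuple[str, str, int]  # (symbol, ">=" or "<=", constant); hypothetical encoding


def simplify_bound_guards(guards: List[Guard]) -> List[Guard]:
    """Drop duplicate and weaker bound guards, keeping the tightest
    lower and upper bound for each symbol."""
    lower: Dict[str, int] = {}
    upper: Dict[str, int] = {}
    for sym, op, c in guards:
        if op == ">=":
            lower[sym] = max(c, lower.get(sym, c))
        elif op == "<=":
            upper[sym] = min(c, upper.get(sym, c))
        else:
            raise ValueError(f"unsupported op {op!r}")
    out: List[Guard] = []
    for sym in sorted(set(lower) | set(upper)):
        if sym in lower:
            out.append((sym, ">=", lower[sym]))
        if sym in upper:
            out.append((sym, "<=", upper[sym]))
    return out


if __name__ == "__main__":
    raw = [("s0", ">=", 2), ("s0", ">=", 3), ("s0", ">=", 3),
           ("s0", "<=", 100), ("s1", "<=", 8), ("s0", "<=", 64)]
    print(simplify_bound_guards(raw))
    # [('s0', '>=', 3), ('s0', '<=', 64), ('s1', '<=', 8)]
```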
I got here thinking about #121079, and unfortunately this plan doesn't completely solve the problem of re-propagating fake tensor metadata in Inductor passes. At the Inductor level, there is no good use case for recreating the fake tensor environment, because at this point tensor metadata can no longer change (but see also @bdhirsh and @jansel's work on resizing storage to zero, which unfortunately is making its way to Inductor). But if we re-propagate data-dependent compute, we will end up with the reallocation problem again. @Chillee's incremental repropagator could potentially help, since we would be required not to repropagate item() calls. However, there is something fundamentally difficult about passes that want to operate on data-dependent compute and replace it with something else: if the new unbacked SymInt is truly not the same as the old one, then the pass is also responsible for propagating all of the old runtime asserts to the new one, which might not be easy, especially if the pass eliminated the unbacked SymInt entirely! We can probably deal with that when we get there, though.
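The difficulty with passes that replace an unbacked SymInt can be sketched with a hypothetical helper over string expressions (the real ShapeEnv works with Sympy expressions, not strings). Renaming the symbol in every assert that mentions it is the easy case; when the pass drops the symbol entirely, this sketch refuses rather than silently losing the asserts.

```python
import re
from typing import List, Optional


def migrate_runtime_asserts(asserts: List[str], old: str,
                            new: Optional[str]) -> List[str]:
    """Rewrite deferred runtime asserts after a pass replaces unbacked
    symbol `old` with `new` (hypothetical helper, toy string expressions)."""
    # Word boundaries so that renaming u0 does not touch u01.
    pattern = re.compile(rf"\b{re.escape(old)}\b")
    affected = [a for a in asserts if pattern.search(a)]
    if new is None:
        if affected:
            # The pass eliminated the symbol: these facts have nowhere to go.
            raise ValueError(f"pass dropped {old} but asserts mention it: {affected}")
        return list(asserts)
    return [pattern.sub(lambda _m: new, a) for a in asserts]


if __name__ == "__main__":
    asserts = ["u0 >= 0", "u0 < u01", "u1 <= 10"]
    print(migrate_runtime_asserts(asserts, "u0", "u3"))
    # ['u3 >= 0', 'u3 < u01', 'u1 <= 10']
```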
cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @avikchaudhuri @lezcano @tugsbayasgalan
Versions
main