Completely redo how ShapeEnv guards are generated #90528
Conversation
Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90528
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 Failures. As of commit 8fc5d5a, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
global COUNTER
sympy_expr = Symbol(f"s{COUNTER}", positive=True, integer=True)
COUNTER += 1
sympy_expr.shape_env = self
I need to fix this up: the problem is that Sympy deduplicates symbols that have the same name, and this is very confusing when there are multiple ShapeEnvs over different generations. Need to figure out the proper way to prevent this.
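For context, the deduplication in question: sympy interns symbols by name and assumptions, so two ShapeEnvs that both hand out an `s0` end up with the same symbol. A quick demonstration:

```python
import sympy

a = sympy.Symbol("s0", positive=True, integer=True)
b = sympy.Symbol("s0", positive=True, integer=True)
print(a == b)   # True: same name and assumptions compare equal
print(a is b)   # typically True as well, because sympy caches Symbol construction
```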
Can we just initialize sympy with a starting int? Sympy then increments it as it works, and we read it out as we wrap up the frame.
Then, when you have "generation"-al systems like dynamo, the lifecycle is (a rough sketch follows the list):
- Enter frame
- Make shape_env(counter_pos)
- Do stuff (shape_env increments counter_pos as it goes)
- Exit frame, record counter_pos + 1 at shape_env end time
- Enter frame
- Make shape_env(counter_pos)
- ...
etc.
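A minimal sketch of this lifecycle, assuming a hypothetical ShapeEnv that simply numbers its symbols from a caller-supplied offset (all names here are illustrative, not the real API):

```python
import sympy

class ShapeEnv:
    def __init__(self, counter_pos: int) -> None:
        self.counter_pos = counter_pos

    def create_symbol(self) -> sympy.Symbol:
        sym = sympy.Symbol(f"s{self.counter_pos}", positive=True, integer=True)
        self.counter_pos += 1
        return sym

counter_pos = 0
for frame in range(2):
    shape_env = ShapeEnv(counter_pos)      # enter frame
    print(shape_env.create_symbol())       # do stuff: first frame yields s0, s1;
    print(shape_env.create_symbol())       # second frame continues with s2, s3
    counter_pos = shape_env.counter_pos    # exit frame: record where we stopped
```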
My main problem is that I want the number to reset every fresh frame, because I generally want my compilation to be indifferent to what other computation went on. So I can't have a generation counter either.
Sorry, I think I misunderstood you. What is the desired behavior?
I get to number from 0 every iteration, but each ShapeEnv still gets distinct Symbol objects.
@@ -428,8 +428,16 @@ def wrapper(self, *args, **kwargs):

 # This stub exists so we can easily add metadata to sympy symbols
 class Symbol(sympy.Symbol):
-    __slots__: List[str] = []
+    __slots__: List[str] = ['snames', 'shape_env']
+    snames: List[str]
In the end, it was pretty useful for debugging purposes to see what snames produced a symbol!
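For reference, a small sketch of how the metadata-carrying stub can be used when debugging; the attribute names mirror the diff above, but the usage itself is illustrative rather than the actual ShapeEnv code:

```python
from typing import List
import sympy

# Stub that exists so we can attach metadata to sympy symbols (mirrors the diff above).
class Symbol(sympy.Symbol):
    __slots__: List[str] = ['snames', 'shape_env']
    snames: List[str]

s = Symbol("s0", positive=True, integer=True)
s.snames = ["L['x'].size()[0]"]   # source expressions that produced this symbol
print(s, s.snames)                # s0 ["L['x'].size()[0]"]
```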
Tensor as_strided_tensorimpl_meta_symint(const Tensor& self, SymIntArrayRef sym_size, SymIntArrayRef sym_stride, optional<c10::SymInt> sym_storage_offset_) {
  auto sym_storage_offset = sym_storage_offset_.value_or(self.sym_storage_offset());
  auto result = at::detail::make_tensor<TensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
-  setStrided(result, sym_size, sym_stride, sym_storage_offset);
+  setStridedUnchecked(result, sym_size, sym_stride, sym_storage_offset);
as_strided calls are the usual reason we generate a guard on a base symbol. Removing all the tests is probably the wrong thing to do, but we're also generating these guards for stupid reasons, e.g., check out this godforsaken guard:
File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
set_meta(proxy, e)
File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 131, in set_meta
proxy.node.meta['val'] = snapshot_fake(val)
File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 121, in snapshot_fake return val.detach()
File "/data/users/ezyang/a/pytorch/torch/_subclasses/fake_tensor.py", line 896, in __torch_dispatch__ r = func(*args, **kwargs)
File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 285, in __call__
return self._op(*args, **kwargs or {})
File "/data/users/ezyang/a/pytorch/torch/_decomp/decompositions.py", line 1572, in nop_decomposition
return aten.alias(x)
File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 500, in __call__
return self._op(*args, **kwargs or {})
File "/data/users/ezyang/a/pytorch/torch/_meta_registrations.py", line 1258, in meta_alias
return self.view(self.shape)
File "/data/users/ezyang/a/pytorch/torch/_refs/__init__.py", line 3935, in view
return _reshape_view_helper(a, *shape, allow_copy=False)
File "/data/users/ezyang/a/pytorch/torch/_refs/__init__.py", line 3140, in _reshape_view_helper
return prims.view_of(a)
File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 285, in __call__
return self._op(*args, **kwargs or {})
File "/data/users/ezyang/a/pytorch/torch/_prims/__init__.py", line 1782, in _view_of_meta
return a.as_strided(a.shape, a.stride(), a.storage_offset())
I can't even... what the fuck lol
wtf indeed
if not config.dynamic_shapes:
    return None

expr_to_tensor_ref: Dict[sympy.Symbol, Dict[TensorReference, None]] = {}
do you even need TensorReference anymore?
no, i'm going to delete it
torch/_dynamo/guards.py (Outdated)
except Exception:
    # TODO: this is getting suppressed smh
    logging.warning(f"failing guard allocated at {tb}")
    raise
Hold up: today we got it to where we have no missing symbols, so why are we suppressing something?
the warning is not showing up for some reason when I trigger this locally haha
torch/_dynamo/guards.py (Outdated)
expr_as_str = " and ".join(exprs)
code_parts.append(expr_as_str)
verbose_code_parts.append(expr_as_str)
nit: move these down; we want these last, I think, since invoking sympy is costly.
not only is it costly, but it is wrong to access size before we know we have a tensor
oh yes, that too, I was trying to remember why I initially had it here.
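A toy illustration of the ordering being discussed (this is not dynamo's real guard builder, and the guard strings are made up): the cheap check that establishes we actually have a tensor comes first, and the sympy-derived shape expressions are appended last, so they are only evaluated once that check passes.

```python
code_parts = [
    "isinstance(L['x'], torch.Tensor)",   # establish we have a tensor before touching .size()
]
shape_exprs = [
    "L['x'].size()[0] == L['y'].size()[0]",
    "L['x'].size()[0] > 1",
]
code_parts.append(" and ".join(shape_exprs))  # shape guards go last
print(" and ".join(code_parts))
```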
@pytorchbot merge -f "previous ci was good, lint fix only"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Follow up to #90528
Fixes #90696
Pull Request resolved: #90704
Approved by: https://github.com/weiwangmeta, https://github.com/atalman, https://github.com/malfet
Stack from ghstack (oldest at bottom):
Instead of inferring shape mappings from a bunch of data structures that were plumbed through InstructionTranslator, we work out the mappings by just iterating over the GraphArgs and mapping symbols to arguments as they show up. If multiple argument sizes/strides/offsets map to the same symbol, that means they are duck sized, so we also generate extra equality tests asserting that they are equal. Finally, we generate 0/1 specialization guards. The resulting code is much shorter, and I think also easier to understand.
TODO: Delete all the tensor ref tracking code, it's unnecessary
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @chunyuan-w @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @desertfire
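To make the description above concrete, here is a rough, self-contained sketch of the idea (not the actual dynamo implementation; the data shapes and helper names are illustrative): walk the graph arguments, record which source expression each symbol came from, guard static dimensions on their exact value, emit equality guards for duck-sized dimensions, and emit 0/1 specialization guards for symbolic ones.

```python
import sympy
from collections import defaultdict

def produce_guards(graph_args):
    # graph_args: list of (source_name, sizes), where sizes mixes ints and sympy symbols
    symbol_to_sources = defaultdict(list)
    guards = []
    for source, sizes in graph_args:
        for i, s in enumerate(sizes):
            ref = f"{source}.size()[{i}]"
            if isinstance(s, sympy.Symbol):
                symbol_to_sources[s].append(ref)
            else:
                guards.append(f"{ref} == {s}")       # static dim: guard on the exact value
    for sym, refs in symbol_to_sources.items():
        for other in refs[1:]:
            guards.append(f"{refs[0]} == {other}")   # duck sizing: same symbol => sizes must agree
        guards.append(f"{refs[0]} > 1")              # 0/1 specialization: symbolic dims were not 0 or 1
    return guards

print(produce_guards([("L['x']", [sympy.Symbol("s0"), 3]),
                      ("L['y']", [sympy.Symbol("s0")])]))
```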