Conversation

ezyang
Contributor

@ezyang ezyang commented Dec 9, 2022

Stack from ghstack (oldest at bottom):

Instead of inferring shape mappings from a bunch of data structures that were plumbed through InstructionTranslator, we work out the mappings by just iterating over the GraphArgs and mapping symbols to arguments as they show up. If multiple argument sizes/strides/offsets map to the same symbol, this means they are duck sized, so we also generate extra equality tests asserting that they must be equal. Finally, we generate 0/1 specialization guards. The resulting code is much shorter, and I think also easier to understand.
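
To make the new flow concrete, here is a minimal sketch (illustrative only, not the actual dynamo code; produce_shape_guards and the source strings are simplified stand-ins) of how iterating over the graph args yields the symbol-to-source mapping, the duck-sizing equality guards, and the constant (0/1 specialization) guards:

import sympy

def produce_shape_guards(graph_args):
    # graph_args: list of (source_name, sizes) pairs; each size is an int or a sympy Symbol.
    # Illustrative stand-in for dynamo's GraphArgs and the fake tensors' sizes/strides/offsets.
    symbol_to_sources = {}
    guards = []
    for source, sizes in graph_args:
        for i, s in enumerate(sizes):
            if isinstance(s, sympy.Symbol):
                # Remember every place this symbol showed up.
                symbol_to_sources.setdefault(s, []).append(f"{source}.size()[{i}]")
            else:
                # Plain int (e.g. a 0/1-specialized dim): guard on the exact value.
                guards.append(f"{source}.size()[{i}] == {s}")
    # Duck sizing: every source that mapped to the same symbol must stay equal.
    for sources in symbol_to_sources.values():
        first, *rest = sources
        guards.extend(f"{first} == {other}" for other in rest)
    return guards

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(produce_shape_guards([("L['x']", (s0, 1)), ("L['y']", (s0, 3))]))
# ["L['x'].size()[1] == 1", "L['y'].size()[1] == 3", "L['x'].size()[0] == L['y'].size()[0]"]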

TODO: Delete all the tensor ref tracking code, it's unnecessary

Signed-off-by: Edward Z. Yang ezyang@fb.com

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @chunyuan-w @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @desertfire

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Dec 9, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90528

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit 8fc5d5a:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Dec 9, 2022
ezyang added a commit that referenced this pull request Dec 9, 2022
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: f7afdb2
Pull Request resolved: #90528
global COUNTER
sympy_expr = Symbol(f"s{COUNTER}", positive=True, integer=True)
COUNTER += 1
sympy_expr.shape_env = self
Contributor Author

I need to fix this up: the problem is that sympy deduplicates symbols that have the same name, and this is very confusing when there are multiple ShapeEnvs over different generations. I need to figure out the proper way to prevent this.
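
For reference, the behavior being described: sympy symbols with the same name and assumptions compare equal (and are typically returned from sympy's cache as the same object), so an "s0" allocated for one ShapeEnv is indistinguishable from an "s0" allocated for a later one:

import sympy

a = sympy.Symbol("s0", positive=True, integer=True)  # allocated for one ShapeEnv
b = sympy.Symbol("s0", positive=True, integer=True)  # allocated for a later ShapeEnv
print(a == b)  # True - sympy considers them the same symbol
print(a is b)  # typically True as well, thanks to sympy's symbol cache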

Collaborator

Can we just initialize sympy with a starting int? Sympy then increments it as it works, and we read it back out as we wrap up the frame.

Then, when you have "generation"-al systems like dynamo, the lifecycle is (see the sketch after this list):

  • Enter frame
  • Make shape_env(counter_pos)
  • do stuff (shape_env increments counter_pos as it goes)
  • Exit frame, record counter_pos + 1 at shape_env end time
  • Enter frame
  • Make shape_env(counter_pos)
    ...
    etc
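
A runnable sketch of that lifecycle (ShapeEnv here is a toy stand-in, just to illustrate the proposed counter handoff):

import itertools
import sympy

class ShapeEnv:
    # Toy stand-in: symbols continue numbering from wherever the previous env stopped.
    def __init__(self, start: int):
        self._counter = itertools.count(start)
        self.next_counter = start

    def create_symbol(self) -> sympy.Symbol:
        n = next(self._counter)
        self.next_counter = n + 1
        return sympy.Symbol(f"s{n}", positive=True, integer=True)

counter_pos = 0
for _ in range(2):                             # two "frames"
    env = ShapeEnv(counter_pos)                # enter frame: make shape_env(counter_pos)
    env.create_symbol(); env.create_symbol()   # do stuff
    counter_pos = env.next_counter             # exit frame: record where the counter ended
# Frame 1 allocates s0, s1; frame 2 allocates s2, s3 - no name collisions across frames.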

Contributor Author

My main problem is that I want the number to reset every fresh frame, because I generally want my compilation to be indifferent to what other computation went on. So I can't have a generation counter either.

Collaborator

Sorry, I think I misunderstood you, what is the desired behavior?

Contributor Author

I want to number from 0 every iteration, but still have a distinct Symbol each time.

@@ -428,8 +428,16 @@ def wrapper(self, *args, **kwargs):

# This stub exists so we can easily add metadata to sympy symbols
class Symbol(sympy.Symbol):
-    __slots__: List[str] = []
+    __slots__: List[str] = ['snames', 'shape_env']
+    snames: List[str]
Contributor Author

In the end, it was pretty useful for debugging purposes to see what snames produced a symbol!
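
Concretely, the stub (per the hunk above) lets a symbol carry the source names that produced it plus a back-pointer to its ShapeEnv - roughly like this (simplified, not the exact PR code):

from typing import List
import sympy

# Stub subclass so we can hang debug metadata off sympy symbols.
class Symbol(sympy.Symbol):
    __slots__: List[str] = ['snames', 'shape_env']
    snames: List[str]

# ShapeEnv-side usage, mirroring the create_symbol path commented on earlier:
s = Symbol("s0", positive=True, integer=True)
s.snames = ["x.size()[0]"]   # which size/stride expression produced this symbol (debugging aid)
s.shape_env = object()       # stand-in for the owning ShapeEnv
print(s, s.snames)           # s0 ['x.size()[0]']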

Tensor as_strided_tensorimpl_meta_symint(const Tensor& self, SymIntArrayRef sym_size, SymIntArrayRef sym_stride, optional<c10::SymInt> sym_storage_offset_) {
  auto sym_storage_offset = sym_storage_offset_.value_or(self.sym_storage_offset());
  auto result = at::detail::make_tensor<TensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
-  setStrided(result, sym_size, sym_stride, sym_storage_offset);
+  setStridedUnchecked(result, sym_size, sym_stride, sym_storage_offset);
Contributor Author

as_strided calls are the usual reason we generate a guard on a base symbol. Removing all the tests is probably the wrong thing to do, but we're also generating these guards for stupid reasons; e.g., check out this godforsaken guard:

File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
    set_meta(proxy, e)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 131, in set_meta
    proxy.node.meta['val'] = snapshot_fake(val)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 121, in snapshot_fake    return val.detach()
  File "/data/users/ezyang/a/pytorch/torch/_subclasses/fake_tensor.py", line 896, in __torch_dispatch__    r = func(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/data/users/ezyang/a/pytorch/torch/_decomp/decompositions.py", line 1572, in nop_decomposition
    return aten.alias(x)
  File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 500, in __call__
    return self._op(*args, **kwargs or {})
  File "/data/users/ezyang/a/pytorch/torch/_meta_registrations.py", line 1258, in meta_alias
    return self.view(self.shape)
  File "/data/users/ezyang/a/pytorch/torch/_refs/__init__.py", line 3935, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/data/users/ezyang/a/pytorch/torch/_refs/__init__.py", line 3140, in _reshape_view_helper
    return prims.view_of(a)
  File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/data/users/ezyang/a/pytorch/torch/_prims/__init__.py", line 1782, in _view_of_meta
    return a.as_strided(a.shape, a.stride(), a.storage_offset())

I can't even... what the fuck lol

Collaborator

wtf indeed

if not config.dynamic_shapes:
    return None

expr_to_tensor_ref: Dict[sympy.Symbol, Dict[TensorReference, None]] = {}
Collaborator

do you even need TensorReference anymore?

Contributor Author

No, I'm going to delete it.

Comment on lines 701 to 704
except Exception:
    # TODO: this is getting suppressed smh
    logging.warning(f"failing guard allocated at {tb}")
    raise
Collaborator

Hold up - today we got it to the point where we have no missing symbols, so why are we suppressing something?

Contributor Author

the warning is not showing up for some reason when I trigger this locally haha

Comment on lines 709 to 711
expr_as_str = " and ".join(exprs)
code_parts.append(expr_as_str)
verbose_code_parts.append(expr_as_str)
Collaborator

Nit: move these down - we want these last, I think, as invoking sympy is costly.

Contributor Author

Not only is it costly, it is also wrong to access size before we know we have a tensor.

Collaborator

Oh yes, that too - I was trying to remember why I initially had it here.
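
In other words, the shape expressions should be appended after the tensor type checks, so the generated guard never calls .size() on something that isn't a tensor. Roughly (illustrative sketch, not the real GuardBuilder code):

def build_guard_code(tensor_names, shape_exprs):
    code_parts = []
    # Cheap checks first: establish that each guarded name really is a tensor.
    for name in tensor_names:
        code_parts.append(f"isinstance({name}, torch.Tensor)")
    # Only then the (costlier, sympy-derived) shape expressions, which index into
    # .size()/.stride() and would raise on a non-tensor.
    if shape_exprs:
        code_parts.append(" and ".join(shape_exprs))
    return " and ".join(code_parts)

print(build_guard_code(["x", "y"], ["x.size()[0] == y.size()[0]", "x.size()[1] == 1"]))
# isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor) and x.size()[0] == y.size()[0] and x.size()[1] == 1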

@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 10, 2022
ezyang added a commit that referenced this pull request Dec 10, 2022
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 5e07c30
Pull Request resolved: #90528
@github-actions github-actions bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Dec 10, 2022
ezyang added a commit that referenced this pull request Dec 10, 2022
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: fc7f6c9
Pull Request resolved: #90528
ezyang added a commit that referenced this pull request Dec 10, 2022
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: f838a96
Pull Request resolved: #90528
@ezyang
Contributor Author

ezyang commented Dec 10, 2022

@pytorchbot merge -f "previous ci was good, lint fix only"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Skylion007 added a commit to Skylion007/pytorch that referenced this pull request Dec 10, 2022
pytorchmergebot pushed a commit that referenced this pull request Dec 11, 2022
Fixes a minor issue I noticed in #90528; also a follow-up to #89000. @ezyang

Pull Request resolved: #90630
Approved by: https://github.com/ezyang
pytorchmergebot pushed a commit that referenced this pull request Dec 13, 2022
Signed-off-by: Eli Uriegas <eliuriegas@meta.com>

Follow up to #90528

Fixes #90696
Pull Request resolved: #90704
Approved by: https://github.com/weiwangmeta, https://github.com/atalman, https://github.com/malfet
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1641/head branch June 8, 2023 16:38