Update on "Evaluate symexprs on load path of cache not write"
When caching is enabled, an internal model fails with
```
assert_size_stride(bmm_9, (17, s0, 512), (54784, 512, 1))
AssertionError: expected size 17==17, stride 57344==54784 at dim=0
```
Looking at this model, the exact problem is: when the cache is hit on the forward graph, the generated code for the backward graph fails, because the strides of the forward outputs, which are passed to the backward graph as inputs, are not what was expected.
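To make the numbers in the error concrete, here is an illustrative pure-Python sketch (not Inductor's real `assert_size_stride` helper) of how contiguous strides follow from sizes: for a `(17, s0, 512)` output the dim-0 stride is `512 * s0`, and the mismatch above corresponds to the expression having been evaluated with one value of `s0` at save time while the failing run sees another.

```python
# Illustrative helper (not Inductor's real assert_size_stride): compute
# row-major (contiguous) strides from a shape.
def contiguous_strides(size):
    strides, acc = [], 1
    for s in reversed(size):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))

# A (17, s0, 512) output compiled with the dynamic dim s0 == 107 ...
assert contiguous_strides((17, 107, 512)) == (54784, 512, 1)
# ... while a run with s0 == 112 produces a dim-0 stride of 57344,
# tripping a stale cached check like `stride 57344==54784 at dim=0`.
assert contiguous_strides((17, 112, 512)) == (57344, 512, 1)
```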

This PR changes the evaluation logic so that evaluation of the output stride exprs is deferred to the load path, as opposed to eagerly doing it on the save path.
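As a rough sketch of the save/load split (illustrative names and string expressions, not Inductor's real API), the save path keeps the stride expressions symbolic, and only the load path evaluates them against the current run's size hints:

```python
# Hypothetical sketch: defer stride-expression evaluation from the cache's
# save path to its load path, so strides match the run that loads the entry.

def save_output_strides(exprs_per_output):
    """Save path: store the expressions symbolically; do NOT evaluate yet."""
    return list(exprs_per_output)

def load_output_strides(cached, size_hints):
    """Load path: evaluate each stride expression against the size hints
    derived from this run's example inputs."""
    out = []
    for exprs in cached:
        if exprs is None:
            out.append(None)
        else:
            out.append([eval(e, {}, dict(size_hints)) for e in exprs])
    return out

# e.g. a (17, s0, 512) output, stored symbolically at save time:
cached = save_output_strides([("512*s0", "512", "1")])
print(load_output_strides(cached, {"s0": 107}))  # [[54784, 512, 1]]
```

Evaluating at load time means a run where `s0` differs from the save-time hint still gets strides consistent with its own inputs, instead of the stale values baked in at save time.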

I have not been able to come up with a unit test repro for this problem.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
oulgen committed Jun 18, 2024
1 parent efa5da5 commit 7584d50
Showing 1 changed file with 2 additions and 4 deletions.
torch/_inductor/compile_fx.py
```diff
@@ -567,14 +567,12 @@ def compile_fx_inner(
     if context is not None and context.output_strides is not None:
         assert len(context.output_strides) == 0
         shape_env = _shape_env_from_inputs(example_inputs)
-        V.graph.sizevars = SizeVarAllocator(shape_env)
+        sizevars = SizeVarAllocator(shape_env)
         for exprs in compiled_graph.output_strides:
             if exprs is None:
                 context.output_strides.append(None)
             else:
-                context.output_strides.append(
-                    tuple(V.graph.sizevars.size_hint(s) for s in exprs)
-                )
+                context.output_strides.append([sizevars.size_hint(s) for s in exprs])

     if aot_mode:
         return compiled_graph
```
