
Manual Installation of Jinja (for wheel env) or networkx (for conda) Package is Required to Use Torch.Compile #95671

Closed
weiwangmeta opened this issue Feb 28, 2023 · 7 comments


@weiwangmeta
Contributor

weiwangmeta commented Feb 28, 2023

🐛 Describe the bug

While testing the v2.0.0 release candidates and nightlies, installing with
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu117 or

and running the resnet18 code snippet from #95223 produces the error shown in the Error logs section.

The dynamo extra in pytorch/setup.py (line 1035 at commit b818b3f):

extras_require['dynamo'] = ['pytorch-triton==2.0.0+' + triton_pin[:10], 'jinja2']

shows that jinja2 is only installed when torch is installed with that extra, e.g. "pip install torch[dynamo]". However, our recommended installation command usually omits "[dynamo]".
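To make the consequence concrete, here is a toy model (not pip's actual resolver) of how setuptools extras behave: packages listed under an extra are pulled in only when the requirement names that extra. The helper `requested_packages`, the base dependency list, and the `<pin>` placeholder standing in for `triton_pin[:10]` are illustrative assumptions; only the contents of the `dynamo` extra mirror the setup.py line quoted above.

```python
import re

# Toy model of how pip treats setuptools extras; NOT pip's real resolver.
# INSTALL_REQUIRES is an illustrative assumption; the 'dynamo' extra mirrors
# the setup.py line quoted above ("<pin>" is a placeholder for triton_pin[:10]).
INSTALL_REQUIRES = ["filelock", "typing-extensions", "sympy"]
EXTRAS_REQUIRE = {
    "dynamo": ["pytorch-triton==2.0.0+<pin>", "jinja2"],
}


def requested_packages(spec: str) -> list:
    """Return the dependency strings pulled in by a spec like 'torch[dynamo]'."""
    match = re.fullmatch(r"\s*([A-Za-z0-9_.-]+)\s*(?:\[([^\]]*)\])?\s*", spec)
    if match is None:
        raise ValueError(f"unsupported requirement spec: {spec!r}")
    # Split the bracketed extras (if any) into normalized names.
    extras = [e.strip().lower() for e in (match.group(2) or "").split(",") if e.strip()]
    deps = list(INSTALL_REQUIRES)
    for extra in extras:
        # Unknown extras are simply skipped in this toy model.
        deps.extend(EXTRAS_REQUIRE.get(extra, []))
    return deps


print("jinja2" in requested_packages("torch"))          # False
print("jinja2" in requested_packages("torch[dynamo]"))  # True
```

With the plain "pip install torch" spec, jinja2 never enters the dependency set, which matches the failure above.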

Should we fix this, or do we expect users to run an extra "pip install jinja2" step?

cc @ezyang @gchanan @zou3519 @soumith @msaroufim @wconstab @ngimel @bdhirsh @malfet @ng

Error logs

[2023-02-28 00:33:39,306] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/graph.py", line 342, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/kernel/mm.py", line 133, in tuned_addmm
mm_template.generate(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 349, in generate
assert self.template, "requires jinja2"
AssertionError: requires jinja2
Traceback (most recent call last):
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/graph.py", line 342, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/kernel/mm.py", line 133, in tuned_addmm
mm_template.generate(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 349, in generate
assert self.template, "requires jinja2"
AssertionError: requires jinja2

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 708, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/__init__.py", line 1393, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2133, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
graph.run(*example_inputs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/graph.py", line 203, in run
return super().run(*args)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/graph.py", line 421, in run_node
result = super().run_node(n)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_inductor/graph.py", line 346, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AssertionError: requires jinja2
target: aten.addmm.default
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_62', layout=FixedLayout('cuda', torch.float32, size=[1000], stride=[1]))
))
args[1]: TensorBox(StorageBox(
ComputedBuffer(name='buf183', layout=FixedLayout('cuda', torch.float32, size=(16, 512), stride=[512, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf182, i1 + 512 * i0)
tmp1 = index_expr(49, torch.float32)
tmp2 = tmp0 / tmp1
return tmp2
,
ranges=(16, 512),
origins={view}
))
))
args[2]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_61', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
),
FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
origins=
)
)

While executing %addmm : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_62, %view, %permute), kwargs = {})
Original traceback:
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torchvision/models/resnet.py", line 280, in _forward_impl
x = self.fc(x)
| File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/weiwangmeta/rn18.py", line 10, in <module>
out = compiled_model(x)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in __call__
return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 215, in _fn
return fn(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 343, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
transformations(instructions, code_options)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in run
super().run()
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 619, in run
and self.step()
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 583, in step
getattr(self, inst.opname)(inst)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1941, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 555, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 626, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 713, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AssertionError: requires jinja2
target: aten.addmm.default
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_62', layout=FixedLayout('cuda', torch.float32, size=[1000], stride=[1]))
))
args[1]: TensorBox(StorageBox(
ComputedBuffer(name='buf183', layout=FixedLayout('cuda', torch.float32, size=(16, 512), stride=[512, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf182, i1 + 512 * i0)
tmp1 = index_expr(49, torch.float32)
tmp2 = tmp0 / tmp1
return tmp2
,
ranges=(16, 512),
origins={view}
))
))
args[2]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_61', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
),
FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
origins=
)
)

While executing %addmm : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_62, %view, %permute), kwargs = {})
Original traceback:
File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torchvision/models/resnet.py", line 280, in _forward_impl
x = self.fc(x)
| File "/home/weiwangmeta/.conda/envs/clean-nightly-20230226/lib/python3.9/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

Minified repro

conda create -n test_nightly python=3.10
conda activate test_nightly

python resnet18.py (where resnet18.py is the resnet18 snippet from #95223)

Versions

Nightlies up to 02/27/2023.
v2.0.0 release candidates up to RC2; the issue is also present at the tip of the release/2.0 branch.

@weiwangmeta
Contributor Author

Installing from the test channel (pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu117) does not seem to reproduce the issue; only the nightly channel (pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu117) does. Not sure why.

@weiwangmeta weiwangmeta changed the title Manual Installation of Jinja Package is Required to Use Torch.Compile Manual Installation of Jinja (for wheel env) or networkx (for conda) Package is Required to Use Torch.Compile Feb 28, 2023
@weiwangmeta
Contributor Author

weiwangmeta commented Feb 28, 2023

For conda env:

conda create -n test_rc2_test_conda python=3.10
conda activate test_rc2_test_conda
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch-test -c nvidia
python rn18.py

failed with:

torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: Need networkx installed to perform smart recomputation heuristics

Verified that the conda nightly binaries fail in the same way (networkx required).

Summary so far:

         nightly            RC2 (as of 02/28/2023)
wheel    needs jinja2       OK
conda    needs networkx     needs networkx
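Until both packages ship as regular dependencies, a quick pre-flight check can tell users which optional package their environment lacks before they hit the Inductor error. A minimal sketch using only the standard library; `missing_compile_deps` is a hypothetical helper, and the package list simply reflects the two errors reported in this thread (jinja2 for the nightly wheels, networkx for the conda builds), not an official list of torch.compile requirements.

```python
import importlib.util

# Packages found missing in this thread, mapped to the error each one triggered.
# This mapping is taken from the reports above, not from any official torch list.
REQUIRED_FOR_COMPILE = {
    "jinja2": "AssertionError: requires jinja2",
    "networkx": "Need networkx installed to perform smart recomputation heuristics",
}


def missing_compile_deps(modules=REQUIRED_FOR_COMPILE):
    """Return the names of top-level modules that cannot be found on sys.path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_compile_deps()
    if missing:
        for name in missing:
            print(f"{name} is missing; torch.compile may fail with: {REQUIRED_FOR_COMPILE[name]}")
        print("Workaround: pip install " + " ".join(missing))
    else:
        print("jinja2 and networkx are both importable")
```

`find_spec` only locates the module without importing it, so the check is cheap and has no side effects.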

@weiwangmeta
Contributor Author

For the conda nightly: after "pip install networkx", it fails the same way as the nightly wheels:
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AssertionError: requires jinja2

@weiwangmeta
Contributor Author

For the conda test channel (i.e. RC2), the error disappeared after "pip install networkx" (as with the RC2 wheels).

So the jinja2 issue appears to be specific to the master branch (nightlies), whereas the networkx issue for conda affects RC2 as well.

@weiwangmeta
Contributor Author

To clarify, pytorch/builder#1327 fixes the conda part of the issue.

@ng

ng commented Mar 1, 2023 via email

@atalman
Contributor

atalman commented Mar 2, 2023

Validated; this is fixed.

@atalman atalman reopened this Mar 2, 2023
@atalman atalman closed this as completed Mar 2, 2023
atalman pushed a commit to atalman/pytorch that referenced this issue Mar 9, 2023
Should fix pytorch#95671  for nightly wheels issue. v2.0.0 RC does not need this.
Pull Request resolved: pytorch#95691
Approved by: https://github.com/malfet
atalman added a commit that referenced this issue Mar 9, 2023
Should fix #95671  for nightly wheels issue. v2.0.0 RC does not need this.
Pull Request resolved: #95691
Approved by: https://github.com/malfet

Co-authored-by: Wei Wang <weiwangmeta@meta.com>
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue May 3, 2023
Should fix pytorch#95671  for nightly wheels issue. v2.0.0 RC does not need this.
Pull Request resolved: pytorch#95691
Approved by: https://github.com/malfet

Co-authored-by: Wei Wang <weiwangmeta@meta.com>