🐛 [Bug] Transformers T5 Model does not compile via FX Path #1740

Status: Closed · gs-olive opened this issue Mar 16, 2023 · 1 comment
Labels: bug (Something isn't working) · component: fx · No Activity

gs-olive commented Mar 16, 2023

Bug Description

When compiling the T5-Base model via the FX path, the errors below are encountered. Note that the model can optionally be pre-traced with the HuggingFace symbolic tracer first; the logs below are labeled Pre-Traced / NOT Pre-Traced accordingly.

NOT Pre-Traced: `torch_tensorrt.fx.compile(model, is_aten=True,...)`

```
[2023-03-16 00:51:26,256] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-03-16 00:51:40,315] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-03-16 00:51:40,640] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
[2023-03-16 00:51:40,640] torch._dynamo.output_graph: [INFO] Step 2: done compiler function dynamo_normalization_capturing_compiler
[2023-03-16 00:51:41,904] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __getitem__
[2023-03-16 00:51:41,911] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing __getitem__ (RETURN_VALUE)
Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 116, in dynamo_trace
    return torchdynamo.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 706, in export
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 118, in __call__
    return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py", line 1395, in forward
    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 391, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 105, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 263, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _compile
    hooks.guard_export_fn(output.guards)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 670, in guard_export_print
    assert out_guards is None, "whole graph export entails exactly one guard export"
AssertionError: whole graph export entails exactly one guard export
```
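The log shows Dynamo tracing two separate frames (forward, then __getitem__ on the model output object), while torchdynamo.export asserts that a whole-graph export emits exactly one set of guards. A minimal sketch of the export call that aten_tracer.py makes, assuming the torch 2.1 nightly API from the environment below; the call shape is taken from the traceback, not from reading the repo:

```python
# Sketch of the whole-graph export that dynamo_trace performs.
# Assumption: torch 2.1.0.dev nightly, where torch._dynamo.export
# returns a (graph_module, guards) pair.
import torch._dynamo as torchdynamo

def dynamo_trace_sketch(model, example_inputs):
    # export requires the entire model to be captured as ONE graph; if a
    # second frame gets compiled (here, __getitem__ on the T5 output),
    # the "exactly one guard export" assertion fires.
    graph_module, guards = torchdynamo.export(model, *example_inputs)
    return graph_module, guards
```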
Pre-Traced: `torch_tensorrt.fx.compile(model, is_aten=True,...)`

```
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 159, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 447, in compose_bmm
    new_func,
UnboundLocalError: local variable 'new_func' referenced before assignment
```
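The UnboundLocalError suggests that compose_bmm in lower_basic_pass_aten.py binds new_func only inside pattern-matching branches, and the bmm pattern produced by the traced T5 graph matches none of them. An illustrative sketch of that failure mode; the predicate names are hypothetical placeholders, not the pass's real helpers:

```python
# Illustration only; not the actual compose_bmm source. Shows how falling
# through every pattern-matching branch leaves `new_func` unbound.
import torch

def matches_baddbmm_pattern(node) -> bool:  # hypothetical placeholder
    return False

def matches_bmm_pattern(node) -> bool:      # hypothetical placeholder
    return False

def compose_bmm_sketch(node):
    if matches_baddbmm_pattern(node):
        new_func = torch.ops.aten.baddbmm.default
    elif matches_bmm_pattern(node):
        new_func = torch.ops.aten.bmm.default
    # If neither branch matches (as with T5's traced graph), the next
    # line raises: UnboundLocalError: local variable 'new_func'
    # referenced before assignment
    return new_func
```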
Pre-Traced: `torch_tensorrt.fx.compile(model, is_aten=True,...)` + PR #1708

```
Got 5 acc subgraphs and 6 non-acc subgraphs
Traceback (most recent call last):
  File "case_dict.py", line 217, in <module>
    T5MODEL()
  File "case_dict.py", line 135, in T5MODEL
    fx_trt_model = torch_tensorrt.fx.compile(traced, [input_ids, attention_mask, decoder_input_ids],
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 202, in lower_func
    lowered_module = self._lower_func(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 178, in lower_pass
    interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 130, in __call__
    interp_result: TRTInterpreterResult = interpreter.run(
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 204, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 137, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 275, in run_node
    trt_node = super().run_node(n)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 179, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 57, in aten_ops_adaptive_avg_poolnd
    raise RuntimeError(f"We do not support {target} has dim={args[1]}")
RuntimeError: We do not support aten.mean.dim has dim=[-1]

While executing %mean_dim : [#users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_tensor_scalar, [-1], True), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f0c8007ccf0>: None, <tensorrt.tensorrt.ITensor object at 0x7f0c803853b0>: ((1, 1, 1, 14), torch.float32, False, (14, 14, 14, 1), torch.channels_last, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f0c80354630>: None, <tensorrt.tensorrt.ITensor object at 0x7f0c80385a30>: ((1, 14, 768), torch.float32, True, (10752, 768, 1), torch.contiguous_format, False, {})}})
Original traceback:
  File "<eval_with_key>.0", line 24, in forward
    mean = pow_1.mean(-1, keepdim = True);  pow_1 = None
```
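The rejected mean over dim=[-1] originates in T5's RMS-style layer norm, which computes pow(2).mean(-1, keepdim=True) over the hidden dimension (visible above as pow_tensor_scalar feeding mean.dim). A standalone sketch of that computation, paraphrased from transformers' T5LayerNorm rather than copied verbatim:

```python
# Standalone sketch of the subgraph the converter rejects, paraphrasing
# T5LayerNorm's RMS-norm variance computation.
import torch

class T5LayerNormLike(torch.nn.Module):
    def __init__(self, hidden_size: int = 768, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # This mean over the last dim is lowered to aten.mean.dim with
        # dim=[-1], which the converter above refuses to handle.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)
```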

To Reproduce

Steps to reproduce the behavior:

  1. Initialize the model: `T5Model.from_pretrained("t5-base").eval().cuda()`
  2. Initialize three input tensors (`input_ids`, `attention_mask`, `decoder_input_ids`), for example: `torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")`
  3. (Optional) Pre-trace the model with the transformers tools: `transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask", "decoder_input_ids"])`
  4. Compile the model via the FX path, as in the sketch below
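Putting the steps together, a minimal repro sketch; the input list and is_aten=True are taken from the report above, and any further torch_tensorrt.fx.compile arguments are omitted:

```python
# Minimal reproduction sketch assembled from the steps above.
import torch
import torch_tensorrt
import transformers
from transformers import T5Model

model = T5Model.from_pretrained("t5-base").eval().cuda()

# Three integer input tensors of shape (1, 14)
input_ids = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
attention_mask = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
decoder_input_ids = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")

# (Optional) pre-trace with the HuggingFace symbolic tracer
traced = transformers.utils.fx.symbolic_trace(
    model, input_names=["input_ids", "attention_mask", "decoder_input_ids"]
)

# Compile via the FX path (fails with the errors shown above)
fx_trt_model = torch_tensorrt.fx.compile(
    traced,
    [input_ids, attention_mask, decoder_input_ids],
    is_aten=True,
)
```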

Expected behavior

Model should compile via the FX path

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version: fce0a01 (commit)
  • PyTorch Version: 2.1.0.dev20230313+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7
github-actions bot commented:
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
