🐛 [Bug] Transformers BERT Model does not compile via FX Path #1673

Closed
gs-olive opened this issue Feb 15, 2023 · 7 comments
Labels: bug (Something isn't working), component: fx, No Activity


@gs-olive
Collaborator

gs-olive commented Feb 15, 2023

Bug Description

When compiling the BERT base uncased model via the FX path, the following errors are encountered, depending on the entry point used:

Via torchtrt.compile(model, ir="fx",...)
Traceback (most recent call last):
  File "bert.py", line 163, in <module>
    trt_mod = torchtrt.compile(traced, ir="fx", **compile_spec)
  File "~/TensorRT/py/torch_tensorrt/_compile.py", line 142, in compile
    return torch_tensorrt.fx.compile(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 247, in <lambda>
    trace_func=lambda module, inputs: acc_tracer.trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 667, in trace
    traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 585, in rewriter_base_trace
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 309, in trace
    return super().trace(rewritten, concrete_args), rewritten
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "<eval_with_key>.1", line 9, in forward
TypeError: slice indices must be integers or None or have an __index__ method
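The TypeError above is raised by Python itself rather than by Torch-TensorRT: a slice bound must be an int, None, or an object that implements __index__. Which value reaches the slice in the generated forward is not visible from this trace; one plausible trigger (an assumption, not confirmed in this thread) is a symbolic placeholder standing in for an integer such as a sequence length. The stdlib-only sketch below reproduces the message with a hypothetical SymbolicLength class and does not involve FX.

```python
class SymbolicLength:
    """Hypothetical stand-in for a traced value that has no concrete integer."""


class IndexableLength:
    """Hypothetical value that exposes __index__, so Python accepts it as a slice bound."""

    def __init__(self, value):
        self.value = value

    def __index__(self):
        return self.value


tokens = list(range(20))


def try_slice(bound):
    """Return tokens[:bound], or the TypeError message if Python rejects the bound."""
    try:
        return tokens[:bound]
    except TypeError as exc:
        return str(exc)


print(try_slice(14))                   # plain int: accepted
print(try_slice(IndexableLength(14)))  # has __index__: accepted
print(try_slice(SymbolicLength()))     # no __index__: returns the TypeError message
```

The difference between the two classes is only the presence of __index__, which is the hook Python consults when an object is used where an integer index is required.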
Via dynamo_model = torch._dynamo.optimize(fx2trt_compiler)(model)

Note: dynamo_model(*inputs) must be called to cause model compilation and elicit the error.

Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 96, in fx2trt_compiler
    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
  File "~/TensorRT/py/torch_tensorrt/fx/test/tracer/dynamo_backend.py", line 27, in fx2trt
    acc_model = acc_tracer.trace(model, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 681, in trace
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 185, in propagate
    return super().run(*args)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 63, in run_node
    result = self._run_node(n)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 43, in _run_node
    return super().run_node(n)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/shape_prop.py", line 152, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_ids, %self_embeddings_word_embeddings_weight), kwargs = {padding_idx: 0, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) with meta={}
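As the note above says, wrapping a model with Dynamo does not compile anything by itself; the backend only runs once the wrapped model is called. The sketch below makes that visible with a hypothetical recording_backend that stores each graph Dynamo hands it. It uses the public torch.compile(model, backend=...) wrapper, whereas this thread calls the lower-level torch._dynamo.optimize(backend)(model) entry point.

```python
import torch

compiled_graphs = []


def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: remember the graph, then run it unchanged
    compiled_graphs.append(gm)
    return gm.forward


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
opt_model = torch.compile(model, backend=recording_backend)

print(len(compiled_graphs))  # 0: wrapping alone does not invoke the backend

x = torch.randn(2, 4)
with torch.no_grad():
    out = opt_model(x)

# The first call triggers tracing, so the backend has now received a graph
print(len(compiled_graphs))
```

This is why a backend error such as the ShapeProp failure only surfaces on the first dynamo_model(*inputs) call.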
Via torch_tensorrt.fx.compile(..., is_aten=True,...)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 153, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 430, in compose_bmm
    if len(real_other.meta["val"].size()) == 2:
KeyError: 'val'
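The KeyError above comes from a lowering pass reading node.meta["val"] on a node that has no such entry. Tracers that run with fake tensors, such as the Dynamo export used inside the aten tracer, normally record each node's example value there, but nothing guarantees it for every node. The sketch below shows a defensive lookup with a hypothetical tensor_rank helper; it is not Torch-TensorRT's actual fix, and the placeholder metadata is attached by hand purely for illustration.

```python
import torch
from torch.fx import symbolic_trace


def tensor_rank(node):
    """Hypothetical helper: rank of node.meta["val"], or None when it is missing."""
    val = node.meta.get("val")  # .get avoids the KeyError seen in the traceback
    if isinstance(val, torch.Tensor):
        return val.dim()
    return None


class MatMul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)


gm = symbolic_trace(MatMul())

# Plain symbolic_trace runs no example inputs, so it records no "val".
# Attach values to the placeholders only; the matmul node is left without one.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["val"] = torch.empty(2, 3, 4)

ranks = {node.name: tensor_rank(node) for node in gm.graph.nodes}
print(ranks)
```

A pass written this way can skip (or fall back on) nodes whose rank is unknown instead of aborting the whole lowering.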

To Reproduce

Steps to reproduce the behavior:

  1. Initialize model: BertModel.from_pretrained("bert-base-uncased").eval().cuda()
  2. Initialize two input tensors, for example: torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  3. (Optional) Use the transformers tools to trace the model via: transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask"])
  4. Compile the model via the FX path (for example, with one of the three entry points above)

Expected behavior

The model should compile via the FX path.

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version (e.g. 1.0.0): a219e05
  • PyTorch Version (e.g. 1.0): 2.0.0.dev20230209+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7

Additional context

Relevant to Issue #1634 and PR #1648, which aim to establish 1:1 parity between the FX and TorchScript (TS) model compatibility tests.

gs-olive added the bug (Something isn't working) and component: fx labels on Feb 15, 2023
@frank-wei
Contributor

Nice try!

  1. I tried direct FX last year. It has problems here and there, since FX tracing runs into various corner cases.
  2. Using Dynamo would work around the tracing issues that FX faces. One workaround is to tag the embedding as a leaf module via leaf_module (https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/fx/lower_setting.py#L37).
  3. It is a bug that I reported to PT2 a few days ago. They just fixed it, so you can try again with today's nightly.

@gs-olive
Collaborator Author

Thanks for the detailed suggestions! To follow up, I have tried 2 and 3. With 2, compilation still fails with the same error despite adding leaf_module_list={torch.nn.Embedding}. With 3, using the Feb 27 Torch 2.0 nightly, it now produces a different error message:

  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 157, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 439, in compose_bmm
    new_func,
UnboundLocalError: local variable 'new_func' referenced before assignment
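An UnboundLocalError like this one means the function binds new_func only on some branches, and the current input took a path that skipped every assignment. The sketch below reproduces the pattern with a hypothetical pick_func helper (not the actual compose_bmm code) alongside a variant that decides explicitly what happens when no branch matches.

```python
def pick_func_buggy(rank):
    # new_func is bound only for the ranks handled below
    if rank == 2:
        new_func = "mm"
    elif rank == 3:
        new_func = "bmm"
    return new_func  # UnboundLocalError for any other rank


def pick_func_safe(rank):
    # Every path returns, so there is no unbound name to reference
    if rank == 2:
        return "mm"
    if rank == 3:
        return "bmm"
    return None  # caller can leave the node untouched


try:
    pick_func_buggy(4)
except UnboundLocalError as exc:
    print(type(exc).__name__)

print(pick_func_safe(4))
```

The exact message text differs between Python versions (3.8 says "referenced before assignment"), but the exception type is the same.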

For additional context, I have been using the transformers symbolic tracer (transformers.utils.fx.symbolic_trace) to generate an FX module before passing the model to the FX compiler, since without it the model fails with the error:

torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_lru_cache_wrapper) [] {}
...
RuntimeError: The user code is using a feature we don't support.
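The Unsupported message names _lru_cache_wrapper, which is the type CPython gives to any function decorated with functools.lru_cache. The later traces in this thread show the call originating in transformers' get_parameter_dtype, which calls is_torch_tpu_available(); that this function is cached is inferred from the error text. The stdlib sketch below uses a hypothetical is_accelerator_available stand-in to show that a cached function is a wrapper object rather than a plain Python function, while the original function stays reachable through __wrapped__.

```python
import functools
import types


@functools.lru_cache()
def is_accelerator_available():
    # Hypothetical stand-in for an lru_cache-decorated availability check
    return False


wrapper_type = type(is_accelerator_available).__name__
print(wrapper_type)  # _lru_cache_wrapper
print(isinstance(is_accelerator_available, types.FunctionType))  # False
print(isinstance(is_accelerator_available.__wrapped__, types.FunctionType))  # True
```

Because the wrapper is a C-level object rather than Python bytecode, a tracer that inlines Python functions has nothing to step into, which matches Dynamo treating it as a UserDefinedObjectVariable.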

gs-olive self-assigned this on Feb 28, 2023
@frank-wei
Contributor

frank-wei commented Feb 28, 2023

  2. To clarify: Dynamo tracing completes and the graph is passed to the acc tracer, which then fails in ShapeProp. That is odd, because ShapeProp just lets the input flow through the graph with eager-mode computation. So I wonder whether the input is correct. Can you make sure eager mode, e.g. model(*input), works?
  3. The good news is that aten tracing works; you can print out the result here: https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py#L151
    The error is likely due to a bug at https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py#L138, which you can try commenting out first.
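The eager-mode check suggested above can be scripted. The sketch below uses a hypothetical TinyEncoder (an embedding followed by a linear layer) in place of BERT. It first runs the model eagerly on the exact inputs, then traces it with torch.fx.symbolic_trace and runs torch.fx.passes.shape_prop.ShapeProp over the same inputs; according to the tracebacks above, the acc tracer's AccShapeProp builds on this class. ShapeProp executes each node eagerly and stores the result's shape and dtype in node.meta["tensor_meta"].

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class TinyEncoder(nn.Module):
    # Hypothetical stand-in for the BERT embedding + projection, not the real model
    def __init__(self, vocab=100, hidden=8):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden, padding_idx=0)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids):
        return self.proj(self.emb(input_ids))


model = TinyEncoder().eval()
input_ids = torch.randint(0, 100, (1, 14), dtype=torch.int32)

# Step 1: confirm eager mode works on exactly these inputs
with torch.no_grad():
    eager_out = model(input_ids)

# Step 2: trace, then push the same inputs through the graph node by node
gm = symbolic_trace(model)
ShapeProp(gm).propagate(input_ids)

shapes = {
    node.name: tuple(node.meta["tensor_meta"].shape)
    for node in gm.graph.nodes
    if "tensor_meta" in node.meta
}
print(shapes)
```

If the eager call succeeds but propagate raises, the node named in the error is the place to look, since the inputs themselves are known to be valid.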

@gs-olive
Collaborator Author

gs-olive commented Mar 2, 2023

Thank you for your help with this issue. For 2, I verified that model(*input) is functional despite the displayed error. For 3, I made some changes to the view and bmm lowering passes in lower_basic_pass_aten.py, which result in successful compilation. These changes can be found in PR #1708. I noticed that the torchdynamo.export call within the aten tracer inserts bmm operations that cause issues in the lowering pass (see the PR for additional details).

Note that the model is pre-traced using the transformers symbolic tracing utility, as FX tracing otherwise fails.

Script to Reproduce
import torch
import torch_tensorrt
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace

# Disable caching and auxiliary outputs that are not needed for inference
kwargs = {
    "use_cache"           : False,
    "output_attentions"   : False,
    "output_hidden_states": False,
}
model = BertModel.from_pretrained("bert-base-uncased", cache_dir="./", **kwargs).eval().cuda()

# input_ids and attention_mask; randint's upper bound is exclusive, so both are all zeros
input0 = torch.randint(0, 1, (1, 14), dtype=torch.int32).cuda()
input1 = torch.randint(0, 1, (1, 14), dtype=torch.int32).cuda()

# Pre-trace with the HuggingFace FX tracer, since plain FX tracing fails on this model
traced = transformers_trace(model, input_names=["input_ids", "attention_mask"])

# Compile via the aten path of the FX frontend (succeeds with the lowering changes in PR #1708)
fx_trt_model = torch_tensorrt.fx.compile(
    traced,
    [input0, input1],
    min_acc_module_size=3,
    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
    explicit_batch_dimension=True,
    is_aten=True,
    dynamic_batch=False,
)

out = fx_trt_model(input0, input1)
Supported/Unsupported aten Operators
Supported node types in the model:
torch.ops.aten.sym_size: ((torch.float32,), {})
torch.ops.aten.sym_size: ((torch.int32,), {})
torch.ops.aten.mul.Tensor: ((torch.float32,), {})
_operator.add: ((), {})
torch.ops.aten.add.Tensor: ((torch.float32, torch.float32), {})
torch.ops.aten.linear: ((torch.float32, torch.float32, torch.float32), {})
torch.ops.aten.div.Tensor: ((torch.float32,), {})

Unsupported node types in the model:
torch_tensorrt.fx.passes.lower_basic_pass_aten.aten_compose_getitem_slice: ((torch.int64,), {})
torch.ops.aten.expand.default: ((torch.int64,), {})
torch.ops.aten.slice.Tensor: ((torch.float32,), {})
torch.ops.aten.slice.Tensor: ((torch.int32,), {})
torch.ops.aten.unsqueeze.default: ((torch.int32,), {})
torch.ops.aten._to_copy.default: ((torch.int32,), {})
torch.ops.aten.rsub.Scalar: ((torch.float32,), {})
torch.ops.aten.embedding.default: ((torch.float32, torch.int64), {})
torch.ops.aten.embedding.default: ((torch.float32, torch.int32), {})
torch.ops.aten.layer_norm.default: ((torch.float32, None, torch.float32, torch.float32), {})
torch.ops.aten.reshape: ((torch.float32,), {})
torch.ops.aten.permute.default: ((torch.float32,), {})
torch.ops.aten.transpose.int: ((torch.float32,), {})
torch.ops.aten.matmul: ((torch.float32, torch.float32), {})
torch.ops.aten._softmax.default: ((torch.float32,), {})
torch.ops.aten.gelu.default: ((torch.float32,), {})
torch.ops.aten.select.int: ((torch.float32,), {})
torch.ops.aten.tanh.default: ((torch.float32,), {})

@gs-olive
Collaborator Author

gs-olive commented Mar 14, 2023

Hi @frank-wei - thanks for your suggestions on dealing with the BERT compilation issues. I attempted compilation using the latest Torch version (2.1.0.dev20230312+cu117), both with and without the HuggingFace symbolic tracer (Pre-Traced / NOT Pre-Traced below). I have provided the results here:

1. NOT Pre-Traced torch._dynamo.optimize(fx2trt, nopython=True)(model) + EITHER tracer
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_lru_cache_wrapper) [] {}

from user code:
   File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 735, in dtype
    return get_parameter_dtype(self)
  File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 192, in get_parameter_dtype
    if is_torch_tpu_available():
2. Pre-Traced torch._dynamo.optimize(fx2trt, nopython=True)(model) + ACC tracer

Note: This was created with leaf_module_list={torch.nn.Embedding}

  File "/usr/local/lib/python3.8/dist-packages/torch/_subclasses/fake_tensor.py", line 1282, in validate
    raise Exception(
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(1, 14), dtype=torch.int32), cuda:0), 0), **{}) 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "~/trtre-models/collections_testing/dynamo_compiler.py", line 128, in fx2trt_compiler
    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
  File "~/trtre-models/collections_testing/dynamo_compiler.py", line 53, in fx2trt
    acc_model = acc_tracer.trace(model, inputs, leaf_module_list={torch.nn.Embedding})
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 681, in trace
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/shape_prop.py", line 185, in propagate
    return super().run(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 137, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 63, in run_node
    result = self._run_node(n)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py", line 43, in _run_node
    return super().run_node(n)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/shape_prop.py", line 152, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%self_embeddings_word_embeddings : [#users=1] = call_module[target=self_embeddings_word_embeddings](args = (%input_ids,), kwargs = {}) with meta={}

While executing %self_embeddings_word_embeddings : [#users=1] = call_module[target=self_embeddings_word_embeddings](args = (%input_ids,), kwargs = {})
Original traceback:
None
[2023-03-14 00:35:29,946] torch._dynamo.output_graph: [INFO] Step 2: done compiler function fx2trt_compiler
FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead
3. Pre-Traced torch._dynamo.optimize(fx2trt, nopython=True)(model) + ATEN tracer
[2023-03-14 00:23:15,697] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function fx2trt_compiler
Traceback (most recent call last):
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 110, in dynamo_trace
    return torchdynamo.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 706, in export
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 118, in __call__
    return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 281, in __call__
    raise e
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 391, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 105, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 263, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 326, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 300, in transform
    tracer = InstructionTranslator(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1785, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1788, in <genexpr>
    VariableBuilder(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/builder.py", line 174, in __call__
    return self._wrap(value).clone(**self.options())
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/builder.py", line 300, in _wrap
    return type_dispatch(self, value)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/builder.py", line 742, in wrap_tensor
    assert type(value) in (torch.Tensor, torch.nn.Parameter)
AssertionError: 


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "~/trtre-models/collections_testing/dynamo_compiler.py", line 128, in fx2trt_compiler
    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
  File "~/trtre-models/collections_testing/dynamo_compiler.py", line 54, in fx2trt
    # acc_model = aten_tracer.opt_trace(model, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 150, in opt_trace
    fx_module, _ = trace(f, args)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 128, in trace
    graph_module, guards = dynamo_trace(f, args, True, "symbolic")
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 122, in dynamo_trace
    raise RuntimeError(
RuntimeError: torchdynamo internal error occured. Please see above stacktrace
[2023-03-14 00:23:15,737] torch._dynamo.output_graph: [INFO] Step 2: done compiler function fx2trt_compiler
FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 326, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 313, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1841, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 597, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 560, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1920, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 569, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 615, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 167, in time_wrapper
    compilation_metrics[key].append(time_spent)
KeyError: OutputGraph.call_user_compiler

from user code:
   File "<eval_with_key>.0", line 527, in forward
    pooler_activation = self.pooler.activation(pooler_dense);  pooler_dense = None
4. Pre-Traced torch_tensorrt.fx.compile(model, is_aten=True,...)

See #1673 (comment) and PR #1708, which makes this path work.

5. NOT Pre-Traced torch_tensorrt.fx.compile(model, is_aten=True,...)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_lru_cache_wrapper) [] {}

from user code:
   File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 735, in dtype
    return get_parameter_dtype(self)
  File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 192, in get_parameter_dtype
    if is_torch_tpu_available():
6. Pre-Traced torch_tensorrt.fx.compile(model, is_aten=False,...)
Traceback (most recent call last):
  File "bert.py", line 163, in <module>
    trt_mod = torchtrt.compile(traced, ir="fx", **compile_spec)
  File "~/TensorRT/py/torch_tensorrt/_compile.py", line 142, in compile
    return torch_tensorrt.fx.compile(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 247, in <lambda>
    trace_func=lambda module, inputs: acc_tracer.trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 667, in trace
    traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 585, in rewriter_base_trace
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py", line 309, in trace
    return super().trace(rewritten, concrete_args), rewritten
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "<eval_with_key>.1", line 9, in forward
TypeError: slice indices must be integers or None or have an __index__ method
7. NOT Pre-Traced torch_tensorrt.fx.compile(model, is_aten=False,...)

Symbolic trace fails due to control flow
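Plain torch.fx.symbolic_trace pushes Proxy objects through forward, so a Python if on a tensor-derived value cannot be resolved and tracing raises torch.fx.proxy.TraceError. The toy below (a hypothetical GatedScale module, not BERT code) shows the failure and a branch-free rewrite with torch.where that traces cleanly; whether the branches that break acc tracing in BERT can be rewritten this way is not established in this thread.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class GatedScale(torch.nn.Module):
    # Data-dependent Python branch: fine in eager mode, untraceable by plain FX
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


class GatedScaleBranchFree(torch.nn.Module):
    # Same semantics, expressed as a tensor op so the condition stays in the graph
    def forward(self, x):
        return torch.where(x.sum() > 0, x * 2, x)


try:
    symbolic_trace(GatedScale())
    trace_failed = False
except TraceError as exc:
    trace_failed = True
    print(exc)

gm = symbolic_trace(GatedScaleBranchFree())
x = torch.randn(3, 4)
print(torch.equal(gm(x), GatedScale()(x)))
```

The trade-off is that torch.where computes both branches, which is harmless here but can be wasteful when a branch is expensive.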

Without pre-tracing, the aten path runs into unsupported operators/functions, and the acc path runs into control-flow issues. So far, the only paths above in which the error originates in Dynamo itself (before the fx2trt compiler is invoked) are 1 and 5. In all of the other cases, the error arises in "Step 2: calling compiler function fx2trt". Please let me know what you think of the solution proposed in 4, and whether it could be a viable approach. It is surprising to me that 4 works (with the PR) but 3 does not.

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@narendasan
Collaborator

Now works in the dynamo frontends

3 participants