🚀 The feature, motivation and pitch
The << operator (declaration in the ATen dialect: aten::__lshift__.Scalar(Tensor self, Scalar other)) and the >> operator (declaration in the ATen dialect: aten::__rshift__.Scalar(Tensor self, Scalar other)) have been supported as dunder methods in PyTorch since this MR. Currently, however, they are not supported in the Edge dialect (I hope I am using those concepts correctly). Take exporting a model like the following:
class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x) -> torch.Tensor:
        return x << 2
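For reference, the operator itself behaves as expected in eager mode. A quick sanity check with an integer tensor (my own snippet, separate from the export flow):

import torch

# Eager-mode check: the shift works as a dunder method on an integer tensor
# and matches multiplication by 2 ** 2.
m = ToyModel()
x = torch.arange(2, 10, 2, dtype=torch.int64)
assert torch.equal(m(x), x * 2 ** 2)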
When lowering this model, the default to_out_variant pass gives the following error:
Traceback (most recent call last):
File "/nfs/home/zizhang/workspace/convert_litert/torch_lshift.py", line 29, in <module>
executorch_dialect = backend_dialect.to_executorch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 1400, in to_executorch
new_gm_res = p(new_gm)
^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/passes/infra/pass_base.py", line 44, in __call__
res = self.call(graph_module)
^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/__init__.py", line 370, in call
out_var_target = target.to_out_variant()
^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/dialects/edge/_ops.py", line 322, in to_out_variant
out_variant = to_variant(self._op, SchemaKind.out)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/dialects/edge/op/api.py", line 56, in to_variant
native_schema is not None
AssertionError: Schema: aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor cannot be converted to torch.FunctionSchema
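Some extra information that may be useful (this is just my own introspection on a stock PyTorch build, so treat the details as assumptions rather than a diagnosis): the functional schema mentioned in the error message can be inspected from Python, and listing the packet's overloads shows which variants are registered for the dunder op; the to_out_variant pass would ultimately need an out= counterpart of this schema.

import torch

# getattr is used because the op packet has a dunder name; this is plain
# torch.ops introspection, not ExecuTorch API.
lshift = getattr(torch.ops.aten, "__lshift__")
print(lshift.Scalar._schema)  # aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor
print(lshift.overloads())     # names of the registered overloads for this packet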
Alternatives
Currently, we can work around this annoyance by writing return x * 2 ** 2 if torch.compiler.is_compiling() else x << 2 instead (see the sketch below). Still, I would like to suggest this feature, unless it has been left unimplemented on purpose.
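Spelled out as a module, the workaround I currently use looks like this (a sketch; the class name is just for illustration, and it assumes the shift amount is a constant known when the model is written):

import torch

class ToyModelWorkaround(torch.nn.Module):
    def forward(self, x) -> torch.Tensor:
        # While tracing/compiling, use the arithmetic equivalent of x << 2 so the
        # exported graph avoids aten::__lshift__; keep the shift in eager mode.
        return x * 2 ** 2 if torch.compiler.is_compiling() else x << 2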
Additional context
For completeness, and to rule out the possibility that I might be using the API incorrectly, the full model instantiation and export code is as follows:
from pathlib import Path

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge_transform_and_lower
from executorch.runtime import Method, Program, Runtime, Verification
from torch.export import export

if __name__ == "__main__":
    model = ToyModel()
    example_input = (torch.arange(2, 256, 2, dtype=torch.int64),)
    dynamic_shapes = None
    model(*example_input)

    with torch.no_grad():
        aten_dialect = export(model, example_input, dynamic_shapes=dynamic_shapes)

    backend_dialect: EdgeProgramManager = to_edge_transform_and_lower(
        aten_dialect,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    executorch_dialect = backend_dialect.to_executorch()
    executorch_dialect.save("lshift.pte")

    et_runtime: Runtime = Runtime.get()
    program: Program = et_runtime.load_program(
        Path("lshift.pte"),
        verification=Verification.Minimal,
    )
    print("Program methods:", program.method_names)
    forward: Method = program.load_method("forward")
    outputs = forward.execute(example_input)
    print(f"Ran forward({example_input[0].shape})")
    print(f" outputs: {outputs[0].shape}")
RFC (Optional)
No response