Request for support of aten::__lshift__.Scalar and aten::__rshift__.Scalar #8711

@JoshuaGhost

🚀 The feature, motivation and pitch

The << operator (declared in the ATen dialect as aten::__lshift__.Scalar(Tensor self, Scalar other)) and the >> operator (declared in the ATen dialect as aten::__rshift__.Scalar(Tensor self, Scalar other)) have been supported as dunder methods in PyTorch since this MR. Currently, however, they are not supported in the Edge dialect (hopefully I've used those concepts correctly). For example, consider exporting the following model:

class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
    def forward(self, x) -> torch.Tensor:
        return x << 2

The default to_out_variant pass gives the following error:

Traceback (most recent call last):
  File "/nfs/home/zizhang/workspace/convert_litert/torch_lshift.py", line 29, in <module>
    executorch_dialect = backend_dialect.to_executorch(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 1400, in to_executorch
    new_gm_res = p(new_gm)
                 ^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/passes/infra/pass_base.py", line 44, in __call__
    res = self.call(graph_module)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/__init__.py", line 370, in call
    out_var_target = target.to_out_variant()
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/dialects/edge/_ops.py", line 322, in to_out_variant
    out_variant = to_variant(self._op, SchemaKind.out)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/dialects/edge/op/api.py", line 56, in to_variant
    native_schema is not None
AssertionError: Schema: aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor cannot be converted to torch.FunctionSchema

Alternatives

Currently, we can work around this annoyance with return x * 2 ** 2 if torch.compiler.is_compiling() else x << 2. Still, I would like to request this feature, unless it was left unimplemented on purpose.
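
The arithmetic behind this workaround can be sketched in plain Python. The helper names lshift_via_mul and rshift_via_floordiv are hypothetical and are not part of PyTorch or ExecuTorch. The sketch only checks Python integers, which are unbounded, so fixed-width overflow is not modeled. The same identities would be expected to carry over to integer tensors, but that is an assumption not verified here.

```python
def lshift_via_mul(x, n: int):
    """Emulate x << n as x * 2**n (hypothetical helper; n must be >= 0)."""
    if n < 0:
        raise ValueError("shift amount must be non-negative")
    return x * (2 ** n)


def rshift_via_floordiv(x, n: int):
    """Emulate x >> n as x // 2**n (hypothetical helper; n must be >= 0).

    Floor division rounds toward negative infinity, which matches an
    arithmetic (sign-preserving) right shift.
    """
    if n < 0:
        raise ValueError("shift amount must be non-negative")
    return x // (2 ** n)


if __name__ == "__main__":
    # Compare against Python's built-in shifts, including negative values.
    for value in (-7, -1, 0, 1, 6, 255):
        for shift in range(5):
            assert lshift_via_mul(value, shift) == value << shift
            assert rshift_via_floordiv(value, shift) == value >> shift
    print("shift identities hold for the sampled integers")
```

The right-shift variant uses floor division rather than truncating division because truncation would give a different result for negative inputs.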

Additional context

For reference, and in case I have used the API incorrectly, the full model instantiation and export code is as follows:

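# Imports inferred from the names used below; they were not shown in the original report.
from pathlib import Path

import torch
from torch.export import export
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge_transform_and_lower
from executorch.runtime import Method, Program, Runtime, Verification

if __name__ == "__main__":
    model = ToyModel()
    example_input = (torch.arange(2, 256, 2, dtype=torch.int64),)
    dynamic_shapes = None
    model(*example_input)  # run the eager forward pass once as a sanity check
    with torch.no_grad():
        # dynamic_shapes is keyword-only in torch.export.export; the third
        # positional parameter is kwargs.
        aten_dialect = export(model, example_input, dynamic_shapes=dynamic_shapes)

        backend_dialect: EdgeProgramManager = to_edge_transform_and_lower(
            aten_dialect,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        # The AssertionError shown in the traceback is raised by this call.
        executorch_dialect = backend_dialect.to_executorch()
        executorch_dialect.save("lshift.pte")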

    et_runtime: Runtime = Runtime.get()
    program: Program = et_runtime.load_program(
        Path("lshift.pte"),
        verification=Verification.Minimal,
    )
    print("Program methods:", program.method_names)
    forward: Method = program.load_method("forward")

    outputs = forward.execute(example_input)
    print(f"Ran forward({example_input[0].shape})")
    print(f"  outputs: {outputs[0].shape}")

RFC (Optional)

No response

cc @larryliu0820 @manuelcandales

Labels

module: kernels: Issues related to kernel libraries and utilities, and code under kernels/
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

Status

Done
