
Conversation

jiashenC
Contributor

@jiashenC jiashenC commented May 30, 2024

Description

Handle custom ops during TorchScript to ExportedProgram conversion.

```python
torch.library.define(
    "mylib::foo",
    "(Tensor x) -> Tensor",
    lib=lib,
)

# PyTorch custom op implementation.
@torch.library.impl(
    "mylib::foo",
    "CompositeExplicitAutograd",
    lib=lib,
)
def foo_impl(x):
    return x + x

# Meta function of the custom op.
@torch.library.impl_abstract(
    "mylib::foo",
    lib=lib,
)
def foo_meta(x):
    return x + x

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.foo(x)
```

Test Plan

  • Add a test case where a custom op is called and converted: `pytest test/export/test_converter.py -s -k test_ts2ep_converter_custom_op`
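To illustrate what the converter has to do for a call like `torch.ops.mylib.foo(x)`, here is a minimal, torch-free sketch of dispatching ops by qualified name: custom ops are looked up in the same op table as builtin ops. All names here (`OP_TABLE`, `run_graph`, the node tuple layout) are hypothetical illustrations, not the actual converter API.

```python
# Op table: qualified op name -> Python implementation.
# Custom ops ("mylib::foo") are resolved the same way as builtin ops.
OP_TABLE = {
    "aten::add": lambda a, b: a + b,
    "mylib::foo": lambda x: x + x,
}

def run_graph(nodes, inputs):
    """Interpret a flat list of (op_name, arg_names, out_name) nodes."""
    env = dict(inputs)
    for op_name, arg_names, out_name in nodes:
        if op_name not in OP_TABLE:
            raise KeyError(f"unregistered op: {op_name}")
        env[out_name] = OP_TABLE[op_name](*(env[a] for a in arg_names))
    return env

# Graph equivalent to M.forward: return torch.ops.mylib.foo(x)
nodes = [("mylib::foo", ("x",), "out")]
print(run_graph(nodes, {"x": 3})["out"])  # 6
```

The point of the sketch is that once the op is registered (with a meta function so shapes can be traced), the converter needs no special casing for custom ops beyond the name lookup.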


pytorch-bot bot commented May 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127580

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 16b840c with merge base 0de6d24:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jiashenC jiashenC marked this pull request as ready for review May 31, 2024 01:25
Contributor

@angelayi angelayi left a comment


The custom op changes look good, but I think your PR also includes some changes for when the model has attributes. Can we split those out into a separate PR?

Comment on lines 251 to 277
if target is torch.ops.aten.size.int:
target = torch.ops.aten.sym_size.int
Contributor


maybe we can add this to the registry in a followup PR
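The diff above hard-codes a single target remap (`aten.size.int` to `aten.sym_size.int`). A registry version, as the review comment suggests, could look like this sketch; the names `TARGET_REMAP` and `remap_target` are hypothetical, not part of the converter:

```python
# Hypothetical target-remapping registry: TorchScript op -> the
# equivalent op that ExportedProgram expects. Adding a new remap
# becomes a one-line table entry instead of an inline branch.
TARGET_REMAP = {
    "aten::size.int": "aten::sym_size.int",
}

def remap_target(target: str) -> str:
    """Return the remapped target, or the original if no entry exists."""
    return TARGET_REMAP.get(target, target)

print(remap_target("aten::size.int"))  # aten::sym_size.int
print(remap_target("aten::relu"))      # aten::relu
```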

@jiashenC jiashenC requested a review from angelayi June 4, 2024 18:23
@jiashenC jiashenC force-pushed the ts_converter_custom_op branch from f03cf2a to 3378bb4 Compare June 4, 2024 23:41
@jiashenC jiashenC requested a review from angelayi June 4, 2024 23:48
@jiashenC jiashenC force-pushed the ts_converter_custom_op branch from 3378bb4 to 16b840c Compare June 6, 2024 16:39
@jiashenC jiashenC added the topic: not user facing topic category label Jun 6, 2024
@jiashenC
Contributor Author

jiashenC commented Jun 6, 2024

@pytorchbot merge -f "skip stuck tests"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
…torch#127580)

Pull Request resolved: pytorch#127580
Approved by: https://github.com/angelayi
@github-actions github-actions bot deleted the ts_converter_custom_op branch July 8, 2024 01:57
