Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add automatic type promotion to element-wise ops #1240

Merged

Conversation

mfeliz-cruise
Copy link
Contributor

@mfeliz-cruise mfeliz-cruise commented Aug 8, 2022

Adds automatic type promotion to match the default torch-script behavior for element-wise ops. Debug messages added for the type mismatch and cast.

Messages written to log.

DEBUG: [Torch-TensorRT] - Type mismatch for inputs in element-wise operation %3 : Tensor = aten::add(%0, %1, %2): Int32, Float32
DEBUG: [Torch-TensorRT] - Element-wise op type promotion adding cast from Int32 to Float32 for layer %3 : Tensor = aten::add(%0, %1, %2)

Fixes # (issue)

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

  • New feature (non-breaking change which adds functionality)

  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

  • This change requires a documentation update

  • My code follows the style guidelines of this project (You can use the linters)

  • I have performed a self-review of my own code

  • I have commented my code, particularly in hard-to-understand areas and hacks

  • I have made corresponding changes to the documentation

  • I have added tests to verify my fix or my feature

  • New and existing unit tests pass locally with my changes

  • I have added the relevant labels to my PR in so that relevant reviewers are notified

Signed-off-by: Michael Feliz michael.feliz@getcruise.com

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: tests Issues re: Tests labels Aug 8, 2022
@narendasan
Copy link
Collaborator

There was a partial fix in #1201 but I like how generic this is. Can you rebase?

cc: @peri044

Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Once the rebasing is done, it's good to merge

Adds automatic type promotion to match the default torch-script behavior for element-wise ops. Debug messages added for the type mismatch and cast.

Messages written to log.
```
DEBUG: [Torch-TensorRT] - Type mismatch for inputs in element-wise operation %3 : Tensor = aten::add(%0, %1, %2): Int32, Float32
DEBUG: [Torch-TensorRT] - Element-wise op type promotion adding cast from Int32 to Float32 for layer %3 : Tensor = aten::add(%0, %1, %2)
```
Fixes # (issue)

Please delete options that are not relevant and/or add your own.

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR in so that relevant reviewers are notified

Signed-off-by: Michael Feliz <michael.feliz@getcruise.com>
@mfeliz-cruise mfeliz-cruise force-pushed the michael.feliz/element_wise_casting branch from 202a7c7 to 10e036b Compare August 9, 2022 22:49
@mfeliz-cruise mfeliz-cruise marked this pull request as ready for review August 9, 2022 22:52
if (1 != scalar) {
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
if (1 != scalar.to<float>()) {
auto alphaTensor = impl::scalar_to_tensor(ctx, scalar);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use scalar_to_tensor instead of impl::scalar_to_tensor here ?

@peri044 peri044 merged commit 679ea21 into pytorch:master Aug 11, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants