
Use general ATen dispatch mechanism #38975

Open
fengyuan14 opened this issue May 25, 2020 · 4 comments
Labels
module: internals Related to internal abstractions in c10 and ATen triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@fengyuan14
Collaborator

fengyuan14 commented May 25, 2020

🚀 Feature

Use the general ATen dispatch mechanism in some cases, such as at::lstm.

Motivation

When implementing a new out-of-source ATen backend extension for PyTorch, we found there is no way to hook the whole LSTM cell under the existing at::lstm dispatch strategy. Our backend gets poor performance here.

In detail, backends of at::lstm are bypassed at runtime: VariableType::lstm calls TypeDefault::lstm directly, bypassing c10::dispatcher. Except for two backends, cudnn and miopen, which are dispatched by hard-coded "if-else", all other backends fall back to component operators such as at::matmul, at::sigmoid, at::tanh, ...

// TypeDefault::lstm -> at::native::lstm
at::native::lstm(...) {
  if (at::cudnn_is_acceptable(_input)) {
    ...
    return at::_cudnn_rnn(...);
  }
  if (at::use_miopen(_input, dropout_p)) {
    ...
    return at::miopen_rnn(...);
  }
  return lstm_default_impl(...); // falls back to component ops: at::matmul, at::tanh, ...
}

We also see the same issue in at::_embedding_bag_backward.

Pitch

at::convolution_overrideable, at::native_batch_norm and at::native_layer_norm are good examples of handling similar cases.
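To illustrate what the convolution_overrideable pattern buys an out-of-tree backend, here is a minimal self-contained sketch. It does not use real PyTorch headers: the registry is a stand-in for the c10 dispatcher, and the backend name "my_backend" and the kernel it registers are hypothetical. The point is that convolution_overrideable is a dispatch point an extension can claim with a whole-op kernel, instead of being skipped over by hard-coded branches.

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Stand-in for the c10 dispatcher's per-backend kernel table.
using ConvKernel = std::function<std::string(const std::string&)>;

std::unordered_map<std::string, ConvKernel>& conv_registry() {
    static std::unordered_map<std::string, ConvKernel> registry;
    return registry;
}

// Analogue of at::convolution_overrideable: a dispatch point that
// out-of-tree backends can override, rather than being hard-coded around.
std::string convolution_overrideable(const std::string& backend,
                                     const std::string& input) {
    auto it = conv_registry().find(backend);
    if (it == conv_registry().end()) {
        throw std::runtime_error(
            "convolution_overrideable not implemented for " + backend);
    }
    return it->second(input);
}

// What an extension would do (in real PyTorch this goes through the
// dispatcher's registration API, not a hand-rolled map).
void register_my_backend() {
    conv_registry()["my_backend"] = [](const std::string& input) {
        // A single fused whole-op kernel, instead of decomposing
        // into component ops like at::matmul / at::sigmoid.
        return "my_backend_conv(" + input + ")";
    };
}
```

With this shape, the backend's fused kernel is reached through normal dispatch; nothing in the core ever needs an "if (backend == ...)" branch for it.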

@ezyang ezyang added module: internals Related to internal abstractions in c10 and ATen triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 26, 2020
@ezyang
Contributor

ezyang commented May 26, 2020

Yeah, basically we should do the same thing as convolution_overrideable in this case. Or move to the glorious new world described in #29548

@fengyuan14
Collaborator Author

Do you mean you won't bypass at::lstm, and TypeDefault::lstm would be a fallback for redispatch? Please correct me if I'm wrong.

@ezyang
Contributor

ezyang commented May 27, 2020

No, we will still bypass lstm. But if it's not one of our hardcoded backends (cpu/cuda), we'll call lstm_overrideable, which is NOT bypassed. This is how the convolution overrideable business works.
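The flow described above can be modeled with a short self-contained sketch. This is a toy, not PyTorch code: the registry stands in for c10::dispatcher, and the backend key "xpu" and the name lstm_overrideable (following the convolution_overrideable naming convention) are assumptions. lstm itself stays bypassed with its hard-coded fast paths, but every other backend is routed through lstm_overrideable, which does go through dispatch.

```cpp
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for the c10 dispatcher's per-backend kernel table.
using LstmKernel = std::function<std::string()>;

std::unordered_map<std::string, LstmKernel>& lstm_registry() {
    static std::unordered_map<std::string, LstmKernel> registry;
    return registry;
}

// NOT bypassed: goes through the dispatch table, so an out-of-tree
// backend can register a fused whole-cell kernel here.
std::string lstm_overrideable(const std::string& backend) {
    auto it = lstm_registry().find(backend);
    if (it != lstm_registry().end()) {
        return it->second();  // extension's fused kernel
    }
    return "lstm_default_impl";  // decomposes into matmul/sigmoid/tanh...
}

// Still bypassed, as today: hard-coded branches for the in-tree
// accelerated paths, then a redispatch for everyone else.
std::string lstm(const std::string& backend) {
    if (backend == "cuda") return "_cudnn_rnn";   // hard-coded fast path
    if (backend == "hip")  return "miopen_rnn";   // hard-coded fast path
    return lstm_overrideable(backend);            // redispatch point
}
```

An extension would then populate the table for its backend (e.g. `lstm_registry()["xpu"] = [] { return "fused_xpu_lstm"; };`) and its fused kernel would be reached without touching the hard-coded branches.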

@fengyuan14
Collaborator Author

Thanks, it's clear to me now.
