[jit][tensorexpr] Added aten::batch_norm into fuser when in inference mode #54204
Conversation
💊 CI failures summary and remediations
As of commit 80161bf (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages: clang-tidy (1/1), step "Add annotations".
… mode ghstack-source-id: 2393e6582316f865f83c9291d5e5b8e25b7c9d5d Pull Request resolved: #54204
Looks good, a couple of comments inline!
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)

// parameter list: input, weight, bias, mean, var, training,
// momentum, eps, cudnn_enabled
Tensor* input = tensors_.at(n->inputs()[0]->unique());
Nit: use n->input(X) instead of n->inputs()[X].
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)

Tensor* weight = tensors_.at(n->inputs()[1]->unique());
Tensor* bias = tensors_.at(n->inputs()[2]->unique());
…
auto inv_var = rsqrt(var->call(c) + eps);
auto weight_v = weight->call(c);
auto bias_v = bias->call(c);
auto alpha = inv_var * weight_v;
auto beta = bias_v - mean->call(c) * alpha;
auto output = input->call(axes) * alpha + beta;
return demoteOutput(output, n->output());
I wonder if we could merge these four cases using weight_v = 1 and bias_v = 0 as default values.
Looks nice! Can you check two things before landing:
- How does that training boolean parameter get set? Is it from the eval() state of the model? What happens if the user switches the model between eval() and train()?
- The promoteInputs question I have below; make sure we "do the right thing" if the argument types don't match.
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)

ExprHandle eps = constant(n->inputs()[7]);
…
// axes: N, C, H, W
std::vector<VarHandle> c(axes.begin() + 1, axes.begin() + 2);
You're just getting one axis here, right (c)? I think I'd be inclined to do VarHandle c = axes[1]; here instead of constructing a one-element vector. call has a variadic overload, so nothing below should have to change (at least, I don't think so).
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)

auto weight_v = weight->call(c);
auto alpha = inv_var * weight_v;
auto output = input->call(axes) * alpha - mean->call(c) * alpha;
return demoteOutput(output, n->output());
I wonder if we need some promoteInputs as well. I'm not certain here b/c pytorch type promotion is complicated, but I wonder what you get if, e.g., input is a FloatTensor and mean is a DoubleTensor. (Not something that's likely to happen, but we should handle it correctly all the same.)
Great question! I investigated this, and below is what I found.
Stack from ghstack:
Differential Revision: D27134348