[jit][tensorexpr] Added aten::batch_norm into fuser when in inference mode

ghstack-source-id: 5507e00060377a2f519fd392359e792cd9374c16
Pull Request resolved: #54204
huiguoo committed Mar 19, 2021
1 parent 0584fd9 commit 799e3d1
Showing 3 changed files with 106 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/test_jit_fuser_te.py
@@ -1661,6 +1661,36 @@ def apply(fn):
" ".join(["Failed:", str(dtype), op.__name__, device])
)

def test_ternary_norm_ops(self):
def apply(fn):
return lambda x, y, z: fn(x, y, z)

ternary_ops = [
F.batch_norm,
]
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
try:
x = self.data_for(dtype, device, size=[5, 3, 128, 128])
y = self.data_for(dtype, device, size=[3])
z = self.data_for(dtype, device, size=[3])
fn = apply(op)
ref = fn(x, y, z)
except Exception:
# If eager mode doesn't support a dtype/op/device combo,
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
try:
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)


    @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure")
    def test_list_ops(self):
        def apply(fn):
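Aside (not part of the diff): what test_ternary_norm_ops exercises can be reproduced standalone, roughly as in the sketch below. It assumes the private fuser toggles this test file already relies on (torch._C._jit_set_texpr_fuser_enabled, torch._C._jit_override_can_fuse_on_cpu) and that a few warm-up calls let the profiling executor specialize and fuse; the printed optimized graph is then expected to contain a prim::TensorExprGroup node.

import torch
import torch.nn.functional as F

# Private toggles mirroring the test harness setup; not a stable public API.
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)

def fn(x, mean, var):
    # weight/bias default to None and training defaults to False (inference),
    # which is the case the new fuser rule targets.
    return F.batch_norm(x, mean, var)

x = torch.rand(5, 3, 128, 128)
mean = torch.rand(3)
var = torch.rand(3)

traced = torch.jit.trace(fn, (x, mean, var))
for _ in range(3):  # warm-up so the profiling executor can specialize and fuse
    traced(x, mean, var)

print(traced.graph_for(x, mean, var))  # look for a prim::TensorExprGroup node
print(torch.allclose(fn(x, mean, var), traced(x, mean, var)))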
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -154,6 +154,7 @@ static const OperatorSet& supported_eltwise_set() {
"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor",
"aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor",
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
// TODO: enable other min/max variants, operators that can be both
// elementwise or reductions:
"aten::min.other(Tensor self, Tensor other) -> Tensor",
@@ -981,6 +982,12 @@ class TensorExprFuser {
      REQ(tensorexpr::pickDeviceType(node->inputs()));
    }

    // Only fuse aten::batch_norm when its 'training' argument is a
    // compile-time constant equal to false (i.e., inference mode).
    if (node->kind() == aten::batch_norm) {
      REQ(node->input(5)->node()->kind() == prim::Constant);
      REQ(!toIValue(node->input(5)).value().toBool());
    }

    REQ(tensorexpr::isSupported(node));
    REQ(typesAreSupported(node));

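Why the constant check above is enough (a sketch, not part of the diff): when F.batch_norm is traced in inference mode, its 'training' argument appears in the graph as input index 5, produced by a prim::Constant node, which is exactly what the two REQ lines read. Inspecting that with the standard TorchScript graph APIs:

import torch
import torch.nn.functional as F

def fn(x, mean, var):
    return F.batch_norm(x, mean, var, training=False)

traced = torch.jit.trace(fn, (torch.rand(2, 3, 8, 8), torch.zeros(3), torch.ones(3)))
bn = next(n for n in traced.graph.nodes() if n.kind() == "aten::batch_norm")
training = list(bn.inputs())[5]   # same index the C++ check reads
print(training.node().kind())     # prim::Constant
print(training.toIValue())        # False -> both REQ(...) conditions hold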
69 changes: 69 additions & 0 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -262,6 +262,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
    case aten::reciprocal:
    case aten::neg:
    case aten::relu:
    case aten::batch_norm:
    case aten::isnan:
    case aten::log:
    case aten::log10:
@@ -1062,6 +1063,74 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
      });
    } break;

    case aten::batch_norm: {
      bool hasWeight = true;
      bool hasBias = true;

      if (v->node()->input(1)->node()->kind() == prim::Constant) {
        const auto val = toIValue(v->node()->input(1)).value();
        if (val.isNone()) {
          hasWeight = false;
        }
      }

      if (v->node()->input(2)->node()->kind() == prim::Constant) {
        const auto val = toIValue(v->node()->input(2)).value();
        if (val.isNone()) {
          hasBias = false;
        }
      }

      auto const& shape = valueShape(v->node()->inputs()[0]);
      return Compute(
          "aten_batch_norm",
          c10::fmap<DimArg>(shape),
          [this, v, hasWeight, hasBias](const std::vector<VarHandle>& axes) {
            TORCH_INTERNAL_ASSERT(axes.size() >= 2);
            auto const& n = v->node();
            // axes: N, C, H, W
            std::vector<ExprHandle> indices(axes.begin(), axes.end());
            ExprHandle c = indices[1];

            // Parameter list:
            // input, weight, bias, mean, var, training, momentum, eps,
            // cudnn_enabled
            std::vector<ExprHandle> inputs = {
                tensorOrConstant(n->input(0), indices), // input
                tensorOrConstant(n->input(3), {c}), // mean
                tensorOrConstant(n->input(4), {c}), // var
                constant(n->input(7)) // eps
            };
            if (hasWeight) {
              inputs.push_back(tensorOrConstant(n->input(1), {c}));
            }
            if (hasBias) {
              inputs.push_back(tensorOrConstant(n->input(2), {c}));
            }
            promoteInputs(inputs);

            ExprHandle input = inputs[0];
            ExprHandle mean = inputs[1];
            ExprHandle var = inputs[2];
            ExprHandle eps = inputs[3];
            ExprHandle weight = FloatImm::make(1);
            ExprHandle bias = FloatImm::make(0);

            if (hasWeight) {
              weight = inputs[4];
            }
            if (hasBias) {
              bias = inputs[5];
            }

            auto inv_var = rsqrt(var + eps);
            auto alpha = inv_var * weight;
            auto beta = bias - mean * alpha;
            auto output = input * alpha + beta;
            return demoteOutput(output, n->output());
          });
    } break;

    case aten::log: {
      return computeOneOperand("aten_log", v, [](const ExprHandle& a) {
        return log(promoteIntegerToDefaultType(a));
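The lowering above folds inference-mode batch norm into a per-element affine expression: output = input * alpha + beta, with alpha = weight * rsqrt(var + eps) and beta = bias - mean * alpha. A quick eager-mode sanity check of that rewrite (a sketch, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.rand(5, 3, 128, 128)
mean, var = torch.rand(3), torch.rand(3) + 0.5
weight, bias = torch.rand(3), torch.rand(3)
eps = 1e-5

# y = (x - mean) / sqrt(var + eps) * weight + bias
#   = x * alpha + beta, with alpha = weight * rsqrt(var + eps), beta = bias - mean * alpha
alpha = weight * torch.rsqrt(var + eps)
beta = bias - mean * alpha
# broadcast the per-channel parameters over (N, C, H, W), matching the {c} index above
y_fused = x * alpha.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

y_ref = F.batch_norm(x, mean, var, weight, bias, False, 0.1, eps)
print(torch.allclose(y_fused, y_ref, atol=1e-5))  # expected: True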
