[jit][tensorexpr] Added aten::batch_norm into fuser when in inference mode (#54204)

Summary: Pull Request resolved: #54204

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D27134348

Pulled By: huiguoo

fbshipit-source-id: 5ea7a6c5bc694fcdfc436dba3fa6eb269420324e
huiguoo authored and facebook-github-bot committed Mar 23, 2021
1 parent fee470d commit 2a53897
Showing 4 changed files with 144 additions and 0 deletions.
38 changes: 38 additions & 0 deletions benchmarks/cpp/tensorexpr/bench_ops.py
@@ -1,5 +1,6 @@
import timeit
import torch
import torch.nn.functional as F

torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)
@@ -65,3 +66,40 @@ def hardswish(x):
teager = timeit.timeit(stmt="op(x)", globals=globals(), number=bench_iters)
tjit = timeit.timeit(stmt="traced(x)", globals=globals(), number=bench_iters)
print(f"{op.__name__:20s} {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")

def test_batch_norm():
op = F.batch_norm
print("{:20s} {:20s} {:>10s} {:>10s} {:>10s}".format("op", "shape", "eager", "nnc", "speedup"))
batch_norm_shapes = [
[1, 64, 112, 112],
[1, 256, 14, 14],
[1, 128, 28, 28],
[1, 64, 56, 56],
[1, 512, 7, 7],
[5, 64, 112, 112],
[5, 256, 14, 14],
[5, 128, 28, 28],
[5, 64, 56, 56],
[5, 512, 7, 7]]
for n, c, h, w in batch_norm_shapes:
x = torch.rand((n, c, h, w))
y = torch.rand((c))
z = torch.rand((c))
traced = torch.jit.trace(lambda x, y, z: op(x, y, z), (x, y, z))

# Warmup.
warmup_iters = 8
for _ in range(warmup_iters):
op(x, y, z)
traced(x, y, z)

# Validate result.
torch.testing.assert_allclose(op(x, y, z), traced(x, y, z))

# Benchmark.
bench_iters = 100
teager = timeit.timeit(stmt="op(x, y, z)", globals=locals(), number=bench_iters)
tjit = timeit.timeit(stmt="traced(x, y, z)", globals=locals(), number=bench_iters)
print(f"{op.__name__:20s} ({n:>3d}, {c:>3d}, {h:>3d}, {w:>3d}) {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")

test_batch_norm()
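
For reference (this snippet is not part of the commit): in test_batch_norm above, F.batch_norm(x, y, z) binds y and z to the running_mean and running_var parameters and relies on the defaults weight=None, bias=None, training=False, momentum=0.1, eps=1e-5, which is exactly the inference-only pattern that the fuser change below accepts. A minimal keyword-explicit sketch of the same call:

import torch
import torch.nn.functional as F

x, mean, var = torch.rand(5, 64, 56, 56), torch.rand(64), torch.rand(64)
implicit = F.batch_norm(x, mean, var)
explicit = F.batch_norm(x, mean, var, weight=None, bias=None,
                        training=False, momentum=0.1, eps=1e-5)
torch.testing.assert_allclose(implicit, explicit)  # identical results
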
30 changes: 30 additions & 0 deletions test/test_jit_fuser_te.py
@@ -1661,6 +1661,36 @@ def apply(fn):
" ".join(["Failed:", str(dtype), op.__name__, device])
)

def test_ternary_norm_ops(self):
def apply(fn):
return lambda x, y, z: fn(x, y, z)

ternary_ops = [
F.batch_norm,
]
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
try:
x = self.data_for(dtype, device, size=[5, 3, 128, 128])
y = self.data_for(dtype, device, size=[3])
z = self.data_for(dtype, device, size=[3])
fn = apply(op)
ref = fn(x, y, z)
except Exception:
# If eager mode doesn't support a dtype/op/device combo,
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
try:
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)


@unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure")
def test_list_ops(self):
def apply(fn):
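
The test_ternary_norm_ops test added above relies on assertAllFused to confirm that the traced batch_norm call is absorbed into a fusion group. Outside the test harness, a rough equivalent is to inspect graph_for for a prim::TensorExprGroup node; the sketch below is not part of the commit and assumes the default profiling executor, which needs a couple of warm-up runs before the optimized graph is available:

import torch
import torch.nn.functional as F

torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)  # keep single-op groups visible

x, mean, var = torch.rand(5, 3, 128, 128), torch.rand(3), torch.rand(3)
fn = torch.jit.trace(lambda x, m, v: F.batch_norm(x, m, v), (x, mean, var))
fn(x, mean, var)  # warm up so the profiling executor specializes and fuses
fn(x, mean, var)
graph = fn.graph_for(x, mean, var)
assert any(n.kind() == "prim::TensorExprGroup" for n in graph.nodes()), graph
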
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -154,6 +154,7 @@ static const OperatorSet& supported_eltwise_set() {
"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor",
"aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor",
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
// TODO: enable other min/max variants, operators that can be both
// elementwise or reductions:
"aten::min.other(Tensor self, Tensor other) -> Tensor",
@@ -983,6 +984,12 @@ class TensorExprFuser {
REQ(tensorexpr::pickDeviceType(node->inputs()));
}

// Only fuse aten::batch_norm when the parameter 'training' is false
if (node->kind() == aten::batch_norm) {
REQ(node->input(5)->node()->kind() == prim::Constant);
REQ(!toIValue(node->input(5)).value().toBool());
}

REQ(tensorexpr::isSupported(node));
REQ(typesAreSupported(node));

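
The new check above admits aten::batch_norm only when its training argument (input 5) is a prim::Constant that evaluates to false. Tracing or scripting a model in eval() mode is what produces such a constant; the sketch below (not part of the commit, graph-walking shown purely for illustration) makes that visible from Python:

import torch

bn = torch.nn.BatchNorm2d(3).eval()  # inference mode: training is recorded as False
traced = torch.jit.trace(bn, torch.rand(1, 3, 8, 8))

node = next(n for n in traced.graph.nodes() if n.kind() == "aten::batch_norm")
training = list(node.inputs())[5]  # input 5 is the 'training' flag
assert training.node().kind() == "prim::Constant"
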
69 changes: 69 additions & 0 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -262,6 +262,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::reciprocal:
case aten::neg:
case aten::relu:
case aten::batch_norm:
case aten::isnan:
case aten::log:
case aten::log10:
@@ -1062,6 +1063,74 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
});
} break;

case aten::batch_norm: {
bool hasWeight = true;
bool hasBias = true;

if (v->node()->input(1)->node()->kind() == prim::Constant) {
const auto val = toIValue(v->node()->input(1)).value();
if (val.isNone()) {
hasWeight = false;
}
}

if (v->node()->input(2)->node()->kind() == prim::Constant) {
const auto val = toIValue(v->node()->input(2)).value();
if (val.isNone()) {
hasBias = false;
}
}

auto const& shape = valueShape(v->node()->inputs()[0]);
return Compute(
"aten_batch_norm",
c10::fmap<DimArg>(shape),
[this, v, hasWeight, hasBias](const std::vector<VarHandle>& axes) {
TORCH_INTERNAL_ASSERT(axes.size() >= 2);
auto const& n = v->node();
// axes: N, C, H, W
std::vector<ExprHandle> indices(axes.begin(), axes.end());
ExprHandle c = indices[1];

// Parameter list:
// input, weight, bias, mean, var, training, momentum, eps,
// cudnn_enabled
std::vector<ExprHandle> inputs = {
tensorOrConstant(n->input(0), indices), // input
tensorOrConstant(n->input(3), {c}), // mean
tensorOrConstant(n->input(4), {c}), // var
constant(n->input(7)) // eps
};
if (hasWeight) {
inputs.push_back(tensorOrConstant(n->input(1), {c}));
}
if (hasBias) {
inputs.push_back(tensorOrConstant(n->input(2), {c}));
}
promoteInputs(inputs);

ExprHandle input = inputs[0];
ExprHandle mean = inputs[1];
ExprHandle var = inputs[2];
ExprHandle eps = inputs[3];
ExprHandle weight = FloatImm::make(1);
ExprHandle bias = FloatImm::make(0);

if (hasWeight) {
weight = inputs[4];
}
if (hasBias) {
bias = inputs[5];
}

auto inv_var = rsqrt(var + eps);
auto alpha = inv_var * weight;
auto beta = bias - mean * alpha;
auto output = input * alpha + beta;
return demoteOutput(output, n->output());
});
} break;

case aten::log: {
return computeOneOperand("aten_log", v, [](const ExprHandle& a) {
return log(promoteIntegerToDefaultType(a));
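
For reference, the alpha/beta folding in the aten::batch_norm lowering above is the standard inference-time batch-norm identity rewritten as a single multiply-add per element. With w = weight, b = bias, \mu = running mean, \sigma^2 = running variance and \varepsilon = eps:

y = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\, w + b
  = x \cdot \underbrace{\frac{w}{\sqrt{\sigma^2 + \varepsilon}}}_{\alpha}
    + \underbrace{\Bigl( b - \mu \cdot \frac{w}{\sqrt{\sigma^2 + \varepsilon}} \Bigr)}_{\beta}

which is exactly inv_var = rsqrt(var + eps), alpha = inv_var * weight, beta = bias - mean * alpha, output = input * alpha + beta in the code, with weight and bias falling back to 1 and 0 when those optional inputs are None.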
