From e1b3937881d0da63899d6892db002d36585fc73b Mon Sep 17 00:00:00 2001
From: Yinghai Lu
Date: Tue, 24 Mar 2020 19:41:52 -0700
Subject: [PATCH] Ignore rest of outputs of LayerNorm when lowering to Glow
 (#35338)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35338

Pull Request resolved: https://github.com/pytorch/glow/pull/4343

- Ignore the rest of the outputs of LayerNorm.
- Support implicit broadcast when loading basic binary ops into Glow.

Reviewed By: jfix71

Differential Revision: D20627768

fbshipit-source-id: beaca0b22ce65e32fbda96780c271f2cff2e4d8e
---
 include/glow/Importer/CommonOperatorLoader.h        | 20 +++++
 lib/Importer/Caffe2ModelLoader.cpp                  |  6 +++---
 .../elementwise_linear_broadcast_net.pbtxt          | 35 ++++++++
 tests/unittests/Caffe2ImporterTest.cpp              | 80 +++++++++++++++++++
 4 files changed, 138 insertions(+), 3 deletions(-)
 create mode 100644 tests/models/caffe2Models/elementwise_linear_broadcast_net.pbtxt

diff --git a/include/glow/Importer/CommonOperatorLoader.h b/include/glow/Importer/CommonOperatorLoader.h
index bbb1522f65..2de1e68051 100644
--- a/include/glow/Importer/CommonOperatorLoader.h
+++ b/include/glow/Importer/CommonOperatorLoader.h
@@ -511,6 +511,26 @@ class CommonOperatorLoader : public ProtobufLoader {
     bool broadcast;
     ASSIGN_VALUE_OR_RETURN_ERR(broadcast, getBroadcast(dict));
 
+    // Check implicit broadcast: accept numpy-style trailing-dim matching.
+    if (!broadcast && in0.dims().size() != in1.dims().size()) {
+      bool validBroadcast = true;
+      auto dimsA = in0.dims();
+      auto dimsB = in1.dims();
+      for (int i = dimsA.size() - 1, j = dimsB.size() - 1; i >= 0 && j >= 0;) {
+        auto a = dimsA[i];
+        auto b = dimsB[j];
+        if (!(a == b || a == 1 || b == 1)) {
+          validBroadcast = false;
+          break;
+        }
+        --i;
+        --j;
+      }
+      if (!validBroadcast) {
+        LOG(WARNING) << "Invalid broadcast rule for inputs of " << opName;
+      }
+      broadcast = validBroadcast;
+    }
 
     int axis = -1;
 
diff --git a/lib/Importer/Caffe2ModelLoader.cpp b/lib/Importer/Caffe2ModelLoader.cpp
index 9e63cbd377..31e33bba07 100644
--- a/lib/Importer/Caffe2ModelLoader.cpp
+++ b/lib/Importer/Caffe2ModelLoader.cpp
@@ -580,10 +580,10 @@ Error Caffe2ModelLoader::loadLayerNorm(const caffe2::OperatorDef &op,
   LayerNormalizationNode *node =
       G_->createLayerNormalization(opName, in, weight, bias, eps);
 
-  RETURN_ERR_IF_NOT(op.output_size() == 1,
-                    "Supporting only one output from LayerNorm");
+  // We only support one output for LayerNorm. Ignore the
+  // rest of the outputs.
+  RETURN_IF_ERR(addNodeAsOutput(op, node, /* numOutputs */ 1));
 
-  RETURN_IF_ERR(addNodeAsOutput(op, node));
 
   return Error::success();
 }
 
diff --git a/tests/models/caffe2Models/elementwise_linear_broadcast_net.pbtxt b/tests/models/caffe2Models/elementwise_linear_broadcast_net.pbtxt
new file mode 100644
index 0000000000..6ab31546c3
--- /dev/null
+++ b/tests/models/caffe2Models/elementwise_linear_broadcast_net.pbtxt
@@ -0,0 +1,35 @@
+name: "elementwise_linear"
+op {
+  input: "X"
+  input: "w"
+  output: "i0"
+  name: ""
+  type: "Mul"
+  arg {
+    name: "axis"
+    i: 1
+  }
+  arg {
+    name: "broadcast"
+    i: 0
+  }
+}
+op {
+  input: "i0"
+  input: "b"
+  output: "el_result"
+  name: ""
+  type: "Add"
+  arg {
+    name: "axis"
+    i: 1
+  }
+  arg {
+    name: "broadcast"
+    i: 0
+  }
+}
+external_input: "X"
+external_input: "w"
+external_input: "b"
+external_output: "el_result"
diff --git a/tests/unittests/Caffe2ImporterTest.cpp b/tests/unittests/Caffe2ImporterTest.cpp
index 92a6890f47..cf0a17ef71 100644
--- a/tests/unittests/Caffe2ImporterTest.cpp
+++ b/tests/unittests/Caffe2ImporterTest.cpp
@@ -2191,6 +2191,86 @@ TEST_F(Caffe2ImporterTest, elementwiseLinearUnspecifiedAxis) {
   EXPECT_EQ(mod.getPlaceholders().size(), 4);
 }
 
+/// Test loading an ElementwiseLinear operator with implicit broadcast.
+TEST_F(Caffe2ImporterTest, elementwiseImplicitBroadcast) {
+  ExecutionEngine EE{};
+  auto &mod = EE.getModule();
+  Function *F = mod.createFunction("main");
+
+  std::string NetDescFilename(
+      GLOW_DATA_PATH
+      "tests/models/caffe2Models/elementwise_linear_broadcast_net.pbtxt");
+  std::string NetWeightFilename(
+      GLOW_DATA_PATH "tests/models/caffe2Models/empty_init_net.pbtxt");
+
+  PlaceholderBindings bindings;
+  Placeholder *output;
+
+  // Since the loader will assume that axis = 1, the 0th dim of the shapes of w
+  // and b must match the 1st dim of X.
+  Tensor X(ElemKind::FloatTy, {5, 10});
+  Tensor w(ElemKind::FloatTy, {10}), b(ElemKind::FloatTy, {10});
+
+  // Destroy the loader after the graph is loaded since the following execution
+  // will not depend on anything from the loader.
+  {
+    Caffe2ModelLoader caffe2LD(NetDescFilename, NetWeightFilename,
+                               {"X", "w", "b"},
+                               {&X.getType(), &w.getType(), &b.getType()}, *F);
+    output = EXIT_ON_ERR(caffe2LD.getSingleOutput());
+  }
+
+  // Check that the shape of the output matches that of the input.
+  std::vector<size_t> expectedDims = {5, 10};
+  EXPECT_TRUE(output->dims().vec() == expectedDims);
+
+  // High level checks on the content of the graph.
+  // It should look like this:
+  //
+  //        X        w          b
+  //        |        |          |
+  //        |        v          v
+  //        |     Reshape    Reshape
+  //        |        |          |
+  //        |        v          v
+  //        |      Tile       Tile
+  //        |       /           /
+  //        v  v---/           /
+  //        Mul               /
+  //         |   /-----------/
+  //         v  v
+  //         Add
+  //          |
+  //          v
+  //         Save
+
+  EXPECT_EQ(F->getNodes().size(), 7);
+  auto *save = getSaveNodeFromDest(output);
+  auto *add = llvm::dyn_cast<AddNode>(save->getInput().getNode());
+  ASSERT_TRUE(add);
+  auto *mul = llvm::dyn_cast<MulNode>(add->getLHS().getNode());
+  ASSERT_TRUE(mul);
+  auto *bTile = llvm::dyn_cast<TileNode>(add->getRHS().getNode());
+  ASSERT_TRUE(bTile);
+  EXPECT_EQ(bTile->getAxis(), 0);
+  auto *XPH = llvm::dyn_cast<Placeholder>(mul->getLHS().getNode());
+  EXPECT_EQ(XPH, mod.getPlaceholderByName("X"));
+  auto *wTile = llvm::dyn_cast<TileNode>(mul->getRHS().getNode());
+  ASSERT_TRUE(wTile);
+  EXPECT_EQ(wTile->getAxis(), 0);
+  auto *bReshape = llvm::dyn_cast<ReshapeNode>(bTile->getInput().getNode());
+  ASSERT_TRUE(bReshape);
+  auto *wReshape = llvm::dyn_cast<ReshapeNode>(wTile->getInput().getNode());
+  ASSERT_TRUE(wReshape);
+  auto *wPH = llvm::dyn_cast<Placeholder>(wReshape->getInput().getNode());
+  EXPECT_EQ(wPH, mod.getPlaceholderByName("w"));
+  auto *bPH = llvm::dyn_cast<Placeholder>(bReshape->getInput().getNode());
+  EXPECT_EQ(bPH, mod.getPlaceholderByName("b"));
+
+  // We have three inputs and one output.
+  EXPECT_EQ(mod.getPlaceholders().size(), 4);
+}
+
 /// Test loading SparseLengthsWeightedSum8BitsRowwise. This is created as a
 /// RowwiseQuantizedSparseLengthsWeightedSumNode. The following inputs/outputs
 /// are used/expected for this test. Note that the DATA input is
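
Reviewer note (not part of the patch): the implicit-broadcast check added in
CommonOperatorLoader.h is the standard numpy-style rule: walk both shapes from
the innermost dimension outward and require every aligned pair of dims to be
equal or 1; extra leading dimensions on the longer shape are always
compatible. Below is a minimal standalone C++ sketch of that rule. The helper
name isValidImplicitBroadcast and the use of plain std::vector<size_t> in
place of Glow's dims containers are illustrative assumptions, not Glow API.

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone restatement of the trailing-dimension check added in
// CommonOperatorLoader.h. Hypothetical helper, not part of Glow.
static bool isValidImplicitBroadcast(const std::vector<size_t> &dimsA,
                                     const std::vector<size_t> &dimsB) {
  for (int i = static_cast<int>(dimsA.size()) - 1,
           j = static_cast<int>(dimsB.size()) - 1;
       i >= 0 && j >= 0; --i, --j) {
    size_t a = dimsA[i];
    size_t b = dimsB[j];
    if (!(a == b || a == 1 || b == 1)) {
      return false; // Mismatched trailing dims: not broadcastable.
    }
  }
  return true;
}

int main() {
  // Mirrors the new test: Mul({5, 10}, {10}) broadcasts w across rows of X.
  std::cout << isValidImplicitBroadcast({5, 10}, {10}) << "\n"; // prints 1
  // Trailing dims 10 vs 4 clash, so the loader would log a warning and keep
  // broadcast = false.
  std::cout << isValidImplicitBroadcast({5, 10}, {4}) << "\n"; // prints 0
  return 0;
}

This is why elementwise_linear_broadcast_net.pbtxt can set broadcast: 0 on
both Mul and Add and still load: the loader detects the rank mismatch between
{5, 10} and {10}, validates it with the rule above, and enables broadcasting
itself, which is then lowered to the Reshape/Tile pattern the test asserts.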