Skip to content
Permalink
Browse files

Move Backend::transformPostLowering to return Expected<bool> (#4346)

Summary:
Pull Request resolved: #4346

`Backend::transformPostLowering()` may require some error handling/returning up the stack. Since it is called in the middle of the optimization pipeline and in the ONNXIFI interface, have it return `Expected<bool>` instead of `bool`.

Also added `RETURN_IF_EXPECTED_IS_ERR()` for when we want to call a function that returns some Expected, propagate any error it carries, and discard the value it holds on success.

Reviewed By: yinghai

Differential Revision: D20637708

fbshipit-source-id: 10d9a124a2984ff9de511a2af0376895f78495c2
  • Loading branch information
jfix71 authored and facebook-github-bot committed Mar 25, 2020
1 parent 3aff53b commit 0ddb67d09f09ceafd579ba2f37f8916a0c68f902
@@ -52,7 +52,7 @@ are two pure virtual functions all backends must implement:

Additionally, there are several virtual functions that backends can override:

- `virtual bool transformPostLowering(Function *F, CompilationContext &cctx) const;`
- `virtual Expected<bool> transformPostLowering(Function *F, CompilationContext &cctx) const;`

- Allow the backend to transform the `Function *F` after [node
lowering](https://github.com/pytorch/glow/blob/master/docs/IR.md#node-lowering)
@@ -109,8 +109,8 @@ class Backend : public Named {
/// giving the backend an opportunity to transform the graph before IRGen. The
/// backend may insert backend and device-specific nodes. The backend is
/// responsible for cleaning up after itself.
/// \returns True if the graph was modified.
virtual bool transformPostLowering(
/// \returns an Expected True if the graph was modified.
virtual Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo = nullptr) const {
return false;
@@ -58,7 +58,7 @@ class Interpreter final : public BackendUsingGlowIR,

bool shouldLower(const Node *N) const override;

bool transformPostLowering(
Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo = nullptr) const override;

@@ -96,7 +96,6 @@
} while (0)

/// Takes an Error and returns it if it's not success.
// TODO: extend this to work with Expected as well.
#define RETURN_IF_ERR(err) \
do { \
if (auto errV = std::forward<glow::detail::GlowError>(err)) { \
@@ -107,6 +106,18 @@
} \
} while (0)

/// Takes an Expected expression \p expV; if it holds an Error, records the
/// current file/line on the Error's stack and returns the Error up to the
/// caller. On success the Expected's value is discarded. The argument is
/// evaluated exactly once, so it is safe to pass a function-call expression
/// (evaluating it twice would re-run side effects and would destroy the
/// first, unchecked errorful Expected, aborting under error checking).
#define RETURN_IF_EXPECTED_IS_ERR(expV)                                        \
  do {                                                                         \
    static_assert(glow::detail::IsExpected<decltype(expV)>(),                  \
                  "Expected value to be a Expected");                          \
    auto &&tempExpV_ = (expV);                                                 \
    if (!tempExpV_) {                                                          \
      auto err = tempExpV_.takeError();                                        \
      err.addToStack(__FILE__, __LINE__);                                      \
      return std::forward<Error>(err);                                         \
    }                                                                          \
  } while (0)

/// Takes an Error and if it contains an ErrorValue then calls FAIL().
#define FAIL_TEST_IF_ERR(err) \
do { \
@@ -44,7 +44,7 @@ class CPUBackend : public LLVMBackend {
static std::string getName() { return "CPU"; }
static unsigned numDevices();

bool transformPostLowering(
Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo = nullptr) const override;

@@ -119,9 +119,9 @@ static Node *optimizeCPUMaxSplat(MaxNode *MN, Function *F) {
new CPUMaxSplatNode(MN->getName(), input, splat->getValue()));
}

bool CPUBackend::transformPostLowering(
Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const {
Expected<bool>
CPUBackend::transformPostLowering(Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const {
LOG_SCOPE(F->getLogContext(), "CPUBackend::transformPostLowering")

bool changed = false;
@@ -1420,7 +1420,7 @@ bool surroundTileWithReshapes(Function *F, TileNode &tile) {

} // namespace

bool HabanaBackend::transformPostLowering(
Expected<bool> HabanaBackend::transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo) const {
LOG_SCOPE(F->getLogContext(), "HabanaBackend::transformPostLowering")
@@ -50,7 +50,7 @@ class HabanaBackend final : public Backend {

bool shouldLower(const Node *N) const override;

bool transformPostLowering(
Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo = nullptr) const override;

@@ -803,7 +803,7 @@ static bool channelwiseQuantizeFloatBias(
return true;
}

bool Interpreter::transformPostLowering(
Expected<bool> Interpreter::transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo) const {
LOG_SCOPE(F->getLogContext(), "Interpreter::transformPostLowering")
@@ -790,7 +790,7 @@ static bool parallelizeFunction(Function *F, CompilationContext &cctx) {
return changed;
}

bool NNPIBackend::transformPostLowering(
Expected<bool> NNPIBackend::transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo) const {
LOG_SCOPE(F->getLogContext(), "NNPIBackend::transformPostLowering");
@@ -50,7 +50,7 @@ class NNPIBackend final : public Backend {
runtime::DeviceManager *
createDeviceManager(const runtime::DeviceConfig &deviceConfig) override;

bool transformPostLowering(
Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo = nullptr) const override;

@@ -231,7 +231,7 @@ class OCLBackend final : public BackendUsingGlowIR {
Expected<std::unique_ptr<CompiledFunction>>
compile(Function *F, const BackendOptions &opts) const override;

bool transformPostLowering(
Expected<bool> transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo) const override;

@@ -27,7 +27,7 @@ using llvm::dyn_cast;
using namespace glow;

/// Perform OpenCL specific post-lowering graph transformation.
bool OCLBackend::transformPostLowering(
Expected<bool> OCLBackend::transformPostLowering(
Function *F, CompilationContext &cctx,
const glow::runtime::DeviceInfo *devInfo) const {
// NCHW transformation is not supported in training mode yet, because of some
@@ -93,7 +93,11 @@ onnxStatus Backend::checkGraphCompatibility(const void *onnxModel,

// Call the backend's transformPostLowering to match the normal compilation
// pipeline then DCE any nodes that are no longer needed.
if (glowBackend_->transformPostLowering(function, cctx)) {
auto changedOrErr = glowBackend_->transformPostLowering(function, cctx);
if (ERR_TO_BOOL(changedOrErr.takeError())) {
return ONNXIFI_STATUS_INTERNAL_ERROR;
}
if (*changedOrErr) {
runDCEPass(function, cctx);
}

@@ -4022,7 +4022,7 @@ Error glow::optimizeFunction(Function *F, const Backend &B,
}

// Allow the backend to transform the graph after lowering.
B.transformPostLowering(F, cctx, devInfo);
RETURN_IF_EXPECTED_IS_ERR(B.transformPostLowering(F, cctx, devInfo));

if (!B.shouldPreQuantizeConstants()) {
// Do the actual float ->fix-point conversion of constant tensors after
@@ -2044,8 +2044,9 @@ static SaveNode *getUniqueSaveNode(Function *F) {
class MockBackendPrequantizeConst : public MockBackend {
bool shouldPreQuantizeConstants() const override { return true; }
bool isOpSupported(const NodeInfo &) const override { return true; }
bool transformPostLowering(Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const override {
Expected<bool>
transformPostLowering(Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const override {
// Check the IR.
EXPECT_EQ(F->getNodes().size(), 1);
auto *save = getUniqueSaveNode(F);
@@ -2058,8 +2059,9 @@ class MockBackendPrequantizeConst : public MockBackend {
class MockBackendNotPrequantizeConst : public MockBackend {
bool shouldPreQuantizeConstants() const override { return false; }
bool isOpSupported(const NodeInfo &) const override { return true; }
bool transformPostLowering(Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const override {
Expected<bool>
transformPostLowering(Function *F, CompilationContext &,
const glow::runtime::DeviceInfo *) const override {
// Check the IR.
EXPECT_EQ(F->getNodes().size(), 2);
auto *save = getUniqueSaveNode(F);
@@ -61,7 +61,8 @@ TEST_F(HabanaBackendTest, SurroundTile) {

// Invoke Habana backend specific graph optimisations.
CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);

// Invoke dead code elimination.
@@ -113,7 +114,8 @@ TEST_F(HabanaBackendTest, DoNotSurroundTile) {

// Invoke Habana backend specific graph optimisations.
CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));

// Graph should not change since input to Tile is already 4D.
EXPECT_FALSE(changed);
@@ -142,7 +144,8 @@ TEST_F(HabanaBackendTest, FuseConvRelu) {

// Invoke Habana backend specific graph optimisations.
CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);

// Now, the graph should look like this:
@@ -216,7 +219,8 @@ TEST_F(HabanaBackendTest, FuseConvAdd) {

// Invoke Habana backend specific graph optimisations.
CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);

// Now, the graph should look like this:
@@ -307,7 +311,8 @@ TEST_F(HabanaBackendTest, FuseConvAddRelu) {

// Invoke Habana backend specific graph optimisations.
CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);

// Now, the graph should look like this:
@@ -398,7 +403,9 @@ TEST_F(HabanaBackendTest, ConvertFC) {
auto *FC = F_->createFullyConnected("fc", input, weight, bias);
auto *save = F_->createSave("save", FC);
CompilationContext cctx;
backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);
ASSERT_TRUE(save);
ASSERT_TRUE(llvm::isa<HabanaFullyConnectedNode>(save->getInput()));
}
@@ -412,7 +419,8 @@ TEST_F(HabanaBackendTest, ConvertConv) {
SaveNode *save = F_->createSave("save", conv);

CompilationContext cctx;
bool changed = backend.transformPostLowering(F_, cctx);
bool changed;
ASSIGN_VALUE_OR_FAIL_TEST(changed, backend.transformPostLowering(F_, cctx));
EXPECT_TRUE(changed);
ASSERT_TRUE(save);
ASSERT_TRUE(llvm::isa<HabanaConvolutionNode>(save->getInput()));

0 comments on commit 0ddb67d

Please sign in to comment.
You can’t perform that action at this time.