From a6fa3b26825fef139aecf7b931aefc7a4c1822ac Mon Sep 17 00:00:00 2001 From: jiej Date: Wed, 9 Dec 2020 15:27:23 -0800 Subject: [PATCH] adding profile_ivalue (#47666) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47666 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D25255573 Pulled By: Krovatkin fbshipit-source-id: 5d8753e4040a3d96105d28d26728125947c7a638 --- aten/src/ATen/core/interned_strings.h | 1 + torch/csrc/jit/codegen/cuda/interface.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/interface.h | 3 ++- .../jit/codegen/cuda/register_interface.cpp | 3 ++- torch/csrc/jit/ir/alias_analysis.cpp | 1 + torch/csrc/jit/ir/ir.cpp | 11 +++++++++ torch/csrc/jit/ir/ir.h | 24 ++++++++++++++++++- torch/csrc/jit/runtime/interpreter.cpp | 4 ++++ torch/csrc/jit/runtime/operator.cpp | 2 ++ .../jit/runtime/register_prim_ops_fulljit.cpp | 10 +++++++- 10 files changed, 57 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 5a0efffea261..7a74ec3b1736 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -139,6 +139,7 @@ namespace c10 { _(prim, HasAttr) \ _(prim, profile) \ _(prim, profile_optional) \ + _(prim, profile_ivalue) \ _(prim, AddStatValue) \ _(prim, TimePoint) \ _(prim, CallFunction) \ diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 8bc3ba3b4c6f..e3efd924efb6 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -36,9 +36,9 @@ void runFusionGroup(const Node* fusion_node, Stack& stack) { void fuseGraph(std::shared_ptr& graph) { TORCH_CHECK( - getFuserInterface()->fn_fuse_graph != nullptr, + getFuserInterface()->fn_fuse_graph_ != nullptr, "Running the CUDA fuser requires a CUDA build."); - getFuserInterface()->fn_fuse_graph(graph); + getFuserInterface()->fn_fuse_graph_(graph); } bool canFuseNode(const Node* node) { diff --git a/torch/csrc/jit/codegen/cuda/interface.h b/torch/csrc/jit/codegen/cuda/interface.h index 7c156b1dc7c9..00d94a9f12e0 100644 --- a/torch/csrc/jit/codegen/cuda/interface.h +++ b/torch/csrc/jit/codegen/cuda/interface.h @@ -2,6 +2,7 @@ #include #include +#include /* * This file contains APIs for cuda fuser; @@ -22,7 +23,7 @@ TORCH_API std::atomic& getCudaFusionGuardMode(); struct CudaFuserInterface { void (*fn_compile_n_)(Node*) = nullptr; void (*fn_run_n_s_)(const Node*, Stack&) = nullptr; - void (*fn_fuse_graph)(std::shared_ptr&) = nullptr; + void (*fn_fuse_graph_)(std::shared_ptr&) = nullptr; bool (*fn_can_fuse_n_)(const Node*) = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/register_interface.cpp b/torch/csrc/jit/codegen/cuda/register_interface.cpp index f340a903131d..284ee05420a1 100644 --- a/torch/csrc/jit/codegen/cuda/register_interface.cpp +++ b/torch/csrc/jit/codegen/cuda/register_interface.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -20,7 +21,7 @@ class RegisterInterface { auto ptr = getFuserInterface(); ptr->fn_compile_n_ = &compileCudaFusionGroup; ptr->fn_run_n_s_ = &runCudaFusionGroup; - ptr->fn_fuse_graph = &CudaFuseGraph; + ptr->fn_fuse_graph_ = &CudaFuseGraph; ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup; RegisterProfilingNode(canFuseNode); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index b055d29164a5..67dbe193f11c 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -524,6 +524,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::SetAttr: return analyzeSetAttr(node); case prim::profile_optional: + case prim::profile_ivalue: case prim::profile: makePointerTo(node->output(), node->inputs().at(0)); return; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index ceb0fd1dbfcf..4714a6ae12f6 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -2053,6 +2053,16 @@ Node* ProfileOptionalOp::allocNewInstance(Graph* g) { return new ProfileOptionalOp(g, {nullptr}); } +void ProfileIValueOp::cloneFrom(Node* other_) { + Node::cloneFrom(other_); + auto other = other_->cast(); + this->callback_ = other->getCallback(); +} + +Node* ProfileIValueOp::allocNewInstance(Graph* g) { + return new ProfileIValueOp(g, {nullptr}); +} + TypePtr NamedValue::type() const { if (value_) { return value_->type(); @@ -2063,6 +2073,7 @@ TypePtr NamedValue::type() const { const Symbol ProfileOp::Kind = ::c10::prim::profile; const Symbol ProfileOptionalOp::Kind = ::c10::prim::profile_optional; +const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue; OperatorSet::OperatorSet(std::initializer_list sig_literals) { for (const char* sig : sig_literals) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 64c8031bd601..b20d5611c55c 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -440,7 +440,7 @@ struct TORCH_API Node { // instructions lowered by the interpreter and not run in the optimized graph bool notExecutedOp() const { return kind_ == prim::Constant || kind_ == prim::profile || - kind_ == prim::profile_optional; + kind_ == prim::profile_optional || kind_ == prim::profile_ivalue; } // Graphs @@ -1368,6 +1368,28 @@ struct TORCH_API ProfileOptionalOp : public Node { std::function&)> callback_; }; +struct TORCH_API ProfileIValueOp : public Node { + static const Symbol Kind; + ProfileIValueOp( + Graph* graph, + std::function&)> callback) + : Node(graph, ::c10::prim::profile_ivalue), callback_(callback) {} + + void cloneFrom(Node* other_) override; + Node* allocNewInstance(Graph* g) override; + + const std::function&)>& getCallback() const { + return callback_; + } + + void setCallback(std::function&)> callback) { + callback_ = callback; + } + + private: + std::function&)> callback_; +}; + // execute a Python function, used for Ops we can't optimize but that we want to // optimize around // diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index ef0f2dae9e0e..4802fd2efafa 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -791,6 +791,9 @@ struct CodeImpl { } else if (node->cast()) { profile_function_table_.push_back( node->cast()->getCallback()); + } else if (node->cast()) { + profile_function_table_.push_back( + node->cast()->getCallback()); } else { TORCH_INTERNAL_ASSERT(false); } @@ -945,6 +948,7 @@ struct CodeImpl { case prim::BailOut: emitBailOut(node); break; + case prim::profile_ivalue: case prim::profile_optional: case prim::profile: emitProfile(node); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 0756d6b58e9f..b9e0f5fbd3fe 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -245,6 +245,7 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::Store, // used in interpreter only prim::profile, // used in interpreter only prim::profile_optional, // used in interpreter only + prim::profile_ivalue, // used in interpreter only prim::TypeCheck, // used in interpreter only prim::FallbackGraph, // converted into prim::CallFunction @@ -303,6 +304,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::SetAttr, prim::profile, prim::profile_optional, + prim::profile_ivalue, prim::TypeCheck, prim::Print, prim::CallFunction, diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index b63a2a228508..8361fb3b3385 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -29,7 +29,15 @@ RegisterOperators reg( {Operator( prim::profile, [](const Node* node) -> Operation { - auto callback = node->cast()->getCallback(); + return [](Stack* stack) { + AT_ERROR( + "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT + }; + }, + aliasAnalysisSpecialCase()), + Operator( + prim::profile_ivalue, + [](const Node* node) -> Operation { return [](Stack* stack) { AT_ERROR( "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT