Skip to content

Commit

Permalink
Add profile_ivalue (#47666)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #47666

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D25255573

Pulled By: Krovatkin

fbshipit-source-id: 5d8753e4040a3d96105d28d26728125947c7a638
  • Loading branch information
jjsjann123 authored and facebook-github-bot committed Dec 9, 2020
1 parent f431e47 commit a6fa3b2
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 6 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Expand Up @@ -139,6 +139,7 @@ namespace c10 {
_(prim, HasAttr) \
_(prim, profile) \
_(prim, profile_optional) \
_(prim, profile_ivalue) \
_(prim, AddStatValue) \
_(prim, TimePoint) \
_(prim, CallFunction) \
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/interface.cpp
Expand Up @@ -36,9 +36,9 @@ void runFusionGroup(const Node* fusion_node, Stack& stack) {

// Dispatches whole-graph fusion to the registered CUDA fuser implementation.
// The function pointer is filled in by the CUDA build's RegisterInterface;
// in a CPU-only build it stays nullptr, so fail with a clear error instead
// of crashing on a null call.
// NOTE: the scraped diff left the pre-rename `fn_fuse_graph` lines interleaved
// with the renamed `fn_fuse_graph_` ones; only the renamed member exists.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  TORCH_CHECK(
      getFuserInterface()->fn_fuse_graph_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  getFuserInterface()->fn_fuse_graph_(graph);
}

bool canFuseNode(const Node* node) {
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/interface.h
Expand Up @@ -2,6 +2,7 @@

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

/*
* This file contains APIs for cuda fuser;
Expand All @@ -22,7 +23,7 @@ TORCH_API std::atomic<bool>& getCudaFusionGuardMode();
// Function-pointer table that decouples the JIT core from the CUDA fuser.
// A CUDA build populates these slots (see register_interface.cpp); otherwise
// they remain nullptr and callers must check before dispatching.
// All members use the trailing-underscore naming convention; the flattened
// diff showed both the old `fn_fuse_graph` and renamed `fn_fuse_graph_` —
// only the renamed member belongs in the struct.
struct CudaFuserInterface {
  void (*fn_compile_n_)(Node*) = nullptr;
  void (*fn_run_n_s_)(const Node*, Stack&) = nullptr;
  void (*fn_fuse_graph_)(std::shared_ptr<Graph>&) = nullptr;
  bool (*fn_can_fuse_n_)(const Node*) = nullptr;
};

Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/register_interface.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/cuda/manager.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/partition.h>

#include <torch/csrc/jit/runtime/profiling_record.h>
Expand All @@ -20,7 +21,7 @@ class RegisterInterface {
auto ptr = getFuserInterface();
ptr->fn_compile_n_ = &compileCudaFusionGroup;
ptr->fn_run_n_s_ = &runCudaFusionGroup;
ptr->fn_fuse_graph = &CudaFuseGraph;
ptr->fn_fuse_graph_ = &CudaFuseGraph;
ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup;

RegisterProfilingNode(canFuseNode);
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/ir/alias_analysis.cpp
Expand Up @@ -524,6 +524,7 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::SetAttr:
return analyzeSetAttr(node);
case prim::profile_optional:
case prim::profile_ivalue:
case prim::profile:
makePointerTo(node->output(), node->inputs().at(0));
return;
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/ir/ir.cpp
Expand Up @@ -2053,6 +2053,16 @@ Node* ProfileOptionalOp::allocNewInstance(Graph* g) {
return new ProfileOptionalOp(g, {nullptr});
}

// Copies node state from `other_`, then additionally copies the profiling
// callback, which the base Node::cloneFrom does not know about.
void ProfileIValueOp::cloneFrom(Node* other_) {
  Node::cloneFrom(other_);
  auto* src = other_->cast<ProfileIValueOp>();
  callback_ = src->getCallback();
}

// Allocates a blank ProfileIValueOp in graph `g` with a null callback.
// Presumably invoked by the cloning machinery, with cloneFrom filling in
// the callback afterwards — confirm against Node::clone callers.
Node* ProfileIValueOp::allocNewInstance(Graph* g) {
return new ProfileIValueOp(g, {nullptr});
}

TypePtr NamedValue::type() const {
if (value_) {
return value_->type();
Expand All @@ -2063,6 +2073,7 @@ TypePtr NamedValue::type() const {

const Symbol ProfileOp::Kind = ::c10::prim::profile;
const Symbol ProfileOptionalOp::Kind = ::c10::prim::profile_optional;
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;

OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
for (const char* sig : sig_literals) {
Expand Down
24 changes: 23 additions & 1 deletion torch/csrc/jit/ir/ir.h
Expand Up @@ -440,7 +440,7 @@ struct TORCH_API Node {
// instructions lowered by the interpreter and not run in the optimized graph
// True for nodes that are lowered to dedicated interpreter instructions
// (or folded away) and therefore never execute as ordinary ops.
// The flattened diff left the pre-commit operand list as a dead statement
// after the return; only the extended list (incl. profile_ivalue) remains.
bool notExecutedOp() const {
  return kind_ == prim::Constant || kind_ == prim::profile ||
      kind_ == prim::profile_optional || kind_ == prim::profile_ivalue;
}

// Graphs
Expand Down Expand Up @@ -1368,6 +1368,28 @@ struct TORCH_API ProfileOptionalOp : public Node {
std::function<void(std::vector<IValue>&)> callback_;
};

// Graph node carrying a profiling callback for IValues. The interpreter
// collects the callback into its profile function table and emits the node
// as a PROFILE instruction (see CodeImpl in interpreter.cpp); the node is
// never executed as a regular op (see Node::notExecutedOp).
struct TORCH_API ProfileIValueOp : public Node {
  static const Symbol Kind;
  ProfileIValueOp(
      Graph* graph,
      std::function<void(std::vector<IValue>&)> callback)
      // Sink parameter: move into the member instead of copying.
      : Node(graph, ::c10::prim::profile_ivalue),
        callback_(std::move(callback)) {}

  void cloneFrom(Node* other_) override;
  Node* allocNewInstance(Graph* g) override;

  // Returns the stored callback by const reference (no copy).
  const std::function<void(std::vector<IValue>&)>& getCallback() const {
    return callback_;
  }

  // Takes the callback by value and moves it into place, so callers passing
  // an rvalue incur no copy and callers passing an lvalue copy exactly once.
  void setCallback(std::function<void(std::vector<IValue>&)> callback) {
    callback_ = std::move(callback);
  }

 private:
  std::function<void(std::vector<IValue>&)> callback_;
};

// execute a Python function, used for Ops we can't optimize but that we want to
// optimize around
//
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/runtime/interpreter.cpp
Expand Up @@ -791,6 +791,9 @@ struct CodeImpl {
} else if (node->cast<ProfileOptionalOp>()) {
profile_function_table_.push_back(
node->cast<ProfileOptionalOp>()->getCallback());
} else if (node->cast<ProfileIValueOp>()) {
profile_function_table_.push_back(
node->cast<ProfileIValueOp>()->getCallback());
} else {
TORCH_INTERNAL_ASSERT(false);
}
Expand Down Expand Up @@ -945,6 +948,7 @@ struct CodeImpl {
case prim::BailOut:
emitBailOut(node);
break;
case prim::profile_ivalue:
case prim::profile_optional:
case prim::profile:
emitProfile(node);
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/operator.cpp
Expand Up @@ -245,6 +245,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
prim::Store, // used in interpreter only
prim::profile, // used in interpreter only
prim::profile_optional, // used in interpreter only
prim::profile_ivalue, // used in interpreter only
prim::TypeCheck, // used in interpreter only
prim::FallbackGraph, // converted into prim::CallFunction

Expand Down Expand Up @@ -303,6 +304,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::SetAttr,
prim::profile,
prim::profile_optional,
prim::profile_ivalue,
prim::TypeCheck,
prim::Print,
prim::CallFunction,
Expand Down
10 changes: 9 additions & 1 deletion torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Expand Up @@ -29,7 +29,15 @@ RegisterOperators reg(
{Operator(
prim::profile,
[](const Node* node) -> Operation {
auto callback = node->cast<ProfileOp>()->getCallback();
return [](Stack* stack) {
AT_ERROR(
"Must be lowered to Interpreter's PROFILE instruction"); // NOLINT
};
},
aliasAnalysisSpecialCase()),
Operator(
prim::profile_ivalue,
[](const Node* node) -> Operation {
return [](Stack* stack) {
AT_ERROR(
"Must be lowered to Interpreter's PROFILE instruction"); // NOLINT
Expand Down

0 comments on commit a6fa3b2

Please sign in to comment.