Skip to content

Commit

Permalink
Revert D29647586: [jit] Renamed prim::Concat as prim::VarConcat
Browse files Browse the repository at this point in the history
Test Plan: revert-hammer

Differential Revision:
D29647586 (db11619)

Original commit changeset: cdd34ea5a3c9

fbshipit-source-id: bab5ac4ed67a00ac151fe39463aa3fb56897d7f4
  • Loading branch information
VitalyFedyunin authored and facebook-github-bot committed Jul 21, 2021
1 parent 48af9de commit 33db828
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 69 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/interned_strings.h
Expand Up @@ -26,6 +26,7 @@ namespace c10 {
_(prim, BroadcastingChunk) \
_(prim, BroadcastSizes) \
_(prim, ReductionSizes) \
_(prim, Concat) \
_(prim, Constant) \
_(prim, ChunkSizes) \
_(prim, ConstantMKLDNNTensor) \
Expand Down Expand Up @@ -83,7 +84,6 @@ namespace c10 {
_(prim, StringIndex) \
_(prim, NumToTensor) \
_(prim, Uninitialized) \
_(prim, VarConcat) \
_(prim, With) \
_(prim, Enter) \
_(prim, Exit) \
Expand Down
108 changes: 54 additions & 54 deletions test/cpp/jit/test_concat_opt.cpp
Expand Up @@ -49,8 +49,8 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
%1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
%2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
%5 : int = prim::Constant[value=0]()
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
%concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %5)
%concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %2, %5)
%res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
return (%res)
)IR";
Expand All @@ -71,14 +71,14 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
// %1 : ...,
// %2 : ...):
// %3 : int = prim::Constant[value=0]()
// %4 : Tensor = prim::VarConcat(%0, %1, %3)
// %7 : Tensor = prim::VarConcat(%4, %2, %3) // UPDATED
// %4 : Tensor = prim::Concat(%0, %1, %3)
// %7 : Tensor = prim::Concat(%4, %2, %3) // UPDATED
// %8 : Tensor[] = prim::ListConstruct(%4, %7)
// return (%8)

testing::FileCheck()
.check_count("= prim::VarConcat(%0, %1, %3)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%4, %2, %3)", 1, /*exactly*/ true)
.check_count("= prim::Concat(%0, %1, %3)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%4, %2, %3)", 1, /*exactly*/ true)
->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
Expand All @@ -94,8 +94,8 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
%1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
%2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
%5 : int = prim::Constant[value=0]()
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5)
%concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%1, %2, %5)
%concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %2, %5)
%res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
return (%res)
)IR";
Expand All @@ -116,14 +116,14 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
// %1 : ...,
// %2 : ...):
// %3 : int = prim::Constant[value=0]()
// %4 : Tensor = prim::VarConcat(%1, %2, %3)
// %7 : Tensor = prim::VarConcat(%0, %4, %3) // UPDATED
// %4 : Tensor = prim::Concat(%1, %2, %3)
// %7 : Tensor = prim::Concat(%0, %4, %3) // UPDATED
// %8 : Tensor[] = prim::ListConstruct(%4, %7)
// return (%8)

testing::FileCheck()
.check_count("= prim::VarConcat(%1, %2, %3)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%0, %4, %3)", 1, /*exactly*/ true)
.check_count("= prim::Concat(%1, %2, %3)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%0, %4, %3)", 1, /*exactly*/ true)
->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
Expand All @@ -140,11 +140,11 @@ TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
%2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
%5 : int = prim::Constant[value=0]()
#CHECK: prim::VarConcat
%concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
#CHECK: prim::Concat
%concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %5)
#CHECK: prim::VarConcat
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5)
#CHECK: prim::Concat
%concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%1, %0, %2, %5)
#CHECK: prim::ListConstruct
%res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2)
Expand Down Expand Up @@ -179,10 +179,10 @@ TEST(ConcatOptTest, MoreCommonInputsElimination) {
%3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
%4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
%5 : int = prim::Constant[value=0]()
%concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
%concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
%concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5)
%concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5)
%concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %5)
%concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %2, %5)
%concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %2, %3, %5)
%concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::Concat(%0, %1, %2, %3, %4, %5)
%res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4)
return (%res)
)IR";
Expand All @@ -201,10 +201,10 @@ TEST(ConcatOptTest, MoreCommonInputsElimination) {
checkOutputs(orig_outputs, opt_outputs);

testing::FileCheck()
.check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%6, %2, %5)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%11, %3, %5)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%12, %4, %5)", 1, /*exactly*/ true)
.check_count("= prim::Concat(%0, %1, %5)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%6, %2, %5)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%11, %3, %5)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%12, %4, %5)", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->run(*graph);
}
Expand Down Expand Up @@ -363,22 +363,22 @@ TEST(ConcatOptTest, UseVariadicCat) {

checkOutputs(orig_outputs, opt_outputs);

// After replacing `aten::cat` with `prim::VarConcat` we should have the
// After replacing `aten::cat` with `prim::Concat` we should have the
// following graph:
//
// graph(%0 : ...,
// %1 : ...):
// %zero : int = prim:Constant[value=0]()
// %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero)
// %varcat : Tensor = prim::Concat(%0, %1, %2, %3, %4, %5, %zero)
// return (%varcat)
testing::FileCheck()
.check_count("= prim::VarConcat(", 1, /*exactly*/ true)
.check_count("= prim::Concat(", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatReplaceMultiple) {
TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
auto graph = std::make_shared<Graph>();

const std::string input =
Expand Down Expand Up @@ -415,11 +415,11 @@ TEST(ConcatOptTest, UseVariadicCatReplaceMultiple) {
// %2 : ...,
// %3 : ....):
// %zero : int = prim:Constant[value=0]()
// %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero)
// %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero)
// %varcat1 : Tensor = prim::Concat(%0, %1, %zero)
// %varcat2 : Tensor = prim::Concat(%2, %3, %zero)
// return (%varcat1, %varcat2)
testing::FileCheck()
.check_count("= prim::VarConcat(", 2, /*exactly*/ true)
.check_count("= prim::Concat(", 2, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
->run(*graph);
Expand Down Expand Up @@ -448,18 +448,18 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {

checkOutputs(orig_outputs, opt_outputs);

// After replacing `aten::cat` with `prim::VarConcat` we should have the
// After replacing `aten::cat` with `prim::Concat` we should have the
// following graph:
//
// graph(%0 : ...,
// %1 : ...):
// %zero : int = prim:Constant[value=0]()
// %input : Tensor[] = prim::ListConstruct(%0, %1)
// %varcat : Tensor = prim::VarConcat(%0, %1, %zero)
// %varcat : Tensor = prim::Concat(%0, %1, %zero)
// return (%varcat, %input)
testing::FileCheck()
.check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
->check_count("= prim::Concat(", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->run(*graph);
}
Expand Down Expand Up @@ -491,20 +491,20 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
checkOutputs(orig_outputs, opt_outputs);

// The input list to `aten::cat` is mutated only after `aten::cat` op. So,
// it should have been replaced with `prim::VarConcat`. The transformed
// graph should look like the following:
// it should have been replaced with `prim::Concat`. The transformed graph
// should look like the following:
//
// graph(%0 : ...,
// %1 : ...,
// %2 : ...):
// %3 : int = prim:Constant[value=0]()
// %4 : Tensor[] = prim::ListConstruct(%0, %1)
// %7 : Tensor = prim::VarConcat(%0, %1, %3)
// %7 : Tensor = prim::Concat(%0, %1, %3)
// %6 : Tensor = aten::append(%4, %2)
// return (%7, %4)
testing::FileCheck()
.check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
->check_count("= prim::Concat(", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->run(*graph);
}
Expand Down Expand Up @@ -541,7 +541,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
testing::FileCheck()
.check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
->check_count("= aten::cat(", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(", 0, /*exactly*/ true)
->check_count("= prim::Concat(", 0, /*exactly*/ true)
->run(*graph);
}

Expand All @@ -552,17 +552,17 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
checkOutputs(orig_outputs, opt_outputs);

// The mutation of the list must be removed and the `aten::cat` op must
// be replaced with the `prim::VarConcat` op in the graph. The
// transformed graph should look like the following:
// be replaced with the `prim::Concat` op in the graph. The transformed
// graph should look like the following:
//
// graph(%0 : ...,
// %1 : ...,
// %2 : ...):
// %3 : int = prim:Constant[value=0]()
// %7 : Tensor = prim::VarConcat(%0, %1, %2, %3)
// %7 : Tensor = prim::Concat(%0, %1, %2, %3)
// return (%7)
testing::FileCheck()
.check_count("= prim::VarConcat(", 1, /*exactly*/ true)
.check_count("= prim::Concat(", 1, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->run(*graph);
Expand Down Expand Up @@ -605,22 +605,22 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
checkOutputs(orig_outputs, opt_outputs);

// All the mutations of the list must be removed and the `aten::cat` ops must
// be replaced with `prim::VarConcat` ops in the graph. The transformed
// graph should look like the following:
// be replaced with `prim::Concat` ops in the graph. The transformed graph
// should look like the following:
//
// graph(%0 : ...,
// %1 : ...,
// %2 : ...,
// %3 : ...,
// %4 : ...):
// %10 : int = prim:Constant[value=0]()
// %5 : Tensor = prim::VarConcat(%0, %1, %10)
// %6 : Tensor = prim::VarConcat(%0, %1, %2, %10)
// %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10)
// %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10)
// %5 : Tensor = prim::Concat(%0, %1, %10)
// %6 : Tensor = prim::Concat(%0, %1, %2, %10)
// %7 : Tensor = prim::Concat(%0, %1, %2, %3, %10)
// %8 : Tensor = prim::Concat(%0, %1, %2, %3, %4, %10)
// return (%5, %6, %7, %8)
testing::FileCheck()
.check_count("= prim::VarConcat(", 4, /*exactly*/ true)
.check_count("= prim::Concat(", 4, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->run(*graph);
Expand Down Expand Up @@ -671,13 +671,13 @@ TEST(
// %1 : ...,
// %2 : ...):
// %3 : int = prim::Constant[value=0]()
// %10 : Tensor = prim::VarConcat(%0, %1, %2, %3)
// %12 : Tensor = prim::VarConcat(%10, %0, %3) // UPDATED
// %10 : Tensor = prim::Concat(%0, %1, %2, %3)
// %12 : Tensor = prim::Concat(%10, %0, %3) // UPDATED
// %8 : Tensor[] = prim::ListConstruct(%10, %12)
// return (%8)
testing::FileCheck()
.check_count("= prim::VarConcat(%0, %1, %2, %3)", 1, /*exactly*/ true)
->check_count("= prim::VarConcat(%10, %0, %3)", 1, /*exactly*/ true)
.check_count("= prim::Concat(%0, %1, %2, %3)", 1, /*exactly*/ true)
->check_count("= prim::Concat(%10, %0, %3)", 1, /*exactly*/ true)
->check_count("= prim::ListConstruct(%10, %12)", 1, /*exactly*/ true)
->check_count("= aten::cat(", 0, /*exactly*/ true)
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
Expand Down
20 changes: 8 additions & 12 deletions torch/csrc/jit/passes/concat_opt.cpp
Expand Up @@ -46,7 +46,7 @@ class ConcatCommonInputsEliminator {
private:
void handleBlock(Block* block) {
for (auto node : block->nodes()) {
if (node->kind() == prim::VarConcat) {
if (node->kind() == prim::Concat) {
handleCat(node);
}
for (Block* block : node->blocks()) {
Expand Down Expand Up @@ -84,18 +84,16 @@ class ConcatCommonInputsEliminator {
// the previous cat ops.
//
// Example:
// %11 = prim::VarConcat(%0, %1, <dim>)
// %11 = prim::Concat(%0, %1, <dim>)
// ...
// %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same
// // as %11
// %13 = prim::Concat(%0, %1, %2, <dim>) // first 2 inputs same as %11
// ...
// = %13 ... // Use %13
//
// After CSE opt:
// %11 = prim::VarConcat(%0, %1, <dim>)
// %11 = prim::Concat(%0, %1, <dim>)
// ...
// %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
// // with %11
// %14 = prim::Concat(%11, %2, <dim>) // Replace first 2 inputs with %11
// ...
// = %14 ... // Replace use of %13 with %14

Expand All @@ -116,8 +114,7 @@ class ConcatCommonInputsEliminator {

std::vector<Value*> new_inputs = {
prev->output(), curr_tensor_inputs.back(), curr_dim};
auto new_concat =
node->owningGraph()->create(prim::VarConcat, new_inputs);
auto new_concat = node->owningGraph()->create(prim::Concat, new_inputs);
new_concat->output()->setType(node->output()->type());
concats_to_replace_[node] = new_concat;
return;
Expand Down Expand Up @@ -161,8 +158,7 @@ class ConcatCommonInputsEliminator {

std::vector<Value*> new_inputs = {
curr_tensor_inputs.front(), prev->output(), curr_dim};
auto new_concat =
node->owningGraph()->create(prim::VarConcat, new_inputs);
auto new_concat = node->owningGraph()->create(prim::Concat, new_inputs);
new_concat->output()->setType(node->output()->type());
concats_to_replace_[node] = new_concat;
return;
Expand Down Expand Up @@ -539,7 +535,7 @@ class VariadicCatUpdater {
}
std::vector<Value*> inputs = list->inputs().vec();
inputs.push_back(cat->input(1));
auto var_cat = cat->owningGraph()->create(prim::VarConcat, inputs);
auto var_cat = cat->owningGraph()->create(prim::Concat, inputs);
GRAPH_UPDATE("Adding\n", *var_cat);
var_cat->insertBefore(cat);
GRAPH_UPDATE("Replacing\n", *cat, "with\n", *var_cat);
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Expand Up @@ -756,9 +756,9 @@ RegisterOperators reg(
// This is an alternative to aten::cat op that takes variable number of
// parameters as input.
// Format:
// prim::VarConcat(Tensors..., dim) -> Tensor
// prim::Concat(Tensors..., dim) -> Tensor
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("prim::VarConcat(...) -> Tensor"),
TORCH_SELECTIVE_SCHEMA("prim::Concat(...) -> Tensor"),
[](Stack* stack) {
auto num_inputs = pop(stack).toInt();
auto dim = pop(stack).toInt();
Expand Down

0 comments on commit 33db828

Please sign in to comment.