Skip to content

Commit

Permalink
Separate root domain and rfactor domain in TransformPrinter (#1716)
Browse files Browse the repository at this point in the history
I feel that this could be helpful for debugging. Using the example that breaks our system in csarofeen#1692

Before
```
TransformPrinter : 
T0_g[ iS0{i1}, iS1{i2} ]
 root domain : (iS0{i1},iS1{i2})
T2_l[ iS11{( ceilDiv(i1, 3) )}, iS15{( 3 * ( ceilDiv(i2, 4) ) )}rf, rS14{4}rf ]
 root domain : (iS9{i1},iS15{( 3 * ( ceilDiv(i2, 4) ) )}rf,rS14{4}rf)
  Split: iS9{i1} by factor 3 -> iS11{( ceilDiv(i1, 3) )}, iS12{3}, start offset: 0, stop offset: 0
T1_g[ rS17{( ceilDiv(i1, 3) )}, rS19{( 3 * ( ceilDiv(i2, 4) ) )} ]
 root domain : (rS16{i1},rS19{( 3 * ( ceilDiv(i2, 4) ) )})
  Split: rS16{i1} by factor 3 -> rS17{( ceilDiv(i1, 3) )}, rS18{3}, start offset: 0, stop offset: 0
}
```

After:
```
TransformPrinter : 
T0_g[ iS0{i1}, iS1{i2} ]
 root domain : (iS0{i1},iS1{i2})
T2_l[ iS11{( ceilDiv(i1, 3) )}, iS15{( 3 * ( ceilDiv(i2, 4) ) )}rf, rS14{4}rf ]
 root domain : (iS9{i1},rS10{i2}rf)
  Split: iS9{i1} by factor 3 -> iS11{( ceilDiv(i1, 3) )}, iS12{3}, start offset: 0, stop offset: 0
  Split: rS10{i2}rf by factor 4 -> iS13{( ceilDiv(i2, 4) )}rf, rS14{4}rf, start offset: 0, stop offset: 0
  Merge: iS12{3} and iS13{( ceilDiv(i2, 4) )}rf -> iS15{( 3 * ( ceilDiv(i2, 4) ) )}rf
 rfactor domain : (iS9{i1},iS15{( 3 * ( ceilDiv(i2, 4) ) )}rf,rS14{4}rf)
  Split: iS9{i1} by factor 3 -> iS11{( ceilDiv(i1, 3) )}, iS12{3}, start offset: 0, stop offset: 0
T1_g[ rS17{( ceilDiv(i1, 3) )}, rS19{( 3 * ( ceilDiv(i2, 4) ) )} ]
 root domain : (rS16{i1},rS19{( 3 * ( ceilDiv(i2, 4) ) )})
  Split: rS16{i1} by factor 3 -> rS17{( ceilDiv(i1, 3) )}, rS18{3}, start offset: 0, stop offset: 0
}
```
  • Loading branch information
zasdfgbnm committed May 19, 2022
1 parent f68b830 commit 3675c70
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,7 @@ void IrTransformPrinter::handle(Fusion* f) {
}

void IrTransformPrinter::printTransforms(TensorView* tv) {
auto root_domain = tv->getMaybeRFactorDomain();
auto all_exp = DependencyCheck::getAllExprsBetween(
{root_domain.begin(), root_domain.end()},
{tv->domain()->domain().begin(), tv->domain()->domain().end()});

auto root_domain = tv->domain()->getRootDomain();
os() << " root domain : (";
for (const auto root_idx : c10::irange(root_domain.size())) {
IrPrinter::handle(root_domain[root_idx]);
Expand All @@ -781,6 +777,33 @@ void IrTransformPrinter::printTransforms(TensorView* tv) {
}
os() << ")\n";

if (tv->hasRFactor()) {
auto rfactor_domain = tv->domain()->getRFactorDomain();

auto all_exp = DependencyCheck::getAllExprsBetween(
{root_domain.begin(), root_domain.end()},
{rfactor_domain.begin(), rfactor_domain.end()});

for (auto exp : all_exp) {
os() << " ";
IrPrinter::handle(exp);
}

os() << " rfactor domain : (";
for (const auto root_idx : c10::irange(rfactor_domain.size())) {
IrPrinter::handle(rfactor_domain[root_idx]);
if (root_idx + 1 < rfactor_domain.size()) {
os() << ",";
}
}
os() << ")\n";
}

auto from = tv->getMaybeRFactorDomain();
auto all_exp = DependencyCheck::getAllExprsBetween(
{from.begin(), from.end()},
{tv->domain()->domain().begin(), tv->domain()->domain().end()});

for (auto exp : all_exp) {
os() << " ";
IrPrinter::handle(exp);
Expand Down

0 comments on commit 3675c70

Please sign in to comment.