Fix rFactor when there are indirect root domain(s), and refactor (#1723)
zasdfgbnm committed May 24, 2022
1 parent 7093e39 commit d6d6b7d
Showing 5 changed files with 294 additions and 233 deletions.
torch/csrc/jit/codegen/cuda/disjoint_set.h (5 additions, 0 deletions)
@@ -86,6 +86,11 @@ class VectorOfUniqueEntries {
     return vector_.empty();
   }
 
+  // Returns the number of elements in this container
+  size_t size() const {
+    return vector_.size();
+  }
+
   // Returns if entry is in this vector
   bool has(T entry) const {
     return set_.find(entry) != set_.end();
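The new accessor rounds out the container's read-only interface. For context, here is a minimal self-contained sketch of the class it lands in. The class name, the size()/has()/empty() methods, and the vector_/set_ members match the diff above; pushBack and the rest of the body are assumptions for illustration, not taken from the commit:

    #include <cstddef>
    #include <unordered_set>
    #include <vector>

    // Sketch of an insertion-ordered set: vector_ preserves insertion order,
    // set_ provides O(1) membership tests and deduplication.
    template <typename T>
    class VectorOfUniqueEntries {
     public:
      // Appends entry only if it has not been seen before (assumed method).
      bool pushBack(T entry) {
        if (set_.insert(entry).second) {
          vector_.push_back(entry);
          return true;
        }
        return false;
      }

      // Returns if this container is empty
      bool empty() const {
        return vector_.empty();
      }

      // Returns the number of elements in this container (the new method)
      size_t size() const {
        return vector_.size();
      }

      // Returns if entry is in this vector
      bool has(T entry) const {
        return set_.find(entry) != set_.end();
      }

     private:
      std::vector<T> vector_;
      std::unordered_set<T> set_;
    };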
torch/csrc/jit/codegen/cuda/ir_nodes.cpp (4 additions, 48 deletions)
@@ -1081,8 +1081,7 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
       outer->container()->zeroVal(),
       merged_id_size->as<Int>(),
       outer->getParallelType(),
-      itype,
-      outer->isRFactorProduct() || inner->isRFactorProduct());
+      itype);
 
   IrBuilder::create<Merge>(outer->container(), merged_id, outer, inner);
 
@@ -1135,17 +1134,15 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
       in->container()->zeroVal(),
       inner_split ? remainder->as<Int>() : factor,
       in->getParallelType(),
-      in->getIterType(),
-      in->isRFactorProduct());
+      in->getIterType());
 
   // inner loop IterDomain
   IterDomain* idi = IrBuilder::create<IterDomain>(
       in->container(),
       in->container()->zeroVal(),
       inner_split ? factor : remainder->as<Int>(),
       in->getParallelType(),
-      in->getIterType(),
-      in->isRFactorProduct());
+      in->getIterType());
 
   IrBuilder::create<Split>(
       in->container(),
@@ -1766,48 +1763,7 @@ TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
 // pair is in order where second is the consumer of first
 std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
     const std::vector<int>& axes_) {
-  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim domain");
-
-  std::vector<int> axes(axes_.size());
-
-  auto ndims = nDims();
-  std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) {
-    return i < 0 ? i + ndims : i;
-  });
-
-  TORCH_CHECK(
-      std::none_of(
-          axes.begin(),
-          axes.end(),
-          [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }),
-      "RFactor axes less than 0 or >= ndims.");
-
-  // We might be able to lift this constraint in some instances, but needs more
-  // investigation.
-  TORCH_CHECK(
-      !hasRFactor(), "Cannot call rfactor on the same tensor domain twice.");
-
-  std::unordered_set<int> axes_set(axes.begin(), axes.end());
-
-  bool rfactor_found = false;
-  bool reduction_found = false;
-  for (decltype(nDims()) i{0}; i < nDims(); i++) {
-    if (axis(i)->isReduction()) {
-      if (axes_set.find(i) != axes_set.end()) {
-        rfactor_found = true;
-      } else {
-        reduction_found = true;
-      }
-    }
-  }
-
-  TORCH_CHECK(
-      rfactor_found && reduction_found,
-      "Invalid rfactor found, rfactor must be provided at least one reduction axis, but not all reduction axes.");
-
-  return std::pair<TensorDomain*, TensorDomain*>{
-      TransformRFactor::runReplay(this, axes),
-      TransformRFactor::runReplay2(this, axes)};
+  return TransformRFactor::runReplay(this, axes_);
 }
 
 Split::Split(
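The deleted checks do not simply disappear: rFactor() now forwards the raw axes_ to TransformRFactor::runReplay, which (given the unchanged std::pair return type and the removed runReplay2 call) appears to produce the producer/consumer pair itself, so the validation presumably moves behind that call. For reference, here is the removed normalization and validation logic restated as a standalone sketch: a hypothetical free function with plain asserts standing in for TORCH_CHECK:

    #include <algorithm>
    #include <cassert>
    #include <unordered_set>
    #include <vector>

    // Normalize rfactor axes the way the deleted code did: wrap negatives,
    // bounds-check, and require at least one but not all reduction axes.
    std::vector<int> normalizeRFactorAxes(
        const std::vector<int>& axes_in,
        int ndims,
        const std::vector<bool>& is_reduction) {
      assert(ndims > 0 && "Tried to rFactor a 0-dim domain");

      // Wrap negative axes, e.g. -1 becomes ndims - 1.
      std::vector<int> axes(axes_in.size());
      std::transform(
          axes_in.begin(), axes_in.end(), axes.begin(), [ndims](int i) {
            return i < 0 ? i + ndims : i;
          });

      // Every axis must land in [0, ndims).
      assert(std::none_of(axes.begin(), axes.end(), [ndims](int i) {
        return i < 0 || i >= ndims;
      }));

      // The rfactor set must include at least one reduction axis, and must
      // also leave at least one reduction axis behind.
      std::unordered_set<int> axes_set(axes.begin(), axes.end());
      bool rfactor_found = false;
      bool reduction_found = false;
      for (int i = 0; i < ndims; i++) {
        if (is_reduction[i]) {
          if (axes_set.count(i)) {
            rfactor_found = true;
          } else {
            reduction_found = true;
          }
        }
      }
      assert(rfactor_found && reduction_found);

      return axes;
    }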
torch/csrc/jit/codegen/cuda/test/test_gpu.cpp (34 additions, 0 deletions)
@@ -19824,6 +19824,40 @@ TEST_F(NVFuserTest, FusionRfactorPredication2_CUDA) {
       &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, FusionRfactorIndirectRoot_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1692
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(3);
+  fusion.addInput(tv0);
+
+  auto tv1 = sum(tv0, {1, 2});
+  fusion.addOutput(tv1);
+
+  tv1->split(2, 4);
+  tv1->split(1, 3);
+  tv1->merge(2, 3);
+  auto rf = tv1->rFactor({-1});
+
+  tv1->split(0, 256);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  rf->computeAt(tv1, -1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::manual_seed(0);
+
+  auto at_in = at::randn({6, 6, 6}, options);
+  auto at_out = at_in.sum({1, 2});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {at_in});
+  auto cg_outputs = fe.runFusion({at_in});
+
+  testValidate(&fusion, cg_outputs, {at_in}, {at_out}, __LINE__, __FILE__);
+}
+
 TEST_F(NVFuserTest, FusionNonDivisibleSplit1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
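To see where the "indirect root domain(s)" of the commit title come from, it helps to trace the schedule in the new test. Re-quoting those four lines with the consumer's axes noted after each call (the shapes are inferred from standard split/merge semantics, not printed by the commit):

    // tv1 = sum(tv0, {1, 2}) starts as [I0, R1, R2]:
    // one iteration axis followed by two reduction axes.
    tv1->split(2, 4);             // [I0, R1, R2/4, 4]
    tv1->split(1, 3);             // [I0, R1/3, 3, R2/4, 4]
    tv1->merge(2, 3);             // [I0, R1/3, 3*(R2/4), 4]
    auto rf = tv1->rFactor({-1}); // rfactor only the innermost axis

After the merge, the axis 3*(R2/4) reaches the root domains R1 and R2 only through intermediate split outputs, so the rfactor replay must walk through IterDomains that are not themselves root domains; presumably this is the case that issue #1692 reports and this commit fixes.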
[Diffs for the remaining 2 changed files did not load in this view.]
