Skip to content

Commit

Permalink
Validate LOOP concrete IDs have complete IterDomains (#1676)
Browse files Browse the repository at this point in the history
* Validate LOOP concrete IDs have complete IterDomains

* cleanup

* Add some comments.

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
  • Loading branch information
naoyam and csarofeen committed May 17, 2022
1 parent c931eda commit 18bee67
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
164 changes: 164 additions & 0 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,101 @@ bool ComputeAtMap::areMapped(
return disjointSetOf(id0, mode)->has(id1);
}

namespace {

// Validate a LOOP concrete ID has the complete ID set required for
// indexing. See issue #1655 and FusionIncompleteConcreteID for an
// example fusion that fails with this validation. Fixing this issue
// would require creating a reference IterDomain with all the
// necessary root ID for for loop extent generation, for indexing, and for
// predication.
//
// root_ids_of_all_ids and root_ids_of_concrete_id consist of EXACT
// concrete IDs.
void validateCompletenessOfLoopConcreteID(
IterDomain* concrete_id,
const ComputeAtMap& ca_map,
const TrivialReductionInfo& trivial_reduction_info,
// All root id's of all IDs in the disjoint id set
const std::unordered_set<IterDomain*>& root_ids_of_all_ids,
// Map from a root id to the concrete id's it's represented in
const std::unordered_set<IterDomain*>& root_ids_of_concrete_id,
const std::unordered_map<IterDomain*, std::vector<IterDomain*>>&
root_id_to_maybe_concrete_ids,
// Disjoint set just for printing
const std::vector<IterDomain*>& id_set,
// All the candidate concrete IDs found for this disjoint id set
const std::vector<IterDomain*>& maybe_concrete_ids) {
std::vector<IterDomain*> root_ids_not_found_with_concrete_id;

for (auto root_id : root_ids_of_all_ids) {
if (root_ids_of_concrete_id.find(root_id) !=
root_ids_of_concrete_id.end()) {
continue;
}

// None of the root IDs of the conrete ID is exactly mapped with
// root_id.

// It is still a valid concrete ID if it has a non-broadcast
// root ID that is mapped with root_id.
if ((root_id->isBroadcast() || trivial_reduction_info.isDerived(root_id)) &&
std::any_of(
root_ids_of_concrete_id.begin(),
root_ids_of_concrete_id.end(),
[&](auto root_id_of_concrete_id) {
return !root_id_of_concrete_id->isBroadcast() &&
!trivial_reduction_info.isDerived(root_id_of_concrete_id) &&
ca_map.areMapped(
root_id,
root_id_of_concrete_id,
IdMappingMode::PERMISSIVE);
})) {
continue;
}

// If all of the corresponding maybe-concrete IDs are exactly
// mapped with the concrete ID, this missing root_id is not a
// problem. This can happen with reduction rfactor, e.g.,
// FusionAdvancedLowering1.
if (std::all_of(
root_id_to_maybe_concrete_ids.at(root_id).begin(),
root_id_to_maybe_concrete_ids.at(root_id).end(),
[&](auto maybe_concrete_id) {
return ca_map.areMapped(
concrete_id, maybe_concrete_id, IdMappingMode::EXACT);
})) {
continue;
}

root_ids_not_found_with_concrete_id.push_back(root_id);
}

if (root_ids_not_found_with_concrete_id.empty()) {
return;
}

// Error detected as some root IDs are not accounted for by the
// concrete ID.
std::stringstream error_msg;
error_msg << "IDs: " << ir_utils::toString(id_set);
error_msg << ", concrete ID: " << concrete_id->toString();
error_msg << ", maybe concrete IDs: "
<< ir_utils::toString(maybe_concrete_ids);
error_msg << ", all root IDs:";
for (auto root_id : root_ids_of_all_ids) {
error_msg << " " << root_id->toString();
}
error_msg << ", root IDs not found with concrete ID: ";
for (auto id : root_ids_not_found_with_concrete_id) {
error_msg << " " << id->toString();
}
TORCH_INTERNAL_ASSERT(
false, "Concrete ID failed to cover all root IDs. ", error_msg.str());
}

} // namespace

IterDomain* ComputeAtMap::computeConcreteId(
IterDomain* id,
IdMappingMode mode) {
Expand All @@ -275,9 +370,13 @@ IterDomain* ComputeAtMap::computeConcreteId(
id->toString());

if (disjoint_set_shared_ptr->vector().size() == 1) {
// If only one entry in the disjoint set, by definition the existing ID has
// to be the concrete ID.
return disjoint_set_shared_ptr->vector().front();
}

// Grab a set of candidate concrete_ids, we track towards the consumers in the
// ID group as one of those is guaranteed to be a valid concrete id.
VectorOfUniqueEntries<IterDomain*> maybe_concrete_ids;
for (auto id : disjoint_set_shared_ptr->vector()) {
bool id_output = true;
Expand All @@ -292,6 +391,8 @@ IterDomain* ComputeAtMap::computeConcreteId(
}
}

// Shouldn't ever happen, it would mean there's an error somewhere in the
// graph.
TORCH_INTERNAL_ASSERT(
maybe_concrete_ids.vector().size(),
"No potential concrete_id's found for ",
Expand All @@ -301,10 +402,27 @@ IterDomain* ComputeAtMap::computeConcreteId(
return maybe_concrete_ids.vector().front();
}

// The concrete_id should have the most roots it can trace back to that are
// iter domains, (non-broadcast/non-reduction). We don't trace back through
// view operations, so the one with the most iter root domains is the concrete
// ID.
IterDomain* concrete_id = nullptr;
int max_iter_root_count = 0;
int max_bcast_root_count = 0;

// For the LOOP map, the concrete ID must account for all root IDs
// of all of the IDs in each disjoit set. At least those ID's that are
// non-broadcast/non-reduction. As broadcast is only important here if it's
// concretized in the set. Track information so we can later make sure the
// concrete id has accounted for all iter domains meaning it has a correct
// loop size.
std::unordered_set<IterDomain*> root_ids_of_all_ids;
std::unordered_set<IterDomain*> root_ids_of_concrete_id;
std::unordered_map<IterDomain*, std::vector<IterDomain*>>
root_id_to_maybe_concrete_ids;

// Populate the above information, look for the concrete id, validate the loop
// concrete ID.
for (auto maybe_concrete_id : maybe_concrete_ids.vector()) {
std::unordered_set<IterDomain*> root_ids;
std::deque<IterDomain*> to_visit;
Expand All @@ -330,6 +448,20 @@ IterDomain* ComputeAtMap::computeConcreteId(
}
}

if (mode == IdMappingMode::LOOP) {
std::transform(
root_ids.begin(),
root_ids.end(),
std::inserter(root_ids_of_all_ids, root_ids_of_all_ids.end()),
[&](const auto root_id) {
auto exact_concrete_id =
getConcreteMappedID(root_id, IdMappingMode::EXACT);
root_id_to_maybe_concrete_ids[exact_concrete_id].push_back(
maybe_concrete_id);
return exact_concrete_id;
});
}

int bcast_root_count = std::count_if(
root_ids.begin(), root_ids.end(), [&](IterDomain* root_id) {
return root_id->isBroadcast()
Expand All @@ -344,12 +476,44 @@ IterDomain* ComputeAtMap::computeConcreteId(
max_iter_root_count = iter_root_count;
max_bcast_root_count = bcast_root_count;
concrete_id = maybe_concrete_id;

// If we update the concrete_id, then update the root_ids_of_concrete_id
// to reflect this id
if (mode == IdMappingMode::LOOP) {
root_ids_of_concrete_id.clear();
std::transform(
root_ids.begin(),
root_ids.end(),
std::inserter(
root_ids_of_concrete_id, root_ids_of_concrete_id.end()),
[&](const auto root_id) {
return getConcreteMappedID(root_id, IdMappingMode::EXACT);
});
}
}
} // end maybe_concrete_id

TORCH_INTERNAL_ASSERT(
concrete_id != nullptr,
"Something went wrong, could not find a concrete id.");

if (mode == IdMappingMode::LOOP) {
// Validate the concrete id has influence from all the roots of all the
// consumers that will map to this concete id in the loop map. This means
// all the consumers in all expressions of the loop nest generated based on
// this concrete ID will have their roots mapping to this concrete ID
// represented in the extent of this concrete id.
validateCompletenessOfLoopConcreteID(
concrete_id,
*this,
trivial_reduction_info_,
root_ids_of_all_ids,
root_ids_of_concrete_id,
root_id_to_maybe_concrete_ids,
disjoint_set_shared_ptr->vector(),
maybe_concrete_ids.vector());
}

return concrete_id;
}

Expand Down
35 changes: 35 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22700,6 +22700,41 @@ TEST_F(NVFuserMultithreadedTest, MultipleFunctions_CUDA) {
}
}

// Repro of issue #1655
TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
auto tv2 = makeSymbolicTensor(2);
fusion.addInput(tv2);

auto tv3 = broadcast(tv0, {true, true, false});
auto tv4 = broadcast(tv1, {false, true, false});
auto tv5 = broadcast(tv2, {true, false, false});

auto tv6 = add(tv3, tv4);
auto tv7 = add(tv3, tv5);

fusion.addOutput(tv6);
fusion.addOutput(tv7);

tv6->merge(0);
tv6->merge(0);

TransformPropagator::from(tv6);

tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined);
tv1->computeAt(tv6, -1, ComputeAtMode::MostInlined);
tv2->computeAt(tv7, -1, ComputeAtMode::MostInlined);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
ASSERT_ANY_THROW(fusion.printKernel());
}

TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Expand Down

0 comments on commit 18bee67

Please sign in to comment.