Skip to content

Commit

Permalink
More precise concretization analysis (#1719)
Browse files Browse the repository at this point in the history
* If a broadcast is concretized to multiple concrete domains but they are
exactly mapped, do not consider it non-uniquely concretized
  • Loading branch information
naoyam committed May 23, 2022
1 parent f4d3630 commit 151d95b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 3 deletions.
21 changes: 20 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace fuser {
namespace cuda {

void ConcretizedBroadcastDomains::build(Fusion* fusion) {
exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);

// Initialize the origin map with input broadcast domains
for (const auto fusion_input_tv :
ir_utils::filterByType<TensorView>(fusion->inputs())) {
Expand Down Expand Up @@ -116,7 +118,9 @@ void ConcretizedBroadcastDomains::markAsConcretized(
auto child = child_domains.front();
child_domains.pop_front();
auto& concrete_ids = broadcast_to_concrete_map_[child];
if (!concrete_ids.emplace(concrete_root_domain).second) {
auto inserted =
insertRootDomainToConcreteDomainSet(concrete_root_domain, concrete_ids);
if (!inserted) {
continue;
}
const auto& child_uses = child->uses();
Expand All @@ -129,6 +133,21 @@ void ConcretizedBroadcastDomains::markAsConcretized(
}
}

// Records new_root_id in id_set unless id_set already holds a domain
// that the exact root-domain map reports as mapped with it.
// Returns true when new_root_id was added, false when it was skipped.
bool ConcretizedBroadcastDomains::insertRootDomainToConcreteDomainSet(
    IterDomain* new_root_id,
    std::unordered_set<IterDomain*>& id_set) {
  // Bail out on the first already-recorded domain that is exactly
  // mapped with the candidate.
  for (IterDomain* known_id : id_set) {
    if (exact_map_->areMapped(new_root_id, known_id)) {
      return false;
    }
  }

  id_set.insert(new_root_id);
  return true;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
Expand Down
11 changes: 9 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#pragma once

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <c10/macros/Export.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -44,6 +45,10 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
IterDomain* broadcast_root_domain,
IterDomain* concrete_root_domain);

//! Insert new_root_id into id_set unless id_set already contains a
//! domain that is exactly mapped with new_root_id. Returns true if
//! new_root_id was inserted and false otherwise.
bool insertRootDomainToConcreteDomainSet(
IterDomain* new_root_id,
std::unordered_set<IterDomain*>& id_set);

private:
//! Maps each root broadcast domain to its original root broadcast
//! domains. There can be multiple original domains due to, e.g.,
Expand All @@ -53,6 +58,8 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
//! Map all broadcast domains to concrete root domains
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
broadcast_to_concrete_map_;

//! Exact root-domain map of the fusion. Created in build() and queried
//! to check whether two concrete domains are exactly mapped.
std::unique_ptr<ExactRootDomainMap> exact_map_;
};

} // namespace cuda
Expand Down
64 changes: 64 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20975,6 +20975,70 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) {
}
#endif

// Exercises ConcretizedBroadcastDomains on three broadcasts of tv0 that
// each feed a Welford op:
//   - tv5:  broadcast axis combined with tv1 and tv2 (separate inputs),
//   - tv11: broadcast axis combined with tv2 and tv3 (tied together by
//           the add that produces tv4),
//   - tv17: broadcast axis is the reduced axis.
TEST_F(NVFuserTest, FusionBroadcastConcretization5_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

// Four independent 1-D symbolic inputs.
auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(1);
fusion.addInput(tv1);
auto tv2 = makeSymbolicTensor(1);
fusion.addInput(tv2);
auto tv3 = makeSymbolicTensor(1);
fusion.addInput(tv3);

// Adding tv2 and tv3 asserts that they have the same shape.
auto tv4 = add(tv2, tv3);
fusion.addOutput(tv4);

// Concretize a broadcast domain with two different domains (tv1's and
// tv2's) through a multi-output expression (Welford). It should be
// considered to be non-uniquely concretized.
auto tv5 = broadcast(tv0, {false, true});
// Reduce only the non-broadcast domain.
auto tvs = Welford(tv5, {0});
auto tv9 = add(tvs.avg, tv1);
auto tv10 = add(tvs.var_sum, tv2);
fusion.addOutput(tv9);
fusion.addOutput(tv10);

// Same pattern as the above, but concretize the broadcast domain
// with tv2 and tv3, which have exactly the same shape, so the
// broadcast should be considered uniquely concretized.
auto tv11 = broadcast(tv0, {false, true});
// Reduce only the non-broadcast domain.
auto tvs2 = Welford(tv11, {0});
auto tv15 = add(tvs2.avg, tv2);
auto tv16 = add(tvs2.var_sum, tv3);
fusion.addOutput(tv15);
fusion.addOutput(tv16);

// Reduce only the broadcast domain. Since it's reduced, it should
// not be considered to be concretized.
auto tv17 = broadcast(tv0, {false, true});
auto tvs3 = Welford(tv17, {1});
fusion.addOutput(tvs3.avg);

// Run the concretization analysis over the whole fusion.
ConcretizedBroadcastDomains bcast_concretization_info;
bcast_concretization_info.build(&fusion);

// tv5's broadcast axis: expected to be reported as possibly
// non-uniquely concretized.
TORCH_CHECK(
bcast_concretization_info.maybeNonUniquelyConcretized(tv5->axis(1)),
"Failed to detect non-unique concretization of ",
tv5->toString());

// tv11's broadcast axis: expected to be reported as uniquely
// concretized.
TORCH_CHECK(
bcast_concretization_info.isUniquelyConcretized(tv11->axis(1)),
"Failed to detect unique concretization of ",
tv11->toString());

// tv17's broadcast axis: expected to be reported as not concretized.
TORCH_CHECK(
!bcast_concretization_info.isConcretized(tv17->axis(1)),
"Failed to detect non-concretization of ",
tv17->toString());
}

TEST_F(NVFuserTest, FusionIssue1430_CUDA) {
// Derived from an expression sorting issue when using loop map, now expr
// sorting uses parallel map.
Expand Down

0 comments on commit 151d95b

Please sign in to comment.