Fix scheduling with polymorphic broadcast (#1714)
* Add a routine to query if a broadcast domain may be concretized to
multiple domains

* Don't group operations together when they share a broadcast domain that may be broadcast to more than one size (see the sketch below).

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
Co-authored-by: jjsjann123 <jiej@nvidia.com>
3 people committed May 19, 2022
1 parent 4ab5ef7 commit f68b830
Showing 5 changed files with 167 additions and 16 deletions.
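A minimal eager-mode sketch of the "polymorphic broadcast" pattern this commit guards against (the shapes and variable names here are illustrative assumptions; the actual repro is in the tests added below). A single unsqueezed tensor is broadcast against two tensors whose extents along the broadcast axis are not known to be equal, so a fused kernel could not assume one broadcast size:

import torch

x0 = torch.randn(16, 2048)
x1 = torch.randn(16, 512)   # different extent on the broadcast axis than x0
x2 = torch.randn(16)

x3 = x2.unsqueeze(-1)       # shape [16, 1]: a broadcast domain
x4 = x3 + x0                # broadcast resolved to extent 2048
x5 = x3 + x1                # broadcast resolved to extent 512
x6 = x5.sum(0)              # reduction feeding a second output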
45 changes: 41 additions & 4 deletions test/test_jit_cuda_fuser.py
@@ -185,10 +185,16 @@ def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(123)
        o = op(*args)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(o, jit_o)
        if check_stride:
            self.assertEqual(o.stride(), jit_o.stride())

        if type(jit_o) is torch.Tensor:
            jit_o = [jit_o, ]
            o = [o, ]

        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
            if check_stride:
                self.assertEqual(oo.stride(), jit_oo.stride())
        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)

    def _run_training_helper(self, jit_op, op, grads, *args):
@@ -4771,6 +4777,37 @@ def t_cpu(x):

        self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_scheduler_with_polymorphic_broadcast(self):
        device = "cuda"
        x0 = torch.randn(1024, 204800, device=device)
        x1 = torch.rand_like(x0)
        x2 = torch.randn(1024, device=device)

        def t(x0, x1, x2):
            x3 = x2.unsqueeze(-1)
            x4 = x3 + x0
            x5 = x3 + x1
            x6 = x5.sum(0)
            return x4, x6

        t_jit = torch.jit.script(t)
        self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)

        x2 = torch.randn(204800, device=device)

        def t2(x0, x1, x2):
            x3 = x2.unsqueeze(0)
            x4 = x3 + x0
            x5 = x3 + x1
            x6 = x5.sum(1)
            return x4, x6

        t2_jit = torch.jit.script(t2)
        self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)


class TestPassManagerCudaFuser(JitTestCase):
    def setUp(self):
32 changes: 24 additions & 8 deletions torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp
@@ -25,8 +25,19 @@ void ConcretizedBroadcastDomains::build(Fusion* fusion) {
}

bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const {
auto it = concretized_domains_.find(id);
return it != concretized_domains_.end();
auto it = broadcast_to_concrete_map_.find(id);
return it != broadcast_to_concrete_map_.end();
}

bool ConcretizedBroadcastDomains::isUniquelyConcretized(IterDomain* id) const {
auto it = broadcast_to_concrete_map_.find(id);
return it != broadcast_to_concrete_map_.end() && it->second.size() == 1;
}

bool ConcretizedBroadcastDomains::maybeNonUniquelyConcretized(
IterDomain* id) const {
auto it = broadcast_to_concrete_map_.find(id);
return it != broadcast_to_concrete_map_.end() && it->second.size() > 1;
}

void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
@@ -67,7 +78,10 @@ void ConcretizedBroadcastDomains::handle(Expr* expr) {
for (const auto& kv : p2c_map) {
auto p_id = kv.first;
auto c_id = kv.second;
const bool is_concretized = !c_id->isBroadcast();
// If the consumer ID is a reduction (i.e., a trivial
// reduction), do not consider it concretized.
const bool is_concretized =
!c_id->isBroadcast() && !c_id->isReduction();
auto it = broadcast_origin_map_.find(p_id);
TORCH_INTERNAL_ASSERT(
it != broadcast_origin_map_.end(),
@@ -79,8 +93,7 @@ void ConcretizedBroadcastDomains::handle(Expr* expr) {
if (is_concretized) {
// Keep track of all the origin domains as concretized
for (auto origin : producer_origins) {
// concretized_root_domains_.insert(origin);
markAsConcretized(origin);
markAsConcretized(origin, c_id);
}
} else {
// Not concretized yet. Propagate forward the origin info.
@@ -95,12 +108,15 @@ void ConcretizedBroadcastDomains::handle(Expr* expr) {
}
}

void ConcretizedBroadcastDomains::markAsConcretized(IterDomain* root_domain) {
std::deque<IterDomain*> child_domains({root_domain});
void ConcretizedBroadcastDomains::markAsConcretized(
IterDomain* broadcast_root_domain,
IterDomain* concrete_root_domain) {
std::deque<IterDomain*> child_domains({broadcast_root_domain});
while (!child_domains.empty()) {
auto child = child_domains.front();
child_domains.pop_front();
if (!concretized_domains_.emplace(child).second) {
auto& concrete_ids = broadcast_to_concrete_map_[child];
if (!concrete_ids.emplace(concrete_root_domain).second) {
continue;
}
const auto& child_uses = child->uses();
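As a rough model of the bookkeeping above, here is a simplified Python sketch (an assumption-laden illustration that stands plain string labels in for the real IterDomain objects): each root broadcast domain maps to the set of concrete root domains it is resolved against, and a broadcast is "uniquely concretized" exactly when that set has a single entry.

from collections import defaultdict

# Simplified stand-in for broadcast_to_concrete_map_ (keys and values are
# plain labels here, not IterDomain pointers).
broadcast_to_concrete = defaultdict(set)

def mark_as_concretized(broadcast_root, concrete_root):
    # Record one more concrete domain that this broadcast domain resolves to.
    broadcast_to_concrete[broadcast_root].add(concrete_root)

def is_concretized(domain):
    return domain in broadcast_to_concrete

def is_uniquely_concretized(domain):
    return len(broadcast_to_concrete.get(domain, ())) == 1

def maybe_non_uniquely_concretized(domain):
    return len(broadcast_to_concrete.get(domain, ())) > 1

# The broadcast of x2 in the repro is resolved against the second axis of both
# x0 and x1; when those extents are not proven equal, two entries accumulate.
mark_as_concretized("bcast(x2)", "x0_axis1")
mark_as_concretized("bcast(x2)", "x1_axis1")
assert maybe_non_uniquely_concretized("bcast(x2)")
assert not is_uniquely_concretized("bcast(x2)")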
18 changes: 14 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h
@@ -24,25 +24,35 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
public:
void build(Fusion* fusion);

//! Is a domain concretized?
bool isConcretized(IterDomain* id) const;

//! Is a domain concretized to a unique concrete domain?
bool isUniquelyConcretized(IterDomain* id) const;

//! Is a domain concretized to multiple concrete domains?
bool maybeNonUniquelyConcretized(IterDomain* id) const;

private:
using IterVisitor::handle;

void handle(BroadcastOp* bop) final;

void handle(Expr* expr) final;

void markAsConcretized(IterDomain* root_domain);
void markAsConcretized(
IterDomain* broadcast_root_domain,
IterDomain* concrete_root_domain);

private:
//! Maps each broadcast domain to its original broadcast
//! Maps each root broadcast domain to its original root broadcast
//! domains. There can be multiple original domains due to, e.g.,
//! binary ops with broadcast domains in both inputs.
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
broadcast_origin_map_;
//! Set of all concretized original domains
std::unordered_set<IterDomain*> concretized_domains_;
//! Map all broadcast domains to concrete root domains
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
broadcast_to_concrete_map_;
};

} // namespace cuda
41 changes: 41 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -779,6 +779,26 @@ static bool checkPatternEquivalence(
return it0 == out_root0.end() && it1 == out_root1.end();
}

// Reusing code from lowering, specifically
// ConcretizedBroadcastDomains::maybeNonUniquelyConcretized in
// lower_trivial_broadcast.cpp, this checks whether a broadcast iteration
// domain is broadcast to seemingly different extents, i.e., within the kernel
// we don't know whether the dimension is broadcast to one size multiple times
// or to different sizes. This is a hard-to-optimize problem and likely
// indicates we shouldn't be fusing.
bool hasNonUniqueBcast(Fusion* fusion) {
ConcretizedBroadcastDomains concretize_info;
concretize_info.build(fusion);

for (auto tv : ir_utils::allTvs(fusion)) {
for (auto id : tv->getRootDomain()) {
if (concretize_info.maybeNonUniquelyConcretized(id)) {
return true;
}
}
}
return false;
}

//! Scheduler interface:
//! Each of the scheduler needs to provide 3 interface functions:
//!
@@ -851,6 +871,13 @@ class ReductionScheduler : public SchedulerEntry {
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Reduction,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

// Make sure reduction axes are consistent through the fusion
auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);
@@ -981,6 +1008,13 @@ class PointWiseScheduler : public SchedulerEntry {
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::PointWise,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

return true;
}

@@ -1048,6 +1082,13 @@ class PersistentKernelScheduler : public SchedulerEntry {
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Persistent,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

auto reduction_tvs =
scheduler_utils::getReductionTvs(fusion, false /* ignore_trivial */);

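Continuing the simplified model above, a hedged sketch of how the new guard composes with the scheduler checks in registry.cpp (the function and parameter names below are hypothetical, not the real C++ API): every tensor's root domains are scanned, and the fusion is rejected as soon as any broadcast domain may resolve to more than one extent.

def has_non_unique_bcast(all_root_domains, maybe_non_uniquely_concretized):
    # all_root_domains: iterable of per-tensor root-domain lists, mirroring
    # ir_utils::allTvs(fusion) + getRootDomain() in the C++ above.
    return any(
        maybe_non_uniquely_concretized(domain)
        for root_domains in all_root_domains
        for domain in root_domains
    )

def can_schedule(all_root_domains, maybe_non_uniquely_concretized):
    if has_non_unique_bcast(all_root_domains, maybe_non_uniquely_concretized):
        # Mirrors canScheduleRejectReason(...) in the three schedulers above.
        return False, "Broadcast dimension might be broadcast to multiple sizes."
    return True, ""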
47 changes: 47 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23011,6 +23011,53 @@ TEST_F(NVFuserTest, FusionContigPredicate_CUDA) {
testValidate(fe.kernel(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// Repro of an issue in the reduction scheduler with a broadcast
// domain concretized to multiple domains that are not proven to have
// the same extent.
TEST_F(NVFuserTest, FusionRepro1713_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv0 = makeSymbolicTensor(2);
auto tv1 = makeSymbolicTensor(2);
auto tv2 = makeSymbolicTensor(1);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);
auto tv3 = broadcast(tv2, {false, true});

auto tv4 = add(tv3, tv0);

auto tv5 = add(tv3, tv1);
auto tv6 = sum(tv5, {0});
fusion->addOutput(tv4);
fusion->addOutput(tv6);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024, 204800}, options);
// The original repro had the same shape as t0, but this should also work
// with a different extent in the second axis.
at::Tensor t1 = at::randn({1024, 123}, options);
at::Tensor t2 = at::randn({1024}, options);
std::vector<IValue> aten_inputs({t0, t1, t2});

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

auto t3 = t2.unsqueeze(-1);
auto t4 = t3 + t0;
auto t5 = t3 + t1;
auto t6 = sum(t5, {0});

testValidate(
executor_cache.fusion(),
cg_outputs,
{t0, t1, t2},
{t4, t6},
__LINE__,
__FILE__);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
