Skip to content

Commit

Permalink
Fix missing cooperative launch (#1726)
Browse files: browse the repository at this point in the history
  • Loading branch information
naoyam committed May 24, 2022
1 parent dc670a2 commit b3d1c3f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class KernelIrScanner : private IrVisitor {
summary_.has_grid_reductions = true;
const auto dom = ir_utils::getTvOutput(grid_reduction)->domain();
updateGridReductionInLoop(dom);
if (grid_reduction->isAllreduce()) {
summary_.has_cooperative_grid_reduction = true;
}
}

void handle(GridBroadcast* grid_broadcast) final {
Expand Down
9 changes: 7 additions & 2 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,8 +1033,9 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) {
auto tv8 = add(tv7, tv6);
fusion.addOutput(tv8);

const int tidx = 512;
groupReductions({tv1, tv4});
tv1->split(1, 128);
tv1->split(1, tidx);
TransformPropagator::from(tv1);

tv0->computeAt(tv8, -1, ComputeAtMode::MostInlined);
Expand All @@ -1044,7 +1045,11 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) {
tv1->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));

std::vector<int64_t> shape({99, 999});
std::vector<int64_t> shape({10, 999});

if (shape.at(0) * ceilDiv(shape.at(1), tidx) > deviceSMCount()) {
GTEST_SKIP() << "Not enough SMs to run this test";
}

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

Expand Down

0 comments on commit b3d1c3f

Please sign in to comment.