Skip to content

Commit

Permalink
Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#…
Browse files Browse the repository at this point in the history
…1710)

* unify segmented and single fusion path
  • Loading branch information
shmsong committed May 24, 2022
1 parent b3d1c3f commit dd23252
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 213 deletions.
6 changes: 1 addition & 5 deletions benchmarks/cpp/nvfuser/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ void runBenchmarkIterations(
std::vector<c10::IValue>& aten_inputs) {
fusion_executor_cache->runFusionWithInputs(aten_inputs);
bool segmented =
fusion_executor_cache->getMostRecentKernelRuntime()->isSegmented() &&
fusion_executor_cache->getMostRecentKernelRuntime()
->fusionSegments()
->groups()
.size() > 1;
fusion_executor_cache->getMostRecentKernelRuntime()->isSegmented();

if (!segmented) {
fusion_executor_cache->profile(true);
Expand Down
26 changes: 26 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,32 @@ std::string toString(const SegmentedEdge* edge) {
return ss.str();
}

std::unique_ptr<SegmentedFusion> SegmentedFusion::fromCompleteFusion(
std::unique_ptr<Fusion> fusion_ptr,
ScheduleHeuristic heuristic) {
auto fusion = fusion_ptr.get();

auto segmented_fusion_ptr =
std::make_unique<SegmentedFusion>(std::move(fusion_ptr));

// Make a group for the single fusion
auto single_group = segmented_fusion_ptr->newGroup();

// Add input and output vals
single_group->input_vals = fusion->inputs();
single_group->output_vals = fusion->outputs();

// Get ordered expression list
single_group->resetExprList();

// Assign heuristics and id for the complete fusion
// to share the runtime path of segmented fusion.
single_group->setHeuristic(heuristic);
single_group->setID(0);

return segmented_fusion_ptr;
}

SegmentedFusion::SegmentedFusion(std::unique_ptr<Fusion> fusion)
: impl_(this), complete_fusion_(std::move(fusion)) {
segmented_fusion_name_ = segmentedFusionName();
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ class TORCH_CUDA_CU_API SegmentedFusion {
public:
explicit SegmentedFusion(std::unique_ptr<Fusion> fusion);

//! Factory function for the un-segmented case, directly
//! constructs a "SegmentedFusion", with the given Fusion
//! as the only group.
static std::unique_ptr<SegmentedFusion> fromCompleteFusion(
std::unique_ptr<Fusion> fusion,
ScheduleHeuristic heuristic);

//! Is the fusion segmented?
bool isSegmented() const {
return !groups_.empty();
Expand Down
Loading

0 comments on commit dd23252

Please sign in to comment.