From fd51304135b987c83ad2c1c9609fa68fc1bcada8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 12 Aug 2025 14:27:56 +0800 Subject: [PATCH 1/6] Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index a08b7c34d..5a433cc1a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit a08b7c34d4a59f89f4dea252fa1a7e458e298ef0 +Subproject commit 5a433cc1af4a6d859cdf2b62c7c5ab28bf5836ea From 34c7205423c93700c635032294f22290db732550 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 12 Aug 2025 22:18:20 +0800 Subject: [PATCH 2/6] Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling - Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes. - Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management. - Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body. - Removed obsolete code and improved overall code clarity and maintainability. --- src/transform/inject_pipeline.cc | 816 ++++++++++--------------------- tilelang/engine/phase.py | 1 - 2 files changed, 250 insertions(+), 567 deletions(-) diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 3d7a4e692..273a76e7c 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -22,7 +22,6 @@ * \brief Transform annotated loops into pipelined one that parallelize * producers and consumers */ -#include #include #include #include @@ -83,137 +82,34 @@ struct BufferAccessInfo { int use = -1; // the last using stage of the buffer }; -class PipelineOpaqueAccessRewriter { -public: - /*! - * \brief Constructor - * \param buffer_data_to_buffer The map from buffer data to buffer. - * \param buffer_remap The map from original buffer to the buffer with updated - * shape for multi-versioning in the software pipeline. \param pipeline_loop - * The original loop to be software pipelined. \param fragment_info - * Information about tensor core fragment - */ - PipelineOpaqueAccessRewriter( - const Map &buffer_data_to_buffer, - const Map &buffer_remap, const For &pipeline_loop, - const std::unordered_map &fragment_info) - : buffer_data_to_buffer_(buffer_data_to_buffer), - buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), - fragment_info_(fragment_info) {} - - PrimExpr Rewrite(const Call &call) { - // Intrinsic calls should be handled explicitly here as they are opaque - // accesses to buffer. - static const auto &load_matrix_sync = builtin::tvm_load_matrix_sync(); - static const auto &store_matrix_sync = builtin::tvm_store_matrix_sync(); - static const auto &mma_sync = builtin::tvm_mma_sync(); - static const auto &access_ptr = builtin::tvm_access_ptr(); - static const auto &ptx_ldmatrix = builtin::ptx_ldmatrix(); - static const auto &ptx_mma = builtin::ptx_mma(); - if (call->op.same_as(load_matrix_sync) || - call->op.same_as(store_matrix_sync)) { - const Buffer &buffer = - buffer_data_to_buffer_.at(Downcast(call->args[0])); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - Array new_args = call->args; - const Buffer &new_buffer = (*it).second; - new_args.Set( - 4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); - return Call(call->dtype, call->op, new_args, call->span); - } - } else if (call->op.same_as(mma_sync)) { - Array new_args = call->args; - for (int i = 0; i < 4; i++) { - const Var &buffer_var = Downcast(call->args[i * 2]); - const PrimExpr &index = call->args[i * 2 + 1]; - const Buffer &buffer = buffer_data_to_buffer_.at(buffer_var); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - PrimExpr new_index = - RewriteWmmaFragmentIndex(buffer, (*it).second, index); - new_args.Set(i * 2 + 1, new_index); - } - } - return Call(call->dtype, call->op, new_args, call->span); - } else if (call->op.same_as(access_ptr)) { - return RewriteBufferAccess(call, {1}); - } else if (call->op.same_as(ptx_mma)) { - return RewriteBufferAccess(call, {6, 8, 10}); - } else if (call->op.same_as(ptx_ldmatrix)) { - return RewriteBufferAccess(call, {3}); - } - return call; - } - -private: - int GetWmmaFragmentSize(const Buffer &buffer) { - auto it = fragment_info_.find(buffer->data.get()); - ICHECK(it != fragment_info_.end()); - const FragmentInfo &info = (*it).second; - return info.GetSize(); - } - - PrimExpr RewriteWmmaFragmentIndex(const Buffer &old_buffer, - const Buffer &new_buffer, - const PrimExpr &old_index) { - PrimExpr new_buffer_offset = old_index; - - int fragment_size = GetWmmaFragmentSize(old_buffer); - PrimExpr offset = floordiv( - foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), old_buffer->shape), - fragment_size); - new_buffer_offset += - floormod(pipeline_loop_->loop_var - pipeline_loop_->min, - new_buffer->shape[0]) * - offset; - return new_buffer_offset; - } - - PrimExpr RewriteBufferAccess(const Call &call, - const std::vector arg_indices) { - auto product = [](const Array &input) { - return foldl( - [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), input); - }; - Array new_args = call->args; - for (int i : arg_indices) { - const Buffer &buffer = - buffer_data_to_buffer_.at(Downcast(call->args[i])); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - const Buffer &new_buffer = (*it).second; - const PrimExpr &old_index = call->args[i + 1]; - PrimExpr offset; - if (new_buffer->strides.empty()) { - offset = product(buffer->shape); - } else { - offset = new_buffer->strides[0]; - } - if (buffer.scope() == "m16n8k8.matrixA" || - buffer.scope() == "m16n8k8.matrixB") { - // mma scope size will shrink by warp size - // @see transform_mma_buffer_layout - ICHECK_EQ(Downcast(floormod(offset, 32))->value, 0) - << "mma scope size should be multiple of warp size"; - offset = floordiv(offset, 32); - } - PrimExpr new_index = - old_index + - floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; - new_args.Set(i + 1, new_index); - } +/*! + * \brief Replace IfThenElse nodes with their then_case, preserving attribute + * nodes \param body The statement to process \param condition The condition to + * match in IfThenElse nodes \return The transformed statement + */ +Stmt replace_if_then_else(Stmt body, PrimExpr condition) { + if (const auto *if_node = body.as()) { + // If this is an IfThenElse with the matching condition, replace it with its + // then_case + if (if_node->condition.same_as(condition)) { + return if_node->then_case; } - return Call(call->dtype, call->op, new_args, call->span); + } else if (const auto *attr_node = body.as()) { + // For attribute nodes, preserve the attribute but process its body + AttrStmt attr_stmt = GetRef(attr_node); + attr_stmt.CopyOnWrite()->body = + replace_if_then_else(attr_node->body, condition); + return attr_stmt; + } else if (const auto *block_node = body.as()) { + // For block nodes, process the body + Block block = GetRef(block_node); + block.CopyOnWrite()->body = + replace_if_then_else(block_node->body, condition); + return block; } - - const Map &buffer_data_to_buffer_; - const Map &buffer_remap_; - const For &pipeline_loop_; - const std::unordered_map &fragment_info_; -}; + // For any other node type, return it unchanged + return body; +} /*! * \brief Rewriter for the body of the software pipeline. This pass inserts @@ -231,19 +127,14 @@ class PipelineBodyRewriter : public StmtExprMutator { * Whether all versions the buffers in the software pipeline are accessed. * This will be used to update block access region. In the prologue and * epilogue of a two-stage software pipeline, only one version of these - * buffers are accessed. \param fragment_info Information about tensor core - * fragment + * buffers are accessed. */ - PipelineBodyRewriter( - const Map &buffer_data_to_buffer, - const Map &buffer_remap, For pipeline_loop, - bool access_all_versions, - const std::unordered_map &fragment_info) + PipelineBodyRewriter(const Map &buffer_data_to_buffer, + const Map &buffer_remap, + For pipeline_loop, bool access_all_versions) : buffer_data_to_buffer_(buffer_data_to_buffer), buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), - access_all_versions_(access_all_versions), - opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_, - pipeline_loop_, fragment_info) {} + access_all_versions_(access_all_versions) {} private: BufferRegion @@ -267,6 +158,36 @@ class PipelineBodyRewriter : public StmtExprMutator { return buffer_region; } + PrimExpr RewriteBufferAccess(const Call &call, + const std::vector arg_indices) { + auto product = [](const Array &input) { + return foldl( + [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), input); + }; + Array new_args = call->args; + for (int i : arg_indices) { + const Buffer &buffer = + buffer_data_to_buffer_.at(Downcast(call->args[i])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + const Buffer &new_buffer = (*it).second; + const PrimExpr &old_index = call->args[i + 1]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = product(buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = + old_index + + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; + new_args.Set(i + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } + Stmt VisitStmt_(const BlockNode *op) final { for (const Buffer &alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); @@ -282,14 +203,14 @@ class PipelineBodyRewriter : public StmtExprMutator { for (const Buffer &alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(alloc_buffer->data); } - return block; + return std::move(block); } Stmt VisitStmt_(const BufferStoreNode *op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_remap_.find(store->buffer); if (it == buffer_remap_.end()) { - return store; + return std::move(store); } const Buffer &new_buffer = (*it).second; auto *n = store.CopyOnWrite(); @@ -297,14 +218,14 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod( (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return store; + return std::move(store); } PrimExpr VisitExpr_(const BufferLoadNode *op) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_remap_.find(load->buffer); if (it == buffer_remap_.end()) { - return load; + return std::move(load); } const Buffer &new_buffer = (*it).second; auto *n = load.CopyOnWrite(); @@ -312,19 +233,21 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod( (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return load; + return std::move(load); } PrimExpr VisitExpr_(const CallNode *op) final { Call call = Downcast(StmtExprMutator::VisitExpr_(op)); - return opaque_access_rewriter_.Rewrite(call); + if (call->op.same_as(builtin::tvm_access_ptr())) { + return RewriteBufferAccess(call, {1}); + } + return call; } Map buffer_data_to_buffer_; Map buffer_remap_; For pipeline_loop_; bool access_all_versions_; - PipelineOpaqueAccessRewriter opaque_access_rewriter_; }; /*! @@ -333,35 +256,14 @@ class PipelineBodyRewriter : public StmtExprMutator { */ class PipelineRewriter : public StmtExprMutator { public: - static Stmt Rewrite( - Map buffer_data_to_buffer, - const std::unordered_set - &double_buffers, - const Array pipeline_allocs, const For &pipeline_loop, - const PipelineInfo &pipeline_info, - const std::unordered_map &fragment_info, - const Map preserved_annotations) { - PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, - pipeline_allocs, pipeline_loop, pipeline_info, - fragment_info, preserved_annotations); - return rewriter.BuildPipeline(); - } - -private: - PipelineRewriter( - Map buffer_data_to_buffer, - const std::unordered_set - &double_buffers, - const Array &pipeline_allocs, const For &pipeline_loop, - const PipelineInfo &pipeline_info, - const std::unordered_map &fragment_info, - const Map preserved_annotations) - + PipelineRewriter(Map buffer_data_to_buffer, + const Array &pipeline_allocs, + const For &pipeline_loop, const PipelineInfo &pipeline_info, + PrimExpr predicate_condition = PrimExpr()) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), - double_buffers_(double_buffers), pipeline_allocs_(pipeline_allocs), - pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), - fragment_info_(fragment_info), - preserved_annotations_(preserved_annotations) {} + pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), + pipeline_info_(pipeline_info), + predicate_condition_(predicate_condition) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the @@ -376,36 +278,61 @@ class PipelineRewriter : public StmtExprMutator { } ordered_stmts_.resize(pipeline_info_.size()); - for (const auto &pair : pipeline_info_) { - const Block &block = pair.first; - int order = pair.second.order; - ordered_stmts_.Set(order, block); + for (const auto &[block, anno] : pipeline_info_) { + ordered_stmts_.Set(anno.order, block); } - // Step 2: Emit the pipeline prologue, body and epilogue. - Stmt prologue = - EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); - Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, - pipeline_loop_->min + pipeline_loop_->extent, false); - // introduce extra lowerbound when the loop length is smaller than num - // stages to ensure the epilogue interval do not overlap the prologue - // interval. - PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; - Optional extra_epilogue_lower_bound = std::nullopt; - if (max_stage_ > 1 && - !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { - if (is_const_int(epigogue_start)) { - epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); - } else { - // for dynamic case, introduce extra lowerbound as loop predicate - // to ensure the epilogue part unrollable. - extra_epilogue_lower_bound = pipeline_loop_->min + max_stage_; + for (const Block &block : ordered_stmts_) { + int stage = pipeline_info_[block].stage; + if (pipeline_info_[block].async) { + auto &state = async_states[stage]; + state.producer_head = pipeline_loop_->min - 1; + for (auto write_region : block->writes) { + auto buffer = write_region->buffer; + state.dst_buffers.insert(buffer.get()); + if (buffer_remap_.count(buffer)) + state.dst_buffers.insert(buffer_remap_[buffer].get()); + } } } - Stmt epilogue = - EmitImpl(epigogue_start, - pipeline_loop_->min + pipeline_loop_->extent + max_stage_, - true, extra_epilogue_lower_bound); + std::unordered_set consumed; + for (const Block &block : ordered_stmts_) { + int stage = pipeline_info_[block].stage; + if (pipeline_info_[block].async) { + auto &state = async_states[stage]; + if (state.commit_groups.empty() || consumed.count(stage)) { + state.commit_groups.push_back({}); + } + state.commit_groups.back().push_back(pipeline_info_[block].order); + consumed.erase(stage); + for (auto write_region : block->writes) { + auto buffer = buffer_remap_.count(write_region->buffer) + ? buffer_remap_[write_region->buffer] + : write_region->buffer; + state.buffer_to_commit_group_[buffer.get()] = + state.commit_groups.size() - 1; + } + } + + for (auto read_region : block->reads) { + for (const auto &[producer_stage_id, producer_state] : async_states) { + if (producer_stage_id <= stage && + producer_state.writes(read_region->buffer)) { + consumed.insert(producer_stage_id); + } + } + } + } + + // Step 2: Emit the pipeline prologue, body and epilogue. + Stmt prologue = EmitImpl(pipeline_loop_->min, + pipeline_loop_->min + max_stage_, true, true); + Stmt body = + EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false, false); + Stmt epilogue = EmitImpl( + pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true); SeqStmt stmt = SeqStmt({prologue, body, epilogue}); @@ -550,9 +477,6 @@ class PipelineRewriter : public StmtExprMutator { num_versions--; } } - if (num_versions == 1 && double_buffers_.count(buffer)) { - num_versions = 2; - } return num_versions; } @@ -584,15 +508,16 @@ class PipelineRewriter : public StmtExprMutator { // valid, it is the "sum of extents of loops that have been executed" - 1, // e.g. for epilogue it is prologue extent + body extent - 1. This is only // needed to compute wait count for epilogue without async producers. - Optional producer_head{PrimExpr(-1)}; - + PrimExpr producer_head; + std::vector> commit_groups; + std::unordered_map buffer_to_commit_group_; bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } }; // Per-stage states that are local to each of pipeline prologue, body, and // epilogue. struct AsyncStateLocal { - struct { + struct PendingWait { // The index into a list of blocks, where async_wait_queue should be // attached at the beginning. int insert_before; @@ -601,198 +526,76 @@ class PipelineRewriter : public StmtExprMutator { PrimExpr wait_count{nullptr}; bool valid() const { return wait_count.defined(); } - } pending_wait; - - // Destination buffers of async operations that have been encountered so far - // in the loop - // - // for (size_t i = 0; i < new_blocks.size(); ++i) { - // ... - // } - // - // This is for tracking which async operations have been issued at the - // "current" iteration, up until a point where we encounter a consumer of - // async result buffers. This is used to decide if the producer_head of each - // buffer points to a copy written in the current or previous iteration. - std::unordered_set seen; + }; + + std::vector pending_waits; // A symbolic expression representing the index the latest async operation // associated with this stage has written into, at the "current" iteration. Optional producer_head; - // The predicate of BlockRealize containing the async operation of this - // stage. - Optional predicate; - // Indices into a list of blocks, where async_commit_queue scope should be - // attached. If multiple async producers are interleaved with their consumer - // in between, we need separate async_commit_queue for each producer. Thus, - // we need multiple sets of indices. - std::vector> commit_groups; - - // This is set to true when we reach a stage that consumes this async stage. - bool consumed{false}; }; /*! Structure holding intermediate information for pipeline loop rewriting. */ struct RewrittenBlockInfo { int stage; + int order; PrimExpr predicate; Block block; PrimExpr access_index; bool is_async; }; - // Determine where to insert async_wait and the corresponding wait count. - void PopulateWaitCounts( - const std::vector &new_blocks, - arith::Analyzer *ana_normalized, - const std::unordered_map &buffer_to_commit_group, - std::map *async_states_local) { - + void PopulateWaitCounts(const std::vector &new_blocks, + std::map *async_states_local) { for (size_t i = 0; i < new_blocks.size(); ++i) { - if (new_blocks[i].is_async) { - // Record the fact that we have encountered these write buffers. - for (auto write_region : new_blocks[i].block->writes) { - (*async_states_local)[new_blocks[i].stage].seen.insert( - write_region->buffer.get()); - } - } - int producer_stage_idx = -1; for (auto read_region : new_blocks[i].block->reads) { - for (auto kv : async_states) { - if (kv.first <= new_blocks[i].stage && - kv.second.writes(read_region->buffer)) { + for (const auto &[stage, state] : async_states) { + if (stage <= new_blocks[i].stage && + state.writes(read_region->buffer)) { // Found an earlier stage where read_region->buffer was // asynchronously written - ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) + ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage) << "A dependency on multiple async stages is not supported"; - producer_stage_idx = kv.first; + producer_stage_idx = stage; } } } - if (producer_stage_idx == -1) continue; - - // The following logic has become complicated to handle case like this: - // - // for i in range(13): - // # Stage 0 - // async_commit_queue(0): - // async_scope: - // A_shared[(i + 3) % 4] = A[...] - // - // - // # Stage 1 - // async_wait_queue(0, 5): - // compute(A_shared[i], B_shared[i]) - // - // # Stage 0 - // async_commit_queue(0) - // async_scope: - // B_shared[(i + 3) % 4] = B[...] - // - // - // Here, multiple async producers in the same stage are interleaved with - // their consumer in between. Since each buffer is associated with - // different commit groups, the wait_count before the consumer should be - // bigger than the simpler case: - // - // for i in range(13): - // # Stage 0 - // async_commit_queue(0): - // async_scope: - // A_shared[(i + 3) % 4] = A[...] - // B_shared[(i + 3) % 4] = B[...] - // - // # Stage 1 - // async_wait_queue(0, 3): - // compute(A_shared[i], B_shared[i]) - // - // The correct wait_count can be determined by considering each commit - // group separately, and summing "per-commit" wait_counts. - // - // From A_shared's perspective, it allows for (i + 3) - i async commit - // groups to be in flight while from B_shared's perspective, the producer - // head at compute points to the copy done by the previous iteration, so - // its wait_count is calculated as ((i - 1) + 3) - i. The sum of the two - // wait_counts gives 5. - // print async_states_local - + const auto &state = async_states[producer_stage_idx]; auto &dep_local_state = (*async_states_local)[producer_stage_idx]; - const auto num_commit_group = dep_local_state.commit_groups.size(); - std::vector> producer_head_per_commit; - - auto add_unique_producer_head = - [&](const Optional &producer_head) { - // if producer_head already in producer_head_per_commit, return - for (const auto &head : producer_head_per_commit) { - if (StructuralEqual()(head, producer_head)) { - return; - } - } - producer_head_per_commit.push_back(producer_head); - }; - - if (num_commit_group == 0) { - // Epilogue, no async producer. Since "local" producer_head is not - // available, use "global" producer_head. - ICHECK(!dep_local_state.producer_head); - add_unique_producer_head( - async_states[producer_stage_idx].producer_head); - } else { - ICHECK(dep_local_state.producer_head); - std::vector need_wait_count(num_commit_group, true); - - for (auto read_region : new_blocks[i].block->reads) { - if (!async_states[producer_stage_idx].writes(read_region->buffer)) - continue; - auto commit_group_id = - buffer_to_commit_group.at(read_region->buffer.get()); - if (!need_wait_count[commit_group_id]) - continue; - - if (!dep_local_state.seen.count(read_region->buffer.get())) { - // Multiple async producers interleaved: The most recent async write - // is from the previous iteration. This is the B_shared case above. - add_unique_producer_head(dep_local_state.producer_head.value() - 1); - } else { - // Normal case - add_unique_producer_head(dep_local_state.producer_head.value()); - } - - need_wait_count[commit_group_id] = false; + PrimExpr in_flight_cnt = 0; + for (const auto &group : state.commit_groups) { + PrimExpr consumer_head = new_blocks[i].access_index; + PrimExpr producer_head; + if (dep_local_state.producer_head.defined()) { + producer_head = dep_local_state.producer_head.value(); + // if the group is after the wait point, minus by 1 + if (group.front() > new_blocks[i].order) + producer_head -= 1; + } else { + producer_head = state.producer_head; } + in_flight_cnt += producer_head - consumer_head; } - auto wait_count = [=, &ana_normalized]() { - auto sum = PrimExpr(0); - for (const auto &producer_head : producer_head_per_commit) { - if (producer_head && - ana_normalized->CanProve(producer_head.value() >= 0)) { - // Here, new_blocks[i].access_index corresponds to "consumer_head". - // The difference of producer_head and consumer_head is precisely - // the number of async commit groups that can still be in flight - // after this wait. - sum += analyzer_.Simplify(producer_head.value() - - new_blocks[i].access_index); - } else { - // The precise count cannot be determined, give up. - return PrimExpr(0); - } - } - return sum; - }(); - - auto &pending_wait = dep_local_state.pending_wait; - - if (!pending_wait.valid()) { - pending_wait = {static_cast(i), wait_count}; - } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { - // Coalesce multiple wait_queue if the later one allows fewer in-flight - // ops. - pending_wait = {pending_wait.insert_before, wait_count}; + // We can relax the in-flight-count by the number of independent commit. + std::unordered_set dependent_groups; + for (const auto &read_region : new_blocks[i].block->reads) { + if (state.buffer_to_commit_group_.count(read_region->buffer.get())) + dependent_groups.insert( + state.buffer_to_commit_group_.at(read_region->buffer.get())); + } + for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) { + if (dependent_groups.count(i) == 0) + in_flight_cnt += 1; + else + break; // stop relaxing } + in_flight_cnt = analyzer_.Simplify(in_flight_cnt); + dep_local_state.pending_waits.push_back( + {static_cast(i), in_flight_cnt}); } } @@ -800,85 +603,38 @@ class PipelineRewriter : public StmtExprMutator { // statements with async scopes (if any). Array CompletePipelineLoopStatements( const std::vector &blocks, - const std::map &async_states_local, - arith::Analyzer *ana_normalized) const { + const std::map &async_states_local) const { std::vector new_blocks = blocks; - std::vector commit_group_indices(new_blocks.size(), -1); for (const auto &[stage_id, state] : async_states_local) { - if (!state.commit_groups.empty()) { - for (size_t i = 0; i < state.commit_groups.size(); ++i) { - for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { - ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); - commit_group_indices[state.commit_groups[i][0] + j] = stage_id; - } - } + for (const auto &pw : state.pending_waits) { + auto &block = new_blocks[pw.insert_before].block; + BlockNode *n = block.CopyOnWrite(); + auto zero = make_zero(DataType::Int(32)); + n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, + AttrStmt(zero, tir::attr::async_wait_inflight_count, + pw.wait_count, n->body)); } + } - if (state.pending_wait.valid()) { - auto attach_wait_scope = [&new_blocks](int i, int stage_id, - PrimExpr wait_count) { - auto &block = new_blocks[i].block; - BlockNode *n = block.CopyOnWrite(); - auto zero = make_zero(DataType::Int(32)); - n->body = - AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, - AttrStmt(zero, tir::attr::async_wait_inflight_count, - wait_count, n->body)); - }; - - if (state.predicate && - !ana_normalized->CanProve(state.predicate.value())) { - // If the async operation that this wait_queue is waiting on is - // predicated, and we cannot prove that the predicate is always true, - // the precise wait count is only valid at iterations where the - // predicate is true; - auto wait_count = - Call(DataType::Int(32), builtin::if_then_else(), - {state.predicate.value(), state.pending_wait.wait_count, 0}); - attach_wait_scope(state.pending_wait.insert_before, stage_id, - wait_count); - } else { - attach_wait_scope(state.pending_wait.insert_before, stage_id, - state.pending_wait.wait_count); - } + // mark the last async stmt as commit + std::unordered_set commit_group_indices; + for (const auto &[stage_id, state] : async_states) { + for (size_t i = 0; i < state.commit_groups.size(); ++i) { + commit_group_indices.insert(state.commit_groups[i].back()); } } Array stmts; - for (size_t i = 0; i < new_blocks.size();) { - if (commit_group_indices[i] == -1) { - // A synchrnous block, not part of any commit group - stmts.push_back( - BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); - ++i; - } else { - Array group_bodies; - auto stage_id = commit_group_indices[i]; - auto predicate = new_blocks[i].predicate; - for (; i < commit_group_indices.size() && - commit_group_indices[i] == stage_id; - ++i) { - ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) - << "Predicates in the same stage are expected to be identical"; - group_bodies.push_back(new_blocks[i].block->body); - } - - if (group_bodies.size() > 1) { - auto merged_bodies = SeqStmt(group_bodies); - group_bodies.clear(); - group_bodies.push_back(merged_bodies); - } - - for (auto body : group_bodies) { - auto commit_queue_scope = - AttrStmt(make_zero(DataType::Int(32)), - tir::attr::async_commit_queue_scope, stage_id, body); - auto new_block = - MakeBlock(commit_queue_scope, buffer_data_to_buffer_); - stmts.push_back(BlockRealize({}, predicate, new_block)); - } + for (size_t i = 0; i < new_blocks.size(); i++) { + Block block = new_blocks[i].block; + if (commit_group_indices.count(new_blocks[i].order)) { + auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, + new_blocks[i].stage, block->body); + block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); } + stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block)); } return stmts; @@ -889,21 +645,16 @@ class PipelineRewriter : public StmtExprMutator { * \param start The start of the range * \param end The end of the range * \param unroll_loop Whether the loop should be unrolled. - * \param extra_loop_lower_bound Extra loop lower bound. * \return The result loop. */ Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, - Optional extra_loop_lower_bound = std::nullopt) { + bool need_bound_check) { PrimExpr new_loop_var; PrimExpr extent = end - start; - auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; - if (analyzer_.CanProve(extent <= 0)) { - return make_nop(); - } bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); if (is_unit_loop) { new_loop_var = start; // use constants as the loop var for unit loops @@ -912,43 +663,34 @@ class PipelineRewriter : public StmtExprMutator { analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); } - // In contrast to analyzer_ which is bound to [start, end), this one is - // bound to the "normalized" range, [pipeline_loop_->min, extent). - arith::Analyzer ana_normalized; - if (!is_unit_loop) { - ana_normalized.Bind(Downcast(new_loop_var), - Range(pipeline_loop_->min, extent)); - } - std::vector new_blocks; // Async related std::map async_states_local; - std::unordered_map buffer_to_commit_group; + PrimExpr normalized_access_index; for (const Block &block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; + int order = pipeline_info_.at(block).order; + PrimExpr inbound = Bool(true); PrimExpr skewed_loop_var = new_loop_var - stage; - PrimExpr inbound = - analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && - (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); - if (extra_loop_lower_bound.defined()) { - inbound = analyzer_.Simplify( - inbound && new_loop_var >= extra_loop_lower_bound.value()); - } + if (need_bound_check) + inbound = + analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); if (analyzer_.CanProve(!inbound)) { continue; } - Block new_block = Downcast(PipelineBodyRewriter( - buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, - max_stage_ != 1, fragment_info_)(block)); + Block new_block = Downcast( + PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, + pipeline_loop_, max_stage_ != 1)(block)); PrimExpr delta = start - pipeline_loop_->min; // This variable corresponds to // - "producer_head" if this stage is an async producer // - "consumer_head" if this stage reads from asynchronously written // buffers. - PrimExpr normalized_access_index = + normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; // Adjust the block predicate and the body according to the final loop @@ -958,76 +700,38 @@ class PipelineRewriter : public StmtExprMutator { Var loop_iter = Downcast(new_loop_var); inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); } - new_block = Downcast(Substitute( new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); - + if (predicate_condition_.defined()) { + BlockNode *n = new_block.CopyOnWrite(); + n->body = IfThenElse( + Substitute(predicate_condition_, + {{pipeline_loop_->loop_var, normalized_access_index}}), + n->body); + } if (pipeline_info_[block].async) { auto &local_state = async_states_local[stage]; - - int commit_group_id = -1; - if (local_state.commit_groups.empty() || local_state.consumed) { - // consumed == true means there is already a consumer stage waiting - // for an eariler async operation of this stage. In such cases, we - // make multiple commit_queue for this stage. - commit_group_id = local_state.commit_groups.size(); - local_state.commit_groups.push_back({new_blocks.size()}); - } else { - // This is the case when one commit_queue groups multiple async - // blocks. with commit_queue(stage): - // async_scope: - // A_shared[...] = ... - // async_scope: - // B_shared[...] = ... - - commit_group_id = local_state.commit_groups.size() - 1; - local_state.commit_groups.back().push_back(new_blocks.size()); - } - - for (auto write_region : new_block->writes) { - async_states[stage].dst_buffers.insert(write_region->buffer.get()); - buffer_to_commit_group[write_region->buffer.get()] = commit_group_id; - } - local_state.producer_head = normalized_access_index; - - if (!local_state.predicate || - ana_normalized.CanProve(local_state.predicate.value())) { - local_state.predicate = inbound; - } else if (local_state.predicate) { - local_state.predicate = - ana_normalized.Simplify(local_state.predicate.value() & inbound); - } - BlockNode *n = new_block.CopyOnWrite(); n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); } - new_blocks.push_back({stage, inbound, new_block, normalized_access_index, + new_blocks.push_back({stage, order, inbound, new_block, + normalized_access_index, pipeline_info_[block].async}); - - for (auto read_region : new_block->reads) { - for (auto kv : async_states) { - int producer_stage_id = kv.first; - if (producer_stage_id <= stage && - kv.second.writes(read_region->buffer)) { - async_states_local[producer_stage_id].consumed = true; - } - } - } } - PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, - &async_states_local); - auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, - &ana_normalized); + PopulateWaitCounts(new_blocks, &async_states_local); + + auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local); Stmt new_loop{nullptr}; if (stmts.empty()) { return make_nop(); } + if (stmts.size() == 1) { new_loop = stmts[0]; } else { @@ -1035,26 +739,22 @@ class PipelineRewriter : public StmtExprMutator { } if (!is_unit_loop) { + Map preserved_annotations; + for (const auto &kv : pipeline_loop_->annotations) { + const String &key = kv.first; + if (kv.first != tir::attr::software_pipeline_stage && + kv.first != tir::attr::software_pipeline_order && + kv.first != tir::attr::software_pipeline_async_stages) { + preserved_annotations.Set(key, kv.second); + } + } new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, - std::move(new_loop), std::nullopt, preserved_annotations_); + std::move(new_loop), std::nullopt, preserved_annotations); } - // Update producer heads in the global async states. - for (const auto &kv : async_states_local) { - const int stage_id = kv.first; - const AsyncStateLocal &state = kv.second; - - if (state.predicate && ana_normalized.CanProve(state.predicate.value()) && - async_states[stage_id].producer_head) { - // Advance the "global" producer head if it is still valid and we know - // exactly how much we can increment - async_states[stage_id].producer_head = - async_states[stage_id].producer_head.value() + extent; - } else { - // Otherwise, invalidate the global producer head - async_states[stage_id].producer_head = std::nullopt; - } + for (const auto &[stage_id, state] : async_states_local) { + async_states[stage_id].producer_head += extent; } return BlockRealize({}, Bool(true), @@ -1063,17 +763,14 @@ class PipelineRewriter : public StmtExprMutator { arith::Analyzer analyzer_; Map buffer_data_to_buffer_; - const std::unordered_set - &double_buffers_; Array pipeline_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; - const std::unordered_map &fragment_info_; + PrimExpr predicate_condition_; int max_stage_ = -1; Map buffer_remap_; Array ordered_stmts_; std::map async_states; - Map preserved_annotations_; }; /*! @@ -1088,7 +785,8 @@ void BuildDependencyGraph(const Array &blocks, ObjectPtrEqual> *dep_src2dst, std::unordered_map, ObjectPtrHash, ObjectPtrEqual> *dep_dst2src) { - std::unordered_map> buffer_writers; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_writers; for (const Block &block : blocks) { for (const BufferRegion &read : block->reads) { @@ -1119,7 +817,6 @@ class PipelineInjector : private StmtExprMutator { const Buffer &buffer = kv.second; injector.buffer_data_to_buffer_.Set(buffer->data, buffer); } - injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body); return injector(func->body); } @@ -1184,6 +881,7 @@ class PipelineInjector : private StmtExprMutator { // can be direct child of the for-loop. If the for-loop has BlockRealize as // its child, the pipeline body will be the child of the block. Stmt pipeline_body{nullptr}; + PrimExpr predicate_condition{nullptr}; Array pipeline_allocs; if (const auto *realize = for_node->body.as()) { const auto &block = realize->block; @@ -1191,7 +889,15 @@ class PipelineInjector : private StmtExprMutator { ICHECK(buffer->IsInstance()); buffer_data_to_buffer_.Set(buffer->data, buffer); } - pipeline_body = block->body; + if (const auto *if_then_else = block->body.as()) { + ICHECK(!if_then_else->else_case.defined()) + << "Pipeline_Planning: Can't handle the body of the loop because " + "it is not a SeqStmt"; + pipeline_body = if_then_else->then_case; + predicate_condition = if_then_else->condition; + } else { + pipeline_body = block->body; + } pipeline_allocs = block->alloc_buffers; } else { pipeline_body = for_node->body; @@ -1256,16 +962,6 @@ class PipelineInjector : private StmtExprMutator { } } - Map preserved_annotations; - for (const auto &kv : op->annotations) { - const String &key = kv.first; - if (kv.first != tir::attr::software_pipeline_stage && - kv.first != tir::attr::software_pipeline_order && - kv.first != tir::attr::software_pipeline_async_stages) { - preserved_annotations.Set(key, kv.second); - } - } - for (size_t i = 0; i < pipeline_stages.size(); i++) { int stage = static_cast(pipeline_stages[i]->value); bool is_async = @@ -1279,9 +975,10 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. - Stmt pipeline = PipelineRewriter::Rewrite( - buffer_data_to_buffer_, double_buffers, pipeline_allocs, - GetRef(op), pipeline_info, fragment_info_, preserved_annotations); + Stmt pipeline = + PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, + GetRef(op), pipeline_info, predicate_condition) + .BuildPipeline(); if (const auto *realize = op->body.as()) { const auto &block = realize->block; @@ -1297,22 +994,12 @@ class PipelineInjector : private StmtExprMutator { buffer_data_to_buffer_.Set(buffer->data, buffer); } - auto it = op->annotations.find(tir::attr::double_buffer_scope); - if (it != op->annotations.end()) { - int buffer_index = Downcast((*it).second).IntValue(); - CHECK(buffer_index >= 0 && - static_cast(buffer_index) < op->writes.size()) - << "ValueError: Index of the buffer exceeds the size of the write " - "regions of the block. (" - << buffer_index << " vs. " << op->writes.size() << ")"; - double_buffers.insert(op->writes[buffer_index]->buffer); - } Block block = Downcast(StmtExprMutator::VisitStmt_(op)); for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); } - return block; + return std::move(block); } bool HasPipelineAnnotation(const ForNode *op) const { @@ -1325,21 +1012,18 @@ class PipelineInjector : private StmtExprMutator { } if (has_stage) { LOG(FATAL) - << "ValueError: Order of the software pipeline is not defined."; + << "ValueError: Stage of the software pipeline is not defined."; } if (has_order) { LOG(FATAL) - << "ValueError: Stage of the software pipeline is not defined."; + << "ValueError: Order of the software pipeline is not defined."; } return false; } Map buffer_data_to_buffer_; - std::unordered_map fragment_info_; - std::unordered_set double_buffers; Optional global_symbol_; }; - } // namespace software_pipeline /*! diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 00a6d05e7..5a53f44d5 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -79,7 +79,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LegalizeVectorizedLoop()(mod) # Add safety checks for memory accesses mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) - # Align dynamic shared memory allocations # Simplify again to clean up any duplicated conditions # that may have been introduced by safety checks # use an enhanced pass to simplify the dynamic symbolics From 6a9b8d94795956624c79f7ff724e4c620cce29d1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 12 Aug 2025 22:29:17 +0800 Subject: [PATCH 3/6] lint fix --- src/transform/inject_pipeline.cc | 64 ++++---------------------------- 1 file changed, 7 insertions(+), 57 deletions(-) diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 273a76e7c..48ea3d6e7 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -82,35 +82,6 @@ struct BufferAccessInfo { int use = -1; // the last using stage of the buffer }; -/*! - * \brief Replace IfThenElse nodes with their then_case, preserving attribute - * nodes \param body The statement to process \param condition The condition to - * match in IfThenElse nodes \return The transformed statement - */ -Stmt replace_if_then_else(Stmt body, PrimExpr condition) { - if (const auto *if_node = body.as()) { - // If this is an IfThenElse with the matching condition, replace it with its - // then_case - if (if_node->condition.same_as(condition)) { - return if_node->then_case; - } - } else if (const auto *attr_node = body.as()) { - // For attribute nodes, preserve the attribute but process its body - AttrStmt attr_stmt = GetRef(attr_node); - attr_stmt.CopyOnWrite()->body = - replace_if_then_else(attr_node->body, condition); - return attr_stmt; - } else if (const auto *block_node = body.as()) { - // For block nodes, process the body - Block block = GetRef(block_node); - block.CopyOnWrite()->body = - replace_if_then_else(block_node->body, condition); - return block; - } - // For any other node type, return it unchanged - return body; -} - /*! * \brief Rewriter for the body of the software pipeline. This pass inserts * `floormod` to indices of the remapped buffer to select the version @@ -258,12 +229,10 @@ class PipelineRewriter : public StmtExprMutator { public: PipelineRewriter(Map buffer_data_to_buffer, const Array &pipeline_allocs, - const For &pipeline_loop, const PipelineInfo &pipeline_info, - PrimExpr predicate_condition = PrimExpr()) + const For &pipeline_loop, const PipelineInfo &pipeline_info) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), - pipeline_info_(pipeline_info), - predicate_condition_(predicate_condition) {} + pipeline_info_(pipeline_info) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the @@ -667,7 +636,6 @@ class PipelineRewriter : public StmtExprMutator { // Async related std::map async_states_local; - PrimExpr normalized_access_index; for (const Block &block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; @@ -690,7 +658,7 @@ class PipelineRewriter : public StmtExprMutator { // - "producer_head" if this stage is an async producer // - "consumer_head" if this stage reads from asynchronously written // buffers. - normalized_access_index = + PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; // Adjust the block predicate and the body according to the final loop @@ -702,13 +670,6 @@ class PipelineRewriter : public StmtExprMutator { } new_block = Downcast(Substitute( new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); - if (predicate_condition_.defined()) { - BlockNode *n = new_block.CopyOnWrite(); - n->body = IfThenElse( - Substitute(predicate_condition_, - {{pipeline_loop_->loop_var, normalized_access_index}}), - n->body); - } if (pipeline_info_[block].async) { auto &local_state = async_states_local[stage]; local_state.producer_head = normalized_access_index; @@ -766,7 +727,6 @@ class PipelineRewriter : public StmtExprMutator { Array pipeline_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; - PrimExpr predicate_condition_; int max_stage_ = -1; Map buffer_remap_; Array ordered_stmts_; @@ -881,7 +841,6 @@ class PipelineInjector : private StmtExprMutator { // can be direct child of the for-loop. If the for-loop has BlockRealize as // its child, the pipeline body will be the child of the block. Stmt pipeline_body{nullptr}; - PrimExpr predicate_condition{nullptr}; Array pipeline_allocs; if (const auto *realize = for_node->body.as()) { const auto &block = realize->block; @@ -889,15 +848,7 @@ class PipelineInjector : private StmtExprMutator { ICHECK(buffer->IsInstance()); buffer_data_to_buffer_.Set(buffer->data, buffer); } - if (const auto *if_then_else = block->body.as()) { - ICHECK(!if_then_else->else_case.defined()) - << "Pipeline_Planning: Can't handle the body of the loop because " - "it is not a SeqStmt"; - pipeline_body = if_then_else->then_case; - predicate_condition = if_then_else->condition; - } else { - pipeline_body = block->body; - } + pipeline_body = block->body; pipeline_allocs = block->alloc_buffers; } else { pipeline_body = for_node->body; @@ -975,10 +926,9 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. - Stmt pipeline = - PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, - GetRef(op), pipeline_info, predicate_condition) - .BuildPipeline(); + Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, + GetRef(op), pipeline_info) + .BuildPipeline(); if (const auto *realize = op->body.as()) { const auto &block = realize->block; From 31f5e2dc9e4f7fa80d2a6c97f8abe71430c51f20 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 12 Aug 2025 23:37:54 +0800 Subject: [PATCH 4/6] Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls - Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves. - Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations. --- src/transform/inject_pipeline.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 48ea3d6e7..0432c7333 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -174,14 +174,14 @@ class PipelineBodyRewriter : public StmtExprMutator { for (const Buffer &alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(alloc_buffer->data); } - return std::move(block); + return block; } Stmt VisitStmt_(const BufferStoreNode *op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_remap_.find(store->buffer); if (it == buffer_remap_.end()) { - return std::move(store); + return store; } const Buffer &new_buffer = (*it).second; auto *n = store.CopyOnWrite(); @@ -189,14 +189,14 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod( (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return std::move(store); + return store; } PrimExpr VisitExpr_(const BufferLoadNode *op) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_remap_.find(load->buffer); if (it == buffer_remap_.end()) { - return std::move(load); + return load; } const Buffer &new_buffer = (*it).second; auto *n = load.CopyOnWrite(); @@ -204,7 +204,7 @@ class PipelineBodyRewriter : public StmtExprMutator { PrimExpr version = floormod( (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); n->indices.insert(n->indices.begin(), version); - return std::move(load); + return load; } PrimExpr VisitExpr_(const CallNode *op) final { @@ -835,7 +835,7 @@ class PipelineInjector : private StmtExprMutator { // Step 1: Recursively rewrite the children first. For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); if (!HasPipelineAnnotation(op)) { - return std::move(for_node); + return for_node; } // Step 2: Find the body and buffer allocations of the pipeline. The body // can be direct child of the for-loop. If the for-loop has BlockRealize as @@ -949,7 +949,7 @@ class PipelineInjector : private StmtExprMutator { for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); } - return std::move(block); + return block; } bool HasPipelineAnnotation(const ForNode *op) const { From 824c66bc8a1b1a81cde5797e6d5c4cf1013aafb8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 13 Aug 2025 12:21:14 +0800 Subject: [PATCH 5/6] test fix --- ...lang_transform_Inject_software_pipeline.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index f6afca839..81c1007eb 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -9,7 +9,6 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.Simplify()(mod) - print(mod["main"]) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) @@ -40,21 +39,29 @@ def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")): C[tx, i] = B[tx, 0] + T.float32(1) @T.prim_func - def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None: + def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")): for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(""): + with T.block(): T.reads(A[tx, 0]) T.writes(C[tx, 0]) B = T.alloc_buffer((2, 16, 1), scope="shared") - with T.block(""): + with T.block(): T.reads(A[tx, 0]) T.writes(B[0, tx, 0]) B[0, tx, 0] = A[tx, 0] * T.float32(2.0) - with T.block(""): - T.reads() - T.writes() - T.evaluate(0) - with T.block(""): + with T.block(): + T.reads(A[tx, 1:1], B[0:2, tx, 0]) + T.writes(B[1:1, tx, 0], C[tx, 0:0]) + for i in range(0): + with T.block(): + T.reads(A[tx, i + 1]) + T.writes(B[i + 1, tx, 0]) + B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0) + with T.block(): + T.reads(B[i, tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[i, tx, 0] + T.float32(1.0) + with T.block(): T.reads(B[0, tx, 0]) T.writes(C[tx, 0]) C[tx, 0] = B[0, tx, 0] + T.float32(1.0) From 9dfa6d37af3b004213aea9d0626dd85b84aa348e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 13 Aug 2025 15:01:42 +0800 Subject: [PATCH 6/6] Enhance global read detection in pipeline planning - Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses. - Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis. - Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code. --- src/transform/pipeline_planning.cc | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 8f50765c8..13630b620 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -6,6 +6,7 @@ #include #include "../target/utils.h" +#include "tvm/ir/expr.h" namespace tvm { namespace tl { @@ -81,7 +82,11 @@ class BufferRegionCollector : public StmtExprVisitor { auto load_region = BufferRegion(load_buffer, region); reads_.push_back(load_region); - if (op->buffer.scope() == "global") { + if (op->buffer.scope() == "global" && !within_condition_expr_) { + // skip condition expr of if_then_else node + // shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i]) + // is not a global read shared[i] = T.if_then_else(global[i] < n, + // global_a[i], global_b[i]) is a global read is_global_read_ = true; } } @@ -103,11 +108,30 @@ class BufferRegionCollector : public StmtExprVisitor { // because we only care about the buffer itself instead of indices reads_.push_back(buffer_region); } + } else if (op->op.same_as(builtin::if_then_else())) { + within_condition_expr_ = true; + this->VisitExpr(op->args[0]); + within_condition_expr_ = false; + for (auto i = 1; i < op->args.size(); i++) { + this->VisitExpr(op->args[i]); + } } else { StmtExprVisitor::VisitExpr_(op); } } + void VisitStmt_(const IfThenElseNode *op) final { + within_condition_expr_ = true; + this->VisitExpr(op->condition); + within_condition_expr_ = false; + this->VisitStmt(op->then_case); + if (op->else_case.defined()) { + within_condition_expr_ = true; + this->VisitStmt(op->else_case.value()); + within_condition_expr_ = false; + } + } + private: Map buffer_data_to_buffer_; Array reads_; @@ -115,6 +139,7 @@ class BufferRegionCollector : public StmtExprVisitor { bool is_global_read_ = false; bool under_buffer_store_ = false; bool is_global_copy_pattern_ = false; + bool within_condition_expr_ = false; }; class PipelinePlanner : public StmtExprMutator {