Skip to content

Commit

Permalink
[TIR] Add software pipelining (apache#10066)
Browse files Browse the repository at this point in the history
* [TIR] Add software pipelining

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* fix

* fix

* lint

* fix

* format

* doc

* remove print

* lint

* lint

* doc

* Apply suggestions from code review

Co-authored-by: Junru Shao <junrushao1994@gmail.com>

* address comments

* address comments

* refactor FragmentInfo::GetSize

* remove unused

* refactor

* address comments

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
7 people authored and pfk-beta committed Apr 11, 2022
1 parent f8d2c81 commit a6b99ab
Show file tree
Hide file tree
Showing 8 changed files with 1,765 additions and 17 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,12 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the stage of a statement in the software pipeline */
constexpr const char* software_pipeline_stage = "software_pipeline_stage";

/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

Expand Down
101 changes: 101 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,107 @@ TVM_DLL Pass ConvertForLoopsToSerial();
*/
TVM_DLL Pass UnifiedStaticMemoryPlanner();

/*!
* \brief This pass transforms annotated loops into pipelined ones where producers and consumers
* are overlapped with the information provided in loop annotations, which enables optimization
* techniques like prefetching and pipeline parallelism.
*
* The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize,
* Block, SeqStmt), and the number of children is denoted by `n` in the documentation.
*
* The following annotations are used to guide the loop transformation:
*
* 1) Loop annotation `software_pipeline_stage` defines the pipeline stage.
* An array of `n` integers, and each element should be in range [0, max_stage],
* where max_stage is the maximum (inclusive) stage.
* 2) Loop annotation `software_pipeline_order` defines the pipeline order.
* An array of `n` integers, a permutation of [0, 1, ..., num_components - 1];
* 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of
* read/write dependency. It's an integer index of the write regions of the block.
*
* Every annotated loop is transformed into a loop with three blocks as its direct children:
*
* 1) Prologue block, where components whose stage is less than `max_stage` is executed;
*
* 2) Body block, where all the components are executed;
*
* 3) Epilogue block, where only components whose stage is greater than 0 will be executed.
* The execution order is controlled by the annotation `software_pipeline_order`,
* and thus could be different than the original order.
*
* Note: For nested software pipelines, the inner software pipeline will be generated first,
* which may affect the number of the direct children of the outer loop.
* In this case, the annotations for the outer software
* pipeline should include the result of the inner software pipeline,
* which is the three blocks as discussed above.
* Example:
*
* Before this pass, the TIR is:
*
* \code{.py}
* @T.prim_func
* def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* for i in T.serial(0, 16,
* annotations={"software_pipeline_stage": [0, 1],
* "software_pipeline_order": [0, 1]}
* ):
* with T.block():
* T.reads(A[tx, i])
* T.writes(C[tx, i])
* B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
* with T.block("B"):
* T.reads(A[tx, i])
* T.writes(B[tx, 0])
* B[tx, 0] = A[tx, i] * T.float32(2)
* with T.block("C"):
* T.reads(B[tx, 0])
* T.writes(C[tx, i])
* C[tx, i] = B[tx, 0] + T.float32(1)
* \endcode
*
* The TIR above annotates the loop as a two-stage pipeline with no reordering.
* After applying this pass, the TIR is transformed into:
*
* \code{.py}
* @T.prim_func
* def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* with T.block():
* T.reads([A[tx, 0:16]])
* T.writes([C[tx, 0:16]])
* B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
* with T.block("prologue"):
* T.reads([A[tx, 0]])
* T.writes([B[0, tx, 0]])
* B[0, tx, 0] = A[tx, 0] * T.float32(2)
* with T.block("body"):
* T.reads([A[tx, 1:16], B[0:2, tx, 0]])
* T.writes([B[0:2, tx, 0], C[tx, 0:15]])
* for i in T.serial(0, 15):
* with T.block("B"):
* T.reads([A[tx, i + 1]])
* T.writes([B[(i + 1) % 2, tx, 0]])
* B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
* with T.block("C"):
* T.reads([B[i % 2, tx, 0]])
* T.writes([C[tx, i]])
* C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
* with T.block("epilogue"):
* T.reads([B[1, tx, 0]])
* T.writes([C[tx, 15]])
* C[tx, 15] = B[1, tx, 0] + T.float32(1)
* \endcode
*
* The original loop has two blocks, B and C, as its direct children. The loop annotations indicate
* that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B
* should be executed in advance of block C by one iteration. The order 0 and 1 specifies the order
* of block B and C inside the body block inside the result TIR.
*
* \return The IR transform pass.
*/
TVM_DLL Pass InjectSoftwarePipeline();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,14 @@ def ConvertForLoopsToSerial():
The result pass
"""
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def InjectSoftwarePipeline():
"""Transform annotated loops into pipelined one that parallelize producers and consumers
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectSoftwarePipeline() # type: ignore
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
Expand Down
Loading

0 comments on commit a6b99ab

Please sign in to comment.