Unify Python and C++ TIR lower API (apache#8110)
CircleSpin authored and Trevor Morris committed Jun 17, 2021
1 parent ba0dd07 commit 4e459cf
Showing 12 changed files with 332 additions and 1,371 deletions.
60 changes: 56 additions & 4 deletions include/tvm/driver/driver_api.h
@@ -34,6 +34,7 @@
#include <tvm/support/with.h>
#include <tvm/target/target.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>

#include <string>
#include <unordered_map>
@@ -42,17 +43,68 @@
#include <vector>

namespace tvm {

/*!
* \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList)
* \param mod The IRModule to lower
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);

/*!
* \brief Lower a PrimFunc under a given name (convert to an IRModule, and optimize it with the pass list
* defined in CreatePassList)
* \param func The PrimFunc to lower
* \param name The name of the lowered function.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
bool simple_mode = false);

/*!
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
* the lowering passes defined in CreatePassList.
* \param sch The TE schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);

TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);

/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
* the lowering passes defined in CreatePassList.
* \param sch The TE schedule to lower.
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);

/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
* to apply lowering passes as well, use LowerSchedule.
* \param sch The schedule
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
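The TVM_DLL declarations above are the C++ side of the unified lowering entry points; the new python/tvm/driver/_ffi_api.py further down in this diff exposes them to Python. As a minimal sketch of the most common path, LowerSchedule driven through tvm.lower (the workload and the name "vecadd" are illustrative, not from the commit):

import tvm
from tvm import te

# Illustrative TE workload; any schedule works here.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# tvm.lower dispatches to the C++ LowerSchedule declared above.
mod = tvm.lower(s, [A, B], name="vecadd")
print(mod)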
12 changes: 0 additions & 12 deletions include/tvm/te/schedule_pass.h
@@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);

/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
* \param stmt The stmt to be transformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/feature.py
@@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
binds, _ = build_module.get_binds(args, compact=False, binds=binds)
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
20 changes: 20 additions & 0 deletions python/tvm/driver/_ffi_api.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.driver"""
import tvm._ffi

tvm._ffi._init_api("driver", __name__)
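tvm._ffi._init_api("driver", __name__) walks the global function registry and attaches every function registered under the "driver." prefix to this module as a PackedFunc. A small sketch of probing for the entry points this commit relies on (the registration names are inferred from the ffi.* calls in build_module.py, not stated in this diff):

import tvm._ffi

# allow_missing=True returns None instead of raising when a name is absent.
for name in ["driver.lower_module", "driver.lower_primfunc", "driver.lower_schedule"]:
    func = tvm._ffi.get_global_func(name, allow_missing=True)
    print(name, "->", func)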
155 changes: 20 additions & 135 deletions python/tvm/driver/build_module.py
@@ -37,96 +37,58 @@
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var

from . import _ffi_api as ffi


def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.
compact : bool
If the statement has already bound to a compact buffer.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps a Tensor to the Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = tvm.tir.decl_buffer(
x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, tvm.tir.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
binds, arg_list = ffi.get_binds(args, compact, binds)
return binds, arg_list
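
get_binds keeps its Python signature but now delegates to the C++ implementation. A minimal sketch of what the two return values look like for a single tensor argument (the tensor is illustrative):

import tvm
from tvm import te
from tvm.driver.build_module import get_binds

A = te.placeholder((4,), name="A")
binds, arg_list = get_binds([A])
# binds maps the Tensor A to a freshly declared Buffer;
# arg_list holds the symbolic buffers/vars in argument order.
assert isinstance(arg_list[0], tvm.tir.Buffer)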


def form_irmodule(sch, args, name, binds):
def schedule_to_module(
sch: schedule.Schedule,
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
"""According to the given schedule, form a function.
Parameters
----------
sch : tvm.te.schedule.Schedule
The given schedule used to form the body
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str
The name of result function.
The name of the result function; defaults to "main"
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
The binds information
Returns
-------
mod : IRModule
The IRModule formed according to the given schedule
"""
# normalize schedule first
pass_ctx = PassContext.current()
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

compact = schedule.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)

stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

func = func.with_attr("global_symbol", name)

if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
return tvm.IRModule({name: func})
return ffi.schedule_to_module(sch, args, name, binds)
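
schedule_to_module is the un-lowered counterpart of tvm.lower: it only forms the IRModule and does not run the pass pipeline. A short sketch contrasting the two (workload names are illustrative):

import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)

raw = schedule_to_module(s, [A, B], name="main")  # no lowering passes applied
lowered = tvm.lower(s, [A, B], name="main")       # full lowering pipeline
assert isinstance(raw, tvm.IRModule) and isinstance(lowered, tvm.IRModule)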


def lower(
inputs: Union[schedule.Schedule, PrimFunc, IRModule],
inp: Union[schedule.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
@@ -136,7 +98,7 @@ def lower(
Parameters
----------
input : Union[schedule.Schedule, PrimFunc, IRModule]
inp : Union[schedule.Schedule, PrimFunc, IRModule]
The TE schedule or TensorIR PrimFunc/IRModule to be built
args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
@@ -160,90 +122,13 @@ def lower(
m : IRModule
The result IRModule
"""
# config setup
pass_ctx = PassContext.current()
instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])

lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

# Phase 0
pass_list = lower_phase0
is_legacy_te_schedule: bool = False

if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for lowering from TE schedule")
mod = form_irmodule(inputs, args, name, binds)
is_legacy_te_schedule = True
elif isinstance(inputs, PrimFunc):
func = inputs.with_attr("global_symbol", name)
if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
elif isinstance(inputs, IRModule):
mod = inputs
else:
raise TypeError(
f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
)

# Phase 1
if is_legacy_te_schedule:
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
]
else:
pass_list += [
tvm.tir.transform.LowerInitBlock(),
tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
tvm.tir.transform.ConvertBlocksToOpaque(),
tvm.tir.transform.CompactBufferAllocation(),
tvm.tir.transform.FlattenBuffer(),
]
pass_list += [
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
]

pass_list += lower_phase1

# Phase 2
if not simple_mode:
pass_list += [(tvm.tir.transform.LoopPartition())]

pass_list += [
tvm.tir.transform.VectorizeLoop(not disable_vectorize),
tvm.tir.transform.InjectVirtualThread(),
tvm.tir.transform.InjectDoubleBuffer(),
tvm.tir.transform.StorageRewrite(),
tvm.tir.transform.UnrollLoop(),
]
pass_list += lower_phase2

# Phase 3
pass_list += [
tvm.tir.transform.Simplify(),
tvm.tir.transform.RemoveNoOp(),
]

pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [tvm.tir.transform.HoistIfThenElse()]
pass_list += lower_phase3

# Instrument BoundCheckers
if instrument_bound_checkers:
pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]

optimize = tvm.transform.Sequential(pass_list)
mod = optimize(mod)
return mod
if isinstance(inp, IRModule):
return ffi.lower_module(inp, simple_mode)
if isinstance(inp, PrimFunc):
return ffi.lower_primfunc(inp, name, simple_mode)
if isinstance(inp, schedule.Schedule):
return ffi.lower_schedule(inp, args, name, binds, simple_mode)
raise ValueError(f"Expected input to be an IRModule, PrimFunc or Schedule, but got {type(inp)}")
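
The Python-side lower is now a thin dispatcher over the three C++ entry points, replacing the pass-list construction that used to live here. A sketch exercising all three paths from one schedule (it assumes the module built by schedule_to_module carries the "main" global symbol, as the code above suggests):

import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

A = te.placeholder((8,), name="A")
B = te.compute((8,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

mod_from_sched = tvm.lower(s, [A, B], name="main")         # -> ffi.lower_schedule
unlowered = schedule_to_module(s, [A, B], name="main")
mod_from_mod = tvm.lower(unlowered)                        # -> ffi.lower_module
mod_from_func = tvm.lower(unlowered["main"], name="main")  # -> ffi.lower_primfunc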


def _build_for_device(input_mod, target, target_host):
39 changes: 0 additions & 39 deletions python/tvm/relay/backend/_backend.py
@@ -20,45 +20,6 @@
from tvm.target import Target


@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering.
Parameters
----------
sch : tvm.te.Schedule
The schedule.
inputs : List[tvm.te.Tensor]
The inputs to the function.
func_name : str
The name of the function.
source_func : tvm.relay.Function
The source function to be lowered.
Returns
-------
mod : tvm.IRModule
The result of lowering.
"""
# pylint: disable=broad-except, import-outside-toplevel
import traceback

try:
f = tvm.driver.lower(sch, inputs, name=func_name)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile function\n"
msg += "-----------------------------\n"
msg += source_func.astext()
raise RuntimeError(msg)
return f


@tvm._ffi.register_func("relay.backend.build")
def build(mod, target, target_host=None):
"""Backend build function.