[DISCUSS] Relax Pass Infrastructure #71
We are actually facing a similar issue that may require joint optimization between passes in training. For example, AutoCast (AMP) may choose to insert a new cast op or reuse an already inserted one. Reusing cast ops could maintain the execution order, but it means the output of one cast op is consumed by multiple subsequent ops, which prevents fusion from happening. In addition, we also have the issue of which backend should be used for each op. As a result, we are considering an infra for joint optimization, including AutoCast, Fusion, DialectOpLowering, and Rematerialization, in the spirit of Google XTAT. Some thoughts about your comments:
A vague approach in my mind is having a configuration similar to AutoTVM. Specifically, developers could use the tuning APIs to represent tunable parameters in a pass, and the tuning infra would be able to collect them and figure out the best combination. Pros: developers can optionally add a dimension of tuning configurations without worrying about anything else. Cons: 1) the tuning infra basically has no idea about what it is tuning, so it's hard to improve tuning efficiency from the search-algorithm side; 2) the configuration scope is per pass, making a pass the smallest granularity. Just my two cents. The approach I mentioned is definitely not optimal, but hopefully it lets others chime in with better ideas.
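To make the idea concrete, here is a rough sketch of what such AutoTVM-style per-pass tunable parameters might look like; the `define_knob` helper and the knob names are hypothetical, purely for illustration, not an existing TVM API:

# Hypothetical sketch of per-pass tunable parameters (not an existing API).
# `define_knob` registers a parameter with the tuning infra, which later
# enumerates the declared values and searches for the best combination.
def define_knob(name, values):
    # stub: a real implementation would record (name, values) in a registry
    return values[0]

class MyFusionPass:
    def __init__(self):
        # The tuning infra could collect these knobs across passes and tune them jointly.
        self.max_fuse_depth = define_knob("max_fuse_depth", [2, 4, 8])
        self.reuse_cast = define_knob("reuse_cast", [True, False])

    def transform_module(self, mod, ctx):
        # ... apply the transformation using self.max_fuse_depth / self.reuse_cast ...
        return mod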
@comaniac, thank you for your input and I totally agree with your thoughts. I expect there would be more exciting opportunities in training since there are more interesting operations and data coming into play, as in the examples you gave. Regarding sequential pass application, I also think it should be one of the fundamental principles in our new infra. However, it will be great if we can allow some opportunity for "true" joint-optimization. I might have an idea, but let me polish it and bring it to our discussion with others' inputs in this thread.
Hi, all! Hope you are all doing well.

Major Challenges in Current Pass Infra
Thus, we want to address these challenges with a new pass infrastructure design.

Terminology

We define two kinds of optimization or analysis passes: heuristic passes and tuning passes.
Design Goal

The new pass infrastructure design aims to provide a flexible and composable optimization pipeline with the following goals.
Design Principles

We propose the following design principles:
These principles will enable the following (H: heuristic pass, T: tuning pass, eval_pass: passes for candidate evaluation); a minimal sketch of such pipelines is shown below.
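For illustration, a minimal sketch of the kinds of compositions these principles allow, reusing the mock passes (TuningLayout, TuningCombineParallelMatmul, MyHeuristicPass) that appear later in this post:

# Sequential mix of a heuristic pass and a tuning pass
pipeline = tvm.transform.Sequential([MyHeuristicPass(), TuningLayout()])
# Joint-optimization: TuningLayout evaluates each candidate together with TuningCombineParallelMatmul
pipeline = tvm.transform.Sequential([TuningLayout(eval_passes=[TuningCombineParallelMatmul()])])
# A heuristic pass participating in a tuning pass's evaluation loop
pipeline = tvm.transform.Sequential([TuningCombineParallelMatmul(eval_passes=[MyHeuristicPass()])])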
Synergy with Relax

The new pass infrastructure will have a unique synergy with Relax. Since Relax aims to express different abstraction layers in an IRModule and offer a universal build for such an IRModule, it will open up various opportunities to explore, such as:
Class Design

Please note that this is pseudo code. For the actual implementation, we may extend the existing data structures and functionalities where possible. For example, given that MetaSchedule provides great fundamental APIs, such as the builder/runner, database API, cost model, etc. for generic tuning methods, @junrushao1994, @zxybazh and I are planning to discuss how we can extend it for the tuning pass with generic tuning methods beyond kernel-level tuning. Also, currently, there are different functions depending on whether you are handling ...

Pass Class

# Base pass class
class Pass():
# Specify dependent passes
def __init__(self, required=[]):
self.required = required
# Pass context class
class PassContext():
def __init__(
self,
# ... include current fields ...
target, # this is necessary for evaluation in tuning pass.
):
# ...
# Base class for heuristic pass
# It will look similar to current pass design.
class HeuristicPass(Pass):
# Actual implementation for optimizations/analysis
def transform_module(
self,
mod: IRModule,
ctx: PassContext
)->IRModule:
# ... contents ...
# Base class for tuning pass
class TuningPass(Pass):
def __init__(self, eval_passes, eval_metric, measure_option):
super().__init__()
# Passes for evaluation to enable joint-optimization
self.eval_passes = eval_passes
# Evaluation criteria for candidates (e.g., execution time)
self.eval_metric = eval_metric
# Measurement option
self.measure_option = measure_option
# Use metaschedule for evaluation
def evaluate(self, ctx, candidates):
target = ctx.config["target"]
# Evaluation
scoreboard = {}
for candidate in candidates:
# Apply pass group before build
seq = tvm.transform.Sequential(self.eval_passes)
candidate = seq(candidate)
# Leverage metaschedule builder/runner to get score
score = ...
scoreboard[candidate] = score
return scoreboard
# Different tuning methods may have different cost model. Can we extend metaschedule?
@staticmethod
def query_cost_model(candidates):
pass
# Can we extend metaschedule cost model?
@staticmethod
def update_database(...):
pass
@staticmethod
def select_best_candidate(scoreboard):
# ... select the best candidate depending on the eval_metric ...
return best_candidate
# Actual implementation for optimizations/analysis
# This will have feedback loops following repeating steps
# (1) candidate generations
# (2) candidate evaluation: this will call evaluate() method
# (3) pick the best candidate and reflect feedback
def transform_module(
self,
mod: IRModule,
ctx: PassContext
)->IRModule:
# ... contents ...
# Some useful APIs
# Sanity check for IRModule
def validate(mod: IRModule) -> bool:
# ... validation logic ...
# Dependency check for a sequence of passes
def validate(seq: List[Pass]) -> bool:
# ... validation logic ...
# Extract certain part of graph in interest
# This will be useful for subgraph benchmarking
def extract_subgraph(mod: Expr) -> Expr:
# ... extraction mechanics ...

[TBD] Data structure for the communication with the build system.

This requires a discussion. Look at D1 for details. This will define the exploration space for optimization passes (e.g., can we explore fusion decisions?).

[TBD] Pass sequence registration interface.

This requires a discussion. Look at D5 for details.

Developer PoV

Developers can design their own custom passes and perform any ...

# This mock tuning pass fuses parallel matmul
@ir.transform.module_pass(opt_level=1)
class TuningCombineParallelMatmul(TuningPass):
def __init__(self, eval_passes = []):
super().__init__(eval_passes)
def transform_module(
self,
mod: IRModule,
ctx: PassContext)->IRModule:
# Candidate generation
new_mod = transform.CombineParallelMatmul()(mod)
# Two candidates: do you want to enable it or disable it?
candidate_pool = [mod, new_mod]
scoreboard = self.evaluate(ctx, candidate_pool)
best_perf, best_mod = self.select_best_candidate(scoreboard)
return best_mod
# This mock tuning pass makes the layout transform decision
@ir.transform.module_pass(opt_level=1)
class TuningLayout(TuningPass):
def __init__(self, eval_passes = []):
super().__init__(eval_passes)
def transform_module(self, mod: IRModule, ctx: PassContext)->IRModule:
# Candidate generation
new_mod = transform.LayoutTransform()(mod)
# Two candidates: do you want to enable it or disable it?
candidate_pool = [mod, new_mod]
scoreboard = self.evaluate(ctx, candidate_pool)
best_perf, best_mod = self.select_best_candidate(scoreboard)
return best_mod

Depending on what developers want, they can run each pass separately, sequentially, or in a joint-optimization fashion.

# Run TuningLayout pass only
custom_pass = TuningLayout()
optimized_mod = custom_pass(mod)
# Run TuningLayout and TuniningCombineParallelMatmul sequentially
# You can also change the order easily
seq = [ TuningLayout(), TuningCombineParallelMatmul() ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)
# Run joint-optimization
seq = [ TuningLayout(eval_passes = [TuningCombineParallelMatmul()]) ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)
# Later, you can generate executable by using relax build system.
lib = relax.vm.build(optimized_mod)

Developers can also design a tuning pass for more interesting optimization decisions if the build system supports them (e.g., fusion decisions, lowering decisions, graph rewriting). Like the previous example, they can also easily test out joint-optimization.

# TASO-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningGraphRewriting(TuningPass):
def __init__(self, eval_passes = []):
super().__init__(eval_passes)
def transform_module(
self,
mod: IRModule,
ctx: PassContext)->IRModule:
# Some tuning approaches repeat search within their tuning budget
budget = ...
while budget>0:
# ... some analysis ...
# (1) Generate candidates
candidates = get_promising_rewritten_graph(mod)
# (2) Evaluate candidates with other passes and minimal build
scoreboard = self.evaluate(ctx, candidates)
# ...
# (3) Reflect the feedback
best_perf, best_mod = self.select_best_candidate(scoreboard)
# ... generate next promising candidates based on the current feedback ...
return best_mod
# Collage-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningBackendPlacement(TuningPass):
def __init__(self, eval_passes = []):
super().__init__(eval_passes)
def transform_module(
self,
mod: IRModule,
ctx: PassContext)->IRModule:
# ... some analysis ...
for node in post_order_dfs(mod):
# ... some analysis ...
# (1) Generate candidates
candidates = get_available_backend_candidates(node)
# (2) Evaluate candidates with other passes and minimal build
scoreboard = self.evaluate(ctx, candidates)
# ...
# (3) Reflect the feedback
best_perf, best_mod = self.select_best_candidate(scoreboard)
annotate(node, best_mod)
# ...
return optimized_mod

If you are interested, you can also play around with its prototype in the Relay world here: link

Discussion Points
Any feedback or thoughts would be greatly appreciated! Since this might be a bold design, I would like to face potential issues early.
Thanks for the valuable inputs today! I would definitely appreciate more feedback if you have any. Some of the feedback during the meeting:
Hi @sunggg, thanks for the great proposal. I have a few questions:
Hi, @ZihengJiang
Currently, I'm thinking of allowing each tuning pass to have its own eval passes, and trying to find potential issues. If you find any corner case, that would be very helpful. The current design specifies eval passes at pass invocation. If this is not what you asked for, would you elaborate a little further?

class TuningPass(Pass):
def __init__(self, eval_passes, eval_metric, measure_option):
super().__init__()
# Passes for evaluation to enable joint-optimization
self.eval_passes = eval_passes
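For instance, with the current design, joint-optimization is expressed when the outer pass is constructed (a minimal sketch reusing the mock passes from the proposal above):

# TuningCombineParallelMatmul evaluates each of its candidates after also
# applying TuningLayout, i.e., joint-optimization of the two decisions.
tuning_pass = TuningCombineParallelMatmul(eval_passes=[TuningLayout()])
optimized_mod = tuning_pass(mod)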
Hi @sunggg, thanks for the great work! Some random ideas/thoughts about pass-ordering infrastructure:
@hypercubestart, thank you for your input!
Yes, great point. I think I may have missed it in the design above, but I'm planning to add ...
Theoretically, if you want to joint-optimize two independent passes, you would just include one pass in the eval passes of the other one (e.g., ...). And I think your example brought up an excellent point. If two passes do similar jobs in different abstraction layers, like LayoutTransform (TIR) / LayoutRewrite (Graph), the later one may revert the earlier one's decision. I believe this is an open question, so I added it as discussion point D3 in my proposal.
A brief thought I brought up in our meeting but wanted to leave in writing: for the management of passes and when they are applicable, we may want some automatic way of specifying which program features passes are designed to handle and automatically checking for them, like the seldom-used feature flags in Relay. Feature flags like those (but coupled with automatic enforcement) could allow for detecting certain kinds of bugs and incompatibilities in advance. (Alternatively, we could have the norm that all passes are expected to support all program features, which would also require testing them against all program features to be certain of that.)
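As a rough sketch of how such automatic enforcement could hook into the pass base class (the handled_features field and detect_features helper below are hypothetical, purely for illustration):

# Hypothetical sketch: each pass declares the program features it supports,
# and the infrastructure checks an IRModule before applying the pass.
class FeatureCheckedPass(Pass):
    # analogous to Relay's feature flags; the feature names here are made up
    handled_features = {"closures", "recursion", "dynamic_shape"}

    def __call__(self, mod, ctx):
        present = detect_features(mod)  # hypothetical feature-detection analysis
        unsupported = present - self.handled_features
        if unsupported:
            raise RuntimeError(f"Pass does not support features: {unsupported}")
        return self.transform_module(mod, ctx)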
Hi, all. I'd like to discuss the first formal version of the Tuning Pass design. Once we reach an agreement, I'd like to start working on the implementation.

Backgrounds

Today’s tuning methods

What you tune
How you tune

Depending on what you accept as input and how you generate candidates:
Search methods - different tuning methods may favor different strategies

Depending on whether you can stop tuning at any time or not, we can categorize the search into two kinds.
Goal
Design Overview

Fundamentally, tuning is a feedback-directed search that repeats three primitives: (1) candidate generation, (2) candidate evaluation, (3) update of the search state. Therefore, as a bare minimum, a tuning pass must provide developers with convenient APIs for these three tuning primitives. Note that a tuning pass often wants to account for the effect of other optimization passes in (2), so the API design should consider this. Besides, there can be optional primitives such as a cost model or a data-driven prediction model (takes an ...).

To ease the development complexity, each tuning pass would generate its search space based on its input.

# Say search space of TuningPass1/2 is s1/s2 respectively.
# 1. Two sequential tuning passes: search space grows additively
Sequential([TuningPass1(), TuningPass2()])
# -> Search space: s1+s2
# 2. Joint optimization: search space grows combinatorially
Sequential([TuningPass1(eval_passes=[TuningPass2()])])
# -> Search space: s1*s2
# Since joint-optimization expands search space rapidly,
# it is highly recommended to apply two tuning passes in sequential
# if they are orthogonal to each other

As a tuning pass can nest other tuning passes for joint-optimization, the trace of a sequence of transformations may not be easy to track. This would cause huge difficulty in understanding its behavior and debugging. Thus, inspired by MetaSchedule, we introduce a trace mechanism: each tuning pass records its decisions as a sequence of (instruction, decision) pairs, for example:
def convert_conv2d_NHWC(mod):
new_mod = ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
return new_mod
def convert_conv2d_NCHW(mod):
new_mod = ConvertLayout({"nn.conv2d": ["NCHW", "default"]})(mod)
return new_mod
def noapply(mod):
return mod
choices = {
"convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
"convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
"NoApply": Choice(noapply),
}
knob = Instruction("LayoutTransform", choices)
Trace length: 2
[1] MockTuningInst1: choice3
[2] MockTuningInst2: choice1

I want to clarify that the goal of the tuning pass is NOT to replace the heuristic passes. Each approach has its own unique strengths and they can help each other evolve. Conceptually, a tuning pass and a heuristic pass differ only in the decision-making process while considering the same set of transformations (e.g., once a fusion decision is made by either tuning or heuristic methods, the process of applying the decision would be the same). Thus, tuning passes and heuristic passes would likely share many API functions. Our design also aims to provide such common functionalities by maximizing code reuse.

Tuning API Design

# Tuning API provides important primitives to implement a tuning method
# Current design integrates MetaSchedule builder/runner and database
# By default, it would use the database for a pass pipeline
# However, it can also manage its own database if necessary
# Users can define eval_passes to consider the interaction with other passes
# Classes
# A choice defines a valid transformation
# Instruction will consider each choice for its candidate generation
# To reduce the search space, each choice may be considered in a probabilistic manner
class Choice:
def __init__(self, func: Callable, constr=None, args=None):
self.func = func # transformation func
# this allows the feedback loop to cover a finer
# granularity of candidate tuning
# (e.g., subgraph tuning, see the Collage example below)
self.constr = constr # constraints, e.g., a condition on tensor shape
self.args = args # arguments for func
class Instruction:
def __init__(
self, name: str, choices: Union[List[Choice], Dict[str, Choice], Dict[int, Choice]]
):
self.name = name
self.choices = choices
# Check if a decision is valid
def verify(self, decision: Union[str, int]) -> bool:
if isinstance(self.choices, dict):
return decision in self.choices
elif isinstance(self.choices, list):
return decision < len(self.choices)
else:
raise Exception("Invalid type for choices")
# Get a choice for a decision
def get_choice(self, decision: Union[str, int]) -> Choice:
assert self.verify(decision)
return self.choices[decision]
# Apply a decision to an input IRModule
def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule:
assert self.verify(decision)
return self.choices[decision].func(mod)
def __str__(self) -> str:
msg = f"{self.name} (# of choices: {len(self.choices)})\n"
if isinstance(self.choices, dict):
for name, choice in self.choices.items():
msg += f" - {name}: {choice}\n"
elif isinstance(self.choices, list):
for idx, choice in enumerate(self.choices):
msg += f" - {idx}: {choice}\n"
else:
raise Exception("Invalid type for choices")
return msg
# Trace maintains a sequence of instructions and their decisions.
# It maintains the input/output IRModule and its performance
class Trace:
def __init__(
self,
in_mod: IRModule,
trace: List[Tuple[Instruction, Union[str, int]]] = []
):
self.in_mod = in_mod
self.trace = trace
self.out_mod = self.apply(in_mod, trace)
self.perf = None
def verify(self):
for (knob, decision) in self.trace:
if not knob.verify(decision):
return False
return True
# Apply certain trace to input IRModule
def apply(self, in_mod: IRModule, trace: Trace) -> IRModule:
out_mod = copy.deepcopy(in_mod)
for knob, decision in trace:
if not knob.verify(decision):
raise Exception("Illegal decision in the trace")
out_mod = knob.apply(out_mod, decision)
self.perf = None
return out_mod
# Add a pair of instruction and its decision to the current trace
def add(self, knob: Instruction, decision: Union[str, int]) -> None:
self.out_mod = knob.apply(self.out_mod, decision)
self.trace.append((knob, decision))
self.perf = None
def __str__(self) -> str:
msg = f"Trace length: {len(self.trace)}\n"
for idx, (knob, decision) in enumerate(self.trace):
msg += f"[{idx+1}] {knob.name}: {decision}\n"
return msg
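# Example usage (illustrative): build a trace and record one decision
#   trace = Trace(mod)
#   trace.add(knob, "convert_conv2d_NHWC")  # using the LayoutTransform knob sketched earlier
#   print(trace)  # -> Trace length: 1
#                 #    [1] LayoutTransform: convert_conv2d_NHWC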
# Helper functions
# Generate the search space for a given trace by using registered choices
# To reduce the search space, it may expand each choice in a probabilistic manner
# A developer can introduce smart search strategies like multi-armed bandit
def generate_candidates(inst, trace: Trace, ctx: PassContext, eval_passes: List[Pass] = None) -> List[Trace]:
candidates = list()
for decision in inst.choices.keys():
choice = inst.choices[decision]
# Generate a new candidate when this condition is satisfied
if choice.constr:
new_trace = copy.deepcopy(trace)
new_trace.add(inst, decision)
candidates.append(new_trace)
# Expand candidates by using eval passes if available
if eval_passes:
candidates = consider_eval_passes(candidates, ctx, eval_passes)
return candidates
# Expands traces generated by current tuning pass with its eval passes
def consider_eval_passes(
seeds: List[Trace], ctx: PassContext, eval_passes: List[Pass] = None
) -> List[Trace]:
candidates = list(seeds)
num = len(candidates)
for i in range(num):
trace = candidates.pop(0)
for eval_pass in eval_passes:
# For a heuristic pass, we create a knob with a single choice for tracking
if isinstance(eval_pass, HeuristicPass):
knob = Instruction(f"{eval_pass.name}", [Choice(eval_pass)])
trace.add(knob, 0)
# Tuning pass expands candidates by visiting its evaluation passes in dfs
else:
trace = eval_pass(trace, ctx)
candidates.append(trace)
return candidates
# Evaluates each candidate with MetaSchedule Runner/Builder
# Its performance can be stored in MetaSchedule Database
def evaluate(ctx, candidates: List[Trace], database=None, eval_config=None):
# These targets will be retrieved from the ctx
target_str, target_host, device_id = (
ctx.config["target"],
ctx.config["target_host"],
ctx.config["device_id"],
)
target = tvm.target.Target(target_str)
device = tvm.device(target_str, device_id)
num_evals = 0
# Evaluation
for candidate in candidates:
if candidate.perf is not None:
continue
num_evals += 1
mod = candidate.out_mod
# Evaluate candidates
# Build candidate
builder = LocalBuilder(f_build= ... )
(builder_result,) = builder.build([BuilderInput(mod, target)])
assert builder_result.artifact_path is not None
assert builder_result.error_msg is None
runner_input = RunnerInput(
builder_result.artifact_path,
target_str,
[], # ArgInfo
)
runner = LocalRunner(
timeout_sec=100,
evaluator_config=eval_config,
f_run_evaluator=eval_func,
)
(runner_future,) = runner.run([runner_input])
runner_result = runner_future.result()
assert runner_result.error_msg is None
perfs = []
for result in runner_result.run_secs:
if isinstance(result, tvm.tir.FloatImm):
result = result.value
assert isinstance(result, float)
assert result >= 0.0
perfs.append(result)
# ...
candidate.perf = tuple([np.mean(perfs), np.std(perfs)])
if database is not None:
workload = database.commit_workload(mod)
record = TuningRecord(
candidate.trace,
perfs,
workload,
target,
[],
)
database.commit_tuning_record(record)
# Choose the best trace
def select_best_candidate(traces):
best_perf, best_trace = sys.maxsize, None
for candidate in traces:
(avg, std) = candidate.perf
# Select best one
if best_perf > avg:
best_perf = avg
best_trace = candidate
return best_trace
# Return trace wrapper if necessary
def get_trace(in_):
if isinstance(in_, Trace):
return in_
if isinstance(in_, IRModule):
return Trace(in_)
elif isinstance(in_, Expr):
return Trace(tvm.IRModule.from_expr(in_))
#...
else:
raise Exception("Invalid input type for pass")
# Extracts matching subgraph for subgraph-level tuning
def extract_subgraph(mod, pattern):
# ...
# [Optional] a cost model that estimates the performance of a trace
def query_cost_model(cost_model, trace:Trace)->float:
assert 0, "Need to implement"
# [Optional] a prediction model that predicts the optimized IRModule
# This can be done by heuristic like AutoTVM
# or data-driven approach like ApplyHistoryBest in MetaSchedule
def predict(mod: IRModule, ctx) -> IRModule:
assert 0, "Need to implement"

Example

Setup

from TuningAPI import (
Choice,
Trace,
Instruction,
generate_candidates,
consider_eval_passes,
evaluate,
select_best_candidate
)

Simple switching decision

class TuningParallelConv2dPass(Pass):
def __init__(self, eval_passes=[], required=[], database=None):
super().__init__(
"TuneCombineParallelConv2D",
required=required,
)
self.eval_passes = eval_passes
self.database=database
def tune(self, trace, ctx):
def apply(mod):
new_mod = InferType()(mod)
new_mod = CombineParallelConv2D(min_num_branches=2)(new_mod)
return new_mod
def noapply(mod):
return mod
choices = {"On": Choice(apply), "Off": Choice(noapply)}
# Tuning pass manages a set of transformation functions
inst = Instruction("InstructionTuningParallelConv2D", choices)
candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
evaluate(ctx, candidates, self.database)
best_trace = select_best_candidate(candidates)
return best_trace
def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
best_trace = self.tune(get_trace(mod), ctx)
return best_trace.out_mod

Layout Transformation

class TuningLayoutPass(Pass):
def __init__(self, eval_passes=[], required=[], database=None):
super().__init__(
"TuneLayout",
required=required,
)
self.eval_passes=eval_passes
self.database=database
self.num_evals = 0
def tune(self, trace, ctx):
def convert_conv2d_NHWC(mod):
new_mod = ConvertLayout({"nn.conv2d": ["NHWC", ...]})(mod)
return new_mod
def convert_conv2d_NCHW(mod):
new_mod = ConvertLayout({"nn.conv2d": ["NCHW", ...]})(mod)
return new_mod
def noapply(mod):
return mod
choices = {
"convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
"convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
"NoApply": Choice(noapply),
}
inst = Instruction("InstructionTuningLayout", choices)
candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
evaluate(ctx, candidates, self.database)
best_trace = select_best_candidate(candidates)
return best_trace
def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
best_trace = self.tune(get_trace(mod), ctx)
return best_trace.out_mod

TASO

class TASO(Pass):
def __init__(self, eval_passes={}, required=[], database=None):
super().__init__(
"TASO",
required=required,
)
self.eval_passes=eval_passes
self.database=database
self.num_evals = 0
def tune(self, trace, ctx):
q = PriorityQueue()
q.push(trace.in_mod)
best_trace, best_perf = None, 1e100
while not q.empty():
g = q.pop()
choices = {}
for s in get_available_substitutions():
for l in get_available_layouts(g, s):
# Key each choice by the substitution/layout pair it applies
choices[f"{s.name}_{l.name}"] = Choice(
get_rewriting_func(s, l)
)
inst = Instruction("tune_rewriting", choices)
candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
evaluate(ctx, candidates, self.database)
best_cand = select_best_candidate(candidates)
if best_cand.perf < best_perf:
best_trace, best_perf = best_cand, best_cand.perf
next_population = get_top_alpha(candidates, best_perf, alpha)
q.push(next_population)
return best_trace
def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
best_trace = self.tune(get_trace(mod), ctx)
return best_trace.out_mod

Collage

class Collage(Pass):
def __init__(self, eval_passes={}, required=[], database=None):
super().__init__(
"Collage",
required=required,
)
self.eval_passes=eval_passes
self.database=database
self.num_evals = 0
def tune(self, trace, ctx):
def func(mod):
g = build_graph(mod)
q = FrontierQueue() # priority queue sorted by node depth
q.push(g.get_root())
while not q.empty():
f = q.pop()
expr = f.get_expr()
choices = {}
new_frontiers = []
for backend, pattern in get_available_backend():
if pattern.match(expr):
# Key each choice by the backend it places this subgraph on
choices[backend.name] = Choice(
get_extract_and_annotate_func(pattern, backend)
)
new_frontiers.append(get_new_frontiers(f, pattern))
inst = Instruction("tune_backend", choices)
new_trace = Trace(tvm.IRModule.from_expr(expr))
candidates = generate_candidates(inst, new_trace, ctx)
# You can manually expand candidates by using eval_pass for more control
new_candidates = []
for candidate in candidates:
backend_name = candidate.trace[-1][1]  # the decision, i.e., the chosen backend
# Depending on the backend, apply different passes
eval_pass = self.eval_passes[backend_name]
cands = consider_eval_passes([candidate], ctx, [eval_pass])
new_candidates.extend(cands)
evaluate(ctx, new_candidates, self.database)
best_trace = select_best_candidate(new_candidates)
update_placement(expr, best_trace)
q.push(new_frontiers)
return apply_best_placement(mod)
inst = Instruction("Collage", [Choice("Subgraph-tuning": func)])
return best_trace
def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
best_trace = self.tune(get_trace(mod), ctx)
return best_trace.out_mod

Pipeline customization

# 1. Apply single tuning pass
custom_pipeline = TuningParallelConv2dPass()
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 2
custom_pipeline = TuningLayoutPass()
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3
# Heuristic pass won't increase the search space
custom_pipeline = TuningLayoutPass(eval_passes=[MyHeuristicPass()])
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3
# 2. Apply two tuning passes in sequential
# This is useful when we know two tuning passes are orthogonal to each other
# (we don't always want combinatorial search space with joint-optimization)
custom_pipeline = Sequential([TuningParallelConv2dPass(), TuningLayoutPass()])
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 2+3
custom_pipeline = Sequential([TuningLayoutPass(), TuningParallelConv2dPass()])
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3+2
# 3. Joint-optimization
custom_pipeline = TuningParallelConv2dPass(eval_passes=[TuningLayoutPass()])
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 2*3
custom_pipeline = TuningLayoutPass(eval_passes=[TuningParallelConv2dPass()])
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2
# Say we have a MockTuningPass with search space of 5
custom_pipeline = TuningLayoutPass(
eval_passes=[TuningParallelConv2dPass(
eval_passes=[MockTuningPass()]
)]
)
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2*5
custom_pipeline = TuningLayoutPass(
eval_passes=[TuningParallelConv2dPass(), MockTuningPass()]
)
with PassContext(config=config):
mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*(2+5)

C++ implementation

namespace transform {
// Since both heuristic and tuning methods live in the same file,
// code sharing is natural
Pass MockHeuristic() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
ConstantFolder folder(m);
return Downcast<Function>(folder(f));
};
return CreateFunctionPass(pass_func, 0, "MockHeuristic", {});
}
TVM_REGISTER_GLOBAL("relax.transform.MockHeuristic").set_body_typed(MockHeuristic);
Pass MockTune(){
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> tune_pass_func =
[=](Function f, IRModule m, PassContext pc) {
// Implement your tuner
auto new_f = MockTuner().tune(f);
return Downcast<Function>(new_f);
};
return CreateFunctionPass(tune_pass_func, 0, "MockTuner", {"InferType"});
}
TVM_REGISTER_GLOBAL("relax.transform.MockTune").set_body_typed(MockTune); Comparison with Prior Work
Opportunity
Limitations
Discussion
Like we briefly discussed during the open development meeting, I think it would be great to start brainstorming about what we want to enable in Relax. I sketched some of my thoughts on existing approaches, so feel free to add comments if you have your own or any questions. I will put them together and bring them to our next discussion meetings so that we can address those issues in our initial design.
Motivation
Recent studies demonstrate that feedback from low-level information can substantially help various high-level decisions. For example, TASO searches for the best form of the computation graph by exploring various graph rewriting rules (e.g., layout transformation) with such feedback. Also, Collage finds the most efficient multi-backend execution strategy in a feedback-directed fashion (related discussion: #46). There are various studies regarding flag-level tuning approaches as well.
However, the conventional pass infrastructure is designed around the idea of progressive lowering and cannot provide seamless integration of these tuning approaches. This prevents the adoption of various tuning approaches across different abstraction layers and their joint optimization opportunities (e.g., TASO+Collage).
Thus, we want to open up new opportunities by offering natural integration of tuning approaches with a new pass infrastructure design. Please note that we will separate these optimization passes from the build (related discussion: #49).
Existing Approaches