This repository has been archived by the owner on May 22, 2023. It is now read-only.

[DISCUSS] Relax Pass Infrastructure #71

Open

sunggg opened this issue Jan 21, 2022 · 10 comments

Comments

@sunggg
Collaborator

sunggg commented Jan 21, 2022

As we briefly discussed during the open development meeting, I think it would be great to start brainstorming what we want to enable in Relax. I sketched some of my thoughts on existing approaches below, so feel free to add comments if you have your own thoughts or any questions. I will put them together and bring them to our next discussion meeting so that we can address these issues in our initial design.

Motivation

Recent studies demonstrate that feedback from low-level information can substantially help various high-level decisions. For example, TASO searches for the best form of the computation graph by exploring various graph rewriting rules (e.g., layout transformation) with such feedback. Also, Collage finds the most efficient multi-backend execution strategy in a feedback-directed fashion (related discussion: #46). There are various studies regarding flag-level tuning approaches as well.

However, the conventional pass infrastructure is designed around the idea of progressive lowering and cannot provide seamless integration of these tuning approaches. This prevents the adoption of various tuning approaches across different abstraction layers and their joint-optimization opportunities (e.g., TASO+Collage).

Thus, we want to open up new opportunities by offering natural integration of tuning approaches in a new pass infrastructure design. Please note that we will separate these optimization passes from the build (related discussion: #49).

Existing Approaches

  • TVM Relay: Only provides heuristic passes in a progressive-lowering fashion (passes apply one by one, sequentially, in one direction from high level to low level).
    • Various heuristic passes are applied within build. Since there is no separation between build and these optimization passes, it is quite awkward to bring in existing tuning approaches, as tuning methods need to invoke build to collect low-level feedback (the build-within-build problem).
    • Each pass applies in sequential order. Thus, joint optimization is only possible in a greedy manner; each pass makes its own optimization decisions assuming other passes did or will do their best.
  • Google XTAT: Introduces a couple of tuning passes into XLA. There are several (potential) limitations.
    • Their support is limited to node-configuration-based tuning approaches (i.e., they cannot support graph-rewriting methods like TASO).
    • Their search process seems to be sequential. Although this approach may prune significant search space, it might be suboptimal since a later optimization decision may affect an earlier one.
    • Even though they claim that introducing a new tuning pass is simple, it is unclear how simple it actually is. For example, developers may need to handle fixing illegal combinations and decide on hyperparameters (e.g., neighborhood size).
@comaniac
Contributor

We are actually facing a similar issue that may require joint optimization between passes in training. For example, AutoCast (AMP) may choose to insert a new cast op or reuse an already inserted one. Reusing cast ops could maintain the execution order, but it means the output of one cast op is used by two or more subsequent ops, which prevents fusion from happening. In addition, we also have the issue of which backend should be used for each op. As a result, we are thinking about an infra for joint optimization, including AutoCast, Fusion, DialectOpLowering, and Rematerialization.

In terms of Google XTAT, some thoughts about your comments:

  1. I agree that per-node configuration seems not flexible enough. As mentioned in Section II.D, using XTAT as a subroutine could work around this limitation, but it also means the pass developer has to consider a larger scope than just the pass itself. However, one advantage of node configuration I can think of is that it enables partial tuning, which may save lots of time.
  2. I'm actually not against sequential pass application, because some passes are implemented with the assumption that the IR is already XX, where XX can be "fused", "simplified", "canonicalized", "dead code eliminated", etc. IMHO, as long as we have a tuning process (i.e., joint optimization) so that early-applied passes can still make good decisions for later passes, fixing the order seems not to be an issue and could largely simplify the design (for both developing an infra and passes).
  3. This echoes the previous two points about the developer experience, and this is the most important part to me.

A vague approach in my mind is having a configuration similar to AutoTVM. Specifically, developers could use the tuning APIs to represent tunable parameters in a pass, and the tuning infra would collect them and figure out the best combination. Pros: developers can optionally add a dimension of tuning configurations without worrying about anything else. Cons: 1) the tuning infra basically has no idea what it is tuning, so it's hard to improve tuning efficiency on the search-algorithm side; 2) the configuration scope is per pass, making a pass the smallest granularity.
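For instance, a rough sketch of what this could look like (all names here, such as ctx.define_knob and apply_autocast, are hypothetical and not an existing API):

# Hypothetical sketch: a pass declares its tunable parameters, and the tuning
# infra collects them and searches for the best combination.
class AutoCastPass:
    def transform_module(self, mod, ctx):
        # The infra would record these knobs and later replay the chosen values.
        reuse_cast = ctx.define_knob("autocast_reuse_cast", [True, False])
        amp_dtype = ctx.define_knob("autocast_dtype", ["float16", "bfloat16"])
        return apply_autocast(mod, reuse=reuse_cast, dtype=amp_dtype)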

Just my two cents. The approach I mentioned is definitely not optimal, but hopefully this could let others chime in for better ideas.

@sunggg
Collaborator Author

sunggg commented Jan 21, 2022

@comaniac, thank you for your input; I totally agree with your thoughts. I expect there will be even more exciting opportunities in training since more interesting operations and data come into play, as in the examples you gave. Regarding sequential pass application, I also think it should be one of the fundamental principles in our new infra. However, it would be great if we could allow some opportunity for "true" joint optimization. I might have an idea, but let me polish it and bring it to our discussion along with others' inputs in this thread.

@sunggg
Collaborator Author

sunggg commented Feb 28, 2022

Hi, all! Hope you are all doing well.
Since our last discussion, I've worked on identifying more concrete challenges in the current pass infra and drafted an initial design to address them. Let me start with a summary of the challenges I found.

Major Challenges in Current Pass Infra

  • C1: Bringing in a new tuning method requires significant engineering effort and tweaks.
    • Since the lowering pipeline is one-directional from the high-level abstraction layer to the low level (progressive lowering), it is tricky to provide feedback from the low level to high-level optimizations.
    • Currently, optimization passes live within the build pipeline. Thus, if a tuning method is implemented as a pass, it needs to invoke build within the build pipeline to collect low-level feedback (the build-within-build problem).
    • Thus, existing tuning methods either build their own standalone framework or tweak the build pipeline. However, this is not scalable and is labor-intensive. When there is more than one tuning pass, the complexity grows very quickly.
  • C2: No official support for joint-optimization.
    • This is the outcome of C1.
    • Some promising joint-optimization opportunities are
      • TASO+Collage
      • AutoCast+Fusion+Rematerialization for training as @comaniac mentioned
      • Kernel tuning + flag-level tuning at codegen (e.g., llvm flag)
  • C3: The build process follows a central pipeline with various control logic. It is not easy to support various contexts (e.g., target-dependent passes, conflicting passes) in a flexible and customizable manner.
    • To cover various contexts within a central pipeline, the current infra provides various customization knobs, such as diverse if-statements in the core pipeline, opt_level, PassContext, and custom keyword arguments like relay.ext.tensorrt.options. However, it is hard to understand at a glance which optimization passes apply in which context. This is expected to worsen as we broaden our support for more diverse contexts, such as dynamic execution, training, quantization, mixed precision, etc.
    • Although some stakeholders want to apply their own passes while leveraging existing passes as much as possible, there is no scalable and systematic solution for conflicting passes or target-specific passes. They may maintain their own branches to address such issues, but this is not a desirable solution and would make future merges/fetches painful after heavy customization. Some examples are as follows:
      • Accelerators may need their own partitioning/scheduling passes with their own codegen, as in the active discussion for UMA.
      • The Winograd optimization pass should not apply when the target backend library already provides it.
      • Relay rewriting passes may conflict with TASO-like graph rewriting.
  • C4: Under the hood, the pass infra has many "implicit" behaviors that make customization difficult.
    • Since a later optimization may interfere with an earlier optimization decision, the pass sequence should be explicit and clear, e.g., a graph-level layout decision may be altered by layout tuning at the TIR level.
    • Due to C3, there are a few weird behaviors (potentially bugs) in the current pass infra. Currently, each pass needs to register its minimum opt_level to expose easy customizability to users. In the following example, only 1 and 4 apply both optimization passes. Also, 4 applies more than the two provided passes for some reason: InferType, InferType, CombineParallelConv2D, InferType. Without pass instrumentation, it is not explicit what we should expect from the application of passes.
      # We want to apply the following two optimization passes: InferType and CombineParallelConv2D
      # Minimum opt_level - InferType: 0, CombineParallelConv2D: 4
      # 1. Direct invocation
      #    -> Applies both optimization
      mod = relay.transform.InferType()(mod)
      out1 = relay.transform.CombineParallelConv2D(min_num_branches=2)(mod)
      
      # 2. Use transform.Sequential
      #    -> Does not apply CombineParallelConv2D
      seq = tvm.transform.Sequential(
          passes=[relay.transform.InferType(), relay.transform.CombineParallelConv2D(min_num_branches=2)],
      )
      out2 = seq(mod)
      
      # 3. Specify opt_level explicitly in Sequential
      #    -> Still, does not apply CombineParallelConv2D 
      seq = tvm.transform.Sequential(
          opt_level=4,
          passes=[relay.transform.InferType(), relay.transform.CombineParallelConv2D(min_num_branches=2)],
      )
      out3 = seq(mod)
      
      # 4. PassContext
      #    -> Finally applies CombineParallelConv2D
      #       Didn't even need to specify opt_level at Sequential
      seq = tvm.transform.Sequential(
          passes=[relay.transform.InferType(), relay.transform.CombineParallelConv2D(min_num_branches=2)],
      )
      with tvm.transform.PassContext(opt_level=4):
          mod4 = seq(mod)
      
    • Dependencies between passes are also implicit and hard-coded in the pass registration.
  • C5: No experimental optimization level for new optimizations. If a new optimization may have unexpected side effects, we may want gradual integration.

Thus, we want to address these challenges with a new pass infrastructure design.

Terminology

We define two kinds of optimization or analysis passes.

  • Heuristic pass: Pass without any feedback loop
  • Tuning pass: A pass that uses feedback loop(s) consisting of three repeating steps: (1) candidate generation, (2) evaluation, and (3) reflection of feedback (see the sketch after this list)
    • Any optimization decision can be adaptively adjusted by using feedback loops. Switching decisions (e.g., enable/disable an optimization), ordering of optimizations, and mutation decisions (e.g., layout transformation) are representative tuning problems. For example:
      • Graph-rewriting (mutation) problem: TASO
      • Backend placement problem: Collage
      • Flag-level switching/ordering: various studies to identify the best flag sequence/ordering for llvm, gcc
    • A tuning pass may want to apply existing optimization passes (heuristic, tuning, or both) for its candidate evaluation. For example, you may want to apply ConstantFolding to each candidate.
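As a rough sketch, such a feedback loop could look like the following (generate_candidates and evaluate_on_device are illustrative placeholders here, not an existing API):

# Illustrative skeleton of a tuning pass feedback loop.
def tuning_loop(mod, budget):
    best_mod, best_score = mod, evaluate_on_device(mod)        # hypothetical helpers
    while budget > 0:
        candidates = generate_candidates(best_mod)             # (1) candidate generation
        scores = [evaluate_on_device(c) for c in candidates]   # (2) evaluation
        cand, score = min(zip(candidates, scores), key=lambda p: p[1])
        if score < best_score:                                 # (3) reflect the feedback
            best_mod, best_score = cand, score
        budget -= len(candidates)
    return best_mod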

Design Goal

The new pass infrastructure design aims to provide a flexible and composable optimization pipeline with the following goals.

  • G0: Legacy support for existing heuristic passes and build in a progressive-lowering fashion
  • G1: Natural integration of tuning methods
  • G2: Address various interactions between passes
    • Interaction between heuristic passes
    • Interaction between tuning passes
    • Interaction between heuristic pass and tuning pass
  • G3: Unlock true joint-optimization across abstraction layers
  • G4: Explicit handling of
    • Target-dependent optimizations
    • Conflicting passes
    • Dependency between passes
    • New pass with a potential risk
  • G5: Offer decent development experience
    • Clear design principles for developers, clarifying how to design a custom pass and how to integrate a new pass with existing passes.
    • First class metaprogramming and customization experience

Design Principles

We propose the following design principles:

  • P1: Essentially, each pass is an IRModule → IRModule transformation. With Relax, this IRModule may contain various abstraction layers.
  • P2: A tuning pass can invoke other tuning/heuristic passes for its evaluation (pass hierarchy). This allows the candidate evaluation with the consideration for later optimizations and unlocks joint-optimization.
  • P3: Decouple pass from the build system. After transformation with a sequence of passes, users generate the optimized executable with universal build. This build can be invoked by tuning pass for the candidate evaluation.
  • P4: Users can explicitly customize and register any pass pipeline sequence for their context. For example, if users want to introduce a new partitioning pass that may conflict with an existing pass, they can register a new sequence of passes with the pass infrastructure without needing to modify or add hooks in the build pipeline.
  • P5: Pass infra provides basic common functionalities
    • Interface for pass hierarchy.
    • Pass dependency checker
    • IRModule sanity checker
    • ...

These principles will enable the following (H: heuristic pass, T: tuning pass, eval_passes: passes for candidate evaluation):

  • Conventional progressive lowering style of heuristic passes
    seq = [ H1, H2, H3, H4, ... ]
  • Introduction of a tuning pass with existing heuristic passes
    seq = [H1, H2, T1, H3, ...]
    seq = [H1, H2, T1(eval_passes=[H3, H4]), H5, ...]
  • Joint-optimization
    seq = [H1, H2, T1(eval_passes=[T2, ...]), H4, ...]
  • Target-specific lowering
    • acc1_seq = [H1, H2, H3 ... ]
    • acc2_seq = [H1, H2', H3', ... ]
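As a rough illustration of P4 and the target-specific sequences above, registration could look like the following sketch (register_pass_sequence and load_pass_sequence are hypothetical names in the spirit of D5 below; H1, H2, H2_alt, etc. stand for arbitrary heuristic passes):

# Hypothetical registration of target-specific pass sequences (P4).
acc1_seq = tvm.transform.Sequential([H1(), H2(), H3()])
acc2_seq = tvm.transform.Sequential([H1(), H2_alt(), H3_alt()])
register_pass_sequence("acc1", acc1_seq)   # hypothetical API
register_pass_sequence("acc2", acc2_seq)

# Users then pick a sequence for their context without touching the build pipeline.
seq = load_pass_sequence("acc1")           # hypothetical API, cf. D5
optimized_mod = seq(mod)
lib = relax.vm.build(optimized_mod)        # decoupled universal build (P3)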

Synergy with Relax

The new pass infrastructure will have unique synergy with Relax. Since Relax aims to express different abstraction layers in an IRModule and offer a universal build for such an IRModule, it will open up various opportunities to explore, such as:

  • Joint-optimization between graph-level transformations and TIR-level transformations
  • Exploration of lowering decisions outside of the build pipeline. This will allow easier integration of external functions as PackedFunc.
  • Optimizations for dynamic models (e.g., optimization with symbolic tensor shapes)
  • ...

Class Design

Please note that this is pseudocode. For the actual implementation, we may extend the existing data structures and functionalities where possible. For example, given that MetaSchedule provides great fundamental APIs, such as builder/runner, a database API, a cost model, etc. for generic tuning methods, @junrushao1994, @zxybazh and I are planning to discuss how we can extend it for the tuning pass with generic tuning methods beyond kernel-level tuning. Also, there are currently different functions depending on whether you are handling a Function or an IRModule; their handling is omitted for simplicity.

Pass Class

# Base pass class
class Pass():
    # Specify dependent passes
    def __init__(self, required=[]):
        self.required = required

# Pass context class
class PassContext():
    def __init__(
         self,
         # ... include current fields ...
         target, # this is necessary for evaluation in tuning pass.  
     ):
        # ...

# Base class for heuristic pass
# It will look similar to current pass design.
class HeuristicPass(Pass):
    # Actual implementation for optimizations/analysis
    def transform_module(
           self, 
           mod: IRModule, 
           ctx: PassContext
         )->IRModule:
        # ... contents ...
   

# Base class for tuning pass
class TuningPass(Pass):
    def __init__(self, eval_passes, eval_metric, measure_option):
        super().__init__()
        # Passes for evaluation to enable joint-optimization
        self.eval_passes = eval_passes
        # Evaluation criteria for candidates (e.g., execution time)
        self.eval_metric = eval_metric
        # Measurement option
        self.measure_option = measure_option
    
    # Use metaschedule for evaluation
    def evaluate(self, ctx, candidates):
        target = ctx.config["target"]
        # Evaluation
        scoreboard = {}
        for candidate in candidates:
            # Apply pass group before build
            seq = tvm.transform.Sequential(self.eval_passes)     
            candidate = seq(candidate)
            
            # Leverage metaschedule builder/runner to get score
            score = ...
            scoreboard[candidate] = score
        return scoreboard
    
    # Different tuning methods may have different cost model. Can we extend metaschedule?
    @staticmethod
    def query_cost_model(candidates):
        pass

    # Can we extend metaschedule cost model?
    @staticmethod
    def update_database(...):
        pass
    
    @staticmethod
    def select_best_candidate(scoreboard):
        # ... select the best candidate depending on the eval_metric ...
        return best_candidate

    # Actual implementation for optimizations/analysis
    # This will have feedback loops following repeating steps
    # (1) candidate generation
    # (2) candidate evaluation: this will call the evaluate() method
    # (3) pick the best candidate and reflect feedback
    def transform_module(
           self, 
           mod: IRModule, 
           ctx: PassContext
         )->IRModule:
        # ... contents ...
    

# Some useful APIs
# Sanity check for an IRModule
def validate(mod: IRModule) -> bool:
  # ... validation logic ...

# Dependency check for a sequence of passes
def validate_sequence(seq: List[Pass]) -> bool:
  # ... validation logic ...

# Extract a certain part of the graph of interest
# This will be useful for subgraph benchmarking
def extract_subgraph(expr: Expr) -> Expr:
  # ... extraction mechanics ... 

[TBD] Data structure for the communication with the build system.

This requires a discussion. Look at D1 for details. This will define the exploration space for optimization passes. (e.g., can we explore fusion decisions?)

[TBD] Pass sequence registration interface.

This requires a discussion. Look at D5 for details.

Developer PoV

Developers can design their own custom passes and perform any IRModule -> IRModule transformation. As an example, we can design simple mock tuning passes that decide whether to apply a certain heuristic pass based on low-level feedback.

# This mock tuning pass fuses parallel matmuls
@ir.transform.module_pass(opt_level=1)
class TuningCombineParallelMatmul(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)
            
    def transform_module(
                  self, 
                  mod: IRModule, 
                  ctx: PassContext)->IRModule:
        # Candidate generation
        new_mod = transform.CombineParallelMatmul()(mod)
        # Two candidates: enable it or disable it?
        candidate_pool = [mod, new_mod]
        scoreboard = self.evaluate(ctx, candidate_pool)
        best_perf, best_mod = self.select_best_candidate(scoreboard)
        return best_mod
    

# This mock tuning pass makes a layout transform decision
@ir.transform.module_pass(opt_level=1)
class TuningLayout(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)
            
    def transform_module(self, mod: IRModule, ctx: PassContext)->IRModule:
        # Candidate generation
        new_mod = transform.LayoutTransform()(mod)
        # Two candidates: enable it or disable it?
        candidate_pool = [mod, new_mod]
        scoreboard = self.evaluate(ctx, candidate_pool)
        best_perf, best_mod = self.select_best_candidate(scoreboard)
        return best_mod

Depending on what developers want, they can run each pass separately, sequentially, or in a joint-optimization fashion.

# Run TuningLayout pass only  
custom_pass = TuningLayout()
optimized_mod = custom_pass(mod)

# Run TuningLayout and TuningCombineParallelMatmul sequentially
# You can also change the order easily
seq = [ TuningLayout(), TuningCombineParallelMatmul() ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)

# Run joint-optimization
seq = [ TuningLayout(eval_passes = [TuningCombineParallelMatmul()]) ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)

# Later, you can generate executable by using relax build system.
lib = relax.vm.build(optimized_mod)

Developers can also design a tuning pass for more interesting optimization decisions if the build system supports them (e.g., fusion decisions, lowering decisions, graph rewriting). Like the previous example, they can also easily test out joint optimization.

# TASO-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningGraphRewriting(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(
         self,
         mod: IRModule,
         ctx: PassContext)->IRModule:
        # Some tuning approaches repeat the search within their tuning budget
        best_mod = mod
        budget = ...
        while budget > 0:
            # ... some analysis ...
            # (1) Generate candidates
            candidates = get_promising_rewritten_graphs(best_mod)
            # (2) Evaluate candidates with other passes and a minimal build
            scoreboard = self.evaluate(ctx, candidates)
            # ...
            # (3) Reflect the feedback
            best_perf, best_mod = self.select_best_candidate(scoreboard)
            # ... generate the next promising candidates based on the current feedback ...
        return best_mod

# Collage-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningBackendPlacement(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(
         self,
         mod: IRModule,
         ctx: PassContext)->IRModule:
        # ... some analysis ...
        for node in post_order_dfs(mod):
            # ... some analysis ...
            # (1) Generate candidates
            candidates = get_available_backend_candidates(node)
            # (2) Evaluate candidates with other passes and a minimal build
            scoreboard = self.evaluate(ctx, candidates)
            # ...
            # (3) Reflect the feedback
            best_perf, best_candidate = self.select_best_candidate(scoreboard)
            annotate(node, best_candidate)
            # ...
        # return the module with the best backend placements applied
        return optimized_mod

If you are interested, you can also play around with its prototype in the Relay world here: link

Discussion Points

  • D1: What kinds of decisions should we expose to this pass infra, and how can the pass infrastructure flow optimization decisions to the build system? (related: [DISCUSS] Relax minimum build pipeline #49)
    • With Relax, the pass infra may manage not only Relax->Relax and TIR->TIR passes, but also lowering decisions and more.
    • Where should we put this new information? Do we need to introduce a new data structure AttributeTable (pair<Relax var, attributes>)? Or can we just add new attribute fields to the Relax IR? Other options are extending IRModule or PassContext.
      @R.function
      def main(x: Tensor[(32, 32), "float32"], w: Tensor[(32, 32), "float32"]) -> Tensor:
          with R.dataflow():
              lv0 = R.call_tir((32, 32), tir_matmul, (x, w))
              lv1 = R.call_tir((32, 32), tir_relu, (lv0))

      # Attribute table
      { lv0: attrs = [
                        dev = gpu0,
                        backend = tvm,
                        fusion_group = 1,
                        pass = custom_pass,
                        flag = ["-finline -fCSE ..."],
                        ...
                     ],
        ...
      }

      # Input distribution for dynamic graph
      {
           ["data", ...]
      }
      
    • Currently, TVM also supports passing a dict to the build module to a certain extent
      with tvm.target.cuda() as cuda_tgt:
           s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
           m1 = tvm.lower(s1, [A, B, C], name="test_add1")
           m2 = tvm.lower(s2, [A, B, C], name="test_add2")
           rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
  • D2: How does minimum build naturally integrate custom codegen (BYOC)? (related: [DISCUSS] Relax minimum build pipeline  #49)
    • Think from the user's PoV
    • Previous works on composable codegen: TVM BYOC, JAX, Functorch, MLIR
  • D3: Search space can be explosive when tuning multiple abstraction layers. It would be great if we can define it in a smart way or prune it with domain knowledge.
    • Also, how can we deal with same optimization across different abstraction layer? e.g., loop unrolling in Relax-level, TIR-level, codegen-level
  • D4: How do we enforce dependencies between passes in an intuitive way? Currently, this design allows specifying the dependent passes at the instantiation of a pass and validates whether the given sequence satisfies the dependency constraints. Would this be enough? Some may want to apply certain passes only under certain conditions (UMA)
  • D5: How do we provide default pass sequences? We may provide a default heuristic pipeline, a default tuning pipeline, and an experimental pipeline for gradual integration. One possible way:
    # User PoV
    seq = tvm.relax.load_pass_sequence("default_heuristic")
    seq = tvm.relax.load_pass_sequence("default_tuning")
    
    # Experiment sequences for gradual integration
    seq = tvm.relax.load_pass_sequence("experiment_heuristic")
    seq = tvm.relax.load_pass_sequence("experiment_tuning")
    
    # Easy to check what's in sequence
    print(seq)
  • D6: How can we offer a decent customization experience? Can we add more interactive features? (e.g., check what kinds of implementations are registered for an operator)
  • D7: Can the tuning pass support generic tuning techniques? Some tuning algorithms conduct algorithmic-level tuning (e.g., an FFT tuner).
    • Some thoughts: the current tuning pass can support generic tuning methods as long as they are not input-dependent. If the best choice depends on the input, it would require runtime support for dynamic adjustment.
  • D8 (Low priority): Project management tool
    • LLVM-style project management might be useful to allow user-side pass development outside the TVM codebase (link)
    • Enable both C++/Python implementation
    # Example for user's project directory
    my_pass
    |- src    : C++ codes live here
    |- python : Python codes live here
    |- build  : build directory
    |- cmake-related files/directories
    |- ...
    # Users can import their pass (living outside of tvm core) in their script.
    from my_pass import cool_optimization
    
    seq = [ 
       ...
       cool_optimization
       ...
    ]
    ...
    with tvm.transform.PassContext(...):
        mod = seq(...)

Any feedback or thoughts would be greatly appreciated! Since this might be a bold design, I would like to surface potential issues early.

@sunggg
Collaborator Author

sunggg commented Mar 1, 2022

Thanks for the valuable input today! I would definitely appreciate more feedback if you have any.
Feel free to leave more :)

Some feedback from the meeting:

  1. Use program feature for pass application - link
  2. Dependency checker
  3. Some tuning passes may require their own parameters - how can we systematically pass them?
  4. Better design to explore the order of phases - can we make it more intuitive and organic?
  5. Investigate potentially related work for the tuning pass - MLIR? XLA?

@ZihengJiang
Contributor

ZihengJiang commented Mar 3, 2022

Hi @sunggg , thanks for the great proposal. I have a few questions:

  • For the tuning pass with eval passes usage T1(eval_passes=[H3, H4]), will the H3 and H4 happen before or after T1?
  • If we want to do joint optimization with several tunable passes, where each tunable pass has its own eval passes, how can we represent this with the current API?

@sunggg
Collaborator Author

sunggg commented Mar 3, 2022

Hi, @ZihengJiang

  • For the tuning pass with eval passes usage T1(eval_passes=[H3, H4]), will the H3 and H4 happen before or after T1?

eval_passes will be applied for candidate evaluation: each candidate generated by T1 will be evaluated after applying the given passes (H3 and H4). example
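Concretely, for seq = [H1, H2, T1(eval_passes=[H3, H4]), H5], the evaluation of each candidate produced by T1 roughly looks like the following (illustrative pseudocode; measure stands in for the low-level measurement):

# Illustrative only: how eval_passes participate in T1's candidate evaluation.
for candidate in t1_candidates:                       # candidates generated by T1
    evaluated = Sequential([H3(), H4()])(candidate)   # eval_passes applied to each candidate
    score = measure(evaluated)                        # feedback used by T1
# T1 keeps the best candidate; H5 then applies to that result afterwards.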

  • If we want to do joint-optimization with several tunable passes, each tunable pass has its own eval passes, how can we represent this with current API?

Currently, I'm thinking of allowing each tunable pass its own eval passes and trying to find potential issues. If you find any corner case, that would be very helpful. The current design specifies eval passes when a pass is instantiated. If this is not what you asked for, would you elaborate a little further?

class TuningPass(Pass):
   def __init__(self, eval_passes, eval_metric, measure_option):
        super().__init__()
        # Passes for evaluation to enable joint-optimization
        self.eval_passes = eval_passes

@hypercubestart
Collaborator

Hi @sunggg, thanks for the great work! Some random ideas/thoughts about pass-ordering infrastructure:

  • In Relay, the quantization workflow is split into 3 separate passes: QuantizeAnnotate, QuantizeCalibrate, QuantizeRealize. For QuantizeCalibrate, we need to represent required pre-passes (annotate) and post-passes (realize), and in addition the passes must be run one right after the other. This might be challenging to represent in the current architecture unless all three quantize passes are combined into a single pass.
  • Also, a follow-up to @ZihengJiang's question about joint optimization: if I had two passes, LayoutTransform (TIR) / LayoutRewrite (graph), and wanted to joint-optimize them, does this require writing a completely new tunable pass? If so, how would this interact with passes that may require LayoutTransform but not the new pass?

@sunggg
Collaborator Author

sunggg commented Mar 3, 2022

@hypercubestart, thank you for your input!

  • In Relay, quantization workflow is split into 3 separate passes: QuantizeAnnotate, QuantizeCalibrate, QuantizeRealize. For QuantizeCalibrate, we need to represent required pre-passes (annotate) and post-passes (realize), and in addition the passes must be run one after the other. This might challenging to represent in the current architecture unless all three quantize passes are combined into a single pass

Yes, great point. I think I may have missed this in the design above, but I'm planning to add required_passes to the base Pass class and make sure the given sequence satisfies such constraints. Like we briefly discussed during the meeting, I'm considering providing this information at pass instantiation, like we do for eval_passes, for more flexible configuration, rather than hard-coding it like the current infra does (e.g., the same optimization may need different analysis passes depending on whether it is a static or dynamic model). I will make this clear in the next draft.
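For example, for the quantization workflow above, this could look roughly like the following (hypothetical, not a final API):

# Hypothetical: declare required pre-passes at instantiation instead of
# hard-coding them in the pass registration.
annotate = QuantizeAnnotate()
calibrate = QuantizeCalibrate(required=[annotate])
realize = QuantizeRealize(required=[annotate, calibrate])
seq = [annotate, calibrate, realize]
validate_sequence(seq)  # the dependency checker verifies the ordering constraints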

  • also follow-up to @ZihengJiang 's question about joint-optimization. If I had two passes LayoutTransform (TIR)/LayoutRewrite(Graph), and wanted to joint-optimize, does this require rewriting a completely new tunable pass? If so, how would this interact with passes that may expect to require LayoutTransform but not the new pass?

Theoretically, if you want to joint-optimize two independent passes, you would just include one pass in the eval passes of the other, e.g., seq = [T1(eval_passes=[T2])].
If you are 100% sure that those two passes don't need to be separated, you may consider designing a pass that combines both. However, even if you introduce this new tuning pass, you don't necessarily need to remove the existing passes, because other passes may depend on them like you described. And of course, since this may introduce some overlap between passes, I think we need to carefully consider which passes we want to provide by default while offering full customization support as an infrastructure.

And I think your example brings up an excellent point. If two passes do similar jobs at different abstraction layers, like LayoutTransform (TIR) / LayoutRewrite (graph), the later one may revert the earlier one's decision. I believe this is an open question, so I added it as discussion point D3 in my proposal.

@slyubomirsky
Collaborator

A brief thought I brought up in our meeting but wanted to leave in writing: For the management of passes and when they are applicable, we may want some automatic way of specifying which program features passes are designed to handle and automatically check for them like the seldom-used feature flags in Relay. Feature flags like those (but coupled with automatic enforcement) could allow for detecting certain kinds of bugs and incompatibilities in advance. (Alternatively, we could have the norm that all passes are expected to support all program features, which would also require testing them against all program features to be certain of that.)
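For reference, a rough sketch of such automatic enforcement using Relay's existing feature detection could look like this (exact import paths and signatures are from memory and may differ):

# Sketch only: a pass declares which program features it supports, and the infra
# checks the module before applying the pass.
from tvm.relay.analysis import detect_feature
from tvm.relay.feature import Feature

SUPPORTED = {Feature.fVar, Feature.fGlobalVar, Feature.fConstant, Feature.fCall}

def can_apply(pass_supported_features, mod):
    # Reject modules containing features the pass does not declare support for.
    return detect_feature(mod).issubset(pass_supported_features)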

@YuchenJin YuchenJin changed the title [DISCUSSION] Relax Pass Infrastructure [DISCUSS] Relax Pass Infrastructure Mar 23, 2022
@sunggg
Collaborator Author

sunggg commented Apr 20, 2022

Hi, all. I'd like to discuss the first formal version of the Tuning Pass design. Once we reach an agreement, I'd like to start working on the implementation.

Backgrounds

Today’s tuning methods

What you tune

  • Graph (e.g., TASO, Collage)
  • Subgraph if you want
  • Kernel/tir-level (e.g., AutoTVM, AutoScheduler, MetaSchedule)
  • Codegen options (e.g., CUTLASS)
  • Heuristic decisions, phase ordering (e.g., conventional auto-tuning methods)

How you tune

Depending on what you accept as input and how you generate candidates

  • Input: Graph, Candidate: Graph (e.g., TASO)
  • Input: Graph, Candidate: Subgraph (e.g., Collage)
  • Input: Subgraph, Candidate: Kernel
  • Input: Kernel, Candidate: Kernel (e.g., AutoTVM, AutoScheduler, MetaSchedule)

Search methods - different tuning methods may favor different strategies

Depending on whether you can stop tuning at any time or not, we can categorize the search into two kinds:

  • K1: You can stop tuning anytime
    • e.g., Evolutionary search, simulated annealing, multi-armed bandit, random, reinforcement learning
  • K2: You cannot stop tuning until its completion
    • e.g., DP, ILP

Goal

  • G1: Provide essential tuning primitives, common support, and interfaces
    • Natural integration with MetaSchedule
  • G2: Unlock easy customization of joint-optimizations
  • G3: Offer an intuitive development experience (e.g., debugging)
  • G4: Find a synergistic use of both tuning passes and heuristic passes
    • Tuning and heuristic passes each have their own trade-offs
    • We want to evolve both kinds of passes together within a positive feedback loop

Design Overview

Fundamentally, tuning is a feedback-directed search that repeats three primitives: (1) candidate generation, (2) candidate evaluation, and (3) update of search state. Therefore, as a bare minimum, a tuning pass must offer developers a convenient API for these three tuning primitives. Note that a tuning pass often wants to account for the effect of other optimization passes in (2), so the API design should consider this. Besides, there can be optional primitives such as a cost model or a prediction model that takes an IRModule and predicts the optimized IRModule, potentially in a data-driven manner (ApplyHistoryBest is a representative example; AutoTVM has a scoring heuristic to find the closest workload in the database).

[Figure: heuristic vs. tuning pass trade-off]

To ease development complexity, each tuning pass would generate its search space based on its input IRModule - each pass does not need to worry about what other passes are doing, as long as each pass is guaranteed to produce a valid IRModule. Later, a user can customize how to apply these tuning passes when defining the sequence of a pass pipeline.

# Say search space of TuningPass1/2 is s1/s2 respectively.

# 1. Two sequential tuning passes: search space grows additively
Sequential([TuningPass1(), TuningPass2()])
# -> Search space: s1+s2

# 2. Joint optimization: search space grows combinatorially 
Sequential([TuningPass1(eval_passes=[TuningPass2()])])
# -> Search space: s1*s2

# Since joint-optimization expands search space rapidly, 
# it is highly recommended to apply two tuning passes sequentially
# if they are orthogonal to each other

Since a tuning pass can nest other tuning passes for joint optimization, the trace of a sequence of transformations may not be easy to track. This could make it very difficult to understand its behavior and debug it. Thus, inspired by MetaSchedule, we introduce Instruction and Trace.

  • Instruction: A knob for a code transformation. It defines a set of choices. A single tuning pass can apply more than one Instruction.
def convert_conv2d_NHWC(mod):
    new_mod = ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
    return new_mod

def convert_conv2d_NCHW(mod):
    new_mod = ConvertLayout({"nn.conv2d": ["NCHW", "default"]})(mod)
    return new_mod

def noapply(mod):
    return mod

choices = {
    "convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
    "convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
    "NoApply": Choice(noapply),
}
knob = Instruction("LayoutTransform", choices)
  • Trace: Instruction + decision (i.e., one of its choices)
Trace length: 2
[1] MockTuningInst1: choice3
[2] MockTuningInst2: choice1
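Putting the two together, applying a decision and recording it could look like this (using the Instruction/Trace API sketched later in this proposal and the LayoutTransform knob defined above):

# Apply the NHWC choice of the LayoutTransform knob and record it in a trace.
trace = Trace(mod)                      # wrap the input IRModule
trace.add(knob, "convert_conv2d_NHWC")  # applies the choice and appends it to the trace
print(trace)                            # e.g., "Trace length: 1 / [1] LayoutTransform: convert_conv2d_NHWC"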

I want to clarify that the goal of the tuning pass is NOT to replace the heuristic passes. Each approach has its own unique strengths, and they can help each other evolve.

[Figure: relationship between heuristic and tuning passes]

Conceptually, a tuning pass and a heuristic pass differ only in the decision-making process while considering the same set of transformations (e.g., once a fusion decision is made by either a tuning or a heuristic method, the process of applying the decision would be the same). Thus, tuning passes and heuristic passes would likely share many API functions. Our design also aims to provide such common functionality by maximizing code reuse.
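For example, a fusion decision could be shared between the two kinds of passes, with only the decision procedure differing (illustrative only; greedy_fusion_groups, enumerate_fusion_groups, and measure are placeholders):

# The transformation is shared; only the decision-making differs.
def apply_fusion_decision(mod, groups):
    ...  # rewrite mod according to the chosen fusion groups (shared helper)

def heuristic_fusion(mod):
    return apply_fusion_decision(mod, greedy_fusion_groups(mod))        # fixed rule

def tuning_fusion(mod, ctx):
    candidates = [apply_fusion_decision(mod, g) for g in enumerate_fusion_groups(mod)]
    scores = [measure(c, ctx) for c in candidates]                      # low-level feedback
    return candidates[scores.index(min(scores))]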

[Figure: pass design concept]

Tuning API Design

# Tuning API provides important primitives to implement a tuning method
# Current design integrates MetaSchedule builder/runner and database
# By default, it would use the database for a pass pipeline
# However, it can also manage its own database if necessary
# Users can define eval_passes to consider the interaction with other passes

# Classes
# A choice defines a valid transformation
# Instruction will consider each choice for its candidate generation
# To reduce the search space, each choice may be considered in a probabilistic manner
class Choice:
     def __init__(self, func: Callable, constr=None, args=None):
         self.func = func       # transformation func
                                # it allows the feedback loop to cover finer
                                # granularity of candidate tuning
                                # (e.g., subgraph tuning; see the Collage example below)
         self.constr = constr   # constraints e.g., condition on tensor shape
         self.args = args       # arguments for func

class Instruction:
     def __init__(
         self, name: str, choices: Union[List[Choice], Dict[str, Choice], Dict[int, Choice]]
     ):
         self.name = name
         self.choices = choices

     # Check if a decision is valid
     def verify(self, decision: Union[str, int]) -> bool:
         if isinstance(self.choices, dict):
             return decision in self.choices
         elif isinstance(self.choices, List):
             return decision < len(self.choices)
         else:
             raise Exception("Invalid type for choices")

     # Get a choice for a decision
     def get_choice(self, decision: Union[str, int]) -> Choice:
         assert self.verify(decision)
         return self.choices[decision]
   
     # Apply a decision to an input IRModule
     def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule:
         assert self.verify(decision)
         return self.choices[decision].func(mod)

     def __str__(self) -> str:
         msg = f"{self.name} (# of choices: {len(self.choices)})\n"
         if isinstance(self.choices, dict):
             for name, choice in self.choices.items():
                 msg += f"  - {name}: {choice}\n"
         elif isinstance(self.choices, List):
             for idx, choice in enumerate(self.choices):
                 msg += f"  - {idx}: {choice}\n"
         else:
             raise Exception("Invalid type for choices")
         return msg

 # Trace maintains a sequence of instructions and their decisions.
 # It maintains the input/output IRModule and its performance
 class Trace:
     def __init__(
                    self, 
                    in_mod: IRModule, 
                    trace: List[Tuple[Instruction, Union[str, int]]] = []
                  ):
         self.in_mod = in_mod
         self.trace = trace
         self.out_mod = self.apply(in_mod, trace)
         self.perf = None

     def verify(self):
         for (knob, decision) in self.trace:
             if not knob.verify(decision):
                 return False
         return True

     # Apply a given sequence of (instruction, decision) pairs to an input IRModule
     def apply(self, in_mod: IRModule, trace: List[Tuple[Instruction, Union[str, int]]]) -> IRModule:
         out_mod = copy.deepcopy(in_mod)
         for knob, decision in trace:
             if not knob.verify(decision):
                 raise Exception("Illegal decision in the trace")
             # Apply each decision on top of the previously transformed module
             out_mod = knob.apply(out_mod, decision)
         self.perf = None
         return out_mod

     # Add a pair of instruction and its decision to the current trace
     def add(self, knob: Instruction, decision: Union[str, int]) -> None:
         self.out_mod = knob.apply(self.out_mod, decision)
         self.trace.append((knob, decision))
         self.perf = None

     def __str__(self) -> str:
         msg = f"Trace length: {len(self.trace)}\n"
         for idx, (knob, decision) in enumerate(self.trace):
             msg += f"[{idx+1}] {knob.name}: {decision}\n"
         return msg

# Helper functions
# Generate the search space for a given trace by using registered choices
# To reduce the search space, it may expand each choice in a probabilistic manner
# A developer can introduce smart search strategies like multi-armed bandit
def generate_candidates(inst, trace: Trace, ctx: PassContext, eval_passes: List[Pass] = None) -> List[Trace]:
   candidates = list()
   for decision in inst.choices.keys():
       choice = inst.choices[decision]
       # Skip this choice if its constraint is not satisfied
       # (assuming constr is a predicate over the candidate module)
       if choice.constr is not None and not choice.constr(trace.out_mod):
           continue
       new_trace = copy.deepcopy(trace)
       new_trace.add(inst, decision)
       candidates.append(new_trace)
   # Expand candidates by using eval passes if available
   if eval_passes:
       candidates = consider_eval_passes(candidates, ctx, eval_passes)
   return candidates

# Expands traces generated by current tuning pass with its eval passes
def consider_eval_passes(
   seeds: List[Trace], ctx: PassContext, eval_passes: List[Pass] = None
) -> List[Trace]:
   candidates = list(seeds)
   num = len(candidates)
   for i in range(num):
       trace = candidates.pop(0)
       for eval_pass in eval_passes:
           # For a heuristic pass, we create a knob with a single choice for tracking
           if isinstance(eval_pass, HeuristicPass):
               knob = Instruction(f"{eval_pass.name}", [Choice(eval_pass)])
               trace.add(knob, 0)
           # Tuning pass expands candidates by visiting its evaluation passes in dfs
           else:
               trace = eval_pass()(trace, ctx)

       candidates.append(trace)
   return candidates

# Evaluates each candidate with the MetaSchedule Runner/Builder
# Its performance can be stored in the MetaSchedule Database
# (eval_config is optional so call sites can pass just the database)
def evaluate(ctx, candidates: List[Trace], database=None, eval_config=None):
    # These targets will be retrieved from the ctx
    target_str, target_host, device_id = (
        ctx.config["target"],
        ctx.config["target_host"],
        ctx.config["device_id"],
    )
    target = tvm.target.Target(target_str)
    device = tvm.device(target_str, device_id)

    num_evals = 0
    # Evaluation
    for candidate in candidates:
        # Skip candidates that have already been measured
        if candidate.perf is not None:
            continue
        num_evals += 1
        mod = candidate.out_mod

        # Build the candidate with the MetaSchedule builder
        builder = LocalBuilder(f_build= ... )
        (builder_result,) = builder.build([BuilderInput(mod, target)])

        assert builder_result.artifact_path is not None
        assert builder_result.error_msg is None

        runner_input = RunnerInput(
            builder_result.artifact_path,
            target_str,
            [],  # ArgInfo
        )

        # Run the candidate with the MetaSchedule runner
        runner = LocalRunner(
            timeout_sec=100,
            evaluator_config=eval_config,
            f_run_evaluator=eval_func,
        )

        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()

        assert runner_result.error_msg is None
        perfs = []
        for result in runner_result.run_secs:
            if isinstance(result, tvm.tir.FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
            perfs.append(result)

        # ...
        candidate.perf = (np.mean(perfs), np.std(perfs))

        if database is not None:
            workload = database.commit_workload(mod)
            record = TuningRecord(
                candidate.trace,
                perfs,
                workload,
                target,
                [],
            )
            database.commit_tuning_record(record)
# Choose the best trace
def select_best_candidate(traces):
    best_perf, best_trace = sys.maxsize, None
    for candidate in traces:
        (avg, std) = candidate.perf
        # Select best one
        if best_perf > avg:
            best_perf = avg
            best_trace = candidate
    return best_trace

# Return trace wrapper if necessary
def get_trace(in_):
     if isinstance(in_, Trace):
         return in_
     if isinstance(in_, IRModule):
         return Trace(in_)
     elif isinstance(in_, Expr):
         return Trace(tvm.IRModule.from_expr(in_))
     #...
     else:
         raise Exception("Invalid input type for pass")
   
# Extracts a matching subgraph for subgraph-level tuning
def extract_subgraph(mod, pattern):
    # ...

# [Optional] a cost model that estimates the performance of a trace
def query_cost_model(cost_model, trace: Trace) -> float:
    assert 0, "Need to implement"

# [Optional] a prediction model that predicts the optimized IRModule
# This can be done by a heuristic like AutoTVM's
# or a data-driven approach like ApplyHistoryBest in MetaSchedule
def predict(mod: IRModule, ctx) -> IRModule:
    assert 0, "Need to implement"

Example

Setup

from TuningAPI import (
        Choice, 
        Trace, 
        Instruction, 
        generate_candidates, 
        consider_eval_passes,
        evaluate, 
        select_best_candidate
)

Simple switching decision

class TuningParallelConv2dPass(Pass):
     def __init__(self, eval_passes=[], required=[], database=None):
         super().__init__(
             "TuneCombineParallelConv2D",
             required=required,
         )
         self.eval_passes = eval_passes
         self.database = database

     def tune(self, trace, ctx):
         def apply(mod):
             new_mod = InferType()(mod)
             new_mod = CombineParallelConv2D(min_num_branches=2)(new_mod)
             return new_mod

         def noapply(mod):
             return mod

         choices = {"On": Choice(apply), "Off": Choice(noapply)}
         # Tuning pass manages a set of transformation functions
         inst = Instruction("InstructionTuningParallelConv2D", choices)
         candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
         evaluate(ctx, candidates, self.database)
         best_trace = select_best_candidate(candidates)
         return best_trace

     def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
         best_trace = self.tune(get_trace(mod), ctx)
         return best_trace.out_mod

Layout Transformation

class TuningLayoutPass(Pass):
     def __init__(self, eval_passes=[], required=[], database=None):
         super().__init__(
             "TuneLayout",
             required=required,
         )
         self.eval_passes=eval_passes
         self.database=database
         self.num_evals = 0

     def tune(self, trace, ctx):
         def convert_conv2d_NHWC(mod):
             new_mod = ConvertLayout({"nn.conv2d": ["NHWC", ...]})(mod)
             return new_mod

         def convert_conv2d_NCHW(mod):
             new_mod = ConvertLayout({"nn.conv2d": ["NCHW", ...]})(mod)
             return new_mod

         def noapply(mod):
             return mod

         choices = {
             "convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
             "convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
             "NoApply": Choice(noapply),
         }
         inst = Instruction("InstructionTuningLayout", choices)
         candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
         evaluate(ctx, candidates, self.database)
         best_trace = select_best_candidate(candidates)
         return best_trace

     def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
         best_trace = self.tune(get_trace(mod), ctx)
         return best_trace.out_mod

TASO

[Figure: TASO]

class TASO(Pass):
    def __init__(self, eval_passes={}, required=[], database=None):
        super().__init__(
            "TASO",
            required=required,
        )
        self.eval_passes = eval_passes
        self.database = database
        self.num_evals = 0

    def tune(self, trace, ctx):
        q = PriorityQueue()
        q.push(trace.in_mod)
        best_trace, best_perf = None, 1e100
        while not q.empty():
            g = q.pop()
            choices = {}
            for s in get_available_substitutions():
                for l in get_available_layouts(g, s):
                    choices[f"{s.name}_{l.name}"] = Choice(get_rewriting_func(s, l))
            inst = Instruction("tune_rewriting", choices)
            candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
            evaluate(ctx, candidates, self.database)
            best_cand = select_best_candidate(candidates)
            if best_cand.perf[0] < best_perf:  # compare mean latency
                best_trace, best_perf = best_cand, best_cand.perf[0]
            next_population = get_top_alpha(candidates, best_perf, alpha)
            q.push(next_population)

        return best_trace

    def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
        best_trace = self.tune(get_trace(mod), ctx)
        return best_trace.out_mod

Collage


class Collage(Pass):
    def __init__(self, eval_passes={}, required=[], database=None):
        super().__init__(
            "Collage",
            required=required,
        )
        self.eval_passes = eval_passes
        self.database = database
        self.num_evals = 0

    def tune(self, trace, ctx):
        # The whole backend-placement search is wrapped in a single transformation
        # function so that the feedback loop covers subgraph-level tuning.
        def func(mod):
            g = build_graph(mod)
            q = FrontierQueue()  # priority queue sorted by node depth
            q.push(g.get_root())
            while not q.empty():
                f = q.pop()
                expr = f.get_expr()
                choices = {}
                new_frontiers = []
                for backend, pattern in get_available_backend():
                    if pattern.match(expr):
                        choices[backend.name] = Choice(
                            get_extract_and_annotate_func(pattern, backend)
                        )
                        new_frontiers.append(get_new_frontiers(f, pattern))
                inst = Instruction("tune_backend", choices)
                new_trace = Trace(tvm.IRModule.from_expr(expr))
                candidates = generate_candidates(inst, new_trace, ctx)
                # You can manually expand candidates by using eval_passes for more control
                new_candidates = []
                for candidate in candidates:
                    # The decision of the last trace step is the chosen backend name
                    backend_name = candidate.trace[-1][1]
                    # Depending on the backend, apply different passes
                    eval_pass = self.eval_passes[backend_name]
                    cands = consider_eval_passes([candidate], ctx, [eval_pass])
                    new_candidates.extend(cands)
                evaluate(ctx, new_candidates, self.database)
                best_trace = select_best_candidate(new_candidates)
                update_placement(expr, best_trace)
                q.push(new_frontiers)
            return apply_best_placement(mod)

        # Record the whole subgraph-level search as a single trace step
        inst = Instruction("Collage", {"Subgraph-tuning": Choice(func)})
        best_trace = copy.deepcopy(trace)
        best_trace.add(inst, "Subgraph-tuning")
        return best_trace

    def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
        best_trace = self.tune(get_trace(mod), ctx)
        return best_trace.out_mod

Pipeline customization

# 1. Apply single tuning pass
custom_pipeline = TuningParallelConv2dPass()     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 2

custom_pipeline = TuningLayoutPass()     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3

# Heuristic pass won't increase the search space
custom_pipeline = TuningLayoutPass(eval_passes=[MyHeuristicPass()])     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3

# 2. Apply two tuning passes in sequential
# This is useful when we know two tuning passes are orthogonal to each other
# (we don't always want combinatorial search space with joint-optimization)
custom_pipeline = Sequential([TuningParallelConv2dPass(), TuningLayoutPass()])
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 2+3

custom_pipeline = Sequential([TuningLayoutPass(), TuningParallelConv2dPass()])
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3+2

# 3. Joint-optimization
custom_pipeline = TuningParallelConv2dPass(eval_passes=[TuningLayoutPass()])  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 2*3

custom_pipeline = TuningLayoutPass(eval_passes=[TuningParallelConv2dPass()])  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2

# Say we have a MockTuningPass with search space of 5
custom_pipeline = TuningLayoutPass(
     eval_passes=[TuningParallelConv2dPass(
                      eval_passes=[MockTuningPass()]
                  )]
)  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2*5

custom_pipeline = TuningLayoutPass(
     eval_passes=[TuningParallelConv2dPass(), MockTuningPass()]
)  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*(2+5)

C++ implementation

namespace transform {
// Since both heuristic and tuning methods live in the same file,
// code sharing is natural

Pass MockHeuristic() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        ConstantFolder folder(m);
        return Downcast<Function>(folder(f));
      };
  return CreateFunctionPass(pass_func, 0, "MockHeuristic", {});
}

TVM_REGISTER_GLOBAL("relax.transform.MockHeuristic").set_body_typed(MockHeuristic);

Pass MockTune() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> tune_pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // Implement your tuner
        auto new_f = MockTuner().tune(f);
        return Downcast<Function>(new_f);
      };
  return CreateFunctionPass(tune_pass_func, 0, "MockTuner", {"InferType"});
}
TVM_REGISTER_GLOBAL("relax.transform.MockTune").set_body_typed(MockTune);

}  // namespace transform

Comparison with Prior Work

  • Relay passes

    • Existing tuning passes: AutoTVM, AutoScheduler, MetaSchedule, Collage
    • Kernel tuning, such as AutoTVM, is done offline, outside of the pass pipeline (we just load the best config inside the pipeline)
    • The Collage RFC aims to jointly optimize backend placement and kernel tuning, but its extensibility is not clear - it may cause headaches to add another tuning layer for layout and graph transformations
  • XTAT in Google XLA

    • Strength
      • Unlock customizable joint-optimization framework
      • Single global search method → easy control over the tuning process
    • Weakness
      • They split the search space generation of a single pass into multiple steps. So, its joint-optimization looks close to a fine-grained sequential approach
        • Their joint optimization: [A,B,C,A,B,C ...]
        • Their sequential optimization: [A,A,A,B,B,B...]
      • Some critical user-side decisions seem hard to make
        • How do we split each pass? How do we decide the pass sequence? (e.g., [A,B,C,A,B,C...]? [A,A,B,B,C,B,C]?)
      • The interaction between tuning/heuristic passes is unclear and may be over-simplified and suboptimal
        • When generating candidates within each pass, XTAT fixes the configurations of later optimizations by using the heuristic decision or the best per-node decision so far
        • It is unclear how good this is
      • Only supports per-node decisions → cannot support graph rewriting like TASO
      • It is unclear what a developer should do to introduce a tuning pass and how easy the process is
      • Single global search method → limits search methods and tuning method designs

Opportunity

  • Extensive support for tuning methods and their joint-optimization
    • Every year, we see new compiler tuning techniques. However, it is unclear how to bring those innovations into the framework and deploy them.
    • This design can provide one of the first solutions. I believe this is an interesting step forward towards Compiler 2.0
      [slides: Compiler2.0 - CGO keynote - 04-22.pdf] [video]

Limitations

  • Cannot support input-dependent optimization choices. However, I believe it is okay since we have not observed such tuning practices yet.
  • Still, we start from a fixed model imported from the frontend, and we are only allowed to transform the graph within computational equivalence.

Discussion

  • Intuitive interface and debuggability - can we make something like TensorBoard?
  • How can we use the tuning time wisely in the presence of multiple tuning methods?
