Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Add simple HLO if conversion pass #22974

Merged
merged 2 commits into from Mar 18, 2019

Commits on Dec 15, 2018

  1. [XLA] Add simple HLO if conversion pass

    kConditional operations are currently generally disallowed in parallel contexts
    (e.g. in mapped computations). The julia XLA frontend was running into this limitation
    quite a bit, because existing julia code tends to use the terniary operator for select,
    e.g. to describe the derivative of a `max` call (and thus a `relu`) - see the
    definitions of the derivatives of `max` at
    https://github.com/JuliaDiff/DiffRules.jl/blob/master/src/rules.jl#L94
    
    To support these sorts of patterns, add a simple if conversion pass that converts
    conditionals in parallel context by equivalent select calls (which are well supported),
    i.e. a computation like:
    
    ```
    if {
     %pif = () parameter(0)
     ROOT %cif = f32[] constant(0)
    }
    
    else {
     %pelse = () parameter(0)
     ROOT %celse = f32[] constant(1)
    }
    
    mapped {
     %a = f32[] parameter(0)
     %b = f32[] parameter(1)
     %lt = pred[] less-than(%a, %b)
     %t = () tuple()
     ROOT %conditional = f32[] conditional(%lt, %t, %t), true_computation=if, false_computation=else
    }
    
    ENTRY comp {
     %p1 = f32[1000]{0} parameter(0)
     %p2 = f32[1000]{0} parameter(1)
     ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped
    }
    ```
    
    gets rewritten to
    
    ```
    mapped {
     %a = f32[] parameter(0)
     %b = f32[] parameter(1)
     %cif = f32[] constant(0)
     %celse = f32[] constant(1)
     %lt = pred[] less-than(%a, %b)
     ROOT %select = f32[] select(%lt, %cif, %celse)
    }
    
    ENTRY comp {
     %p1 = f32[1000]{0} parameter(0)
     %p2 = f32[1000]{0} parameter(1)
     ROOT %mapped = f32[1000]{0} map(%p1, %p2) dimensions={0} to_apply=mapped
    }
    ```
    
    To keep things simple, this is accomplished by first rewriting the conditional
    to two calls and a select and then inlining the individual calls. Naturally,
    the transformation is only applied if the called computation do not
    have side effects (which they generally don't if they're in parallel
    context). In the future, it would be good to let MapInliner further
    simplify this to an implicitly mapped select.
    Keno committed Dec 15, 2018
    Configuration menu
    Copy the full SHA
    27b9374 View commit details
    Browse the repository at this point in the history

Commits on Mar 16, 2019

  1. Configuration menu
    Copy the full SHA
    7026001 View commit details
    Browse the repository at this point in the history