
[Opt] Add a strength reduction pass #944

Open
2 tasks done
xumingkuan opened this issue May 10, 2020 · 14 comments
Labels: enhancement, welcome contribution

Comments

@xumingkuan
Collaborator

xumingkuan commented May 10, 2020

Update after #1065:
TODO:

  • Add operators shl and shr for optimization
  • a % pot (power of two constant) -> a & (pot - 1)

Concisely describe the proposed feature
I would like to add a strength reduction pass (https://en.wikipedia.org/wiki/Strength_reduction).
A list of peephole optimizations in this pass (done in #1065):

(Feel free to add more optimizations to this list)
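To illustrate the kind of rewrites such a pass performs, here is a minimal sketch of a few peephole strength reductions over a toy expression tree. The tuple encoding and the op names (`mul`, `add`, `shl`, `pow`) are hypothetical and are not Taichi's IR; the real pass would work on Taichi IR statements in C++. The `shl` rule assumes an integer operand, since it is wrong for floats.

```python
# Toy expression IR (hypothetical, not Taichi's): ("const", value), ("var", name),
# or (op, lhs, rhs) for binary ops such as "mul", "add", "shl", "pow".

def is_power_of_two(v):
    """True for positive integer powers of two: 1, 2, 4, 8, ..."""
    return isinstance(v, int) and v > 0 and (v & (v - 1)) == 0

def reduce_strength(e):
    """Bottom-up rewrite of a few expensive ops into cheaper equivalents."""
    if e[0] in ("const", "var"):
        return e
    op, lhs, rhs = e[0], reduce_strength(e[1]), reduce_strength(e[2])
    if op == "pow" and rhs == ("const", 2):
        return ("mul", lhs, lhs)                 # pow(x, 2) -> x * x
    if op == "mul" and rhs == ("const", 2):
        return ("add", lhs, lhs)                 # x * 2 -> x + x
    if op == "mul" and rhs[0] == "const" and is_power_of_two(rhs[1]):
        k = rhs[1].bit_length() - 1              # log2 of the constant
        return ("shl", lhs, ("const", k))        # x * 2^k -> x << k (integer x only)
    return (op, lhs, rhs)

x = ("var", "x")
print(reduce_strength(("pow", x, ("const", 2))))   # ('mul', ('var', 'x'), ('var', 'x'))
print(reduce_strength(("mul", x, ("const", 8))))   # ('shl', ('var', 'x'), ('const', 3))
```

In an SSA-style IR like the dumps in this thread, `x * x` references the same statement twice, so duplicating `lhs` does not recompute it.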

Describe the solution you'd like (if any)
I would like to place this pass in full_simplify, alongside alg_simp.

Additional comments
Shall we replace <i32x1> a * c where c is a constant power of 2 with a << log2(c) (and also replace <i32x1> a / c with a >> log2(c))? Do we have << and >> operators?
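A quick sanity check of these identities in plain Python (a sketch; `trunc_div` is a hypothetical helper emulating C-style integer `/`). Multiplying by 2^k and shifting left agree for all integers, and floor division agrees with an arithmetic right shift, but a division that truncates toward zero (as in C) does not agree for negative dividends. So `a / c -> a >> log2(c)` on a signed type is only safe if `a` is known to be non-negative or the division is a floor division.

```python
def trunc_div(a, b):
    """C-style integer division: the quotient is rounded toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q

for a in range(-100, 101):
    for k in range(6):
        c = 1 << k                      # constant power of two, c = 2**k
        assert a * c == a << k          # multiplication -> left shift: always OK
        assert a // c == a >> k         # floor division -> arithmetic right shift: OK
        if a >= 0:
            assert trunc_div(a, c) == a >> k  # truncating division: OK only for a >= 0

print(trunc_div(-7, 2), -7 >> 1)        # -3 -4: the shift rounds toward -infinity
```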

Shall we just write the optimization in the alg_simp pass and add an AlgSimp::alg_is_two function instead of introducing a new pass?

Shall we expose functions such as AlgSimp::alg_is_one somewhere so that other passes can use them, and split the useless-assertion elimination into a new pass?

@xumingkuan xumingkuan added the feature request Suggest an idea on this project label May 10, 2020
@xumingkuan xumingkuan self-assigned this May 10, 2020
@yuanming-hu
Member

Thanks for proposing this!

> Shall we replace <i32x1> a * c where c is a constant power of 2 with a << log2(c) (and also replace <i32x1> a / c with a >> log2(c))? Do we have << and >> operators?

Yes, that would be helpful. But we first need to add shl and shr (and their signed versions) as four new BinaryOpTypes, and implement the code generators for several backends. This should be a separate issue.
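Why the signed versions matter: on a 32-bit value, a left shift produces the same bits whether the value is signed or unsigned, but a right shift does not. A logical shift fills the vacated high bits with zeros, while an arithmetic shift replicates the sign bit (the u32 IR dump later in this thread uses bit_shl and bit_sar). The sketch below emulates both with Python integers masked to 32 bits; the helper names are illustrative only.

```python
MASK32 = 0xFFFFFFFF

def as_i32(x):
    """Reinterpret the low 32 bits of x as a signed two's-complement value."""
    x &= MASK32
    return x - (1 << 32) if x & 0x80000000 else x

def shl32(x, k):
    """Left shift; the resulting bit pattern is the same for signed and unsigned."""
    return (x << k) & MASK32

def shr32(x, k):
    """Logical right shift (unsigned): vacated high bits become 0."""
    return (x & MASK32) >> k

def sar32(x, k):
    """Arithmetic right shift (signed): vacated high bits copy the sign bit."""
    return as_i32(x) >> k

print(hex(shl32(-8, 1)))   # 0xfffffff0, same bits as -16
print(shr32(-8, 1))        # 2147483644 (0x7ffffffc)
print(sar32(-8, 1))        # -4
```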

> Shall we just write the optimization in the alg_simp pass and add an AlgSimp::alg_is_two function instead of introducing a new pass?

Sounds good. I don't think there's a need for two separate passes.

> Shall we expose functions AlgSimp::alg_is_one to somewhere so that more passes can use them, and we can split the useless assertion elimination to a new pass?

All SGTM.

@archibate
Collaborator

> << and >>

Cool! Powers of two are really common. We don't have these operators yet, but we can add them in a separate PR for that purpose, e.g., bit_shl and bit_shr.

> pow(x, 2)

Sounds better than the current solution, thanks! I guess we will move all of the Python-side pow optimization into this C++ pass.

@k-ye
Copy link
Member

k-ye commented May 11, 2020

Great! Could you also optimize a % pot (power of two constant) -> a & (pot - 1)?

Currently, % is translated to:

```python
quotient = Expr(taichi_lang_core.expr_floordiv(self.ptr, other.ptr))
multiply = Expr(taichi_lang_core.expr_mul(other.ptr, quotient.ptr))
return Expr(taichi_lang_core.expr_sub(self.ptr, multiply.ptr))
```

This lowering prevents such an optimization.
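The three-step lowering computes Python's floored modulo. A plain-Python mirror of it (an illustration, not Taichi code) shows why the result is correct but opaque: it is only recognizable as `a % c` once a pass looks at the whole floordiv/mul/sub chain together.

```python
def lowered_mod(a, c):
    """Mirrors the current lowering of `a % c` into three generic ops."""
    quotient = a // c            # expr_floordiv
    multiply = c * quotient      # expr_mul
    return a - multiply          # expr_sub

# Python guarantees a == (a // c) * c + a % c, so the chain equals a % c.
for a in range(-20, 21):
    for c in (-4, -3, 2, 4, 5):
        assert lowered_mod(a, c) == a % c

print(lowered_mod(-7, 4), -7 % 4)   # 1 1
```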

Right now, if I change `if i % 2 == 1:` to `if i & 1 == 1:`, the example runs about 4% faster on Metal (24 sps -> 25 sps).
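Two details make this rewrite safe in Python syntax, checked below with the standard `ast` module: `&` binds more tightly than `==`, so `i & 1 == 1` parses as `(i & 1) == 1`; and because `%` here is Python's floored modulo, `i % 2 == 1` and `i & 1 == 1` agree for negative `i` too. Under a C-style remainder, `-3 % 2` would be `-1` and the two conditions would differ.

```python
import ast

# `&` has higher precedence than `==` in Python: the Compare node's left side is the BinOp.
expr = ast.parse("i & 1 == 1", mode="eval").body
assert isinstance(expr, ast.Compare) and isinstance(expr.left, ast.BinOp)
assert isinstance(expr.left.op, ast.BitAnd)

# With Python's floored `%`, both tests agree for every integer, including negatives.
for i in range(-1000, 1001):
    assert (i % 2 == 1) == (i & 1 == 1)

print(-3 % 2, -3 & 1)   # 1 1
```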

@archibate
Collaborator

> And it's preventing such an optimization.

Let's make it expr_python_mod at some point then.

@xumingkuan
Collaborator Author

xumingkuan commented May 11, 2020


I wonder why we didn't just use BinaryOpType::mod...

@xumingkuan
Collaborator Author

Test case:

```python
import taichi as ti

ti.init(print_ir=True)

@ti.kernel
def func(i: ti.i32):
    if i % 2 == 1:
        print(i)

func(1)
```

Result:

```
kernel {
  $0 = offloaded  {
    <i32 x1> $1 = arg[0]
    <i32 x1> $2 = const [2]
    <i32 x1> $3 = floordiv $1 $2
    <i32 x1> $4 = mul $2 $3
    <i32 x1> $5 = sub $1 $4
    <i32 x1> $6 = const [1]
    <i32 x1> $7 = cmp_eq $5 $6
    <i32 x1> $8 = bit_and $6 $7
    $9 : if $8 {
      <i32 x1> $10 = arg[0]
      print i, $10
    }
  }
}
[debug] i = 1
```

@archibate
Collaborator

> I wonder why we didn't just use BinaryOpType::mod...

BinaryOpType::mod is a C/C++-style mod (ti.raw_mod), while % is a Python-style mod (ti.mod).
One day we'll rename it to BinaryOpType::raw_mod.
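For reference, the two conventions differ only when the operands have different signs. A minimal Python sketch (the names `raw_mod` and `py_mod` are illustrative, chosen to echo ti.raw_mod and ti.mod): the C-style remainder takes the sign of the dividend, as `math.fmod` does, while Python's `%` takes the sign of the divisor.

```python
import math

def raw_mod(a, b):
    """C/C++-style remainder (truncated division): sign follows the dividend."""
    r = abs(a) % abs(b)
    return r if a >= 0 else -r

def py_mod(a, b):
    """Python-style modulo (floored division): sign follows the divisor."""
    return a % b

for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
    assert raw_mod(a, b) == math.fmod(a, b)   # fmod uses C semantics
    print(a, b, raw_mod(a, b), py_mod(a, b))
# 7 2 1 1
# -7 2 -1 1
# 7 -2 1 -1
# -7 -2 -1 -1
```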

@xumingkuan
Copy link
Collaborator Author

```
<i32 x1> $3 = floordiv $1 $2
<i32 x1> $4 = mul $2 $3
<i32 x1> $5 = sub $1 $4
```

Maybe we can just replace $5 with mod $1 $2 here?
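A minimal sketch of that replacement as a hand-written peephole check over a toy SSA table, where each statement id maps to an op and its operand ids (a hypothetical encoding, not Taichi's IR). One caveat from the discussion above: since BinaryOpType::mod is C-style, the replacement has to keep Python-style semantics unless the operands are known to be non-negative, so the sketch emits a hypothetical `pymod` op.

```python
def fold_floor_mod(stmts):
    """Rewrite  q = floordiv(a, c); m = mul(c, q); r = sub(a, m)  into  r = pymod(a, c).

    `stmts` maps statement ids to tuples: (op, operand_id, ...) or ("const", value).
    The now-unused floordiv/mul statements are left for dead-code elimination.
    """
    out = dict(stmts)
    for sid, stmt in stmts.items():
        if stmt[0] != "sub":
            continue
        a, m = stmt[1], stmt[2]
        mul = stmts.get(m)
        if mul is None or mul[0] != "mul":
            continue
        # mul is commutative, so accept the quotient in either operand slot.
        for c, q in ((mul[1], mul[2]), (mul[2], mul[1])):
            if stmts.get(q) == ("floordiv", a, c):
                out[sid] = ("pymod", a, c)
                break
    return out

# The statements $1 .. $5 from the IR dump of the test case.
ir = {
    1: ("arg", 0),
    2: ("const", 2),
    3: ("floordiv", 1, 2),
    4: ("mul", 2, 3),
    5: ("sub", 1, 4),
}
print(fold_floor_mod(ir)[5])   # ('pymod', 1, 2)
```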

@yuanming-hu
Member

yuanming-hu commented Jul 5, 2020

> ```
> <i32 x1> $3 = floordiv $1 $2
> <i32 x1> $4 = mul $2 $3
> <i32 x1> $5 = sub $1 $4
> ```
>
> Maybe we can just replace $5 with mod $1 $2 here?

Yeah, but before that maybe we need an IR pattern matcher so that we can quickly locate patterns like these in Taichi IR.

@xumingkuan
Collaborator Author

> Yeah, but before that maybe we need an IR pattern matcher so that we can quickly locate patterns like these in Taichi IR.

How would an "IR pattern matcher" work? binary_op_simplify already does hand-written pattern checking, which looks quite simple to me. How would a pattern matcher shorten the code or make it faster?

@yuanming-hu
Copy link
Member

Halide's source files whose names start with Simplify_ are what we can learn from. For example:
https://github.com/halide/Halide/blob/666d4cfa690ae877be7d207743e3a501ea9f4625/src/Simplify_Add.cpp#L51
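The idea behind those files, roughly, is to state each simplification as a declarative `pattern -> replacement` rule with wildcards instead of hand-writing the matching code for every case. A toy Python version of that idea follows; the tuple encoding, the `?`-prefixed wildcard convention, and the `pymod` op are all hypothetical, and this is not Halide's actual API.

```python
# Expressions are nested tuples such as ("sub", ("var", "i"), ("const", 2)).
# In patterns, strings starting with "?" are wildcards that match any subexpression.

def match(pattern, expr, bindings):
    """Structurally match `pattern` against `expr`, recording wildcard bindings."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in bindings:                  # repeated wildcard: must be equal
            return bindings[pattern] == expr
        bindings[pattern] = expr
        return True
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        return len(pattern) == len(expr) and all(
            match(p, e, bindings) for p, e in zip(pattern, expr))
    return pattern == expr

def substitute(template, bindings):
    """Build the replacement by plugging the bound subexpressions into `template`."""
    if isinstance(template, str) and template.startswith("?"):
        return bindings[template]
    if isinstance(template, tuple):
        return tuple(substitute(t, bindings) for t in template)
    return template

# Each rule is (pattern, replacement); "pymod" is a hypothetical Python-style mod op.
RULES = [
    (("sub", "?a", ("mul", "?c", ("floordiv", "?a", "?c"))), ("pymod", "?a", "?c")),
    (("add", "?x", ("sub", "?y", "?x")), "?y"),    # x + (y - x) -> y (integers only)
]

def rewrite_once(expr):
    """Apply the first matching rule at the root of `expr`, if any."""
    for pattern, replacement in RULES:
        bindings = {}
        if match(pattern, expr, bindings):
            return substitute(replacement, bindings)
    return expr

i, two = ("var", "i"), ("const", 2)
print(rewrite_once(("sub", i, ("mul", two, ("floordiv", i, two)))))
# ('pymod', ('var', 'i'), ('const', 2))
```

The appeal is that each new simplification becomes one entry in `RULES` while the matching logic is shared; by itself this does not make matching faster.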

@xumingkuan
Copy link
Collaborator Author

xumingkuan commented May 11, 2021


Current result when the data type is u32 after #2332:

```
kernel {
  $0 = offloaded  
  body {
    <i32> $1 = const [1]
    <u32> $2 = arg[0]
    <u32> $3 = const [1]
    <u32> $4 = bit_sar $2 $3
    <u32> $5 = bit_shl $4 $3
    <u32> $6 = sub $2 $5
    <i32> $7 = cmp_eq $6 $3
    <i32> $8 = bit_and $7 $1
    $9 : if $8 {
      print $2, "\n"
    }
  }
}
```
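The remaining u32 chain still computes `a & (pot - 1)`: `(a >> k) << k` only clears the low k bits, and any bits that `bit_sar` shifts in at the top are shifted back out by `bit_shl` modulo 2^32. A Python emulation of statements $4 to $6 (helper names are illustrative), checked against a single mask:

```python
MASK32 = 0xFFFFFFFF

def sar32(a, k):
    """Arithmetic right shift of a 32-bit pattern (the top bit is replicated)."""
    signed = a - (1 << 32) if a & 0x80000000 else a
    return (signed >> k) & MASK32

def lowered_u32_mod(a, k):
    """Emulates $4 = bit_sar a k; $5 = bit_shl $4 k; $6 = sub a $5 on u32 values."""
    t4 = sar32(a, k)
    t5 = (t4 << k) & MASK32
    return (a - t5) & MASK32

samples = [0, 1, 2, 3, 5, 0x7FFFFFFF, 0x80000000, 0x80000001, 0xFFFFFFFF]
for a in samples:
    for k in range(6):
        pot = 1 << k
        # The three-statement chain equals one bit_and (and the unsigned a % pot).
        assert lowered_u32_mod(a, k) == a & (pot - 1) == a % pot
```

So, for unsigned operands, one more peephole rule could fold $4 to $6 into a single `bit_and`.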

@xumingkuan
Copy link
Collaborator Author

I found it hard to optimize a % pot (power of two constant) -> a & (pot - 1) when a is a signed integral type because a % pot is not a & (pot - 1) when a is negative.
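Whether `a & (pot - 1)` is a valid replacement depends on which modulo is meant. In two's complement, `a & (pot - 1)` always equals the floored (Python-style) `a % pot`, even for negative `a`, so if `%` has Python semantics, as described earlier in the thread, the mask is exact. It is the C-style remainder (truncated division, as in ti.raw_mod) that differs for negative `a`, and there a sign fix-up is needed. A sketch with plain Python integers, which behave like sign-extended two's complement under `&` (`c_rem_pot` is a hypothetical helper name):

```python
def c_rem_pot(a, k):
    """C-style a % 2**k (sign follows a), computed with a mask plus a sign fix-up."""
    pot = 1 << k
    r = a & (pot - 1)            # floored remainder, always in [0, pot)
    if a < 0 and r != 0:
        r -= pot                 # move into (-pot, 0) to match truncated division
    return r

for a in range(-64, 65):
    for k in range(6):
        pot = 1 << k
        assert a & (pot - 1) == a % pot              # floored mod: plain mask works
        c_style = -(-a % pot) if a < 0 else a % pot  # reference C-style remainder
        assert c_rem_pot(a, k) == c_style

print(-3 % 4, -3 & 3, c_rem_pot(-3, 2))   # 1 1 -3
```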

@xumingkuan
Copy link
Collaborator Author

Maybe we want to refer to numpy/numpy#17727 for further optimizations.

@ailzhang ailzhang added this to To do in Compiler Frontend & Middle-end via automation Feb 12, 2022