Skip to content

Commit

Permalink
Update on "[Inductor][CPP] Add Min/Max with VecMask"
Browse files Browse the repository at this point in the history
**Summary**
Fix issue: #126824 which is missing the support of `min/max` with `VecMask`.

**TestPlan**
```
python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_max_cpu_bool
python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_min_cpu_bool
```

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
  • Loading branch information
leslie-fang-intel committed May 23, 2024
1 parent c8ab9d3 commit 90a84cc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 14 deletions.
12 changes: 0 additions & 12 deletions aten/src/ATen/cpu/vec/vec_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,6 @@ class VecMask {
return result;
}

static VecMask<T, N> minimum(
const VecMask<T, N>& a,
const VecMask<T, N>& b) {
return VecMask<T, N>::blendv(b, a, a < b);
}

static VecMask<T, N> maximum(
const VecMask<T, N>& a,
const VecMask<T, N>& b) {
return VecMask<T, N>::blendv(b, a, a > b);
}

void store(bool* b, int count = size()) {
constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size();
auto res = this->to<bool, L>();
Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ def minimum(a, b):
if a.dtype == torch.bool:
assert b.dtype == torch.bool
a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
return f"decltype({a_cast})::minimum({a_cast}, {b_cast})"
return f"{a_cast} & {b_cast}"
else:
return f"at::vec::minimum({a}, {b})"

Expand All @@ -1327,7 +1327,7 @@ def maximum(a, b):
if a.dtype == torch.bool:
assert b.dtype == torch.bool
a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
return f"decltype({a_cast})::maximum({a_cast}, {b_cast})"
return f"{a_cast} | {b_cast}"
else:
return f"at::vec::maximum({a}, {b})"

Expand Down

0 comments on commit 90a84cc

Please sign in to comment.