Fix vectorized ops.masked #130130
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130130
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ec91d11 with merge base 6f275ae.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you add a test?
The fix LGTM. Let's wait for @jgong5 to approve the
torch/_inductor/codegen/cpp.py
Outdated

  assert isinstance(var, CppCSEVariable)
  assert var.dtype is not None
- if not var.is_vec:
+ if not (var.is_vec or (mask and mask.is_vec)):
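For context outside the review thread, a minimal sketch of what this predicate change selects, using a made-up FakeCSEVar stand-in for CppCSEVariable (only the is_vec flag matters; names and structure are illustrative, not Inductor's):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeCSEVar:
    # Made-up stand-in for CppCSEVariable, reduced to the one attribute used here.
    name: str
    is_vec: bool = False

def use_scalar_path(var: FakeCSEVar, mask: Optional[FakeCSEVar]) -> bool:
    # Old predicate: not var.is_vec -- fall back to the scalar assert whenever
    # the index variable is scalar.
    # New predicate: also stay on the vector path when the mask is vectorized,
    # since the scalar fallback cannot consume a vector mask.
    return not (var.is_vec or (mask is not None and mask.is_vec))

print(use_scalar_path(FakeCSEVar("tmp0"), None))                      # True
print(use_scalar_path(FakeCSEVar("tmp0"), FakeCSEVar("tmp1", True)))  # False: vec mask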
Does the existing test cover this case - either var or mask being vec?
The new test does cover the case when var is not vec, but mask is.
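For illustration, a hypothetical repro in this spirit: a compiled function whose masked gather pairs a vectorized mask with a scalar index expression. This is a sketch under assumed lowering behavior, not the actual test added in this PR; whether it reaches the fixed path depends on Inductor's lowering decisions.

import torch

def fn(x, idx):
    # The torch.where masks the indirect load, which on CPU can lower to
    # ops.masked around a gather with an indirect_assert bounds check on idx.
    valid = idx >= 0
    safe_idx = torch.where(valid, idx, torch.zeros_like(idx))
    return torch.where(valid, x[safe_idx], torch.zeros_like(x))

x = torch.randn(128)
idx = torch.randint(-4, 128, (128,))
compiled = torch.compile(fn)
torch.testing.assert_close(compiled(x, idx), fn(x, idx))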
torch/_inductor/codegen/cpp.py
Outdated

  var_dtype = var.dtype
  lower_scalar = lower
  upper_scalar = upper
  var_scalar = var
  if not var.is_vec:
      var = f"{self._get_vec_type(var_dtype)}({var})"
  if mask and not mask.is_vec:
      mask = f"{self._get_vec_type(var_dtype)}({mask})"
  if lower:
-     lower = f"{self._get_vec_type(var.dtype)}({lower})"
+     lower = f"{self._get_vec_type(var_dtype)}({lower})"
  if upper:
-     upper = f"{self._get_vec_type(var.dtype)}({upper})"
+     upper = f"{self._get_vec_type(var_dtype)}({upper})"
  if lower and upper:
      cond = f"({lower} <= {var}) & ({var} < {upper})"
-     cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
+     cond_print = f"{lower_scalar} <= {var_scalar} < {upper_scalar}"
  elif lower:
      cond = f"{lower} <= {var}"
-     cond_print = f"{lower_scalar} <= {var}"
+     cond_print = f"{lower_scalar} <= {var_scalar}"
  else:
      assert upper
      cond = f"{var} < {upper}"
-     cond_print = f"{var} < {upper_scalar}"
+     cond_print = f"{var_scalar} < {upper_scalar}"
- cond = f"({self._get_mask_type(var.dtype)}({cond})).all_masked()"
+ cond = f"{self._get_mask_type(var_dtype)}({cond})"
I meant something like

lower_scalar = lower
upper_scalar = upper
if var.is_vec:
    if lower:
        lower = f"{self._get_vec_type(var.dtype)}({lower})"
    if upper:
        upper = f"{self._get_vec_type(var.dtype)}({upper})"
if lower and upper:
    cond = f"({lower} <= {var}) & ({var} < {upper})"
    cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
elif lower:
    cond = f"{lower} <= {var}"
    cond_print = f"{lower_scalar} <= {var}"
else:
    assert upper
    cond = f"{var} < {upper}"
    cond_print = f"{var} < {upper_scalar}"
if not var.is_vec:
    cond = f"{self._get_mask_type(var.dtype)}({cond})"
if mask and not mask.is_vec:
    mask = f"{self._get_mask_type(var.dtype)}({mask})"
Doing it this way, we don't vectorize the whole comparison unnecessarily: we perform the scalar comparison and then vectorize the result.
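To make the difference concrete, a standalone sketch of the two strategies; the at::vec type names are illustrative placeholders, not Inductor's exact generated output:

# Both strategies build a C++ condition string for a scalar index variable.
var, lower = "tmp3", "0"

def vec(expr):
    # Placeholder broadcast of a scalar expression to a vector type.
    return f"at::vec::Vectorized<int64_t>({expr})"

# Before: broadcast both operands, then do a vector comparison.
cond_vectorized = f"{vec(lower)} <= {vec(var)}"

# After (the suggestion above): compare the scalars directly and broadcast
# only the boolean result into a mask type.
cond_scalar_then_wrapped = f"at::vec::VecMask<int64_t,1>({lower} <= {var})"

print(cond_vectorized)           # two broadcasts + a vector compare
print(cond_scalar_then_wrapped)  # one scalar compare + one broadcast of the result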
Thanks for explaining, and for the code.
This is much cleaner, thanks.
torch/_inductor/codegen/cpp.py
Outdated

        mask = f"({mask}).all_masked()"
    return super().indirect_assert(var, lower, upper, mask)

var_dtype = var.dtype
nit: delete this line
assert isinstance(var, CppCSEVariable)
assert var.dtype is not None
if not var.is_vec:
    if isinstance(mask, CppCSEVariable) and mask.is_vec:
Is this equivalent to if mask and mask.is_vec, or can mask be of a different type?
The signature is

def indirect_assert(
    self,
    var: Union[CSEVariable, str],
    lower: Optional[str],
    upper: Optional[str],
    mask: Optional[str] = None,
)

in torch._inductor.codegen.common.Kernel. The mask annotation should have been Optional[Union[CSEVariable, str]], just like var's. I was guarding against the str case.
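In other words, the annotation permits a plain str, which has no is_vec attribute, so a bare truthiness check could raise at codegen time. A minimal standalone illustration, again with a made-up FakeCSEVar stand-in:

from dataclasses import dataclass

@dataclass
class FakeCSEVar:
    # Made-up stand-in for CppCSEVariable.
    name: str
    is_vec: bool

def handle(mask):
    # `if mask and mask.is_vec:` would raise AttributeError for a str mask,
    # which the Optional[str] annotation allows; the isinstance guard does not.
    if isinstance(mask, FakeCSEVar) and mask.is_vec:
        return "vector mask path"
    return "scalar mask path"

print(handle(None))                          # scalar mask path
print(handle("some_mask_expr"))              # scalar mask path, no crash
print(handle(FakeCSEVar("tmp_mask", True)))  # vector mask path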
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #130299 Approved by: https://github.com/lezcano ghstack dependencies: #130130
Pull Request resolved: #130130 Approved by: https://github.com/jgong5, https://github.com/lezcano
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang