Conversation


pytorch-bot bot commented Jul 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130130

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ec91d11 with merge base 6f275ae:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@isuruf isuruf changed the title from "Check that _load_mask is not set in indirect_assert" to "Fix vectorized ops.masked" on Jul 5, 2024
Collaborator

@jgong5 jgong5 left a comment


Can you add a test?

isuruf added a commit that referenced this pull request Jul 9, 2024
ghstack-source-id: 826dd85
Pull Request resolved: #130130
Collaborator

@lezcano lezcano left a comment


The fix LGTM. Let's wait for @jgong5 to approve.

@lezcano lezcano requested a review from jgong5 July 12, 2024 07:39
  assert isinstance(var, CppCSEVariable)
  assert var.dtype is not None
- if not var.is_vec:
+ if not (var.is_vec or (mask and mask.is_vec)):
Collaborator


Does the existing test cover this case - either var or mask being vec?

Collaborator Author


The new test covers the case where var is not vec but mask is.
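For context, a rough sketch of the kind of pattern under discussion (hypothetical, not the actual test added in this PR): an indirect load inside a masked region, where the bound check on the index stays scalar while the mask is vectorized. Whether it actually hits the vectorized masked path depends on the CPU inductor backend and its configuration.

import torch

def fn(x, idx):
    # pad introduces a masked (out-of-bounds) branch; the integer index makes
    # the subsequent load indirect, which exercises the bounds assert in codegen
    return torch.nn.functional.pad(x, (0, 1))[idx]

x = torch.randn(64)
idx = torch.randint(0, 65, (64,))
compiled = torch.compile(fn)
torch.testing.assert_close(compiled(x, idx), fn(x, idx))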

Comment on lines 2594 to 2617

+        var_dtype = var.dtype
         lower_scalar = lower
         upper_scalar = upper
+        var_scalar = var
+        if not var.is_vec:
+            var = f"{self._get_vec_type(var_dtype)}({var})"
+        if mask and not mask.is_vec:
+            mask = f"{self._get_vec_type(var_dtype)}({mask})"
         if lower:
-            lower = f"{self._get_vec_type(var.dtype)}({lower})"
+            lower = f"{self._get_vec_type(var_dtype)}({lower})"
         if upper:
-            upper = f"{self._get_vec_type(var.dtype)}({upper})"
+            upper = f"{self._get_vec_type(var_dtype)}({upper})"
         if lower and upper:
             cond = f"({lower} <= {var}) & ({var} < {upper})"
-            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
+            cond_print = f"{lower_scalar} <= {var_scalar} < {upper_scalar}"
         elif lower:
             cond = f"{lower} <= {var}"
-            cond_print = f"{lower_scalar} <= {var}"
+            cond_print = f"{lower_scalar} <= {var_scalar}"
         else:
             assert upper
             cond = f"{var} < {upper}"
-            cond_print = f"{var} < {upper_scalar}"
-        cond = f"({self._get_mask_type(var.dtype)}({cond})).all_masked()"
+            cond_print = f"{var_scalar} < {upper_scalar}"
+        cond = f"{self._get_mask_type(var_dtype)}({cond})"
Collaborator

@lezcano lezcano Jul 12, 2024


I meant something like

        lower_scalar = lower
        upper_scalar = upper
        if var.is_vec:
            if lower:
                lower = f"{self._get_vec_type(var.dtype)}({lower})"
            if upper:
                upper = f"{self._get_vec_type(var.dtype)}({upper})"
        if lower and upper:
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = f"{lower_scalar} <= {var}"
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = f"{var} < {upper_scalar}"
        if not var.is_vec:
            cond = f"{self._get_mask_type(var.dtype)}({cond})"
        if mask and not mask.is_vec:
            mask = f"{self._get_mask_type(var.dtype)}({mask})"

Doing it this way, we don't vectorize the whole comparison unnecessarily: we perform the scalar comparison and then vectorize the result.
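A standalone sketch (hypothetical helper names, not the inductor API) of the difference between the two strategies for a scalar var: broadcasting every operand and comparing lane-wise, versus comparing once as scalars and broadcasting only the resulting boolean, which is what the suggestion above does.

def vectorize_operands(var, lower, upper, vec_type="at::vec::Vectorized<int64_t>"):
    # broadcast every operand to a vector, then compare lane-wise
    lower, upper, var = (f"{vec_type}({s})" for s in (lower, upper, var))
    return f"({lower} <= {var}) & ({var} < {upper})"

def vectorize_result(var, lower, upper, mask_type="at::vec::VecMask<int64_t, 1>"):
    # compare once as scalars, then broadcast the single boolean into a mask
    return f"{mask_type}(({lower} <= {var}) & ({var} < {upper}))"

print(vectorize_operands("tmp0", "0", "ks0"))
print(vectorize_result("tmp0", "0", "ks0"))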

Collaborator Author

@isuruf isuruf Jul 12, 2024


Thanks for the explanation and the code.

isuruf added 3 commits July 12, 2024 22:53
Collaborator

@lezcano lezcano left a comment


This is much cleaner, thanks.

                mask = f"({mask}).all_masked()"
            return super().indirect_assert(var, lower, upper, mask)

        var_dtype = var.dtype
Collaborator


nit: delete this line

        assert isinstance(var, CppCSEVariable)
        assert var.dtype is not None
        if not var.is_vec:
            if isinstance(mask, CppCSEVariable) and mask.is_vec:
Collaborator


is this equivalent to if mask and mask.is_vec, or can mask be of a different type?

Collaborator Author


The signature is

def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[str] = None,
    )

in torch._inductor.codegen.common.Kernel

It should have been Optional[Union[CSEVariable, str]], just like var. I was guarding against mask being a plain str.
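A minimal standalone sketch (hypothetical class, not CppCSEVariable itself) of why the isinstance guard matters when mask may be a plain str:

from typing import Optional, Union

class FakeCSEVariable(str):
    # stand-in for CppCSEVariable; real CSE variables carry an is_vec flag
    is_vec = True

def needs_all_masked(mask: Optional[Union[str, FakeCSEVariable]]) -> bool:
    # "mask and mask.is_vec" would raise AttributeError for a plain str,
    # so check the concrete type first
    return isinstance(mask, FakeCSEVariable) and mask.is_vec

assert needs_all_masked(FakeCSEVariable("vec_mask"))
assert not needs_all_masked("scalar_mask")
assert not needs_all_masked(None)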

@amjames amjames requested a review from jgong5 July 16, 2024 14:56
isuruf added 3 commits July 16, 2024 16:46
@lezcano
Collaborator

lezcano commented Jul 17, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Jul 17, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@isuruf
Collaborator Author

isuruf commented Jul 17, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2024
mlazos pushed a commit that referenced this pull request Jul 18, 2024
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
@github-actions github-actions bot deleted the gh/isuruf/63/head branch August 17, 2024 01:59