Add xlogy and xlog1py references #77712
Conversation
* Add reference implementations for `bitwise_not`, `i0`, `i1`, `zeta`, `xlogy`, `xlog1py`
* Add prim operations for `i0`, `i1`, `zeta`
✅ No failures (0 pending) as of commit 786fa83 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
This looks great, @rdspring1! It shows a solid understanding of the system, fixes several issues, and adds the needed prims, references, and OpInfos. I made a few inline comments for your review; most of them are very minor.
Don't hesitate to reach out on Slack if you have any questions!
@mruberry I updated the logic handling scalar values, but still needed to make them into …
torch/_refs/__init__.py (Outdated)

```python
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    assert isinstance(a, TensorLike) or isinstance(b, TensorLike)
```
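The special-case semantics this reference has to preserve can be sketched for plain Python scalars (a hypothetical model for illustration, not the actual reference implementation, which operates on tensors and goes through type promotion):

```python
import math

def xlogy_scalar(a: float, b: float) -> float:
    # Hypothetical scalar model of xlogy's special cases:
    # a NaN in b propagates, a == 0 yields 0 even where log(b)
    # would be -inf, and otherwise the result is a * log(b).
    if math.isnan(b):
        return math.nan
    if a == 0.0:
        return 0.0
    return a * math.log(b)
```

For example, `xlogy_scalar(0.0, 0.0)` returns `0.0` rather than `nan`, which is exactly the property that distinguishes `xlogy` from a naive `a * log(b)`.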
Change these asserts to be checks, and add a comment for why the `scalar_tensor` construction is needed.
```python
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    assert isinstance(a, TensorLike)
```
I think these asserts can be removed
lint complains without the asserts.
It's OK then to add a comment explaining the asserts are to satisfy the linter, or to disable the linter's checks on the line that's complaining (with a note about why the check is disabled)
torch/_refs/__init__.py (Outdated)

```python
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    assert isinstance(a, TensorLike) or isinstance(b, TensorLike)
    if isinstance(a, TensorLike) and isinstance(b, Number):
```
This is an awkward construction because it constructs the `scalar_tensor` in the datatype of the tensor, but it's OK to do so because the type promotion decorator has already converted the tensor to the appropriate dtype.
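A toy illustration of why this is safe (a minimal model of the `INT_TO_FLOAT` promotion step, not PyTorch's actual decorator machinery): by the time the op body wraps the scalar, an integer input tensor has already been mapped to a float computation dtype, so wrapping the scalar in the tensor's dtype cannot produce an integer scalar tensor.

```python
def int_to_float_compute_dtype(input_dtype: str) -> str:
    # Toy model of ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
    # boolean/integer dtypes are promoted to the default float dtype
    # before the op body runs; float dtypes pass through unchanged.
    int_like = {"bool", "uint8", "int8", "int16", "int32", "int64"}
    return "float32" if input_dtype in int_like else input_dtype
```

Under this model, `int_to_float_compute_dtype("int64")` is `"float32"` (assuming `float32` is the default float dtype), while `"float64"` passes through, so `scalar_tensor(b, dtype=a.dtype)` always sees a float-like `a.dtype`.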
These conditions can be simplified, since one of `a` and `b` is a TensorLike object per the above check:

```python
a = torch.scalar_tensor(a, dtype=b.dtype, device=b.device) if isinstance(a, Number) else a
b = torch.scalar_tensor(b, dtype=a.dtype, device=a.device) if isinstance(b, Number) else b
```
mypy also complained about this change, since the other argument is `Union[TensorLikeType, NumberType]`.
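The mypy issue can be reproduced without PyTorch. The sketch below uses a stand-in class (`FakeTensor` is illustrative, not the real `TensorLikeType`) to show why the trailing asserts are needed to narrow the `Union` for the type checker:

```python
from numbers import Number
from typing import Tuple, Union

class FakeTensor:
    """Illustrative stand-in for a tensor-like object."""
    dtype = "float32"
    device = "cpu"

def wrap_scalars(
    a: Union[FakeTensor, Number], b: Union[FakeTensor, Number]
) -> Tuple[FakeTensor, FakeTensor]:
    assert isinstance(a, FakeTensor) or isinstance(b, FakeTensor)
    if isinstance(a, FakeTensor) and isinstance(b, Number):
        b = FakeTensor()  # stands in for scalar_tensor(b, ...)
    elif isinstance(b, FakeTensor) and isinstance(a, Number):
        a = FakeTensor()  # stands in for scalar_tensor(a, ...)
    # mypy still types a and b as the full Union at this point; these
    # asserts exist only to narrow the types so the return annotation checks.
    assert isinstance(a, FakeTensor) and isinstance(b, FakeTensor)
    return a, b
```

Removing the final asserts leaves the runtime behavior identical, but mypy can no longer prove the return type, which is the complaint discussed above.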
Darn. I really don't like mypy
Looks good overall, but see a couple of cleanup notes!
@pytorchbot merge -g

Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 3 additional jobs have failed; the first few of them are: trunk, trunk / macos-12-py3-arm64-mps / Run MPS tests, trunk / linux-focal-rocm5.2-py3.7 / test (default, 1, 2, linux.rocm.gpu). Details for Dev Infra team: raised by workflow job.

@pytorchbot rebase -s

You don't have permissions to rebase this PR, only people with write permissions may rebase PRs.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @rdspring1.
* Add reference implementations for `xlogy` and `xlog1py`
* Replace `_wrap_scalar` helper function with `scalar_tensor` prim

Pull Request resolved: pytorch#77712
Approved by: https://github.com/mruberry
cc @ezyang @mruberry @ngimel @lezcano @fdrocha