[primTorch] Adds broadcast_to, column_stack references #78416
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit 9864de3.
```python
    return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))


def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
```
Nice
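For context on the `broadcast_to` reference: the diff above only shows its signature, but shape refs in primTorch typically delegate to `prims.broadcast_in_dim`. A minimal sketch along those lines is shown below; it assumes `prims.broadcast_in_dim(a, shape, broadcast_dimensions)` is available and is illustrative rather than the exact merged code.

```python
# Illustrative sketch only, not necessarily the merged implementation.
# Assumes prims.broadcast_in_dim(a, shape, broadcast_dimensions) exists.
import torch
import torch._prims as prims


def broadcast_to_sketch(a: torch.Tensor, size):
    # Map each existing dimension of `a` onto the trailing dimensions of `size`,
    # then let broadcast_in_dim produce the broadcasted result.
    start = len(size) - a.ndim
    dims = tuple(range(start, start + a.ndim))
    return prims.broadcast_in_dim(a, size, dims)


# Example: broadcast a (3,) tensor to (2, 3).
t = torch.randn(3)
assert broadcast_to_sketch(t, (2, 3)).shape == (2, 3)
```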
torch/_refs/__init__.py (Outdated)
```python
@out_wrapper
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    aligned_tensors = [
        x if x.ndim > 1 else torch._prims.expand_dims(x, list(range(x.ndim, 2)))
```
Just prims. is fine (as with broadcast_to impl)
torch/_refs/__init__.py (Outdated)
```python
        x if x.ndim > 1 else torch._prims.expand_dims(x, list(range(x.ndim, 2)))
        for x in tensors
    ]
    return prims.cat(aligned_tensors, 1)
```
Prefer using the reference to the prim (so just cat(...))
Any reason we prefer to use references inside `_refs` instead of prims? Using prims in references looks like a cleaner implementation.
torch/_refs/__init__.py (Outdated)
```python
@out_wrapper
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    aligned_tensors = [
```
Style nit: tuples are better than lists if not modifying the container
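Putting the inline suggestions together (use `prims.expand_dims` without the `torch._prims` prefix, call the `cat` reference instead of `prims.cat`, and build a tuple rather than a list), the revised body could look roughly like the sketch below. This is an illustration of the feedback, not necessarily the exact merged code; it assumes it lives inside `torch/_refs/__init__.py`, where `out_wrapper`, `prims`, `cat`, and the type aliases are already in scope.

```python
# Sketch of column_stack with the review feedback applied; assumes this sits in
# torch/_refs/__init__.py with out_wrapper, prims, cat, and the type aliases
# already imported there.
@out_wrapper
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    # Promote 0-D and 1-D inputs to 2-D column vectors; leave >=2-D inputs as-is.
    aligned_tensors = tuple(
        x if x.ndim > 1 else prims.expand_dims(x, list(range(x.ndim, 2)))
        for x in tensors
    )
    # Per the review, call the cat reference rather than prims.cat.
    return cat(aligned_tensors, 1)
```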
Nice work @jjsjann123! See a few inline notes, then merge this when you like, using pytorchbot.
@pytorchbot merge this

Hey @jjsjann123.
Summary:
1. Added references for the two ops.
2. Inherited the original operators' OpInfo tests.

TODO for a future PR: add primTorch references for `dsplit` and `dstack`. Those two should use `atleast_3d`, which is in a different packet right now.

Pull Request resolved: #78416
Approved by: https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7ea9c6edc2f0dd2187db5f42359df2ddc7b503fe
Reviewed By: seemethere
Differential Revision: D36815590
Pulled By: seemethere
fbshipit-source-id: 57feb0b546e198b4675c346d15be7c7cfe287cc7
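Since the PR reuses the original operators' OpInfo tests, a quick manual sanity check in the same spirit might look like the snippet below. It is illustrative only, assumes `torch._refs` exposes both new references after this PR, and is not the actual OpInfo test.

```python
# Illustrative consistency check, not the OpInfo test itself.
# Assumes torch._refs exposes broadcast_to and column_stack after this PR.
import torch
import torch._refs as refs

a = torch.randn(3)
b = torch.randn(3, 1)

# Reference results should match the eager ATen operators.
assert torch.equal(refs.broadcast_to(a, (2, 3)), torch.broadcast_to(a, (2, 3)))
assert torch.equal(refs.column_stack((a, b)), torch.column_stack((a, b)))
```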
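For the `dstack` part of the TODO above, a rough sketch of how such a reference could lean on `atleast_3d` is shown below, written here against the public `torch` API for clarity. The helper name is hypothetical and this is not code from this PR.

```python
# Hypothetical sketch of a dstack-style implementation built on atleast_3d;
# not code from this PR.
import torch


def dstack_sketch(tensors):
    # Promote every input to at least 3-D, then concatenate along dim 2
    # (the depth dimension), mirroring torch.dstack's documented behavior.
    aligned = tuple(torch.atleast_3d(t) for t in tensors)
    return torch.cat(aligned, dim=2)


x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# torch.dstack((x, y)) has shape (1, 3, 2); the sketch should agree with it.
assert torch.equal(dstack_sketch((x, y)), torch.dstack((x, y)))
```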