[primTorch] refs: hsplit, vsplit #78418
Conversation
❌ 1 new failure as of commit 564a0b9 (more details on the Dr. CI page).
🕵️ 1 new failure recognized by patterns; the following CI failures do not appear to be due to upstream breakages.
Args:
    input (Tensor): tensor to split.
    indices_or_sections (Tensor, int or list or tuple of ints): See argument in :func:`torch.tensor_split`.
Does not accept tensor
pytorch/aten/src/ATen/native/native_functions.yaml
Lines 4407 to 4411 in 8ad305f

- func: hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
  variants: function, method
- func: hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
  variants: function, method
Nice catch
torch/_refs/__init__.py
Outdated
dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, int):
    split_size = indices_or_sections
    if not (split_size != 0 and a.shape[dim] % split_size == 0):
The double negation here is a little hard to read -- maybe propagate the `not`?
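The suggested rewrite would push the `not` into the condition so the error path reads positively. A minimal sketch (the function name and message wording are illustrative):

```python
def check_divisible(dim_size: int, split_size: int) -> None:
    """De Morgan's rewrite of `not (split_size != 0 and dim_size % split_size == 0)`:
    raise when the split size is zero or does not evenly divide the dimension."""
    if split_size == 0 or dim_size % split_size != 0:
        raise RuntimeError(
            f"attempted to split a dimension of size {dim_size}, "
            f"which is not divisible by the split_size {split_size}!"
        )
```

The two forms are logically equivalent, but the propagated version states each failing case directly.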
is this error message redundant with error checking done by tensor_split?
I think this msg is better than the one that tensor_split throws:
tensor_split: number of sections must be greater than 0, but was 0
vs
torch.hsplit attempted to split along dimension 1, but the size of the dimension 5 is not divisible by the split_size 0!
OK
torch/_refs/__init__.py
Outdated
    + str(split_size)
    + "!"
)
raise RuntimeError(msg)
The msg is here, so that's cool, but same questions as for hsplit.
Overall looks good, @kshitij12345! A couple inline comments for your review
torch/_refs/__init__.py
Outdated
"hsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}"
)
Would be nice if check could also take the type of error to raise. This would be changed to check(isinstance(indices_or_sections, (list, tuple)), msg, error_type=TypeError)?
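The suggested variant could look like the sketch below. This is a hypothetical signature, not torch's actual helper; `error_type` defaults to `RuntimeError` so existing call sites would keep their behavior (the lambda-for-the-message part is discussed further down the thread):

```python
def check(condition, msg_fn, error_type=RuntimeError):
    """Raise `error_type(msg_fn())` when `condition` is false.

    `msg_fn` is a callable so the message string is only built on failure.
    """
    if not condition:
        raise error_type(msg_fn())

# Usage matching the review suggestion (TypeError instead of RuntimeError):
# check(isinstance(x, (list, tuple)), lambda: "expected list or tuple",
#       error_type=TypeError)
```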
cc @ezyang -- we're already seeing variants of check pop up for different error types.
Optional third argument sounds good to me
torch/_refs/__init__.py
Outdated
dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, int):
    split_size = indices_or_sections
    msg = (
Doing final checks -- what I don't like about this pattern is that we're constructing the message even when we don't error, cc @ezyang. I think we have to check for the condition and construct the message only in the error state; then, if we want to use check, just check(False, ...).
(The other thing we could do is define a function that returns the string, which would also delay string construction and which check does accept. I'm not sure how costly defining a function is, however -- I'd prefer we hide as much as possible behind the error checks.)
This is why check takes a lambda: check(b, lambda: string_construction). Taking a string directly is wrong and won't actually work. I checked, and dynamo is able to trace past the lambda construction.
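The lambda pattern being described can be sketched as follows. The counter is only there to demonstrate laziness; `check` here is a simplified stand-in for the real helper:

```python
def check(b, msg_fn):
    # The message callable only runs on failure, so the happy path
    # never pays for string formatting.
    if not b:
        raise RuntimeError(msg_fn())

calls = []

def expensive_msg():
    calls.append(1)  # record that formatting actually ran
    return "torch.hsplit attempted an invalid split"

check(True, expensive_msg)  # passes: expensive_msg is never invoked
```

After the passing check, `calls` is still empty; only a failing check invokes `expensive_msg` and builds the string.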
torch/_refs/__init__.py
Outdated
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    msg = (
Same issue here about guarding message construction on the error
…lop/refs/h-v-split
Gentle Ping @mruberry
    return tuple(splits)

def hsplit(
Nit: I feel like a common helper could have implemented these, but nbd.
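The common helper the nit has in mind could look like the sketch below: hsplit and vsplit differ only in the split dimension and the minimum rank they require. The helper names, error wording, and the `tensor_split` stub are all illustrative, not the actual primTorch code:

```python
def tensor_split(a, indices_or_sections, dim):
    # Stand-in for the real torch.tensor_split the refs delegate to;
    # here it just echoes its arguments so the dispatch is visible.
    return (a, indices_or_sections, dim)

def _split_along(a, indices_or_sections, dim, min_ndim, op_name):
    # Shared validation + dispatch for hsplit/vsplit.
    if a.ndim < min_ndim:
        raise RuntimeError(
            f"torch.{op_name} requires a tensor with at least "
            f"{min_ndim} dimensions, but got a tensor with {a.ndim} dimensions!"
        )
    return tensor_split(a, indices_or_sections, dim)

def hsplit(a, indices_or_sections):
    # hsplit splits along dim 1, except for 1-D tensors where it uses dim 0.
    return _split_along(
        a, indices_or_sections, 0 if a.ndim == 1 else 1, 1, "hsplit"
    )

def vsplit(a, indices_or_sections):
    # vsplit always splits along dim 0 and needs at least a 2-D input.
    return _split_along(a, indices_or_sections, 0, 2, "vsplit")
```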
Cool!
This just needs a rebase to resolve the merge conflict, @kshitij12345
Onnx failure looks unrelated to the PR. See #78844
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here
Hey @kshitij12345.
Summary: As per title

TODO:
* [x] Add error inputs (already exist)

Pull Request resolved: #78418
Approved by: https://github.com/mruberry
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/c461d8a9777ad54ab5a5587d8c67bf2d19bca894
Reviewed By: osalpekar
Differential Revision: D36959164
Pulled By: osalpekar
fbshipit-source-id: 6f0a6258742777a583cdc82d4e62b3eb1314d5bf