[primTorch] slice and transpose & etc. #76727
Conversation
CI failures summary (Dr. CI): As of commit be6a6e7, none of the CI failures appear to be this PR's fault. There are 2 ongoing upstream failures, probably caused by upstream breakages that are not fixed yet.
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = utils.wrap_scalar(b, dtype=a.dtype)
This looks wrong for CUDA inputs (CUDA ops accept scalar tensors on CPU, but here the device checks on the prim would error out).
This happens after those checks, though
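To make the device concern concrete, here is a minimal sketch (the `wrap_scalar` below is a stand-in, not the actual `utils.wrap_scalar`): wrapping a Python number yields a 0-dim CPU tensor, which eager CUDA ops accept but a same-device check in a prim would reject.

```python
import torch

def wrap_scalar(b, *, dtype):
    # Stand-in helper: wrap a Python number as a 0-dim tensor of the
    # given dtype. torch.tensor places it on the CPU by default, so a
    # prim requiring all tensor arguments to share a device would reject
    # it next to a CUDA operand, while eager PyTorch accepts CPU scalar
    # tensors alongside CUDA tensors.
    return torch.tensor(b, dtype=dtype)

a = torch.ones(3, dtype=torch.float32)  # imagine device="cuda" here
b = wrap_scalar(2, dtype=a.dtype)       # 0-dim CPU tensor
assert b.ndim == 0 and b.device.type == "cpu"
```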
torch/_prims/__init__.py (outdated):
    )
    raise ValueError(msg)

    if a.ndim != len(start_indices):
It's the same condition as on line 916, so you've errored out already?
What is our general stance on where error checking like this should happen? E.g., for reductions I do it in the ref, and the prim assumes the dims it receives are valid. You've already canonicalized indices before sending them here, hence no negative dims; doesn't it follow that you've also error-checked them?
Correct, this should be limit_indices here.
I think the prim should do its own error checking, and in the future we should create an error context so the name of the called function appears in the error message, and try to minimize redundant error checking in the ref.
(from a trace perspective only prim error checking matters)
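As a sketch of what prim-side error checking for slice could look like (a hypothetical helper under the assumptions above, not the merged code):

```python
# Hypothetical validation for the slice prim's meta function: check rank
# agreement first, then per-dimension bounds, so each error message can
# name the offending argument.
def check_slice_args(shape, start_indices, limit_indices, strides):
    if not (len(shape) == len(start_indices) == len(limit_indices) == len(strides)):
        raise ValueError("index and stride lists must match the tensor's rank")
    for dim, (start, limit, stride, length) in enumerate(
        zip(start_indices, limit_indices, strides, shape)
    ):
        if start < 0 or limit < start or limit > length:
            raise ValueError(
                f"invalid slice bounds [{start}, {limit}) for dim {dim} of length {length}"
            )
        if stride < 1:
            raise ValueError(f"strides must be positive, but got {stride} for dim {dim}")
```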
    raise ValueError(msg)
    if x > y:
        msg = (
            "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
Is it a stop index or an end index?
In the Python slice docs it's start, stop and step: https://docs.python.org/3/c-api/slice.html
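For reference, Python's own slice object uses start/stop/step as its attribute names:

```python
# "stop" is the official name in Python's slice API.
s = slice(2, 10, 3)
print(s.start, s.stop, s.step)  # 2 10 3
```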
    msg = (
        "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
        " its corresponding dimension in shape {1}".format(
            limit_indices, a.shape
Why full lists here, unlike in other error messages?
idk
torch/_prims/__init__.py (outdated):
    new_shape = []
    for x, y in zip(start_indices, limit_indices):
        new_shape.append(y - x)
Should it be divided by the stride?
Yes! Good catch.
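A sketch of the corrected shape computation (hypothetical helper; the argument names are assumed from the surrounding diff):

```python
import math

def slice_output_shape(start_indices, limit_indices, strides):
    # With stride s, the slice [x, y) keeps elements x, x + s, x + 2s, ...,
    # so each output dimension has length ceil((y - x) / s), not (y - x).
    return [
        math.ceil((y - x) / s)
        for x, y, s in zip(start_indices, limit_indices, strides)
    ]

assert slice_output_shape([0], [5], [2]) == [3]  # keeps elements 0, 2, 4
```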
torch/_prims/__init__.py (outdated):
    raise ValueError(msg)

    start_indices = [0] * a.ndim
    limit_indices = [0] * a.ndim
Should limit_indices be a.shape?
Yes! That's weird; I thought I fixed that.
    axis: int = 0,
) -> Tensor:
    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
Yep, they should be.
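A sketch of the fixed lowering, with `slice_in_dim_indices` as a hypothetical helper: limit_indices starts from the full shape (not zeros), and only the sliced dimension is narrowed.

```python
def slice_in_dim_indices(shape, axis, start, limit):
    # Every dimension defaults to its full range [0, dim length); only
    # the sliced axis gets narrowed before lowering to the slice prim.
    start_indices = [0] * len(shape)
    limit_indices = list(shape)
    start_indices[axis] = start
    limit_indices[axis] = limit
    return start_indices, limit_indices

assert slice_in_dim_indices((4, 5), 1, 1, 3) == ([0, 1], [4, 3])
```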
| """ | ||
| return torch.tensor(a, dtype=type_to_dtype(type(a))) | ||
| if dtype is None: | ||
| return torch.tensor(a, dtype=type_to_dtype(type(a))) |
It's confusing that type promotion is now split between the ref (which converts the tensor) and here (which converts the scalar). Maybe the wrapping should indeed be done in the ref.
Yeah, I'm going to look at scalar handling more generally and ensure we're preserving scalars properly.
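One way to keep the dtype decision in a single place, sketched with hypothetical helpers (the `type_to_dtype` mapping below follows PyTorch's usual defaults for wrapped Python numbers, which is an assumption here, not the merged code):

```python
import torch

def type_to_dtype(typ):
    # Assumed mapping from a Python scalar type to a torch dtype.
    if typ is bool:
        return torch.bool
    if typ is int:
        return torch.int64
    if typ is float:
        return torch.get_default_dtype()
    if typ is complex:
        return torch.complex64  # assuming the default float dtype is float32
    raise ValueError(f"unsupported scalar type {typ}")

def wrap_number(a, dtype=None):
    # If the caller (the ref) has already computed a promoted dtype, it
    # passes it in; otherwise fall back to the scalar's natural dtype.
    if dtype is None:
        dtype = type_to_dtype(type(a))
    return torch.tensor(a, dtype=dtype)
```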
| msg = "cat expects at least one tensor, but received zero!" | ||
| raise ValueError(msg) | ||
|
|
||
| _dim = utils.canonicalize_dims(tensors[0].ndim, dim) |
Why not canonicalize_idx?
Yep, might as well
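For illustration, a minimal dim canonicalization sketch (hypothetical helper; the actual utils functions may differ in name and signature):

```python
def canonicalize_dim(ndim, dim):
    # Map a possibly-negative dim into the valid range [0, ndim),
    # raising on out-of-range values.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a tensor of rank {ndim}")
    return dim % ndim

assert canonicalize_dim(3, -1) == 2
assert canonicalize_dim(3, 1) == 1
```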
    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
Isn't it weird that none of the prims are type-promoting except copy_to, which is? Would it be better to check the out dtype here instead of relying on copy_to to error out? If refs are documentation in code, then making error conditions clear is one of the goals.
Let's reconsider this when we apply out handling as a wrapper?
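A sketch of the explicit check being suggested, as a hypothetical helper rather than the merged code:

```python
import torch

def check_out(out: torch.Tensor, result: torch.Tensor) -> None:
    # Validate the out= tensor in the ref itself, instead of relying on a
    # type-promoting copy_to to raise, so the error conditions are visible.
    if out.dtype != result.dtype:
        raise RuntimeError(
            f"out dtype {out.dtype} does not match result dtype {result.dtype}"
        )
    if out.device != result.device:
        raise RuntimeError(
            f"out device {out.device} does not match result device {result.device}"
        )
```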
@pytorchbot merge this please
Summary: This PR...

Adds the following prims:
- slice
- slice_in_dim
- transpose

Adds the following refs:
- cat
- permute
- transpose
- swap_axes (alias for transpose)
- tensor_split

Makes the following test improvements:
- adds reference inputs for torch.permute
- adds a NumPy reference for torch.permute
- adds reference inputs for torch.cat

Fixes the following bugs:
- adds support for scalars to the min and max prims

Pull Request resolved: #76727
Approved by: https://github.com/ngimel
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ef9f56eb0b86a2224652f5c5ac1f82dcb4f444c6
Reviewed By: malfet
Differential Revision: D36134094
Pulled By: malfet
fbshipit-source-id: 89e44eff92cf1b4bbc8ce123e45ee96ec580435e
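The new refs are intended to mirror the corresponding public torch operations, for example:

```python
import torch

a = torch.arange(6).reshape(2, 3)
torch.permute(a, (1, 0))         # shape (3, 2); mirrored by refs.permute
torch.swapaxes(a, 0, 1)          # alias for transpose
torch.tensor_split(a, 3, dim=1)  # three (2, 1) chunks
torch.cat([a, a], dim=0)         # shape (4, 3)
```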