
Conversation

@mruberry (Collaborator) commented May 3, 2022

This PR...

Adds the following prims:

  • slice
  • slice_in_dim
  • transpose

Adds the following refs:

  • cat
  • permute
  • transpose
  • swap_axes (alias for transpose)
  • tensor_split

Makes the following test improvements:

  • adds reference inputs for torch.permute
  • adds a NumPy reference for torch.permute
  • adds reference inputs for torch.cat

Fixes the following bugs:

  • adds support for scalars to the min and max prims

@facebook-github-bot (Contributor) commented May 3, 2022

💊 CI failures summary and remediations

As of commit be6a6e7 (more details on the Dr. CI page):

None of the CI failures appear to be your fault 💚



🚧 2 ongoing upstream failures:

These were probably caused by upstream breakages that are not fixed yet.


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@mruberry changed the title from "shape ops" to "[primTorch] slice and transpose & etc." on May 3, 2022
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = utils.wrap_scalar(b, dtype=a.dtype)
Collaborator:

this looks wrong for cuda inputs (they accept scalar tensors on cpu, but here device checks on prim would error out)

@mruberry (Author):

This happens after those checks, though
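
For context, a minimal sketch of the scalar-wrapping pattern under discussion, assuming an elementwise ref that wraps a Python number before invoking the op; `_wrap_scalar` and `_add_ref` are illustrative stand-ins, not the PR's actual helpers:

```python
import torch

def _wrap_scalar(value, *, dtype):
    # Wrap a Python number as a 0-dim CPU tensor with the requested dtype.
    # Eager PyTorch ops accept CPU scalar tensors alongside CUDA tensors; the
    # concern above is whether a prim's device check would reject the result.
    return torch.tensor(value, dtype=dtype)

def _add_ref(a, b):
    # Stand-in for an elementwise ref: wrap the scalar operand, then call the op.
    if isinstance(a, torch.Tensor) and isinstance(b, (bool, int, float, complex)):
        b = _wrap_scalar(b, dtype=a.dtype)
    return torch.add(a, b)
```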

)
raise ValueError(msg)

if a.ndim != len(start_indices):
Collaborator:

it's the same condition as in line 916, so you've errored out already?

Collaborator:

what is our general stance on where error checking like this should happen? E.g. for reductions I do it in the ref, and the prim assumes the dims it receives are valid. You've already canonicalized indices before sending them here, hence no negative dims; doesn't it follow that you've also error-checked them?

@mruberry (Author):

Correct, this should be limit_indices here

I think the prim should do its own error checking, and in the future we should create an error context so the name of the called function appears in the error message, and try to minimize redundant error checking in the ref.

@mruberry (Author):

(from a trace perspective only prim error checking matters)
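
A hedged sketch of the prim-side validation being discussed, checking rank first and then per-dimension bounds; the helper name and error text are illustrative rather than the PR's exact code:

```python
def _check_slice_args(a, start_indices, limit_indices):
    # Rank check: one start and one limit index per dimension.
    if a.ndim != len(start_indices) or a.ndim != len(limit_indices):
        raise ValueError(
            f"Expected {a.ndim} start and limit indices, got "
            f"{len(start_indices)} and {len(limit_indices)}"
        )
    # Bounds check: 0 <= start <= limit <= dimension length.
    for dim, (start, limit, length) in enumerate(
        zip(start_indices, limit_indices, a.shape)
    ):
        if not (0 <= start <= limit <= length):
            raise ValueError(
                f"Invalid slice [{start}, {limit}) for dimension {dim} "
                f"of length {length}"
            )
```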

raise ValueError(msg)
if x > y:
msg = (
"Attempting to slice a tensor but a stop index in {0} is greater than the length of "
Collaborator:

is it stop index or end index?

@mruberry (Author):

In the Python slice docs it's start, stop and step: https://docs.python.org/3/c-api/slice.html

msg = (
"Attempting to slice a tensor but a stop index in {0} is greater than the length of "
" its corresponding dimension in shape {1}".format(
limit_indices, a.shape
Collaborator:

why full lists here unlike other error messages?

@mruberry (Author):

idk


new_shape = []
for x, y in zip(start_indices, limit_indices):
    new_shape.append(y - x)
Collaborator:

should it be divided by stride?

@mruberry (Author):

yes! Good catch
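
A minimal sketch of the stride-corrected output-shape computation agreed on here, assuming positive strides; per Python slice semantics the extent of each dimension is ceil((stop - start) / step), clamped at zero:

```python
def _slice_output_shape(start_indices, limit_indices, strides):
    # length = ceil((stop - start) / step) for a positive step, never negative
    new_shape = []
    for start, stop, step in zip(start_indices, limit_indices, strides):
        new_shape.append(max(0, (stop - start + step - 1) // step))
    return new_shape

# e.g. _slice_output_shape([0, 1], [5, 7], [1, 2]) -> [5, 3]
```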

raise ValueError(msg)

start_indices = [0] * a.ndim
limit_indices = [0] * a.ndim
Collaborator:

should limit_indices be a.shape?

@mruberry (Author):

Yes! -- that's weird I thought I fixed that

    axis: int = 0,
) -> Tensor:
    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
Collaborator:

yep they should be
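
A hedged sketch of how slice_in_dim can defer to the general slice prim once the defaults are fixed: every dimension keeps its full extent and only the requested axis is narrowed. Plain indexing stands in for the prim call so the sketch is self-contained:

```python
import torch

def _slice_in_dim_ref(a, start_index, limit_index, stride=1, axis=0):
    # Default to the full extent of every dimension, then narrow one axis.
    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
    strides = [1] * a.ndim
    start_indices[axis] = start_index
    limit_indices[axis] = limit_index
    strides[axis] = stride
    # Stand-in for the general slice prim added in this PR.
    index = tuple(
        slice(s, l, st) for s, l, st in zip(start_indices, limit_indices, strides)
    )
    return a[index]
```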

"""
return torch.tensor(a, dtype=type_to_dtype(type(a)))
if dtype is None:
return torch.tensor(a, dtype=type_to_dtype(type(a)))
Collaborator:

It's confusing that type promotion is now split between the ref (which converts the tensor) and here (which converts the scalar). Maybe wrapping should indeed be done in the ref.

@mruberry (Author):

Yeah I'm going to look at scalar handling more generally and ensure we're preserving scalars properly
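
For reference, a minimal sketch of the scalar-to-tensor helper in question; the dtype table below is an assumption standing in for type_to_dtype:

```python
import torch

# Assumed mapping from Python number types to default torch dtypes.
_DEFAULT_DTYPES = {
    bool: torch.bool,
    int: torch.int64,
    float: torch.get_default_dtype(),
    complex: torch.complex64,
}

def _scalar_tensor(a, *, dtype=None):
    # With no explicit dtype, infer one from the Python type of the scalar.
    if dtype is None:
        dtype = _DEFAULT_DTYPES[type(a)]
    return torch.tensor(a, dtype=dtype)
```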

msg = "cat expects at least one tensor, but received zero!"
raise ValueError(msg)

_dim = utils.canonicalize_dims(tensors[0].ndim, dim)
Collaborator:

why not canonicalize_idx?

@mruberry (Author):

Yep, might as well
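
A hedged sketch of the index canonicalization the comment asks about; the helper name follows the review comment rather than the exact utility in the PR:

```python
def _canonicalize_idx(rank, idx):
    # Wrap a possibly negative dimension index into range and validate it.
    # A 0-dim tensor is treated as having one addressable dimension, matching
    # PyTorch's usual dim-wrapping behavior.
    wrapped_rank = max(rank, 1)
    if idx < 0:
        idx += wrapped_rank
    if not 0 <= idx < wrapped_rank:
        raise IndexError(f"Dimension out of range (got {idx} for rank {rank})")
    return idx
```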


if out is not None:
    out = _maybe_resize_out(out, result.shape)
    return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
Collaborator:

Isn't it weird that all prims are not type-promoting except copy_to, which is? Would it be better to instead check the out dtype here, rather than relying on copy_to to error out? If refs are documentation in code, then making error conditions clear is one of the goals.

@mruberry (Author):

Let's reconsider this when we apply out handling as a wrapper?
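
A minimal sketch combining the reviewer's suggestion (an explicit dtype check in the ref) with the out= handling shown above; the helper name and error text are assumptions, and the resize/copy steps mirror what _maybe_resize_out and copy_to do:

```python
import torch

def _handle_out(out, result):
    # No out tensor: return the computed result directly.
    if out is None:
        return result
    # Surface the dtype mismatch here instead of relying on the copy prim.
    if out.dtype != result.dtype:
        raise RuntimeError(
            f"out dtype {out.dtype} does not match result dtype {result.dtype}"
        )
    # Resize if needed, then copy the result into out.
    if out.shape != result.shape:
        out.resize_(result.shape)
    return out.copy_(result)
```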

@mruberry (Author) commented May 4, 2022

@pytorchbot merge this please

facebook-github-bot pushed a commit that referenced this pull request May 5, 2022
Summary:
This PR...

Adds the following prims:
- slice
- slice_in_dim
- transpose

Adds the following refs:
- cat
- permute
- transpose
- swap_axes (alias for transpose)
- tensor_split

Makes the following test improvements:
- adds reference inputs for torch.permute
- adds a NumPy reference for torch.permute
- adds reference inputs for torch.cat

Fixes the following bugs:
- adds support for scalars to the min and max prims

Pull Request resolved: #76727
Approved by: https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ef9f56eb0b86a2224652f5c5ac1f82dcb4f444c6

Reviewed By: malfet

Differential Revision: D36134094

Pulled By: malfet

fbshipit-source-id: 89e44eff92cf1b4bbc8ce123e45ee96ec580435e
@mruberry deleted the shape_ops branch on May 19, 2022 at 18:59