
Conversation


@mruberry mruberry commented May 16, 2022

This PR...

Filed the Following Issues

  • #77553
  • #77526
  • #77600

Testing

  • Updates test_dtypes to no longer attempt to test the backward of sample inputs where no inputs require grad
  • Adds a new test_python_reference_errors; it ensures the meta operations for references throw errors as expected
  • Updates compare_tensor_meta to better handle CUDA devices, and (temporarily) restricts stride checking to the CUDA device type
  • Elementwise unary and elementwise binary operators now have arbitrarily strided reference inputs
  • Reference inputs for _like functions are added
  • An OpInfo for torch.empty is added
  • Reference inputs for torch.clone are added
  • A NumPy reference for clone is added
  • Adds OpInfos for refs.empty and refs.empty_like

Prims

  • Renames the "max" and "min" prims to "maximum" and "minimum," respectively, to better conform to their ATen names
  • Adds the empty, empty_like, full, and full_like prims
  • Fixes the elementwise meta function's stride propagation
  • Fixes clone's meta function's stride propagation
  • Fixes convert_element_type's meta's stride propagation
  • Adds a (temporary) _to_dtype private prim that casts a tensor while preserving its stride permutation
  • Removes the _set prim comment
  • Adds utils.compute_elementwise_output_strides, which computes the correct output strides for elementwise operations
  • Corrects an issue where utils.make_contiguous_strides_for was creating the incorrect strides for tensors with no elements

References

  • Adds the empty, empty_like, full, full_like, and ones_like refs
  • Extends make_elementwise_unary_reference to accept an additional callable to perform extra input validation
  • Adds an extra validation function to handle refs.neg(BoolTensor)
  • Updates the isfinite ref to call ones_like when appropriate
  • Models Python scalar handling for elementwise binary operations
  • Adds a 64-dim check for the amin and amax references
  • opmath is now a flag that can be set separately for CPU and CUDA

@mruberry mruberry requested a review from ezyang May 16, 2022 13:25
@mruberry mruberry requested a review from ngimel as a code owner May 16, 2022 13:25

facebook-github-bot commented May 16, 2022


❌ 1 New Failures

As of commit 5166070 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-18T10:09:48.3150056Z ##[error]Process completed with exit code 1.
2022-05-18T10:09:48.2940528Z FAILED: Build did NOT complete successfully (34 packages loaded, 3468 targets configured)
2022-05-18T10:09:48.2954904Z ERROR: Couldn't start the build. Unable to run tests
2022-05-18T10:09:48.3001491Z FAILED: Build did NOT complete successfully (34 packages loaded, 3468 targets configured)
2022-05-18T10:09:48.3093802Z + cleanup
2022-05-18T10:09:48.3094117Z + retcode=1
2022-05-18T10:09:48.3094396Z + set +x
2022-05-18T10:09:48.3150056Z ##[error]Process completed with exit code 1.
2022-05-18T10:09:48.3202248Z Prepare all required actions
2022-05-18T10:09:48.3218739Z ##[group]Run ./.github/actions/chown-workspace
2022-05-18T10:09:48.3218956Z env:
2022-05-18T10:09:48.3219100Z   IN_CI: 1
2022-05-18T10:09:48.3219265Z   IS_GHA: 1
2022-05-18T10:09:48.3219451Z   GIT_DEFAULT_BRANCH: master
2022-05-18T10:09:48.3219628Z ##[endgroup]
2022-05-18T10:09:48.3233897Z ##[group]Run docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
2022-05-18T10:09:48.3234244Z docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
2022-05-18T10:09:48.3245726Z shell: /usr/bin/bash --noprofile --norc -e -o pipefail {0}

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@mruberry mruberry changed the title stride prop [primTorch] Enforces stride metadata May 16, 2022
# Test doesn't support non-tensor inputs
DecorateInfo(unittest.expectedFailure,
'TestMathBits',
'test_neg_view'),
Contributor

Join @zou3519 and me in advocating that these skips should be automatically generated and saved ;) It's very time-consuming to manually track all of these down

Contributor

+1 -> #74642

Collaborator Author

Yeah but not in this PR

Contributor

I think you may already be past the tipping point where it will be faster to sit down and add this infrastructure than to play popcorn with the CI for the next week

# the kernel is invoked on cpu, so it makes strides contiguous
if a.device.type == "cpu":
return prims.convert_element_type(a, dtype)
return prims._to_dtype(a, dtype)
Contributor

The comment is very reasonable, but I still do not understand why cpu gets special cased in the condition here.

Collaborator Author

I'm tweaking this now to see if I can get CPU strides validated

device = inferred_device if device is None else device

if isinstance(device, str):
device = torch.device(device)
Contributor

This is impossible according to the type signature. Relax the type signature?

Collaborator Author

Good point - fixed
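
For reference, a minimal sketch of the kind of relaxation being suggested: accept either a str or a torch.device and normalize. The helper name and signature below are illustrative, not the PR's actual code.

```python
from typing import Optional, Union

import torch

DeviceLike = Union[str, torch.device]  # relaxed annotation (illustrative name)

def _canonicalize_device(device: Optional[DeviceLike],
                         inferred_device: torch.device) -> torch.device:
    # Fall back to the inferred device, then normalize strings like "cuda:1".
    device = inferred_device if device is None else device
    if isinstance(device, str):
        device = torch.device(device)
    return device
```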

if a.device != b.device:
msg = "Devices {0} and {1} are not equal!".format(a.device, b.device)
raise AssertionError(msg)
# Handles special cuda:0 vs cuda case
Contributor

Why is this necessary?

Collaborator Author

Somehow we're getting both values

Contributor

If you query the device of a tensor, it should always have an index. If there is no index then there is some invariant violation when we are creating the tensors in the first place (we can probably force an index in TensorMeta's constructor)

Collaborator Author

I agree with you, and it would be interesting to hunt it down, but this PR is already a little sprawling.
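
For context, a minimal sketch of the special case being discussed, assuming the intent is only to treat an index-less "cuda" device as equal to "cuda:0"; this is not the PR's exact compare_tensor_meta code.

```python
import torch

def _same_device(a: torch.device, b: torch.device) -> bool:
    # Canonicalize a missing CUDA index (assume device 0) so that
    # torch.device("cuda") compares equal to torch.device("cuda:0").
    def canonical(d: torch.device) -> torch.device:
        if d.type == "cuda" and d.index is None:
            return torch.device("cuda", 0)
        return d
    return canonical(a) == canonical(b)

print(_same_device(torch.device("cuda"), torch.device("cuda:0")))  # True
```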


# NOTE: currently we are only validating strides on CUDA, because
# we are using opmath on both CPU and CUDA, which causes
# divergent stride behavior vs. the CPU, which does not use opmath
Contributor

the "we" here is ambiguous; I assume you're talking about refs, compared to the reference CPU implementations? But it's also surprising that CPU TensorIterator doesn't preserve strides because it "lost" the information when doing a dtype conversion for type promotion. Isn't that a bug?

Contributor

It's not so nice that strides can only be validated on CUDA; this means that if you're working on strides it's mandatory to be on a CUDA machine (for me at least, my default dev env is non-CUDA)

Collaborator Author

It may be a bug, but I'm trying to get the tests to pass at the moment by modeling the CPU behavior
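
For readers unfamiliar with the term, a rough sketch of what "opmath" means in this context: low-precision inputs are upcast to a wider computation dtype, the op runs there, and the result is cast back. The helper below is illustrative only, not the PR's flag plumbing.

```python
import torch

def _unary_with_opmath(fn, a: torch.Tensor) -> torch.Tensor:
    # Upcast half-precision inputs to float32 for the computation ("opmath"),
    # then cast the result back to the original dtype.
    compute_dtype = torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    return fn(a.to(compute_dtype)).to(a.dtype)

x = torch.randn(4, dtype=torch.float16)
print(_unary_with_opmath(torch.exp, x).dtype)  # torch.float16
```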

# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] is incorrect, because it
# says that strides will be preserved even if they are not
# "non overlapping and dense", but this is incorrect. The
Contributor

Overlapping/sparse strides get preserved in the sense that they implicitly define some permutation, and that permutation is preserved in the (contiguous) output strides

Collaborator Author

Then the note in the C++ should say that instead of what it currently says
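
To make the point about permutations concrete, here is a minimal, self-contained sketch (not the PR's compute_elementwise_output_strides) of producing contiguous output strides that preserve a single input's stride order:

```python
import torch

def output_strides_preserving_permutation(t: torch.Tensor):
    shape, strides = t.shape, t.stride()
    # Order dimensions from the largest stride to the smallest.
    perm = sorted(range(t.ndim), key=lambda d: strides[d], reverse=True)
    # Contiguous strides for the permuted shape.
    permuted_shape = [shape[d] for d in perm]
    new_strides = [1] * t.ndim
    acc = 1
    for i in range(t.ndim - 1, -1, -1):
        new_strides[i] = acc
        acc *= permuted_shape[i]
    # Scatter the strides back to the original dimension order.
    out = [0] * t.ndim
    for i, d in enumerate(perm):
        out[d] = new_strides[i]
    return tuple(out)

# Example: a permuted (but dense) input's stride order is preserved.
a = torch.empty(2, 3, 4).permute(0, 2, 1)         # shape (2, 4, 3), strides (12, 1, 4)
print(output_strides_preserving_permutation(a))    # (12, 1, 4)
```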

if ndim == 0:
return ()
if ndim == 1:
return (1,)
Contributor

TBH, I'm not sure PrimTorch should be in the business of defining these short circuits, if the general algorithm works for these cases too.

Collaborator Author

I added a comment to review removing them if they're unnecessary

return 0

perm = tuple(range(ndim))
perm = sorted(perm, key=cmp_to_key(_cmp), reverse=True)
Contributor

Why not define perm as a list and then .sort() it?

Collaborator Author

Yeah that would have worked, too
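
The suggestion, sketched with an illustrative comparator (the real _cmp in the PR's utility is more involved):

```python
from functools import cmp_to_key

def _cmp(x: int, y: int) -> int:
    # Illustrative comparator only; returns negative/zero/positive like a C-style cmp.
    return (x > y) - (x < y)

ndim = 4
perm = list(range(ndim))                         # build perm as a list...
perm.sort(key=cmp_to_key(_cmp), reverse=True)    # ...and sort it in place
print(perm)                                      # [3, 2, 1, 0]
```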

perm = tuple(range(ndim))
perm = sorted(perm, key=cmp_to_key(_cmp), reverse=True)

permuted_shape = [-1] * ndim
Contributor

Initializing these with None is safer, because -1 is a valid index

Collaborator Author

-1 is a valid index for a dimension but not a valid dimension length, and initializing with None would change the type

relevant_pairs.append((x, y))

expected = 1
for x, y in sorted(relevant_pairs, key=lambda p: p[1]):
Contributor

From a tracing perspective, this sort is terrifying

Collaborator Author

Luckily the final version of the PR didn't include this function, and the sort in the stride comparison function shouldn't define any validity conditions. It is, however, an example of how we may have to just run the meta functions for our ops to understand the intermediate metadata of certain tensors.

Contributor

So, you are team "no symbolic strides"? Our current default assumption is that strides are symbolic, because from a design perspective that is easier. To make them not symbolic we will have to work (because strides are computed from symbolic quantities aka shapes).


# NOTE: _to_dtype
# This private op casts the input to the desired type while preserving its stride
# permutation, unlike .to(dtype) which will create a tensor with contiguous strides
Contributor

to() is supposed to preserve strides (that's why its memory format defaults to preserve_format). File a bug?

Collaborator Author

Good call -- #77600
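
For reference, a quick sanity-check sketch of the eager behavior being referenced (Tensor.to defaults to memory_format=torch.preserve_format, so a permuted tensor's stride order should survive a dtype cast); this is not part of the PR:

```python
import torch

a = torch.empty(2, 3, 4).permute(0, 2, 1)   # shape (2, 4, 3), strides (12, 1, 4)
b = a.to(torch.float64)                     # memory_format defaults to preserve_format
print(a.stride(), b.stride())               # both expected to be (12, 1, 4) if strides are preserved
```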

try:
requires_grad = a.requires_grad
except Exception as e:
requires_grad = False
Contributor

@eellison if we fully replace TensorMeta with FakeTensor I think it will fix this


result = empty_like(a, device=a.device, dtype=dtype, requires_grad=requires_grad)

# TODO: review if the no_grad context is the best way to model this
Contributor

The entire autograd story here is a bit wishy-washy. But my default assumption was that each prim in primtorch would have an autograd formula explicitly defined for it. So then no_grad here doesn't matter, because a use of _to_dtype should only ever be in a context where there's going to be an explicit autograd formula.

Collaborator Author

You're correct, but that's not a direction we've focused on modeling yet

doc="",
)

# TODO: layout, pin_memory, memory_format
Contributor

Somewhat surprised the meta tests aren't complaining loudly at you on this ;)

Collaborator Author

The samples don't set these options

"""

empty = _make_prim(
schema="empty(int[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
Contributor

Should empty really have a requires_grad argument in PrimTorch? From the perspective of a backend implementer requires_grad ought to have been long erased; there's nothing they're going to be usefully able to do with it.

Collaborator Author

It's certainly interesting to consider; we can always make it exclusive to the ref later when we get into autograd

impl_aten=_empty_like_aten,
return_type=RETURN_TYPE.NEW,
doc=_empty_like_doc,
)
Contributor

How come this is a prim? It doesn't seem very primitive to me.

Collaborator Author

full_like is a jax.lax operation and to make this (torch.empty_like) non-prim in the current system we'd have to do empty+as_strided, and as_strided is an operation we generally don't want to call

per the below thinking, full_like can be made a ref by combining empty_like + fill

Edit: clarified what "this" was referring to and updated per comment below

Collaborator

empty_strided is probably better as a prim, as it is more powerful than empty_like, and empty_like can easily be expressed with it?

impl_aten=_full_aten,
return_type=RETURN_TYPE.NEW,
doc=_full_doc,
)
Contributor

ditto here, in primitives I'd expect an empty allocation and then an inplace fill afterwards (you do have inplace in primtorch, right?)

Collaborator Author

We don't have fill at this time; full and full_like are jax.lax operators and they're kind of natural prims, but yes, we'll likely model them as references in the future
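
In plain eager PyTorch, the decomposition the reviewer is describing looks like the following; this is only an illustration of the empty-allocation-plus-inplace-fill shape of such a ref, not the primTorch implementation:

```python
import torch

def full_like_sketch(a: torch.Tensor, fill_value) -> torch.Tensor:
    # Allocate uninitialized storage with a's metadata, then fill it in place.
    result = torch.empty_like(a)
    return result.fill_(fill_value)

print(full_like_sketch(torch.empty(2, 3), 7.0))
```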

if _tensor_requires_grad(a):
return True
if isinstance(x, torch.Tensor) and x.requires_grad:
return True
Contributor

Why not use tree_map here?

Collaborator Author

yeah, that would work, too
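
A sketch of the pytree-based alternative, using torch.utils._pytree to flatten arbitrarily nested inputs before checking requires_grad; the helper name is illustrative, not the PR's code:

```python
import torch
from torch.utils._pytree import tree_flatten

def _any_requires_grad(*args, **kwargs) -> bool:
    # Flatten nested args/kwargs and check every tensor leaf.
    flat, _ = tree_flatten((args, kwargs))
    return any(isinstance(x, torch.Tensor) and x.requires_grad for x in flat)

print(_any_requires_grad(torch.randn(2, requires_grad=True), scale=2.0))  # True
```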

def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
return prims.utils.TensorMeta(x)
return x
Contributor

oops lol

for ei in error_inputs:
si = ei.sample_input
meta_sample = si.transform(_to_tensormeta)
# TODO: match strings
Contributor

expect tests would be very helpful here, then you wouldn't have to manually type in the correct strings everywhere
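
Until expect tests exist here, matching is typically done with a regex over the message; a minimal sketch (the regex and op are illustrative, tied to the refs.neg(BoolTensor) validation added in this PR):

```python
import unittest

import torch

class TestErrorMessage(unittest.TestCase):
    def test_neg_bool_error(self):
        # Assert both the exception type and (part of) the message.
        with self.assertRaisesRegex(RuntimeError, "not supported"):
            torch.neg(torch.tensor([True, False]))

if __name__ == "__main__":
    unittest.main()
```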

device: torch.device,
requires_grad: bool,
) -> Tensor:
# Note that Mypy thinks torch.full can't accept a complex fill_value
Contributor

That just means full's pyi annotation is incorrect; it needs to be generalized a little then

type_promotion_kind,
use_opmath,
CPU_use_opmath=None,
CUDA_use_opmath=None,
Contributor

Why not just lowercase here?

Collaborator Author

Yeah, that'd be reasonable -- they're capitalized in a lot of the test suite today, so I suppose that's what I was thinking of

Contributor

@ezyang ezyang left a comment

The CR from me is non-substantive. Assuming you can get this to pass tests, merge whenever the tests are passing. The longer we wait, the harder it will be to enforce this.


for idx, x in enumerate(perm):
permuted_shape[idx] = shape[x]

new_strides = make_contiguous_strides_for(permuted_shape)
Collaborator

nit: I'd expect permuted_strides to correspond to permuted_shape, and what you are returning are the output strides.

prims.abs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
prims.abs,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
use_opmath=False,
Collaborator

why?

@mruberry

@pytorchbot merge on green

@pytorchmergebot

Merge failed due to Refusing to merge as mandatory check(s) Lint are not yet run for rule superuser
Raised by https://github.com/pytorch/pytorch/actions/runs/2344451181

@mruberry

@pytorchmergebot merge this

check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

# Filters the tensors to actual tensors
all_tensors = all(isinstance(a, TensorLike) for a in tensors)
Collaborator

all_tensors is not used.

Collaborator Author

Doh! You're right -- thanks @jjsjann123! I'll get that cleaned up

@mruberry mruberry deleted the primtorch_strides_meta branch May 19, 2022 18:50
facebook-github-bot pushed a commit that referenced this pull request May 20, 2022
Summary: This PR... (see the full PR description above; filed issues #77553, #77526, #77600)

Pull Request resolved: #77542
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/580a053832cea61affce5fdb61c737036c8954af

Reviewed By: seemethere

Differential Revision: D36494082

Pulled By: mruberry

fbshipit-source-id: 1f833e53bbd1f50d8658d41dfed8cced99d0ea93
