[primTorch] Enforces stride metadata #77542
Conversation
Dr. CI: As of commit 5166070, 1 new failure was recognized by patterns; it does not appear to be due to upstream breakages.
# Test doesn't support non-tensor inputs
DecorateInfo(unittest.expectedFailure,
             'TestMathBits',
             'test_neg_view'),
Join @zou3519 and me in advocating that these skips should be automatically generated and saved ;) It's very time-consuming to manually track all of these down.
+1 -> #74642
Yeah but not in this PR
I think you may already be past the tipping point where it will be faster to sit down and add this infrastructure than to play popcorn with the CI for the next week
torch/_prims/wrappers.py
Outdated
# the kernel is invoked on cpu, so it makes strides contiguous
if a.device.type == "cpu":
    return prims.convert_element_type(a, dtype)
return prims._to_dtype(a, dtype)
The comment is very reasonable, but I still do not understand why cpu gets special cased in the condition here.
I'm tweaking this now to see if I can get CPU strides validated
device = inferred_device if device is None else device

if isinstance(device, str):
    device = torch.device(device)
This is impossible according to the type signature. Relax the type signature?
Good point - fixed
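As an aside, a minimal sketch of what the relaxed signature could look like; `DeviceLikeType` and `canonicalize_device` are illustrative names, not necessarily the ones used in the PR:

```python
from typing import Optional, Union

import torch

# Illustrative alias: accept either a string or a torch.device.
DeviceLikeType = Union[str, torch.device]

def canonicalize_device(device: Optional[DeviceLikeType]) -> Optional[torch.device]:
    # Strings like "cuda:0" are converted; torch.device (or None) passes through.
    if isinstance(device, str):
        return torch.device(device)
    return device
```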
if a.device != b.device:
    msg = "Devices {0} and {1} are not equal!".format(a.device, b.device)
    raise AssertionError(msg)
# Handles special cuda:0 vs cuda case
Why is this necessary?
Somehow we're getting both values
If you query the device of a tensor, it should always have an index. If there is no index then there is some invariant violation when we are creating the tensors in the first place (we can probably force an index in TensorMeta's constructor)
I agree with you, and it would be interesting to hunt it down, but this PR is already a little sprawling.
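For context, a minimal illustration of the comparison being special-cased here: torch.device equality treats a missing index as distinct from index 0, even though a tensor allocated on "cuda" reports device "cuda:0" when queried.

```python
import torch

# A device parsed without an index compares unequal to the indexed form.
assert torch.device("cuda") != torch.device("cuda:0")
assert torch.device("cuda").index is None
assert torch.device("cuda:0").index == 0
```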
torch/_prims/utils.py
Outdated
# NOTE: currently we are only validating strides on CUDA, because
# we are using opmath on both CPU and CUDA, which causes
# divergent stride behavior vs. the CPU, which does not use opmath
the "we" here is ambiguous; I assume you're talking about refs, compared to the reference CPU implementations? But it's also surprising that CPU TensorIterator doesn't preserve strides because it "lost" the information when doing a dtype conversion for type promotion. Isn't that a bug?
It's not so nice that strides can only be validated on CUDA; this means that if you're working on strides it's mandatory to be on a CUDA machine (for me at least, my default dev env is non-CUDA)
it may be a bug but I'm trying to get the tests to pass at the moment by modeling the CPU behavior
# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] is incorrect, because it
# says that strides will be preserved even if they are not
# "non overlapping and dense", but this is incorrect. The
Overlapping/sparse strides get preserved in the sense that they implicitly define some permutation, and that permutation is preserved in the (contiguous) output strides
Then the note in C++ should say that instead of what it does
if ndim == 0:
    return ()
if ndim == 1:
    return (1,)
TBH, I'm not sure PrimTorch should be in the business of defining these short circuits, if the general algorithm works for these cases too.
I added a comment to review removing them if they're unnecessary
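For reference, a sketch of the point above: a general contiguous-strides computation already covers the 0-D and 1-D cases, assuming the max(dim, 1) convention for zero-size dimensions; `contiguous_strides` is an illustrative stand-in for the utility being discussed.

```python
def contiguous_strides(shape):
    # stride[i] is the product of the sizes to its right, clamping zero-size
    # dims to 1 (assumed convention for tensors with no elements).
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= max(dim, 1)
    return tuple(reversed(strides))

print(contiguous_strides(()))      # ()
print(contiguous_strides((7,)))    # (1,)
print(contiguous_strides((2, 3)))  # (3, 1)
```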
torch/_prims/utils.py
Outdated
    return 0

perm = tuple(range(ndim))
perm = sorted(perm, key=cmp_to_key(_cmp), reverse=True)
Why not define perm as a list and then .sort() it?
Yeah that would have worked, too
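A minimal sketch of the in-place variant; `strides` and `_cmp` below are illustrative stand-ins for the values used in the surrounding function.

```python
from functools import cmp_to_key

strides = (20, 1, 5)  # illustrative input strides

def _cmp(x: int, y: int) -> int:
    # Compare two dimension indices by their stride.
    return strides[x] - strides[y]

perm = list(range(len(strides)))
perm.sort(key=cmp_to_key(_cmp), reverse=True)  # descending stride order
print(perm)  # [0, 2, 1]
```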
perm = tuple(range(ndim))
perm = sorted(perm, key=cmp_to_key(_cmp), reverse=True)

permuted_shape = [-1] * ndim
Initializing these with None is safer, because -1 is a valid index
-1 is a valid index for a dimension but not a valid dimension length, and initializing with None would change the type
torch/_prims/utils.py
Outdated
relevant_pairs.append((x, y))

expected = 1
for x, y in sorted(relevant_pairs, key=lambda p: p[1]):
From a tracing perspective, this sort is terrifying
Luckily the final version of the PR didn't include this function, and the sort in the stride comparison function shouldn't define any validity conditions, although it is an example of how we may have to just run the meta functions for our ops to understand what the intermediate metadata values of certain tensors are
So, you are team "no symbolic strides"? Our current default assumption is that strides are symbolic, because from a design perspective that is easier. To make them not symbolic we will have to work (because strides are computed from symbolic quantities aka shapes).
torch/_prims/__init__.py
Outdated
# NOTE: _to_dtype
# This private op casts the input to the desired type while preserving its stride
# permutation, unlike .to(dtype) which will create a tensor with contiguous strides
to() is supposed to preserve strides (that's why its memory format defaults to preserve_format). File a bug?
Good call -- #77600
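A quick check of the expectation discussed above; since .to() defaults to torch.preserve_format, a dense permuted input is expected to keep its stride permutation rather than come back contiguous (the filed issue tracks where this diverges).

```python
import torch

a = torch.empty(2, 3, 4).permute(2, 0, 1)  # shape (4, 2, 3), strides (1, 12, 4)
b = a.to(torch.float64)
print(a.stride(), b.stride())  # expected to match if strides are preserved
```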
torch/_prims/__init__.py
Outdated
try:
    requires_grad = a.requires_grad
except Exception as e:
    requires_grad = False
@eellison if we fully replace TensorMeta with FakeTensor I think it will fix this
torch/_prims/__init__.py
Outdated
result = empty_like(a, device=a.device, dtype=dtype, requires_grad=requires_grad)

# TODO: review if the no_grad context is the best way to model this
The entire autograd story here is a bit wishy-washy. But my default assumption was that each prim in primtorch would have an autograd formula explicitly defined for it. So then no_grad here doesn't matter, because a use of _to_dtype should only ever be in a context where there's going to be an explicit autograd formula.
You're correct, but that's not a direction we've focused on modeling yet.
doc="", | ||
) | ||
|
||
# TODO: layout, pin_memory, memory_format |
Somewhat surprised the meta tests aren't complaining loudly at you on this ;)
The samples don't set these options
""" | ||
|
||
empty = _make_prim( | ||
schema="empty(int[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", |
Should empty really have a requires_grad argument in PrimTorch? From the perspective of a backend implementer, requires_grad ought to have been long erased; there's nothing they're going to be usefully able to do with it.
It's certainly interesting to consider; we can always make it exclusive to the ref later when we get into autograd
    impl_aten=_empty_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_like_doc,
)
How come this is a prim? It doesn't seem very primitive to me.
full_like is a jax.lax operation and to make this (torch.empty_like) non-prim in the current system we'd have to do empty+as_strided, and as_strided is an operation we generally don't want to call
per the below thinking, full_like can be made a ref by combining empty_like + fill
Edit: clarified what "this" was referring to and updated per comment below
empty_strided is probably better as a prim, as it is more powerful than empty_like, and empty_like can easily be expressed with it?
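A sketch of that suggestion, using the public torch API for illustration (the prims would supply equivalents; layout, pin_memory, and requires_grad are omitted for brevity):

```python
import torch

def empty_like_ref(a: torch.Tensor) -> torch.Tensor:
    # empty_like expressed in terms of the more powerful empty_strided.
    return torch.empty_strided(a.shape, a.stride(), dtype=a.dtype, device=a.device)

x = torch.empty(2, 3, 4).permute(1, 0, 2)
y = empty_like_ref(x)
assert y.shape == x.shape and y.stride() == x.stride()
```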
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)
Ditto here; in primitives I'd expect an empty allocation and then an in-place fill afterwards (you do have in-place ops in primtorch, right?)
We don't have fill at this time; full and full_like are jax.lax operators and they're kind of natural prims, but yes, we'll likely model them as references in the future.
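For illustration, a sketch of full_like modeled as a reference over an allocation plus an in-place fill, again using the public torch API as a stand-in for the eventual prims:

```python
import torch

def full_like_ref(a: torch.Tensor, fill_value) -> torch.Tensor:
    # Allocate with the input's metadata, then fill in place.
    return torch.empty_like(a).fill_(fill_value)

print(full_like_ref(torch.empty(2, 3), 7.0))
```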
if _tensor_requires_grad(a):
    return True
if isinstance(x, torch.Tensor) and x.requires_grad:
    return True
Why not use tree_map here?
yeah, that would work, too
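A minimal sketch of the pytree approach; `any_requires_grad` is an illustrative name, and it uses torch.utils._pytree directly rather than tree_map:

```python
import torch
from torch.utils._pytree import tree_flatten

def any_requires_grad(*args, **kwargs) -> bool:
    # Flatten the inputs and check requires_grad on every tensor leaf.
    leaves, _ = tree_flatten((args, kwargs))
    return any(isinstance(x, torch.Tensor) and x.requires_grad for x in leaves)

t = torch.randn(2, 2, requires_grad=True)
print(any_requires_grad([t], scale=1.0))  # True
```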
def _to_tensormeta(x):
    if isinstance(x, torch.Tensor):
        return prims.utils.TensorMeta(x)
    return x
oops lol
for ei in error_inputs:
    si = ei.sample_input
    meta_sample = si.transform(_to_tensormeta)
    # TODO: match strings
expect tests would be very helpful here, then you wouldn't have to manually type in the correct strings everywhere
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    # Note that Mypy thinks torch.full can't accept a complex fill_value
That just means full's pyi annotation is incorrect; it needs to be generalized a little, then.
torch/_refs/__init__.py
Outdated
    type_promotion_kind,
    use_opmath,
    CPU_use_opmath=None,
    CUDA_use_opmath=None,
why not just lower case here?
yeah that'd be reasonable -- they're capitalized in a lot of the test suite today so I suppose I was thinking of it
The CR from me is non-substantive; assuming you can get this to pass tests, merge this whenever the tests are passing. The longer we wait, the harder it will be to enforce this.
for idx, x in enumerate(perm):
    permuted_shape[idx] = shape[x]

new_strides = make_contiguous_strides_for(permuted_shape)
nit: I'd expect permuted_strides to correspond to permuted_shape, and what you are returning are the output strides.
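For reference, a self-contained sketch of the step being discussed: build the permuted shape from the stride-order permutation, take contiguous strides for it, then scatter those strides back to the original dimension order (`output_strides_for` is an illustrative name).

```python
import torch

def output_strides_for(shape, perm):
    permuted_shape = [shape[d] for d in perm]
    contiguous = torch.empty(permuted_shape).stride()  # contiguous strides for the permuted shape
    out = [0] * len(shape)
    for stride, d in zip(contiguous, perm):
        out[d] = stride  # scatter back to the original dimension order
    return tuple(out)

# Reproduces the strides of torch.empty(2, 3, 4, 5).permute(0, 2, 3, 1)
print(output_strides_for((2, 4, 5, 3), (0, 3, 1, 2)))  # (60, 5, 1, 20)
```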
torch/_refs/__init__.py
Outdated
-    prims.abs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+    prims.abs,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
+    use_opmath=False,
why?
@pytorchbot merge on green
Merge failed due to Refusing to merge as mandatory check(s) Lint are not yet run for rule superuser
@pytorchmergebot merge this
check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

# Filters the tensors to actual tensors
all_tensors = all(isinstance(a, TensorLike) for a in tensors)
all_tensors is not used.
doh! you're right -- thanks @jjsjann123! I'll get that cleaned up.
Summary: This PR...

**Filed the Following Issues**
- #77553
- #77526
- #77600

**Testing**
- Updates test_dtypes to no longer attempt to test the backward of sample inputs where no inputs require grad
- Adds a new test_python_reference_errors; it ensures the meta operations for references throw errors as expected
- Updates compare_tensor_meta to better handle CUDA devices, and (temporarily) restricts stride checking to the CUDA device type
- Elementwise unary and elementwise binary operators now have arbitrarily strided reference inputs
- Adds reference inputs for _like functions
- Adds an OpInfo for torch.empty
- Adds reference inputs for torch.clone
- Adds a NumPy reference for clone
- Adds OpInfos for refs.empty and refs.empty_like

**Prims**
- Renames the "max" and "min" prims to "maximum" and "minimum," respectively, to better conform to their ATen names
- Adds the empty, empty_like, full, and full_like prims
- Fixes the elementwise meta function's stride propagation
- Fixes clone's meta function's stride propagation
- Fixes convert_element_type's meta function's stride propagation
- Adds a (temporary) private _to_dtype prim that casts a tensor while preserving its stride permutation
- Removes the _set prim comment
- Adds utils.compute_elementwise_output_strides, which computes the correct output strides for elementwise operations
- Corrects an issue where utils.make_contiguous_strides_for was creating incorrect strides for tensors with no elements

**References**
- Adds the empty, empty_like, full, full_like, and ones_like refs
- Extends make_elementwise_unary_reference to accept an additional callable that performs extra input validation
- Adds an extra validation function to handle refs.neg(BoolTensor)
- Updates the isfinite ref to call ones_like when appropriate
- Models Python scalar handling for elementwise binary operations
- Adds a 64-dim check to the amin and amax references
- opmath is now a flag that can be set separately for CPU and CUDA

Pull Request resolved: #77542
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/580a053832cea61affce5fdb61c737036c8954af
Reviewed By: seemethere
Differential Revision: D36494082
Pulled By: mruberry
fbshipit-source-id: 1f833e53bbd1f50d8658d41dfed8cced99d0ea93