Alias for polygamma
#59691
Conversation
💊 CI failures summary and remediations
As of commit 6195f5a (more details on the Dr. CI page and at hud.pytorch.org/pr/59691):
🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages:
pytorch_linux_xenial_py3_clang5_asan_test1 (1/1) — Step: "Run tests" (full log | diagnosis details | 🔁 rerun)
…ant consistency checks I need to add comments.
SkipInfo('TestCommon', 'test_variant_consistency_jit'),),
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
SkipInfo('TestCommon', 'test_jit_alias_remapping'),
SkipInfo('TestCommon', 'test_variant_consistency_eager')),
This test fails because of the ordering of args in polygamma / special_polygamma: (int, tensor), but the JIT tests expect the first input to be a tensor.
======================================================================
ERROR: test_variant_consistency_eager_polygamma_polygamma_n_0_cpu_float32 (__main__.TestCommonCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/krshrimali/git/krshrimali/pytorch/torch/testing/_internal/common_device_type.py", line 292, in instantiated_test
result = test_fn(self, *args)
File "/home/krshrimali/git/krshrimali/pytorch/torch/testing/_internal/common_device_type.py", line 266, in test_wrapper
return test(*args, **kwargs)
File "/home/krshrimali/git/krshrimali/pytorch/test/test_ops.py", line 345, in test_variant_consistency_eager
_test_consistency_helper(samples, variants)
File "/home/krshrimali/git/krshrimali/pytorch/test/test_ops.py", line 334, in _test_consistency_helper
variant_forward = variant(cloned,
TypeError: special_polygamma(): argument 'n' (position 1) must be int, not Tensor
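A minimal sketch of the mismatch behind this error, assuming torch.special.polygamma keeps polygamma's (n, input) argument order (which is what this PR adds); the harness call shape is paraphrased from the test lines quoted further below:

```python
import torch

x = torch.rand(3)

torch.polygamma(1, x)          # OK: integer order first, tensor second
torch.special.polygamma(1, x)  # OK: the alias keeps the same (n, input) order

# The variant-consistency test calls every variant with the sample's input
# tensor in the first position, which this signature rejects:
# torch.special.polygamma(x, 1)  # TypeError: argument 'n' must be int, not Tensor
```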
This is not related to JIT, but we use a lambda to reorder x, n when calling polygamma:
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs)
This is because in consistency_eager we only verify gradients for input=Tensor (but not for the tensors in args).
This test now fails because we don't reorder the arguments for the alias (so it expects the first argument to be an int, not a Tensor).
To mitigate this, should we allow specifying lambda wrappers around aliases? Not sure; this might need investigation.
Relevant lines of the test (a simplified sketch of this flow follows below):
- Here we acquire the actual operator of the mentioned alias (Lines 288 to 292 in cf38b20):
  for a_op in op.aliases:
      variants.append(a_op.op)
      variants.append(a_op.method_variant)
      variants.append(a_op.inplace_variant)
      inplace_ops.append(a_op.inplace_variant)
- Here we call the alias variant (Line 354 in cf38b20):
  variant_forward = variant(cloned,
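A simplified sketch of the flow described above, assuming the base polygamma OpInfo wraps the op in a tensor-first lambda while the alias is collected as the raw callable; this is illustrative, not the actual test_ops.py code:

```python
import torch

# The base polygamma entry wraps the op so the tensor comes first:
base_op = lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs)

# The alias variant is collected as the raw function, with no such wrapper:
alias_op = torch.special.polygamma

cloned = torch.rand(3)
for variant in (base_op, alias_op):
    try:
        # The test always passes the sample's input tensor first.
        variant_forward = variant(cloned, 1)
        print(variant, "ok")
    except TypeError as err:
        # The raw alias lands here: argument 'n' must be int, not Tensor.
        print(variant, "failed:", err)
```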
After an offline discussion with @krshrimali and thinking a bit more: we might as well just add a new entry for special.polygamma in OpInfo and maybe add a hand-coded alias test for polygamma (a rough sketch follows below)?
Reasons:
- Even if we go through the hassle of making test_variant_consistency_eager work, we still don't get the alias_remapping test due to the use of the lambda.
- I don't think there are other operators which would need this op_aliases extension. (It is needed only because of the peculiar signature of polygamma.)
@mruberry, what is your opinion?
Thanks!
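A rough sketch of what such a hand-coded alias test could look like; the test name, dtype, and the use of torch.allclose are illustrative, not from the PR:

```python
import torch

def test_special_polygamma_alias():
    # Check that the special-namespace alias matches torch.polygamma
    # for a few orders n.
    x = torch.rand(5, dtype=torch.float64)
    for n in (0, 1, 2):
        assert torch.allclose(torch.special.polygamma(n, x),
                              torch.polygamma(n, x))

test_special_polygamma_alias()
```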
I see the issue. Sure, implementing a new OpInfo for it (with a comment) sounds like a fine workaround.
Overall looks good. Thanks @krshrimali!
I have one minor nit regarding the signature of special::polygamma_out.
However, I am concerned about disabling the consistency_eager test. I have put some pointers below on why the test is failing; it would be nice if we could find a workaround without disabling it. Can you please take a look at that?
Thanks!
return torch::special_polygamma(n, self);
}

inline Tensor& polygamma_out(int64_t n, Tensor& result, const Tensor& self) {
We should stick to the signature of polygamma_out, which is Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self).
… into dev/special_polygamma
Hey @krshrimali! Just checking in here. Looks like some tests are failing -- everything going OK on this PR? Is there something I'm supposed to do for it?
Thanks, @mruberry, for the question. If you could take a look at … P.S.: If we think that …
Got it. No, I don't think logsumexp will take too long. If the JIT team doesn't get back to us soon I'll ping them.
self.name = alias_name
self.op = _getattr_qual(torch, alias_name)
self.op = alias_op if alias_op else _getattr_qual(torch, alias_name)
Nice extension
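A rough sketch of what this alias_op override enables, assuming an AliasInfo-style constructor as in the quoted diff; _getattr_qual is re-implemented here only to keep the sketch self-contained, and the lambda is illustrative:

```python
import torch

def _getattr_qual(module, qualified_name):
    # Resolve a dotted name such as "special.polygamma" on the torch module.
    obj = module
    for attr in qualified_name.split('.'):
        obj = getattr(obj, attr)
    return obj

class AliasInfo:
    def __init__(self, alias_name, alias_op=None):
        self.name = alias_name
        # Use the explicit callable when given (e.g. a lambda that reorders
        # arguments); otherwise look the alias up on the torch namespace.
        self.op = alias_op if alias_op else _getattr_qual(torch, alias_name)

# Default behaviour: the alias resolves to torch.special.polygamma itself.
plain = AliasInfo('special.polygamma')

# With the override: the alias can be wrapped so the tensor comes first,
# matching the base polygamma OpInfo's lambda.
wrapped = AliasInfo(
    'special.polygamma',
    alias_op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs))
```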
Apologies: earlier we decided to keep this to avoid an extra OpInfo entry for the special alias. Later, we decided to instead have an extra OpInfo entry, since this is only required for ops like polygamma which need re-ordering of the arguments (context: #59691 (comment)). This has been removed now.
Unfortunately it looks like the JIT test is still failing: test_variant_consistency_jit_special_polygamma_special_polygamma_n_0_cpu_float32
Thanks @mruberry for taking a look. The skipped test had the wrong class mentioned, hence the errors. This should be fixed now. Should be ready for review once all the tests pass. :) Thank you!
Gentle ping, PTAL @mruberry - whenever you find time. The failing test seems unrelated to the PR.
@@ -7029,6 +7029,34 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
# ~~~~~~~~~~~~~~~ <--- HERE
SkipInfo('TestJit', 'test_variant_consistency_jit'),),
sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
# A separate OpInfo entry for special.polygamma is needed to reorder the arguments
Just testing this for the n=0 case makes sense
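A rough sketch of what the separate special.polygamma entry discussed here would encode (the argument-reordering lambda plus n=0-only sample kwargs); the actual OpInfo fields in the PR may differ, and the evaluation below just mirrors how the test machinery would combine the two lambdas:

```python
import torch

# Tensor-first wrapper, as in the base polygamma entry, but through the alias:
op = lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs)
# Restrict samples to the n=0 case (one dict for the op, one for the reference):
sample_kwargs = lambda device, dtype, input: ({'n': 0}, {'n': 0})

# How a single sample would be evaluated: the kwargs supply n=0 while the
# tensor stays in the first position.
x = torch.rand(3)
op_kwargs, _ = sample_kwargs('cpu', torch.float32, x)
print(op(x, **op_kwargs))  # same as torch.special.polygamma(0, x)
```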
Cool! Thanks @krshrimali
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
See #50345
cc: @mruberry @kshitij12345