
Conversation

kshitij12345 (Collaborator)

Fixes #50747

Reference #50006

@facebook-github-bot (Contributor)

facebook-github-bot commented Mar 1, 2021

💊 CI failures summary and remediations

As of commit 64883fa (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed



test/test_ops.py Outdated
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

def is_inplace(variant):
return variant.__name__.endswith('_')
kshitij12345 (Collaborator Author)

Is there a better way to check this?

Or should we pass the variant type and variant down from the function that calls this helper?

self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

Collaborator

The correct way to check this is by testing if the variant is the inplace variant acquired from the OpInfo.

kshitij12345 (Collaborator Author)

Unfortunately that won't work, as op.get_inplace() is wrapped by _get_safe_inplace.

pytorch/test/test_ops.py

Lines 68 to 73 in c4c77e2

def _get_safe_inplace(self, inplace_variant):
    @wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)
    return _fn

mruberry (Collaborator), Mar 4, 2021

Can you access it through the dunder wrapped attribute?

foo.__wrapped__ is op.get_inplace()

(if foo may or may not be wrapped you'll need to check for the attr first, of course)

kshitij12345 (Collaborator Author)

Yup it works with that. Thanks!!

def is_inplace(variant):
    if hasattr(variant, "__wrapped__"):
        return variant.__wrapped__ is op.get_inplace()
    return variant is op.get_inplace()

__slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_self']

def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None):
def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_self=False):
Collaborator

This approach is reasonable but I still think we should filter sample_inputs() by whether the samples will be used for inplace computations or not. However, let me get a few other opinions on this. Maybe this is a better approach.

Collaborator

Sorry I neglected to update this, @kshitij12345. General consensus is that it'd be preferable to filter sample_inputs() by what's inplace for now. We may want to revise that decision later, however.

kshitij12345 (Collaborator Author), Mar 24, 2021

IIUC, we will have to update the signature of every sample_inputs_* to

sample_inputs_*(op_info, device, dtype, requires_grad, is_inplace_variant)

?

def sample_inputs(self, device, dtype, requires_grad=False):
    """Returns an iterable of SampleInputs.

    These samples should be sufficient to test the function works correctly
    with autograd, TorchScript, etc.
    """
    return self.sample_inputs_func(self, device, dtype, requires_grad)

And update sample_inputs to

def sample_inputs(self, device, dtype, requires_grad=False, is_inplace_variant=False):

Collaborator

We could do that (and make is_inplace_variant kwarg-only). That would be the most discoverable thing.

But we don't have to. The operations for testing the inplace variant can request sample inputs by passing the for_inplace_variant (we can work on the name) kwarg like this:

try:
  samples = op.sample_inputs(..., for_inplace_variant=True)
except TypeError as te:
  samples = op.sample_inputs(...) 

Then functions which support this option can implement it as:

def sample_inputs_foo(..., *, for_inplace_variant=False)

A nicer (but more disruptive) solution would be to make sample_inputs take **kwargs, and then the functions that want to use "for_inplace_variant" or similar options can query for it from the kwarg dict.

What would you think of an approach like that?
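The **kwargs-based variant described above can be sketched with toy stand-ins; the OpInfo class here is a minimal placeholder, sample_inputs_foo is an illustrative name, and for_inplace_variant is the proposed (not final) kwarg name:

```python
def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs):
    # A sample-input function that opts in by querying the kwarg dict.
    if kwargs.get("for_inplace_variant", False):
        # e.g. drop samples that broadcast over self, which inplace ops reject.
        return ["inplace-safe sample"]
    return ["inplace-safe sample", "broadcasting sample"]

class OpInfo:
    # Minimal placeholder for the real OpInfo class.
    def __init__(self, sample_inputs_func):
        self.sample_inputs_func = sample_inputs_func

    def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
        # Forward arbitrary options to the per-op function unchanged.
        return self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)

op = OpInfo(sample_inputs_foo)
inplace_samples = op.sample_inputs("cpu", "float32", for_inplace_variant=True)
```

Functions that ignore **kwargs keep working unchanged, which is what makes this approach incremental.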

kshitij12345 (Collaborator Author)

I like the second approach (though more disruptive). Will try that.
Thanks!

kshitij12345 (Collaborator Author), Mar 24, 2021

For test_variant_consistency_eager,

pytorch/test/test_ops.py

Lines 204 to 242 in 0d81528

@_variant_ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
    for sample in samples:
        # Acquires variants (method variant, inplace variant, aliases)
        method = op.get_method()
        inplace = op.get_inplace()
        # list of all inplace ops: inplace variant + alias inplace variants if exist
        inplace_ops = [inplace, ]
        aliases = []
        for a_op in op.aliases:
            aliases.append(a_op.op)
            aliases.append(a_op.method_variant)
            aliases.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)
        aliases = tuple(aliases)
        inplace_ops = tuple(v for v in inplace_ops if v is not None)
        variants = (v for v in (method, inplace) + aliases if v is not None)
        # Computes function forward and backward values
        sample.input.grad = None
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        expected_grad = None
        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on sample.input, not all
        #   tensor inputs
        # TODO: update to handle checking grads of all tensor inputs as
        #   derived from each tensor output
        if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)):
            expected_forward.sum().backward()
            expected_grad = sample.input.grad
        # Test eager consistency
        for variant in variants:

We will have to put the call to sample_inputs inside the variants loop. Since each call to sample_inputs has to actually materialize all the tensors, and we perform the forward multiple times on the same input sample (once per variant the sample is valid for), this will result in a performance regression for this test.

I am feeling a bit sceptical about this now.
Let me know if that sounds acceptable (or if I missed anything).

Collaborator

That's a great point, and that function could probably use a small refactoring. I agree with you but think we can mitigate the damage:

samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)

for sample in samples:
  # Acquires variants (method variant, inplace variant, aliases)
  method = op.get_method()
  inplace = op.get_inplace()

  # list of all inplace ops: inplace variant + alias inplace variants if exist
  inplace_ops = [inplace, ]

  aliases = []
  for a_op in op.aliases:
    aliases.append(a_op.op)
    aliases.append(a_op.method_variant)
    aliases.append(a_op.inplace_variant)
    inplace_ops.append(a_op.inplace_variant)
  
  aliases = tuple(aliases)
  inplace_ops = tuple(v for v in inplace_ops if v is not None)
  variants = (v for v in (method, inplace) + aliases if v is not None)

What's weird about this is that the test is rebuilding its understanding of the operator's variants on each sample. This is unnecessary since variants are sample-independent. We should lift that section out of the for loop and acquire the variants once, up front.

Now let's look at the rest of the function:

# Computes function forward and backward values
sample.input.grad = None
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
expected_grad = None

if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)):
  expected_forward.sum().backward()
  expected_grad = sample.input.grad

  for variant in variants:
    sample.input.grad = None
    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
    self.assertEqual(expected_forward, variant_forward)

    if expected_grad is not None and (variant not in inplace_ops or op.supports_inplace_autograd):
      variant_forward.sum().backward()
      self.assertEqual(expected_grad, sample.input.grad)

I think we can get rid of the for variant in variants loop here by using an itertools.product(sample inputs, variants) after the sample input and variant acquisition at the start of the test. Then we can create a helper function that executes this test body. The helper function takes a variant, a sample input, and whether the operation is inplace or not (so it knows whether to perform the copy). Then the test works like this:

  • the variants are identified
  • sample inputs are acquired
  • if the op has an inplace variant, inplace sample inputs are acquired
  • an itertools product of sample inputs x non-inplace variants invokes the helper function
  • an itertools product of inplace sample inputs x inplace variants invokes the helper function

Does that make sense? Looking forward to hearing your thoughts.

You are correct that this will still redundantly create sample inputs that are common to both the out-of-place and in-place variant. Which is unfortunate from a performance standpoint, and your idea of marking SampleInputs as safe-for-inplace would have handled this problem better. However I think this is an OK penalty (for now, anyway).
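The restructuring sketched in the bullets above, shown torch-free with toy stand-ins (double/double_ play the roles of the out-of-place and inplace variants, plain lists play the role of SampleInputs, and the copy replaces clone_input_helper):

```python
from itertools import product

def double(xs):
    # Toy "op": out-of-place doubling of a list.
    return [x * 2 for x in xs]

def double_(xs):
    # Toy inplace variant: mutates its argument and returns it.
    for i in range(len(xs)):
        xs[i] *= 2
    return xs

def check_variant(op, variant, sample, *, is_inplace):
    # Shared test body: inplace variants get a defensive copy
    # (the stand-in for clone_input_helper) so the sample survives.
    expected = op(sample)
    arg = list(sample) if is_inplace else sample
    assert variant(arg) == expected

def run_consistency(op, variants, inplace_variants, samples, inplace_samples):
    # Variants are identified once, up front, not per sample.
    # Product of sample inputs x out-of-place variants:
    for sample, variant in product(samples, variants):
        check_variant(op, variant, sample, is_inplace=False)
    # Product of inplace sample inputs x inplace variants:
    for sample, variant in product(inplace_samples, inplace_variants):
        check_variant(op, variant, sample, is_inplace=True)
    return True
```

A usage sketch: run_consistency(double, [double], [double_], [[1, 2]], [[1, 2]]) exercises both products and leaves the sample lists unmutated.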

kshitij12345 (Collaborator Author), Mar 24, 2021

It does make sense. I'll try to do that. Will ping here if I stumble into another blocker.

Thanks!

@kshitij12345 kshitij12345 marked this pull request as ready for review March 31, 2021 06:26
kshitij12345 (Collaborator Author)

@mruberry, have fixed the merge conflicts. I think this is ready for another round. PTAL :)

Note:
One thing to note is that this PR is volatile, in the sense that if a new sample_inputs_* is added with the current signature (without **kwargs), then this PR will lead to errors, as it will pass an invalid argument (so it should be rebased prior to landing, and no new sample_inputs_* should be added in between).

Or should we temporarily add a try/except (similar to the one mentioned above) and then remove it once this signature becomes the norm?

try:
  samples = op.sample_inputs(..., **kwargs)
except TypeError as te:
  samples = op.sample_inputs(...) 

Thanks!

return variant.__wrapped__ is op.get_inplace()
return variant is op.get_inplace()

samples = op.sample_inputs(device, dtype, requires_grad=True,
Collaborator

I think your proposal to wrap this in a try/except so not every sample_input needs to be updated to use **kwargs when this lands is a good idea.
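The fallback degrades gracefully because a legacy-signature function raises TypeError when handed an unknown keyword; a torch-free sketch (both sample_inputs_* names here are illustrative, not actual PyTorch functions):

```python
def sample_inputs_legacy(op_info, device, dtype, requires_grad):
    # Old signature: raises TypeError if passed unknown keyword arguments.
    return ["sample"]

def sample_inputs_updated(op_info, device, dtype, requires_grad, **kwargs):
    # New signature: queries the kwarg dict for extra options.
    if kwargs.get("for_inplace_variant", False):
        return ["inplace sample"]
    return ["sample"]

def get_samples(sample_inputs_func, **kwargs):
    try:
        return sample_inputs_func(None, "cpu", "float32", False, **kwargs)
    except TypeError:
        # Legacy signature: retry without the extra kwargs.
        return sample_inputs_func(None, "cpu", "float32", False)

legacy = get_samples(sample_inputs_legacy, for_inplace_variant=True)   # falls back
updated = get_samples(sample_inputs_updated, for_inplace_variant=True)
```

The cost is that a legacy function silently ignores the option, which is why the try/except is meant to be temporary.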


if len(inplace_ops) > 0:
    inplace_samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad,
                                       for_inplace_variant=True)
Collaborator

This can also be wrapped in a try/except to minimize logical merge conflicts.

yield SampleInput(make_arg((S,)),
                  args=(torch.randn(S, S, device=device) > 0, make_arg(())))
yield SampleInput(make_arg((S,)),
                  args=(torch.randn(S, S, device=device) > 0, 10))
Collaborator

Why this change instead of using the bernoulli op?

kshitij12345 (Collaborator Author)

bernoulli produces a scalar tensor; however, we want an (S, S) tensor.

def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.bool).bernoulli_()

Collaborator

('masked_fill', (M,), (torch.BoolTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),

I guess I mean where did this case go? We don't have to worry about it in this PR.

kshitij12345 (Collaborator Author)

Ah, I replaced it with the following (semantically both are doing the same thing):

SampleInput(make_arg((S,)),
            args=(torch.randn(S, S, device=device) > 0, 10))

Collaborator

OK

mruberry (Collaborator) left a comment

Hey @kshitij12345! Overall this looks good, as usual.

I made a few comments; I like your suggestion to reduce logical merge conflicts by using a try/except.

I think the OpInfos for masked_fill and masked_scatter can also be updated with this change, right?

SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),

SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),

try:
    samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError:
    samples = self.sample_inputs_func(self, device, dtype, requires_grad)
return samples
kshitij12345 (Collaborator Author)

Have added the try/except here. (we don't need to add try/except in test_ops)

kshitij12345 (Collaborator Author)

@mruberry can you please see if this is ok? Will rebase accordingly then. Thanks!

Collaborator

Yep, this seems fine

inplace_ops = tuple(v for v in inplace_ops if v is not None)
variants = (v for v in (method, inplace) + aliases if v is not None)
inplace_variants = tuple(v for v in inplace_ops if v is not None)
variants = tuple(v for v in (method, inplace) + aliases if v is not None)
Collaborator

Nice fix

mruberry (Collaborator) left a comment

Cool! And thanks for including the fix for the variant generator. If you rebase this now I think we can get it landed, @kshitij12345.

I made one comment, but we don't need to worry about it in this PR even if it is an issue.

@facebook-github-bot (Contributor)

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

kshitij12345 (Collaborator Author)

With this, we might want to update/add an entry mentioning this new change at https://github.com/pytorch/pytorch/wiki/Writing-tests-in-PyTorch-1.8 and #54261 (the OpInfo porting tracker issue).

@codecov

codecov bot commented Apr 7, 2021

Codecov Report

Merging #53014 (64883fa) into master (bc05867) will decrease coverage by 0.19%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #53014      +/-   ##
==========================================
- Coverage   77.42%   77.23%   -0.20%     
==========================================
  Files        1895     1895              
  Lines      187524   187556      +32     
==========================================
- Hits       145194   144859     -335     
- Misses      42330    42697     +367     

@mruberry (Collaborator)

mruberry commented Apr 7, 2021

Edit: nevermind, ROCm failures are in the base!

Hmm... I was hoping to land this, but the ROCm build failure is worrying. @kshitij12345, would you take a look?

@facebook-github-bot (Contributor)

@mruberry merged this pull request in 17e5ba4.



Development

Successfully merging this pull request may close these issues.

test_inplace_grad and test_variant_consistency_eager have to be disabled while testing broadcasting semantics in OpInfo based tests

4 participants