
Conversation

@pmeier pmeier commented Apr 26, 2022

This PR refactors our prototype transforms functional tests. They are currently located in test/test_prototype_transforms_functional.py. To ease the reviewing process I've added a new test/test_prototype_transforms_kernels.py module that contains the refactored tests from this PR. In the end that should replace most parts of the old file, but doing it in a separate module avoids GH diff hell.

Status quo

Our current implementation was the first attempt to automate our tests. I took some inspiration from the OpInfo framework from PyTorch core. The basic idea is to define a FunctionalInfo for each functional that stores some metadata about it.

The most important info (and for now the only metadata we store) is the sample_inputs_fn. It yields the call arguments the common tests invoke the functional with.
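
For illustration, here is a minimal sketch of that idea, written from the description above; the container name SampleInput and all implementation details are assumptions, not the exact code in test/test_prototype_transforms_functional.py:

# Minimal sketch of the FunctionalInfo idea described above. The names and
# fields here are illustrative assumptions; only the general shape matches
# the actual test module.
import dataclasses
from typing import Any, Callable, Dict, Iterator, Tuple

import torch


@dataclasses.dataclass
class SampleInput:
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class FunctionalInfo:
    functional: Callable
    # the central piece of metadata: a generator of call arguments
    sample_inputs_fn: Callable[[], Iterator[SampleInput]]

    @property
    def name(self) -> str:
        return self.functional.__name__

    def sample_inputs(self) -> Iterator[SampleInput]:
        yield from self.sample_inputs_fn()

    def __call__(self, sample_input: SampleInput) -> Any:
        return self.functional(*sample_input.args, **sample_input.kwargs)


def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    # stand-in kernel so the sketch is self-contained
    return image.flip(-1)


def sample_inputs_horizontal_flip_image_tensor() -> Iterator[SampleInput]:
    yield SampleInput(args=(torch.rand(3, 32, 32),))


FUNCTIONAL_INFOS = [
    FunctionalInfo(horizontal_flip_image_tensor, sample_inputs_horizontal_flip_image_tensor),
]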

With that we can write common tests that can be @pytest.mark.parametrize'd over the kernel-call-args combinations. For example, a test that checks the torch.jit.script'ed output against its eager counterpart looks like this

@pytest.mark.parametrize(
    ("functional_info", "sample_input"),
    [
        pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
        for functional_info in FUNCTIONAL_INFOS
        for idx, sample_input in enumerate(functional_info.sample_inputs())
    ],
)
def test_eager_vs_scripted(functional_info, sample_input):
    eager = functional_info(sample_input)
    scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
    torch.testing.assert_close(eager, scripted)

Pros / Cons

This architecture has two main benefits over manually writing these tests:

  1. It is trivial to add a new common test for all functionals.
  2. Writing comprehensive tests for a new functional is reduced to defining call args.

Plus, and in contrast to the OpInfos from PyTorch core, the tests are easier to debug, since we use @pytest.mark.parametrize instead of for loops in the test body to iterate over the call args. If one of our tests fails, we can reproduce the parametrization from the log, whereas in PyTorch core one needs to manually find which call args are responsible for the failure.

However, there are also downsides:

  1. Using @pytest.mark.parametrize instantiates everything upfront. Especially with tensor inputs, this can quickly add up to a big chunk of memory. Right now test/test_prototype_transforms_functional.py includes ~23k tests that instantiate tensors during collection (pytest --co test/test_prototype_transforms_functional.py::test_eager_vs_scripted). They all come from a single common test, namely eager vs. scripted. If we add more tests or more ops, this number will grow fast. This is the reason why PyTorch core does not rely on parametrization for their sample inputs, but rather falls back to for loops inside the test.
  2. There is only one kind of sample input. Although each test might have different needs with regard to the sample inputs, we only have one generating function. This means that we need to support all needs with it, which in turn unnecessarily blows up the time to run all tests. For example, to test whether the output of the scripted affine_image_tensor kernel matches its eager counterpart, we only need to test a single set of affine parameters (as long as there is no branching based on them). However, for reference testing, we should test multiple parameter sets to make sure the kernel actually behaves like its reference.

Design goal

This PR sets out to solve the problems detailed above while retaining all the positive aspects of the current implementation.

  1. Introduce the TensorLoader class: it wraps another callable that will ultimately create the tensor, but it knows the shape, dtype, and possibly other feature metadata ahead of time. The device will only be passed at runtime to allow us to parametrize over different devices. With this we can continue to rely on the tensor attributes during sample input generation, e.g.

    height, width = image.shape[-2:]

    height, width = bounding_box.image_size

    without actually instantiating the tensors.

    At test time, the tensor can simply be instantiated with TensorLoader(...).load(device). For convenience, ArgsKwargs was made aware of TensorLoader and got a .load(device) method as well. With these, the common tests will look somewhat like this:

    @pytest.mark.parametrize(
        ("info", "args_kwargs"), [
            (info, args_kwargs) 
            for info in KERNEL_INFOS 
            for args_kwargs in info.sample_inputs
        ]
    )
    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_smoke(info, args_kwargs, device):
        args, kwargs = args_kwargs.load(device)
    
        result = info.kernel(*args, **kwargs)

    This approach of "lazy loading" is similar to the concept of lazy tensors, although it strips out everything we don't need. To avoid confusion, I preferred the term "load" over "lazy" here. A minimal sketch of the loader idea is shown after this list.

  2. Introduce a reference_inputs_fn alongside the sample_inputs_fn. As the name implies, the former will only be used for reference tests and should be comprehensive with respect to the tested values. In contrast, the sample_inputs_fn should only cover all valid code paths. This is on par with what PyTorch core does with their OpInfo framework, although they have even more diverse sample inputs functions, like the error_inputs_func.
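
As referenced in item 1 above, here is a minimal sketch of the loader idea. TensorLoader, ArgsKwargs.load, and make_image_loaders are names used in this PR, but the fields and implementation shown here are illustrative assumptions:

# Minimal sketch of the TensorLoader / ArgsKwargs.load idea from item 1.
# The concrete fields and implementation below are assumptions.
import dataclasses
from typing import Callable, Iterator, Sequence, Tuple

import torch


@dataclasses.dataclass
class TensorLoader:
    # callable that creates the actual tensor; only invoked at test time
    fn: Callable[[Sequence[int], torch.dtype, torch.device], torch.Tensor]
    # metadata known ahead of time, so sample input generation can use
    # e.g. loader.shape[-2:] without instantiating anything
    shape: Tuple[int, ...]
    dtype: torch.dtype

    def load(self, device) -> torch.Tensor:
        return self.fn(self.shape, self.dtype, torch.device(device))


class ArgsKwargs:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def load(self, device):
        # replace every TensorLoader among the args / kwargs with the loaded tensor
        def maybe_load(value):
            return value.load(device) if isinstance(value, TensorLoader) else value

        args = tuple(maybe_load(arg) for arg in self.args)
        kwargs = {name: maybe_load(value) for name, value in self.kwargs.items()}
        return args, kwargs


def make_image_loaders(*, sizes=((32, 32),), dtypes=(torch.float32,)) -> Iterator[TensorLoader]:
    for size in sizes:
        for dtype in dtypes:
            yield TensorLoader(
                fn=lambda shape, dtype, device: torch.rand(shape, dtype=dtype, device=device),
                shape=(3, *size),
                dtype=dtype,
            )

With this in place, a KernelInfo can carry both a sample_inputs_fn (covering the code paths) and a more exhaustive reference_inputs_fn from item 2, and neither ever materializes tensors at collection time.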

Limitations

There are two things that are not included in the current design:

  1. Reference tests against fixed inputs. The reference tests explained above work by using one function that generates sample inputs and passing them to the kernel as well as to a reference function. Thus, we only need to parametrize over one set of sample inputs, and the reference outputs only get computed at runtime. To support fixed references, we would need a map from sample inputs to their fixed references, which would defeat the part of the design that avoids instantiating tensors at collection. Fortunately, we don't need to force everything into the framework proposed here and can simply have separate "manual" tests for this. Of course this also applies to testing error / warning inputs, although in the future we could also integrate that into the framework.
  2. Tests for high-level functionals aka dispatchers. As the name implies, they only dispatch, so it would be quite a waste of resources to test them again with sample inputs if we already tested the kernels. I think we should implement something like "if I put in a PIL Image, the PIL kernel gets called and its output gets returned" (a rough sketch of such a check is shown after this list). However, this is out of scope for this PR.
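
As referenced in item 2, a rough sketch of what such a dispatch check could look like. This is purely hypothetical and not part of this PR; the dispatcher and kernel names as well as the patch target are assumptions:

# Hypothetical dispatcher test, not part of this PR. It assumes a dispatcher
# F.horizontal_flip and a PIL kernel horizontal_flip_image_pil; the patch
# target depends on how the dispatcher actually imports its kernels.
from unittest import mock

import PIL.Image


def test_pil_input_dispatches_to_pil_kernel():
    pil_image = PIL.Image.new("RGB", (32, 32))

    with mock.patch(
        "torchvision.prototype.transforms.functional.horizontal_flip_image_pil"
    ) as pil_kernel:
        from torchvision.prototype.transforms import functional as F

        output = F.horizontal_flip(pil_image)

    # the PIL kernel was called with the PIL input and its output was returned as-is
    pil_kernel.assert_called_once_with(pil_image)
    assert output is pil_kernel.return_value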

Todo

This PR mostly introduces the new framework while adding some kernels as examples. There are three ways to add the remaining ones:

  1. Finish everything in this PR.
  2. Merge this PR and add follow-up PRs for the remaining kernels. This is possible without ripping large gaps in our CI since I've added the new tests in a new module. The last PR of this series could be a clean up to remove all the then duplicated tests in the old module.
  3. Use this PR as feature branch and add more PRs against this until we are finished.

My preference is 3. -> 2. -> 1. but I'll leave that up to the reviewers. Here is the list of what kernels are done or missing:

  • clamp_bounding_box
  • convert_bounding_box_format
  • convert_color_space
    • image
  • adjust_brightness
    • image
  • adjust_contrast
    • image
  • adjust_gamma
    • image
  • adjust_hue
    • image
  • adjust_saturation
    • image
  • adjust_sharpness
    • image
  • autocontrast
    • image
  • equalize
    • image
  • invert
    • image
  • posterize
    • image
  • solarize
    • image
  • affine
    • bounding_box
    • image
    • segmentation_mask
  • center_crop
    • bounding_box
    • image
    • segmentation_mask
  • crop
    • bounding_box
    • image
    • segmentation_mask
  • elastic
    • bounding_box
    • image
    • segmentation_mask
  • five_crop
    • image
  • horizontal_flip
    • bounding_box
    • image
    • segmentation_mask
  • pad
    • bounding_box
    • image
    • segmentation_mask
  • perspective
    • bounding_box
    • image
    • segmentation_mask
  • resize
    • bounding_box
    • image
    • segmentation_mask
  • resized_crop
    • bounding_box
    • image
    • segmentation_mask
  • rotate
    • bounding_box
    • image
    • segmentation_mask
  • ten_crop
    • bounding_box
    • image
    • segmentation_mask
  • vertical_flip
    • bounding_box
    • image
    • segmentation_mask

__all__ = ["assert_close"]


class PILImagePair(TensorLikePair):
pmeier (author) commented:

This is a superset of what the old ImagePair did. It includes the options to only test the aggregated difference or check the percentage of differing pixels. That is on par with what we are currently doing in our stable functional tests:

vision/test/common_utils.py

Lines 172 to 174 in a67cc87

def _assert_approx_equal_tensor_to_pil(
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
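
For context, here is a standalone sketch of those two relaxed comparison modes. This is not the actual PILImagePair implementation, just an illustration of the checks it adds:

# Illustration of the two relaxed comparison modes described above; not the
# actual PILImagePair implementation.
import torch
import torchvision.transforms.functional as F


def assert_tensor_close_to_pil(
    actual, pil_image, *, tol=1e-5, agg_method="mean", allowed_percentage_diff=None
):
    # compare in float to avoid wrap-around for integer image dtypes
    expected = F.pil_to_tensor(pil_image).to(torch.float32)
    abs_diff = (actual.to(torch.float32) - expected).abs()

    if allowed_percentage_diff is not None:
        # only check how many pixels differ at all, not by how much
        assert (abs_diff != 0).float().mean() <= allowed_percentage_diff
    else:
        # only check an aggregate of the difference, e.g. its mean
        assert getattr(torch, agg_method)(abs_diff) <= tol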

Comment on lines 166 to 180
def from_loader(loader_fn):
    def wrapper(*args, **kwargs):
        loader = loader_fn(*args, **kwargs)
        return loader.load(kwargs.get("device", "cpu"))

    return wrapper


def from_loaders(loaders_fn):
    def wrapper(*args, **kwargs):
        loaders = loaders_fn(*args, **kwargs)
        for loader in loaders:
            yield loader.load(kwargs.get("device", "cpu"))

    return wrapper
pmeier (author) commented:

These functions are mostly for "BC" with our current tests. With them we can turn make_*_loader{s} back into make_*. For example, make_images = from_loaders(make_image_loaders). This makes the transition period easier, since we don't need to touch the old files.

In the future, most tests should use the loader architecture. Those that don't could simply invoke TensorLoader(...).load(device) manually.
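
For instance, a usage sketch (it assumes make_image_loaders accepts the dtypes argument shown elsewhere in this PR):

# "BC" usage sketch for the wrappers above
import torch

make_images = from_loaders(make_image_loaders)

for image in make_images(dtypes=[torch.float32]):
    # the tensors are instantiated eagerly (on the default "cpu" device)
    # instead of being handed out as loaders
    assert isinstance(image, torch.Tensor)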

return self.kernel.__name__


def pil_reference_wrapper(pil_kernel):
pmeier (author) commented:

The reference defined in the KernelInfo will be passed the same inputs as the kernel. Since we use the PIL kernel as reference for its tensor counterpart, this is a simple wrapper that avoids defining the same kind of reference function over and over.
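
A possible shape of such a wrapper, sketched under the assumption that the image tensor is the first argument and that a to_pil_image-style conversion is used; the real implementation may differ:

# Sketch of the idea behind pil_reference_wrapper: adapt a PIL kernel so it can
# be fed the same (tensor image, *other args) inputs as the tensor kernel.
import functools

import torchvision.transforms.functional as F


def pil_reference_wrapper(pil_kernel):
    @functools.wraps(pil_kernel)
    def wrapper(image_tensor, *other_args, **kwargs):
        # convert only the image; the remaining arguments are passed through unchanged
        return pil_kernel(F.to_pil_image(image_tensor), *other_args, **kwargs)

    return wrapper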


def sample_inputs_horizontal_flip_image_tensor():
    for image_loader in make_image_loaders(dtypes=[torch.float32]):
        yield ArgsKwargs(image_loader.unwrap())
pmeier (author) commented:

Similar to make_image_loaders, these functions don't need to yield, but it makes the definition less verbose.


pmeier commented Sep 12, 2022

@vfdev-5 The failure comes from a test you wanted to fix:

=================================== FAILURES ===================================
_ test_correctness_elastic_image_or_mask_tensor[elastic_segmentation_mask-make_segmentation_masks-cpu] _
Traceback (most recent call last):
  File "/home/runner/work/vision/vision/test/test_prototype_transforms_functional.py", line 1643, in test_correctness_elastic_image_or_mask_tensor
    sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
  File "/home/runner/work/vision/vision/torchvision/prototype/features/_feature.py", line 87, in __torch_function__
    output = func(*args, **kwargs)
IndexError: index 34 is out of bounds for dimension 1 with size 25
----------------------------- Captured stdout call -----------------------------
torch.Size([1, 64, 76])
torch.Size([4, 1, 64, 76])
torch.Size([0, 64, 76])
torch.Size([4, 0, 64, 76])
torch.Size([9, 64, 76])
torch.Size([4, 4, 64, 76])
torch.Size([1, 25, 18])

Given that this PR does not touch this test at all, my best guess is that it was flaky before and depended on a random seed. Plus, we seem to have missed removing a debug statement:

@pmeier pmeier marked this pull request as ready for review September 12, 2022 16:41

vfdev-5 commented Sep 12, 2022

The failure comes from a test you wanted to fix:

@pmeier it was fixed. The issue is that the input image size is too small for the given bounding boxes. Probably you changed that again somewhere.


pmeier commented Sep 13, 2022

@vfdev-5 You were right and I fixed that in a49f0db. However, now we get this failure:

=================================== FAILURES ===================================
_ test_correctness_perspective_segmentation_mask[startpoints0-endpoints0-cpu] __
Traceback (most recent call last):
  File "/Users/runner/work/vision/vision/test/test_prototype_transforms_functional.py", line 1496, in test_correctness_perspective_segmentation_mask
    torch.testing.assert_close(output_mask, expected_masks)
  File "/Users/runner/hostedtoolcache/Python/3.7.13/x64/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1359, in assert_close
    msg=msg,
  File "/Users/runner/hostedtoolcache/Python/3.7.13/x64/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 1 / 558 (0.2%)
Greatest absolute difference: 1 at index (0, 2, 17)
Greatest relative difference: inf at index (0, 2, 17)

In CI this is only failing on macOS, but this also fails for me locally. Given that we have only a single mismatched element, the test is probably flaky.


vfdev-5 commented Sep 13, 2022

I do not see from your commit where the size was fixed. A more general question: how do I find the code of a failing test? It is totally obfuscated to me :)


pmeier commented Sep 13, 2022

I do not see from your commit where the size was fixed.

The issue was that while refactoring the make_*s functions I didn't make sure that the "data dimensions", i.e. everything that is not a batch dimension, are constant. For example, before the last commit in this PR, setting the size parameter in make_segmentation_masks only applied to a part of the generated samples. This is why the failing test popped back up: you set the size and depended on it being at least that, but some samples were smaller.

A more general question: how do I find the code of a failing test? It is totally obfuscated to me :)

Not sure what you mean. Could you elaborate? Right now there is no failing test in the new tests, so I'll construct one to show what it looks like. Imagine horizontal_flip_bounding_box is not torch.jit.script'able. The failing test will look like this

________________ TestCommon.test_scripted_vs_eager[cpu-horizontal_flip_bounding_box35] _________________

Traceback (most recent call last):
  File "/home/philip/git/pytorch/torchvision/test/test_prototype_transforms_kernels.py", line 343, in test_scripted_vs_eager
    kernel_scripted = torch.jit.script(kernel_eager)
  File "/home/philip/.local/opt/mambaforge/envs/torchvision-dev/lib/python3.7/site-packages/torch/jit/_script.py", line 1344, in script
    qualified_name, ast, _rcb, get_default_args(obj)
RuntimeError: 
Ellipses followed by tensor indexing is currently not supported:
  File "/home/philip/git/pytorch/torchvision/torchvision/prototype/transforms/functional/_geometry.py", line 42
    )

    bounding_box[..., [0, 2]] = image_size[1] - bounding_box[..., [2, 0]]
                                                ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

    return convert_bounding_box_format(


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/philip/git/pytorch/torchvision/test/test_prototype_transforms_kernels.py", line 345, in test_scripted_vs_eager
    raise AssertionError("Trying to `torch.jit.script` the kernel raised the error above.") from error
AssertionError: Trying to `torch.jit.script` the kernel raised the error above.

From the test name and parametrization you should find everything:

  1. The test is defined in the class TestCommon and is named test_scripted_vs_eager
  2. The test was run on the CPU
  3. It checked the kernel horizontal_flip_bounding_box and specifically the 35th sample input. (In this example all sample inputs fail since the scriptability of the kernel is not dependent on the input).

To debug, you can run this exact test with

pytest 'test/test_prototype_transforms_kernels.py::TestCommon::test_scripted_vs_eager[cpu-horizontal_flip_bounding_box35]'


vfdev-5 commented Sep 13, 2022

I was talking about test_correctness_perspective_segmentation_mask[startpoints0-endpoints0-cpu] from the failed macOS job.


pmeier commented Sep 13, 2022

Top-most error in the traceback:

File "/Users/runner/work/vision/vision/test/test_prototype_transforms_functional.py", line 1496, in test_correctness_perspective_segmentation_mask
    torch.testing.assert_close(output_mask, expected_masks)

To reproduce:

pytest 'test/test_prototype_transforms_functional.py::test_correctness_perspective_segmentation_mask[startpoints0-endpoints0-cpu]'

vfdev-5 left a review comment:

Good to me



def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
def make_bounding_box_loader(*, extra_dims=(), format, image_size=None, dtype=torch.float32):
vfdev-5 commented:

Why doesn't the bounding box loader have a num_objects arg?

pmeier (author) replied:

Because we defined a bounding box as (*, 4), i.e. features.BoundingBox([0, 0, 10, 10], ...) is valid although it only has a single dimension. Thus, if you want to have multiple boxes, set extra_dims=(num_objects,)
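
For illustration (a sketch assuming the XYXY format; make_bounding_box_loader is the helper from the diff above):

from torchvision.prototype import features

# a single box: the loaded bounding box has shape (4,)
single_box = make_bounding_box_loader(format=features.BoundingBoxFormat.XYXY)

# multiple boxes: put the number of objects into extra_dims,
# e.g. five boxes -> shape (5, 4)
five_boxes = make_bounding_box_loader(format=features.BoundingBoxFormat.XYXY, extra_dims=(5,))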

@pmeier pmeier merged commit 0b5ebae into pytorch:main Sep 15, 2022
@pmeier pmeier deleted the prototype-functional-test branch September 15, 2022 09:30
facebook-github-bot pushed a commit that referenced this pull request Sep 15, 2022
Reviewed By: jdsgomes

Differential Revision: D39543278

fbshipit-source-id: 413bc5160188c5423d39d9f73387e9a5f25d8af7