
[prototype] Minor speed and nit optimizations on Transform Classes #6837

Merged: 17 commits, Oct 27, 2022

Conversation


@datumbox datumbox commented Oct 25, 2022

Further speed optimizations, nits and code-quality changes:

  • Switching from torch.distributions.Uniform() to tensor.uniform_() speeds up CUDA by 20%. It also aligns with the V1 implementation.
  • Moves input sanitisation from run time to construction time for as many params as possible.
  • Contains a few non-performance refactorings that avoid unnecessary casting and move methods around.
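The swap behind the first bullet can be sketched as a minimal standalone comparison; the helper names below are illustrative, while the real change touches the `_generate_value` helper shown in the diff further down:

```python
import torch

def generate_value_old(left: float, right: float) -> float:
    # Constructs a fresh Uniform distribution object on every call.
    return float(torch.distributions.Uniform(left, right).sample())

def generate_value_new(left: float, right: float) -> float:
    # Fills a one-element tensor in place: no distribution object is
    # created, and the op is JIT-scriptable.
    return torch.empty(1).uniform_(left, right).item()

for _ in range(3):
    v = generate_value_new(0.5, 1.5)
    assert 0.5 <= v <= 1.5
```

Both helpers draw a single float from the same uniform range; only the mechanism differs.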

cc @vfdev-5 @bjuncek @pmeier

@datumbox datumbox marked this pull request as draft October 25, 2022 17:01
@@ -51,7 +51,7 @@ def _check_input(

@staticmethod
def _generate_value(left: float, right: float) -> float:
-        return float(torch.distributions.Uniform(left, right).sample())
+        return torch.empty(1).uniform_(left, right).item()
Contributor Author:

Switching to this random generator gives us a performance boost on GPU. Moreover, this option is JIT-scriptable (in case we decide to add support in the future) and doesn't require constantly initializing a distribution object as before:

[--------- ColorJitter cpu torch.float32 ---------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |       17      |       17    
6 threads: ----------------------------------------
      (3, 400, 400)  |       21      |       21    

Times are in milliseconds (ms).

[--------- ColorJitter cuda torch.float32 --------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      883    
6 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      882    

Times are in microseconds (us).
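A minimal harness in the spirit of this comparison (a CPU-only sketch, not the actual benchmark script that produced the tables above; the iteration count is arbitrary):

```python
import timeit

import torch

def old_sample() -> float:
    # The pre-PR pattern: build a distribution object per call.
    return float(torch.distributions.Uniform(0.0, 1.0).sample())

def new_sample() -> float:
    # The post-PR pattern: in-place uniform fill of a scratch tensor.
    return torch.empty(1).uniform_(0.0, 1.0).item()

n = 1_000
t_old = timeit.timeit(old_sample, number=n)
t_new = timeit.timeit(new_sample, number=n)
print(f"old: {t_old / n * 1e6:.1f} us/call")
print(f"new: {t_new / n * 1e6:.1f} us/call")
```

On CPU the two are close (matching the first table); the gap reported above shows up on CUDA, where the distribution-object overhead is relatively larger.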

@datumbox datumbox added module: transforms Perf For performance improvements prototype labels Oct 25, 2022
torchvision/prototype/transforms/_geometry.py (outdated, resolved)
@@ -80,7 +97,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",


def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
-    bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
+    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
Contributor Author:

I think we can use a list here instead of a set.

Collaborator:

Is this a perf optimization? Otherwise I would prefer the set since

  1. it makes it clearer that we are looking for duplicates
  2. it is aligned with the other query_* functions below.

Functionally, the only difference is that with a set, passing the same bounding box twice in the same sample would not be caught. This indeed seems like a user error, but if we raise there, we should probably check for duplicates everywhere. Not sure if we want to single out bounding boxes here.

Contributor Author:

It's both perf and code correctness. What we want here is to ensure there is only one bbox in the input. If there were two, even if they were the exact same copy, our transforms that rely on query_bounding_box would end up working incorrectly: effectively they would modify the first copy but not the second. This is a different use-case from the rest of the query_* methods, where we are able to handle multiple entries but just want to ensure the same size across them.

Collaborator:

It's both perf and code correctness.

Are lists actually faster than sets? I'm honestly curious since I never benchmarked this 😇

If there were two, even if they were the same exact copy, our transforms that rely on query_bounding_box would end up working incorrectly. Effectively they would end up modifying the first copy but not the second.

Agreed.

This is a different use-case from the rest of the query_* methods where we are able to handle multiple entries but just want to ensure same size across.

You are right; I was not clear enough. What I meant was: we are not checking for duplicate images, videos, masks, ... anywhere. Either we should do that (and that seems like the right thing to do, but it probably has perf implications) or we shouldn't single out bounding boxes here. No strong opinion though; a single mitigation is probably better than none. But we should still discuss whether we need to check this for everything.

Contributor Author:

As discussed offline, query_chw and query_spatial_size must use sets because what we are interested in is enforcing that all images have a single size. Multiple images/videos are allowed in these cases; we just enforce that the size is the same everywhere.

On the other hand, query_bounding_box is used to extract the single bbox. The transforms that use this util handle only one bbox, which is why we need a list. Note that this is identical to _flatten_and_extract_image_or_video from AA, where we also use a list.

So this is primarily a bug fix to restore the right semantics of the function, with secondary positive performance implications.
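The list-vs-set semantics being discussed can be demonstrated with a small self-contained sketch. `BoundingBox` here is a simplified stand-in for `features.BoundingBox`, and the exact error type is illustrative:

```python
class BoundingBox:  # simplified stand-in for features.BoundingBox
    def __init__(self, data):
        self.data = data

def query_bounding_box_set(flat_inputs):
    # Set comprehension: identical objects collapse to one element.
    boxes = {inpt for inpt in flat_inputs if isinstance(inpt, BoundingBox)}
    if len(boxes) != 1:
        raise TypeError("expected exactly one bounding box")
    return boxes.pop()

def query_bounding_box_list(flat_inputs):
    # List comprehension: every occurrence is kept, so duplicates are caught.
    boxes = [inpt for inpt in flat_inputs if isinstance(inpt, BoundingBox)]
    if len(boxes) != 1:
        raise TypeError("expected exactly one bounding box")
    return boxes[0]

box = BoundingBox([0, 0, 10, 10])
sample = ["image", box, box]  # the same bbox object passed twice

# The set deduplicates identical objects, so the error slips through...
assert query_bounding_box_set(sample) is box
# ...while the list keeps both occurrences and raises as intended.
try:
    query_bounding_box_list(sample)
except TypeError:
    pass
```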

@datumbox datumbox changed the title [WIP] [prototype] Speed up more Transform Classes [prototype] Speed up more Transform Classes Oct 26, 2022
@datumbox datumbox marked this pull request as ready for review October 26, 2022 12:54
torchvision/prototype/transforms/_geometry.py (outdated, resolved)
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)

def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
-        return cast(features.Image, F.decode_image_with_pil(inpt))
+        return F.decode_image_with_pil(inpt)  # type: ignore[no-any-return]
Collaborator:

This has to be here, because it seems

@torch.jit.unused
def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:

doesn't "forward" the type annotations 🙄

Contributor Author @datumbox, Oct 26, 2022:

In all other places we took the decision to silence with ignore rather than cast; do we really need the cast here?

Collaborator:

No no, I was just explaining, for future me looking confused at the blame, why we introduced the ignore in the first place.


Comment on lines +40 to +42
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
Contributor Author:

@pmeier Removing the support of defaultdicts would mean an extremely verbose API for users: they would need to define not only the types that make sense but also those that don't (such as Labels, HotLabels and BBoxes), even for padding calls. Here I propose an alternative mitigation. This is part 1 of ensuring we check things properly in the constructor; it's an attempt to bug-fix/extend the validation of the existing method for defaultdicts.

Collaborator:

As discussed offline, let's move forward with the current implementation and refactor later if needed.

Comment on lines +79 to +85
for k, v in fill.items():
fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
sanitized_default = _convert_fill_arg(default_value)
fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill # type: ignore[return-value]
Contributor Author:

@pmeier Second mitigation: this sanitizes the values of all existing keys and patches the default_factory to return a sanitized value. If we don't like this mitigation, then I think we will have to revert the change and go back to doing the verification at runtime instead of at construction time. Let me know your preference.
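The mitigation can be reproduced in a self-contained sketch. The body of `_convert_fill_arg` below is a simplified stand-in for the real conversion logic; `_default_arg` mirrors the tiny helper used in the snippet above:

```python
import functools
from collections import defaultdict

def _convert_fill_arg(fill):
    # Illustrative sanitisation: normalize scalars to a list of floats.
    return [float(fill)] if isinstance(fill, (int, float)) else fill

def _default_arg(value):
    # Returned by the patched default_factory for missing keys.
    return value

def _setup_fill_arg(fill):
    if isinstance(fill, dict):
        # Sanitize the values of all existing keys.
        for k, v in fill.items():
            fill[k] = _convert_fill_arg(v)
        if isinstance(fill, defaultdict) and callable(fill.default_factory):
            # Patch the factory so missing keys also yield a sanitized value.
            default_value = fill.default_factory()
            sanitized = _convert_fill_arg(default_value)
            fill.default_factory = functools.partial(_default_arg, sanitized)
        return fill
    return _convert_fill_arg(fill)

fill = defaultdict(lambda: 0, {"image": 1})
fill = _setup_fill_arg(fill)
assert fill["image"] == [1.0]  # existing key was sanitized in place
assert fill["mask"] == [0.0]   # missing key goes through the patched factory
```

This way the one-time sanitisation at construction time also covers keys the user never spelled out.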

@datumbox datumbox requested a review from pmeier October 26, 2022 15:48
@datumbox datumbox changed the title [prototype] Speed up more Transform Classes [prototype] Minor speed and nit optimizations on Transform Classes Oct 26, 2022
Collaborator @pmeier left a comment:

Thanks Vasilis! LGTM if CI is green.


Contributor Author @datumbox:

@pmeier Thanks. As mentioned offline, I've tested this approach extensively and know it works, but I'm not terribly happy that we need to check explicitly for defaultdict. I would be happy if we reverted this change and moved back to runtime input validation in a follow-up PR. I'll merge to unblock further work but am happy to discuss any changes.

@datumbox datumbox merged commit e1f464b into pytorch:main Oct 27, 2022
@datumbox datumbox deleted the prototype/class_speedup2 branch October 27, 2022 09:56
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2022
…lasses (#6837)

Summary:
* Change random generator for ColorJitter.

* Move `_convert_fill_arg` from runtime to constructor.

* Remove unnecessary TypeVars.

* Remove unnecessary casts

* Update comments.

* Minor code-quality changes on Geometrical Transforms.

* Fixing linter and other minor fixes.

* Change mitigation for mypy.

* Fixing the tests.

* Fixing the tests.

* Fix linter

* Restore dict copy.

* Handling of defaultdicts

* restore int idiom

* Update todo

Reviewed By: YosuaMichael

Differential Revision: D40755989

fbshipit-source-id: d5b475ea9a603c7a137e85db08dcd0db30195e3c