[prototype] Minor speed and nit optimizations on Transform Classes #6837
Conversation
```diff
@@ -51,7 +51,7 @@ def _check_input(

     @staticmethod
     def _generate_value(left: float, right: float) -> float:
-        return float(torch.distributions.Uniform(left, right).sample())
+        return torch.empty(1).uniform_(left, right).item()
```
Switching to this random generator gives us a performance boost on GPU. Moreover, this option is JIT-scriptable (should we decide to add support in the future) and doesn't require constantly initializing a distribution object as before:

```
[--------- ColorJitter cpu torch.float32 ---------]
                  |  old random  |  new random
1 threads: ----------------------------------------
    (3, 400, 400) |      17      |      17
6 threads: ----------------------------------------
    (3, 400, 400) |      21      |      21

Times are in milliseconds (ms).

[--------- ColorJitter cuda torch.float32 --------]
                  |  old random  |  new random
1 threads: ----------------------------------------
    (3, 400, 400) |     1090     |     883
6 threads: ----------------------------------------
    (3, 400, 400) |     1090     |     882

Times are in microseconds (us).
```
```diff
@@ -80,7 +97,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",

 def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
-    bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
+    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
```
I think we can use a list here instead of a set.
Is this a perf optimization? Otherwise I would prefer the set, since

- it makes it clearer that we are looking for duplicates, and
- it is aligned with the other `query_*` functions below.

Functionally, the only difference is that with a set, passing the same bounding box twice in the same sample would not be caught. That indeed seems like a user error, but if we raise there, we should probably check for duplicates everywhere. Not sure if we want to single out bounding boxes here.
It's both perf and code correctness. What we want here is to ensure there is only one bbox in the input. If there were two, even if they were exact copies, our transforms that rely on `query_bounding_box` would end up working incorrectly: effectively they would modify the first copy but not the second. This is a different use case from the rest of the `query_*` methods, where we are able to handle multiple entries and just want to ensure the same size across them.
> It's both perf and code correctness.

Are `list`s actually faster than `set`s? I'm honestly curious, since I never benchmarked this 😇

> If there were two, even if they were the same exact copy, our transforms that rely on `query_bounding_box` would end up working incorrectly. Effectively they would end up modifying the first copy but not the second.

Agreed.

> This is a different use-case from the rest of the `query_*` methods where we are able to handle multiple entries but just want to ensure same size across.

You are right; I was not clear enough. What I meant was: we are not checking for duplicate images, videos, masks, ... anywhere. Either we should do that (and that seems like the right thing to do, but it probably has perf implications) or we shouldn't single out bounding boxes here. No strong opinion, though. A single mitigation is probably better than none, but we should still discuss whether we need this check for everything.
As discussed offline, `query_chw` and `query_spatial_size` must use sets because what we are interested in is enforcing that all images have a single size. Multiple images/videos are allowed in these cases; we just enforce that the size is the same everywhere.

On the other hand, `query_bounding_box` is used to extract the single bbox. The transforms that use this specific util handle only one bbox, and this is why we need a list. Note that this is identical to `_flatten_and_extract_image_or_video` from AA, where we also use a list.

So this is primarily a bug fix to restore the right semantics of the function, and it also has secondary positive performance implications.
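The duplicate-catching behavior discussed above can be sketched roughly like this (the `BoundingBox` class and the error message are hypothetical stand-ins, not the actual torchvision implementation):

```python
from typing import Any, List


class BoundingBox:
    """Hypothetical stand-in for features.BoundingBox."""


def query_bounding_box(flat_inputs: List[Any]) -> BoundingBox:
    # A list (unlike a set) preserves duplicates, so passing the same
    # bbox object twice in one sample trips the length check below.
    bounding_boxes = [x for x in flat_inputs if isinstance(x, BoundingBox)]
    if len(bounding_boxes) != 1:
        raise TypeError(
            f"Expected exactly one bounding box in the sample, "
            f"found {len(bounding_boxes)}"
        )
    return bounding_boxes[0]


box = BoundingBox()
assert query_bounding_box(["image", box, 42]) is box
```

With the earlier set comprehension, `[box, box]` would collapse to a single element and the user error would go unnoticed.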
```diff
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
     _transformed_types = (features.EncodedImage,)

     def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
-        return cast(features.Image, F.decode_image_with_pil(inpt))
+        return F.decode_image_with_pil(inpt)  # type: ignore[no-any-return]
```
This has to be here, because it seems that

```python
@torch.jit.unused
def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
```

doesn't "forward" the type annotations 🙄
In all other places we took the decision to silence with `ignore` rather than `cast`; do we really need the cast here?
No no, I was just explaining why we need the ignore, for future me who is looking confused at the blame, wondering why we introduced it in the first place.
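The trade-off between the two silencing styles can be illustrated with a small sketch (the `erases_annotations` decorator below is a hypothetical stand-in for the annotation-dropping behavior of `@torch.jit.unused`; it is not the real decorator):

```python
from typing import Any, cast


def erases_annotations(fn) -> Any:
    # Hypothetical decorator that returns an untyped wrapper, losing
    # fn's return annotation, the way @torch.jit.unused appears to.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


@erases_annotations
def decode(data: bytes) -> str:
    return data.decode()


def use_cast(data: bytes) -> str:
    # Option 1: cast() reasserts the lost type at the call site.
    return cast(str, decode(data))


def use_ignore(data: bytes) -> str:
    # Option 2: suppress the specific mypy error locally instead.
    return decode(data)  # type: ignore[no-any-return]


assert use_cast(b"hi") == "hi"
assert use_ignore(b"hi") == "hi"
```

Both are no-ops at runtime; the choice is purely about which convention the codebase standardizes on, which is what the thread above settles.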
```python
if isinstance(fill, defaultdict) and callable(fill.default_factory):
    default_value = fill.default_factory()
    _check_fill_arg(default_value)
```
@pmeier Removing support for defaultdicts would mean an extremely verbose API for users. They would need to define not only the types that make sense but also those that don't (such as Labels, HotLabels and BBoxes), even for padding calls. Here I propose an alternative mitigation. This is part 1, to ensure we check things properly in the constructor. It is an attempt to "bug fix"/extend the validation of the existing method for defaultdicts.
As discussed offline, let's move forward with the current implementation and refactor later if needed.
```python
for k, v in fill.items():
    fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
    default_value = fill.default_factory()
    sanitized_default = _convert_fill_arg(default_value)
    fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill  # type: ignore[return-value]
```
@pmeier Second mitigation. This sanitizes the values of all existing keys and patches the `default_factory` to produce a sanitized value. If we don't like this mitigation, then I think we will have to revert the change of doing the verification at constructor time instead of at runtime. Let me know your preference.
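A runnable sketch of the factory-patching idea, with `_convert_fill_arg` and `_default_arg` replaced by simplified hypothetical versions (the real torchvision helpers do more):

```python
import functools
from collections import defaultdict


def _default_arg(value):
    # Hypothetical helper: simply hands back the pre-sanitized value.
    return value


def _convert_fill_arg(fill):
    # Hypothetical sanitizer: normalize numeric fills to float.
    return None if fill is None else float(fill)


def sanitize_fill(fill):
    # Sanitize the values of every existing key...
    for k, v in fill.items():
        fill[k] = _convert_fill_arg(v)
    # ...then patch the factory so that any *future* missing key
    # also yields a sanitized default, not the raw one.
    if isinstance(fill, defaultdict) and callable(fill.default_factory):
        sanitized_default = _convert_fill_arg(fill.default_factory())
        fill.default_factory = functools.partial(_default_arg, sanitized_default)
    return fill


fill = sanitize_fill(defaultdict(lambda: 0, {"image": 1}))
assert fill["image"] == 1.0        # existing key was sanitized
assert fill["missing_key"] == 0.0  # new keys get the sanitized default
```

This keeps the user-facing convenience of defaultdicts while moving the validation cost to constructor time, as the comment above argues.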
Thanks Vasilis! LGTM if CI is green.
@pmeier Thanks. As mentioned offline, I've tested this approach extensively and know it works, but I'm not terribly happy that we need to check explicitly for `defaultdict`.
[prototype] Minor speed and nit optimizations on Transform Classes (#6837)

Summary:
* Change random generator for ColorJitter.
* Move `_convert_fill_arg` from runtime to constructor.
* Remove unnecessary TypeVars.
* Remove unnecessary casts.
* Update comments.
* Minor code-quality changes on Geometrical Transforms.
* Fix linter and other minor issues.
* Change mitigation for mypy.
* Fix the tests.
* Restore dict copy.
* Handle defaultdicts.
* Restore int idiom.
* Update todo.

Reviewed By: YosuaMichael

Differential Revision: D40755989

fbshipit-source-id: d5b475ea9a603c7a137e85db08dcd0db30195e3c
Further speed optimizations, nits and code-quality changes:

- Switching from `torch.distributions.Uniform()` to `tensor.uniform_()` speeds up CUDA by roughly 20%. It also aligns with the implementation in V1.

cc @vfdev-5 @bjuncek @pmeier