add prototype for CIFAR datasets #4511
Conversation
Conflicts:
torchvision/prototype/datasets/_builtin/__init__.py
torchvision/prototype/datasets/utils/_internal.py
Thanks for the PR!
Making a high-level comment here: I think we should benchmark our datasets as we add them, to ensure we don't introduce performance regressions (which is probably the case for CIFAR).
Can you add this to the TODO list of things to add for the datasets? It would probably mean running for a few epochs so that tar extraction etc. is taken into account as well.
(_, category_idx), (_, image_array_flat) = data
image_array = image_array_flat.reshape((3, 32, 32)).transpose(1, 2, 0)
image_buffer = image_buffer_from_array(image_array)
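For context, a minimal sketch of the layout change this performs on a raw CIFAR row (illustration only; the dummy array below is not part of the PR):

import numpy as np

# Each CIFAR record stores the image as a flat uint8 vector of 3 * 32 * 32
# values laid out channel-first (all red pixels, then green, then blue).
image_array_flat = np.zeros(3 * 32 * 32, dtype=np.uint8)

# reshape to channels-height-width, then transpose to height-width-channels,
# which is the layout image encoders such as PIL expect.
image_array = image_array_flat.reshape((3, 32, 32)).transpose(1, 2, 0)
assert image_array.shape == (32, 32, 3)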
While this keeps the APIs consistent, it is worth noting that it brings significant overhead for this dataset and makes it much slower than the current CIFAR datasets in torchvision: we encode and then decode every image, even though CIFAR already stores the images in decoded form.
I want us to keep this in mind for the future, and maybe default to skipping the re-encode + decode for speed.
I've changed the default behavior so that we don't need to encode and decode in every step. This only applies to datasets where this is possible (IIRC only CIFAR and MNIST); "normal" image datasets are untouched by this.
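A minimal sketch of the idea (the function name and signature below are hypothetical, not the decoder hook this PR actually adds):

import numpy as np
import torch

def raw_cifar_decoder(image_array: np.ndarray) -> torch.Tensor:
    # CIFAR already ships decoded pixel data, so the HWC uint8 array can be
    # turned into a CHW tensor directly instead of being re-encoded to an
    # image format and decoded again for every sample.
    return torch.from_numpy(image_array).permute(2, 0, 1).contiguous()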
def __iter__(self) -> Iterator[D]:
    for sequence in self.datapipe:
        yield from iter(sequence)
Is this the equivalent of flattening the sequence? If yes, do we have something already available in datapipes?
Should be doable with https://github.com/pytorch/pytorch/blob/2a5116e1599be7d7fa1be9572f47c316716b74c3/torch/utils/data/datapipes/iter/grouping.py#L102. Otherwise we can always send these datapipes upstream.
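If the linked upstream datapipe covers this, usage could look roughly like the sketch below (assuming the current unbatching datapipe API; the exact class and functional names at the linked commit may differ):

from torch.utils.data.datapipes.iter import IterableWrapper

nested_dp = IterableWrapper([[1, 2, 3], [4, 5]])  # yields sequences of samples
flat_dp = nested_dp.unbatch()  # flattens one level, like the custom __iter__ above
assert list(flat_dp) == [1, 2, 3, 4, 5]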
Let's avoid writing new datapipes if they are redundant with building blocks that already exist upstream.
Good point. I'll design a benchmark utility so we can do this for every dataset.
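A rough sketch of what such a utility could measure (names are illustrative, not the eventual torchvision API):

import time

def benchmark(datapipe, epochs: int = 3) -> None:
    # Iterate the full datapipe several times so that archive extraction and
    # decoding are included in the measurement, and report throughput.
    for epoch in range(epochs):
        start = time.perf_counter()
        num_samples = sum(1 for _ in datapipe)
        elapsed = time.perf_counter() - start
        print(f"epoch {epoch}: {num_samples / elapsed:.1f} samples/s")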
archive_dp = TarArchiveReader(archive_dp)
archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
@ejguan any idea how to appease mypy here, without slapping : IterDataPipe everywhere? Otherwise I'm inclined to blanket-ignore var-annotated here.
The easiest way should be to add an annotation to the variable at the beginning:
Suggested change (replacing the three lines above):

archive_dp: IterDataPipe
archive_dp = resource_dps[0]
archive_dp = TarArchiveReader(archive_dp)
archive_dp = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
archive_dp = Mapper(archive_dp, self._unpickle)
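The upside of this pattern is that the single bare annotation gives mypy a type for archive_dp up front, so every later plain reassignment is checked against IterDataPipe without repeating the annotation on each line or blanket-ignoring var-annotated.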
Summary:
* add prototype for CIFAR datasets

Reviewed By: NicolasHug
Differential Revision: D31505571
fbshipit-source-id: 33a656a3d3c176752491f400aeb0e9327856493e
* add prototype for CIFAR datasets (Conflicts: torchvision/prototype/datasets/_builtin/__init__.py, torchvision/prototype/datasets/utils/_internal.py)
* fix mypy
* cleanup
* more cleanup
* revert unrelated changes
* fix code format
* avoid decoding twice by default
* revert unrelated change
* cleanup
This overlaps with #4510 in some parts, so this will have to be rebased after the other PR is merged.
cc @pmeier @mthrok @bjuncek