From 38428be16f8c01b05f566d70779d9df28d068110 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 24 Feb 2023 17:23:23 +0100
Subject: [PATCH 1/7] add docstring for dataset wrapper

---
 docs/source/datasets.rst                   |  5 ++
 torchvision/datapoints/_dataset_wrapper.py | 65 +++++++++++++++++-----
 2 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 68c72e7af8c..d1e01311d9e 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -169,3 +169,8 @@ Base classes for custom datasets
     DatasetFolder
     ImageFolder
     VisionDataset
+
+Transforms v2
+-------------
+
+.. autofunction:: wrap_dataset_for_transforms_v2
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index e358c83d9d1..de532b49672 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -16,6 +16,51 @@
 
 # TODO: naming!
 def wrap_dataset_for_transforms_v2(dataset):
+    """Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
+
+    .. note::
+
+        So far we only provide wrappers for the most popular datasets. Furthermore, the wrappers only support dataset
+        configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
+        to raise an issue to ``torchvision`` for a dataset or configuration that you need, please act on it.
+
+    The dataset samples are wrapped according to the description below.
+
+    Special
+
+    * :class:`~torchvision.datasets.CocoDetection`: Instead returning the target as list of dicts, now returns it as
+      dict of lists. In addition, the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` are added which
+      wrap the data in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.VOCDetection`
+    * :class:`~torchvision.datasets.SBDataset`
+    * :class:`~torchvision.datasets.CelebA`
+    * :class:`~torchvision.datasets.Kitti`
+    * :class:`~torchvision.datasets.OxfordIIITPet`
+    * :class:`~torchvision.datasets.Cityscapes`
+    * :class:`~torchvision.datasets.WIDERFace`
+
+    Image classification datasets
+
+    This wrapper is a no-op for image classification datasets, since they were already fully supported by
+    :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.
+
+    Segmentation datasets
+
+    Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
+    :class:`PIL.Image.Image`'s. This wrapper leaves the image, i.e. the first item, as is, while wrapping the
+    segmentation mask, i.e. the second item, into a :class:`~torchvision.datapoints.Mask`.
+
+    Video classification datasets
+
+    Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple contained a
+    :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
+    :class:`~torchvision.datapoints.Video` while leaving the other items as is.
+
+    .. note::
+
+        Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
+        ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
+    """
     return VisionDatasetDatapointWrapper(dataset)
 
 
@@ -103,10 +148,6 @@ def raise_not_supported(description):
     )
 
 
-def identity(item):
-    return item
-
-
 def identity_wrapper_factory(dataset):
     def wrapper(idx, sample):
         return sample
@@ -114,10 +155,6 @@ def wrapper(idx, sample):
     return wrapper
 
 
-def pil_image_to_mask(pil_image):
-    return datapoints.Mask(pil_image)
-
-
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
     dict_of_lists = defaultdict(list)
     for dct in list_of_dicts:
@@ -131,7 +168,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
         target = [target]
 
     wrapped_target = tuple(
-        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
+        type_wrappers.get(target_type, lambda x: x)(item) for target_type, item in zip(target_types, target)
     )
 
     if len(wrapped_target) == 1:
@@ -161,7 +198,7 @@ def classification_wrapper_factory(dataset):
 def segmentation_wrapper_factory(dataset):
     def wrapper(idx, sample):
         image, mask = sample
-        return image, pil_image_to_mask(mask)
+        return image, datapoints.Mask(mask)
 
     return wrapper
 
@@ -307,7 +344,7 @@
 
 
 @WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbd_wrapper(dataset):
+def sbdataset_wrapper(dataset):
     if dataset.mode == "boundaries":
         raise_not_supported("SBDataset with mode='boundaries'")
@@ -374,7 +411,7 @@ def wrapper(idx, sample):
             target,
             target_types=dataset._target_types,
             type_wrappers={
-                "segmentation": pil_image_to_mask,
+                "segmentation": datapoints.Mask,
             },
         )
@@ -390,7 +427,7 @@ def cityscapes_wrapper_factory(dataset):
 
     def instance_segmentation_wrapper(mask):
         # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
-        data = pil_image_to_mask(mask)
+        data = datapoints.Mask(mask)
         masks = []
         labels = []
         for id in data.unique():
@@ -409,7 +446,7 @@ def wrapper(idx, sample):
             target,
             target_types=dataset.target_type,
             type_wrappers={
                 "instance": instance_segmentation_wrapper,
-                "semantic": pil_image_to_mask,
+                "semantic": datapoints.Mask,
             },
         )
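For reference, the segmentation behaviour this docstring describes can be seen with a minimal sketch. The dataset root below is a placeholder, and the snippet assumes the 0.15-era API in which the wrapper is exposed from ``torchvision.datapoints``:

    from torchvision import datasets
    from torchvision.datapoints import wrap_dataset_for_transforms_v2

    # A segmentation dataset yields (PIL image, PIL mask) pairs.
    dataset = datasets.VOCSegmentation(root="data/voc", year="2012", image_set="train")
    wrapped = wrap_dataset_for_transforms_v2(dataset)

    image, mask = wrapped[0]
    # The image (first item) is left as is, while the mask (second item) is now a
    # datapoints.Mask, a torch.Tensor subclass that transforms.v2 dispatches on.
    print(type(mask).__name__)  # Mask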
+ """ return VisionDatasetDatapointWrapper(dataset) @@ -103,10 +148,6 @@ def raise_not_supported(description): ) -def identity(item): - return item - - def identity_wrapper_factory(dataset): def wrapper(idx, sample): return sample @@ -114,10 +155,6 @@ def wrapper(idx, sample): return wrapper -def pil_image_to_mask(pil_image): - return datapoints.Mask(pil_image) - - def list_of_dicts_to_dict_of_lists(list_of_dicts): dict_of_lists = defaultdict(list) for dct in list_of_dicts: @@ -131,7 +168,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers): target = [target] wrapped_target = tuple( - type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target) + type_wrappers.get(target_type, lambda x: x)(item) for target_type, item in zip(target_types, target) ) if len(wrapped_target) == 1: @@ -161,7 +198,7 @@ def classification_wrapper_factory(dataset): def segmentation_wrapper_factory(dataset): def wrapper(idx, sample): image, mask = sample - return image, pil_image_to_mask(mask) + return image, datapoints.Mask(mask) return wrapper @@ -307,7 +344,7 @@ def wrapper(idx, sample): @WRAPPER_FACTORIES.register(datasets.SBDataset) -def sbd_wrapper(dataset): +def sbdataset_wrapper(dataset): if dataset.mode == "boundaries": raise_not_supported("SBDataset with mode='boundaries'") @@ -374,7 +411,7 @@ def wrapper(idx, sample): target, target_types=dataset._target_types, type_wrappers={ - "segmentation": pil_image_to_mask, + "segmentation": datapoints.Mask, }, ) @@ -390,7 +427,7 @@ def cityscapes_wrapper_factory(dataset): def instance_segmentation_wrapper(mask): # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21 - data = pil_image_to_mask(mask) + data = datapoints.Mask(mask) masks = [] labels = [] for id in data.unique(): @@ -409,7 +446,7 @@ def wrapper(idx, sample): target_types=dataset.target_type, type_wrappers={ "instance": instance_segmentation_wrapper, - "semantic": pil_image_to_mask, + "semantic": datapoints.Mask, }, ) From 17cb38b4222be815775de9fc93d105f2e0e786a5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 24 Feb 2023 17:10:54 +0000 Subject: [PATCH 2/7] Some minor changes --- docs/source/datasets.rst | 6 ++++- torchvision/datapoints/_dataset_wrapper.py | 30 ++++++++++++++-------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index d1e01311d9e..35e5eaf2a9f 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -173,4 +173,8 @@ Base classes for custom datasets Transforms v2 ------------- -.. autofunction:: wrap_dataset_for_transforms_v2 +.. autosummary:: + :toctree: generated/ + :template: function.rst + + wrap_dataset_for_transforms_v2 diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index de532b49672..a938018bba1 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -14,23 +14,28 @@ __all__ = ["wrap_dataset_for_transforms_v2"] -# TODO: naming! def wrap_dataset_for_transforms_v2(dataset): - """Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`. + """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`. + + .. v2betastatus:: wrap_dataset_for_transforms_v2 function + + Example: + >>> coco = torchvision.datasets.CocoDetection() + >>> coco = wrap_dataset_for_transforms_v2(coco) .. 
From eb43a5a03ef76c4368f1f7a67861d547e9e59a8b Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 27 Feb 2023 09:02:39 +0100
Subject: [PATCH 3/7] add remaining descriptions for special datasets

---
 torchvision/datapoints/_dataset_wrapper.py | 26 ++++++++++++++--------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index de532b49672..8fe52e3f5f8 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -29,15 +29,23 @@ def wrap_dataset_for_transforms_v2(dataset):
     Special
 
     * :class:`~torchvision.datasets.CocoDetection`: Instead returning the target as list of dicts, now returns it as
-      dict of lists. In addition, the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` are added which
-      wrap the data in the corresponding ``torchvision.datapoints``.
-    * :class:`~torchvision.datasets.VOCDetection`
-    * :class:`~torchvision.datasets.SBDataset`
-    * :class:`~torchvision.datasets.CelebA`
-    * :class:`~torchvision.datasets.Kitti`
-    * :class:`~torchvision.datasets.OxfordIIITPet`
-    * :class:`~torchvision.datasets.Cityscapes`
-    * :class:`~torchvision.datasets.WIDERFace`
+        dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format), ``"masks"``,
+        and ``"labels"`` are added, which wrap the data in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.VOCDetection`: the key-value-pairs ``"boxes"`` and ``"labels"`` are added to
+        the target, which wrap the data in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to ``XYXY``
+        coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+    * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, now returns it as
+        dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added, which wrap the data
+        in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
+        :class:`~torchvision.datapoints.Mask` datapoint.
+    * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
+        :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
+        a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
+        ``"labels"``.
+    * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
+        coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
 
     Image classification datasets
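The ``XYXY`` conversion mentioned for ``CelebA`` and ``WIDERFace`` is the standard xywh-to-xyxy box change. The wrapper performs it through the datapoint machinery; as a plain illustration of the arithmetic, ``torchvision.ops.box_convert`` computes the same thing (values made up):

    import torch
    from torchvision.ops import box_convert

    # CelebA-style bbox targets are (x, y, w, h); XYXY stores the two corners instead.
    xywh = torch.tensor([[95.0, 71.0, 226.0, 313.0]])
    xyxy = box_convert(xywh, in_fmt="xywh", out_fmt="xyxy")
    print(xyxy)  # tensor([[ 95.,  71., 321., 384.]])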
From de13ea4447b7fa9273ddb9ca6d136cbe9c0a86f9 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 27 Feb 2023 09:09:12 +0100
Subject: [PATCH 4/7] cleanup

---
 torchvision/datapoints/_dataset_wrapper.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index d43c75a5002..f36f4db7e52 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -20,8 +20,8 @@ def wrap_dataset_for_transforms_v2(dataset):
     .. v2betastatus:: wrap_dataset_for_transforms_v2 function
 
     Example:
-        >>> coco = torchvision.datasets.CocoDetection()
-        >>> coco = wrap_dataset_for_transforms_v2(coco)
+        >>> dataset = torchvision.datasets.CocoDetection(...)
+        >>> dataset = wrap_dataset_for_transforms_v2(dataset)
 
     .. note::
@@ -36,14 +36,15 @@ def wrap_dataset_for_transforms_v2(dataset):
     * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
         returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
         ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
         The original keys are preserved.
-    * :class:`~torchvision.datasets.VOCDetection`: the key-value-pairs ``"boxes"`` and ``"labels"`` are added to
-        the target, which wrap the data in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
+        the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
+        preserved.
-    * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to ``XYXY``
+    * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
         coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
-    * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, now returns it as
-        dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added, which wrap the data
-        in the corresponding ``torchvision.datapoints``.
+    * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a dict
+        of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
+        in the corresponding ``torchvision.datapoints``. The original keys are preserved.
     * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
         :class:`~torchvision.datapoints.Mask` datapoint.
     * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
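The ``Cityscapes`` ``target_type="instance"`` replacement described in these bullets splits the instance id map into per-instance masks plus labels. A sketch of the decoding, following the cityscapesScripts encoding referenced in the wrapper code ("has instances" classes store ``label_id * 1000 + instance_number`` per pixel); the input tensor here is made up for illustration:

    import torch
    from torchvision import datapoints

    id_map = datapoints.Mask(torch.tensor([[26001, 26001, 0], [26002, 24000, 24000]]))

    masks, labels = [], []
    for id in id_map.unique():
        masks.append(id_map == id)
        labels.append(id // 1000 if id >= 1000 else id)

    target = {"masks": datapoints.Mask(torch.stack(masks)), "labels": torch.stack(labels)}
    print(target["labels"])  # tensor([ 0, 24, 26, 26]) -> unlabeled, person, car, car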
The original keys are + The original keys are preserved. + * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to + the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are preserved. - * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to ``XYXY`` + * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY`` coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. - * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, now returns it as - dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added, which wrap the data - in the corresponding ``torchvision.datapoints``. + * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dictsthe wrapper returns a dict + of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data + in the corresponding ``torchvision.datapoints``. The original keys are preserved. * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a :class:`~torchvision.datapoints.Mask` datapoint. * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a From 9138e001b9d98fa3b59b8577e9fbeb43c9eccf34 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 27 Feb 2023 11:33:21 +0000 Subject: [PATCH 5/7] Fix indent for proper formatting --- torchvision/datapoints/_dataset_wrapper.py | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index f36f4db7e52..e1f94847366 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -34,25 +34,25 @@ def wrap_dataset_for_transforms_v2(dataset): Special cases: * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper - returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format), - ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``. - The original keys are preserved. + returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format), + ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``. + The original keys are preserved. * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to - the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are - preserved. + the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are + preserved. * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY`` - coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. + coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dictsthe wrapper returns a dict - of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data - in the corresponding ``torchvision.datapoints``. 
From ca4b1df2ce95ee19b72ff0d1c2a5471975951a04 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 27 Feb 2023 13:25:15 +0100
Subject: [PATCH 6/7] revert cleanup changes

---
 torchvision/datapoints/_dataset_wrapper.py | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index f36f4db7e52..d1c1ac7794d 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -14,6 +14,7 @@
 __all__ = ["wrap_dataset_for_transforms_v2"]
 
 
+# TODO: naming!
 def wrap_dataset_for_transforms_v2(dataset):
     """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
 
@@ -166,6 +167,10 @@ def raise_not_supported(description):
     )
 
 
+def identity(item):
+    return item
+
+
 def identity_wrapper_factory(dataset):
     def wrapper(idx, sample):
         return sample
@@ -173,6 +178,10 @@ def wrapper(idx, sample):
     return wrapper
 
 
+def pil_image_to_mask(pil_image):
+    return datapoints.Mask(pil_image)
+
+
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
     dict_of_lists = defaultdict(list)
     for dct in list_of_dicts:
@@ -186,7 +195,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
         target = [target]
 
     wrapped_target = tuple(
-        type_wrappers.get(target_type, lambda x: x)(item) for target_type, item in zip(target_types, target)
+        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
     )
 
     if len(wrapped_target) == 1:
@@ -216,7 +225,7 @@ def classification_wrapper_factory(dataset):
 def segmentation_wrapper_factory(dataset):
     def wrapper(idx, sample):
         image, mask = sample
-        return image, datapoints.Mask(mask)
+        return image, pil_image_to_mask(mask)
 
     return wrapper
@@ -362,7 +371,7 @@
 
 
 @WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbdataset_wrapper(dataset):
+def sbd_wrapper(dataset):
     if dataset.mode == "boundaries":
         raise_not_supported("SBDataset with mode='boundaries'")
@@ -429,7 +438,7 @@ def wrapper(idx, sample):
             target,
             target_types=dataset._target_types,
             type_wrappers={
-                "segmentation": datapoints.Mask,
+                "segmentation": pil_image_to_mask,
             },
         )
@@ -445,7 +454,7 @@ def cityscapes_wrapper_factory(dataset):
 
     def instance_segmentation_wrapper(mask):
         # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
-        data = datapoints.Mask(mask)
+        data = pil_image_to_mask(mask)
         masks = []
         labels = []
         for id in data.unique():
@@ -464,7 +473,7 @@ def wrapper(idx, sample):
             target,
             target_types=dataset.target_type,
             type_wrappers={
                 "instance": instance_segmentation_wrapper,
-                "semantic": datapoints.Mask,
+                "semantic": pil_image_to_mask,
             },
         )

From 1405991f5b8380ec7b4b41ab47c3d727dc795e30 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 27 Feb 2023 12:28:42 +0000
Subject: [PATCH 7/7] Remove comment

---
 torchvision/datapoints/_dataset_wrapper.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index 026ffdb03a6..87ce3ba93a1 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -14,7 +14,6 @@
 __all__ = ["wrap_dataset_for_transforms_v2"]
 
 
-# TODO: naming!
 def wrap_dataset_for_transforms_v2(dataset):
     """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
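With the series applied, the documented workflow looks as follows; a minimal end-to-end sketch against the 0.15 beta API (paths are placeholders), showing that v2 transforms update the wrapped annotations together with the image:

    from torchvision import datasets
    from torchvision.datapoints import wrap_dataset_for_transforms_v2
    from torchvision.transforms import v2

    dataset = datasets.CocoDetection(
        root="data/coco/train2017",
        annFile="data/coco/annotations/instances_train2017.json",
    )
    dataset = wrap_dataset_for_transforms_v2(dataset)

    # v2 transforms dispatch on the datapoint types, so a geometric transform
    # flips the wrapped boxes and masks together with the image.
    transform = v2.Compose([v2.RandomHorizontalFlip(p=1.0), v2.ToImageTensor()])
    image, target = transform(*dataset[0])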