diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 5587122fb2f..f4f8c44f8ee 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,3 +1,4 @@ +import functools import io import pathlib import re @@ -132,7 +133,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, keep_key=True, ) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) @@ -185,7 +186,7 @@ def _make_datapipe( dp = Filter(dp, self._is_not_rogue_file) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index fb58080c9ae..b3c50c07943 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,4 +1,5 @@ import csv +import functools import io from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence @@ -26,7 +27,6 @@ hint_shuffling, ) - csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) @@ -155,7 +155,7 @@ def _make_datapipe( splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split)) + splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) splits_dp = hint_sharding(splits_dp) splits_dp = hint_shuffling(splits_dp) @@ -181,4 +181,4 @@ def _make_datapipe( keep_key=True, ) dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 940e41adc6d..68147ba0f9e 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -89,7 +89,7 @@ def _make_datapipe( dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index eb606d11e16..e400a1db07d 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,3 +1,4 @@ +import functools import io import pathlib import re @@ -183,12 +184,16 @@ def _make_datapipe( if config.annotations is None: dp = hint_sharding(images_dp) dp = hint_shuffling(dp) - return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder)) meta_dp = Filter( meta_dp, - self._filter_meta_files, - fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations), + functools.partial( + self._filter_meta_files, + split=config.split, + year=config.year, + annotations=config.annotations, + ), ) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) @@ -226,7 +231,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) return Mapper( - dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder) + dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder) ) def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: @@ -235,7 +240,8 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: dp = resources[1].load(pathlib.Path(root) / self.name) dp = Filter( - dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances") + dp, + functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"), ) dp = JsonParser(dp) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 7ef2237c3dc..9ea70296427 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,3 +1,4 @@ +import functools import io import pathlib import re @@ -165,7 +166,7 @@ def _make_datapipe( dp = hint_shuffling(dp) dp = Mapper(dp, self._collate_test_data) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 4cfbb05255f..8f49f1ce72a 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -136,7 +136,7 @@ def _make_datapipe( dp = Zipper(images_dp, labels_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder)) class MNIST(_MNISTBase): diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index c3d94a411f7..82fdb2adf8b 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,3 +1,4 @@ +import functools import io import pathlib import re @@ -152,7 +153,7 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index 8adfb3071bb..9df12c98b9b 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,3 +1,4 @@ +import functools import io from typing import Any, Callable, Dict, List, Optional, Tuple @@ -65,5 +66,5 @@ def _make_datapipe( dp = CSVParser(dp, delimiter=" ") dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) return dp diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index acfe691fe5e..66905fac3bd 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -127,7 +127,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task])) + split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task])) split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_sharding(split_dp) @@ -142,4 +142,4 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder)) + return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 0bf05a6ef91..efffaa80f99 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -1,3 +1,4 @@ +import functools import io import os import os.path @@ -50,12 +51,12 @@ def from_data_folder( categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" dp = FileLister(str(root), recursive=recursive, masks=masks) - dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) + dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root)) dp = hint_sharding(dp) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return ( - Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)), + Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)), categories, )