Skip to content

Commit

Permalink
webdataset prototype - LoadFilesFromDiskIterableDataset (#48955)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #48955

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D25541393

Pulled By: glaringlee

fbshipit-source-id: dea6ad64a7ba40abe45612d99f078b14d1da8bbf
  • Loading branch information
lixinyu authored and facebook-github-bot committed Dec 16, 2020
1 parent 6786b2b commit 001ff3a
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
21 changes: 16 additions & 5 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

from torch.testing._internal.common_utils import (TestCase, run_tests)

from torch.utils.data.datasets import (ListDirFilesIterableDataset)
from torch.utils.data.datasets import (ListDirFilesIterableDataset, LoadFilesFromDiskIterableDataset)

def create_temp_dir_and_files():
    """Create a temporary directory holding three empty files.

    Returns:
        A 4-tuple ``(temp_dir, path1, path2, path3)`` where ``temp_dir`` is the
        ``tempfile.TemporaryDirectory`` object (the caller must keep it alive
        and clean it up) and the paths are the names of the three files
        created inside it.
    """
    # The temp dir and files within it will be released and deleted in tearDown().
    # `noqa: P201` silences the linter warning about not releasing the dir
    # handle within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name

    file_names = []
    for _ in range(3):
        # delete=False keeps the file on disk after the handle is closed; the
        # `with` block closes the handle immediately so no descriptor leaks.
        with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) as temp_file:
            file_names.append(temp_file.name)

    return (temp_dir, file_names[0], file_names[1], file_names[2])

Expand All @@ -34,5 +36,14 @@ def test_listdirfiles_iterable_dataset(self):
for pathname in dataset:
self.assertTrue(pathname in self.temp_files)

def test_loadfilesfromdisk_iterable_dataset(self):
    """Each yielded record must pair a known temp file with its binary content."""
    temp_dir = self.temp_dir.name
    dataset1 = ListDirFilesIterableDataset(temp_dir, '')
    dataset2 = LoadFilesFromDiskIterableDataset(dataset1)

    for rec in dataset2:
        # rec is (pathname, open binary stream). Close both the yielded stream
        # and the reference handle even when an assertion fails.
        with rec[1] as stream, open(rec[0], 'rb') as expected:
            self.assertIn(rec[0], self.temp_files)
            self.assertEqual(stream.read(), expected.read())

if __name__ == '__main__':
run_tests()
3 changes: 2 additions & 1 deletion torch/utils/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .listdirfilesdataset import ListDirFilesIterableDataset
from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset

__all__ = ['ListDirFilesIterableDataset']
__all__ = ['ListDirFilesIterableDataset', 'LoadFilesFromDiskIterableDataset']
16 changes: 16 additions & 0 deletions torch/utils/data/datasets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from typing import List, Union, Iterable


def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
# empty mask matches any input name
if not masks:
Expand All @@ -16,6 +17,7 @@ def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
return True
return False


def get_file_pathnames_from_root(
root: str,
masks: Union[str, List[str]],
Expand All @@ -35,3 +37,17 @@ def onerror(err : OSError):
yield os.path.join(path, f)
if not recursive:
break


def get_file_binaries_from_pathnames(pathnames : Iterable):
    r"""Lazily open every pathname in binary read mode.

    Args:
        pathnames: an iterable of ``str`` file pathnames.

    Yields:
        ``(pathname, stream)`` tuples, where ``stream`` is the object returned
        by ``open(pathname, 'rb')``. The stream is handed over still open; the
        consumer is responsible for closing it.

    Raises:
        TypeError: if ``pathnames`` is not an ``Iterable`` or one of its items
            is not a ``str``. A warning carrying the same text is emitted
            right before the exception is raised.

    Note:
        This is a generator function, so no validation happens and no file is
        opened until the first item is requested.
    """
    if not isinstance(pathnames, Iterable):
        message = "get_file_binaries_from_pathnames needs the input be an Iterable"
        warnings.warn(message)
        # Attach the text to the exception too, so it survives warning filters.
        raise TypeError(message)

    for pathname in pathnames:
        if not isinstance(pathname, str):
            message = "file pathname must be string type, but got {}".format(type(pathname))
            warnings.warn(message)
            raise TypeError(message)

        yield (pathname, open(pathname, 'rb'))
30 changes: 30 additions & 0 deletions torch/utils/data/datasets/loadfilesfromdiskdataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.datasets.common import get_file_binaries_from_pathnames

from typing import Iterable, Iterator

class _LengthNotAvailableError(TypeError, NotImplementedError):
    r"""Raised by ``__len__`` when no nominal length was configured.

    It is a ``TypeError`` so that length-hint consumers such as ``list(...)``
    fall back to plain iteration, and still a ``NotImplementedError`` so that
    existing ``except NotImplementedError`` handlers keep working.
    """


class LoadFilesFromDiskIterableDataset(IterableDataset):
    r""" :class:`LoadFilesFromDiskIterableDataset`.
    IterableDataset to load file binary streams from given pathnames,
    yield pathname and binary stream in a tuple.
    Args:
        dataset: Iterable dataset that provides pathnames
        length: a nominal length of the dataset; the default ``-1`` means
            "unknown", in which case ``len()`` raises
    """

    def __init__(
            self,
            dataset : Iterable,
            length : int = -1):
        super().__init__()
        self.dataset : Iterable = dataset
        self.length : int = length

    def __iter__(self) -> Iterator[tuple] :
        # Delegates to the helper, which yields (pathname, binary stream) pairs.
        yield from get_file_binaries_from_pathnames(self.dataset)

    def __len__(self):
        if self.length == -1:
            # A plain NotImplementedError would propagate out of `list(self)`,
            # because the length-hint protocol only swallows TypeError.
            raise _LengthNotAvailableError(
                "{} instance doesn't have valid length".format(type(self).__name__))
        return self.length

0 comments on commit 001ff3a

Please sign in to comment.