[DataLoader] __getitems__ added to description of Dataset API and better supported within Subset #100375

Closed
wants to merge 7 commits into from
17 changes: 16 additions & 1 deletion torch/utils/data/dataset.py
@@ -41,7 +41,10 @@ class Dataset(Generic[T_co]):
     data sample for a given key. Subclasses could also optionally overwrite
     :meth:`__len__`, which is expected to return the size of the dataset by many
     :class:`~torch.utils.data.Sampler` implementations and the default options
-    of :class:`~torch.utils.data.DataLoader`.
+    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
+    optionally implement :meth:`__getitems__` to speed up batched sample
+    loading. This method accepts a list of sample indices for a batch and
+    returns a list of samples.

     .. note::
       :class:`~torch.utils.data.DataLoader` by default constructs an index
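
As an illustration of the optional hook described in the docstring above, here is a minimal sketch. The class `InMemoryDataset` and its `round_trips` counter are hypothetical names introduced for this example, and the class deliberately does not import torch so that it runs standalone; a real dataset would presumably subclass `torch.utils.data.Dataset`.

```python
from typing import Any, Iterable, List


class InMemoryDataset:
    """Hypothetical map-style dataset; not part of the PR."""

    def __init__(self, rows: Iterable[Any]) -> None:
        self.rows = list(rows)
        self.round_trips = 0  # counts simulated backend accesses

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int) -> Any:
        # Per-sample path: one backend access per index.
        self.round_trips += 1
        return self.rows[index]

    def __getitems__(self, indices: List[int]) -> List[Any]:
        # Batched path: one backend access for the whole batch,
        # returning samples in the order the indices were given.
        self.round_trips += 1
        return [self.rows[i] for i in indices]
```

The point of the batched method is that a backend able to serve many rows per request (a database or a columnar file, for example) pays the access cost once per batch instead of once per sample.
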
@@ -52,6 +55,10 @@ class Dataset(Generic[T_co]):
     def __getitem__(self, index) -> T_co:
         raise NotImplementedError

+    # def __getitems__(self, indices: List) -> List[T_co]:
+    # Not implemented to prevent false-positives in fetcher check in
+    # torch.utils.data._utils.fetch._MapDatasetFetcher
+
     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
         return ConcatDataset([self, other])
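
The comment about false positives can be made concrete with a small sketch. `fetch_batch` below is a simplified stand-in for the kind of attribute check the comment refers to, not the code of `_MapDatasetFetcher` itself, and `StubbedBase`, `TensDataset`, and `PlainTensDataset` are hypothetical classes. It assumes the fetcher picks the batched path whenever a callable `__getitems__` is found.

```python
from typing import Any, List, Sequence


def fetch_batch(dataset: Any, indices: Sequence[int]) -> List[Any]:
    """Simplified stand-in for the fetcher check; not the library's code."""
    batched = getattr(dataset, "__getitems__", None)
    if callable(batched):
        # Batched path: a single call for the whole batch.
        return batched(list(indices))
    # Fallback path: one __getitem__ call per index.
    return [dataset[i] for i in indices]


class StubbedBase:
    """Hypothetical base class that ships a raising stub."""

    def __getitems__(self, indices: List[int]) -> List[Any]:
        raise NotImplementedError


class TensDataset(StubbedBase):
    """Subclass that only implements per-sample access."""

    def __getitem__(self, index: int) -> int:
        return index * 10


class PlainTensDataset:
    """Same dataset, but with no inherited stub."""

    def __getitem__(self, index: int) -> int:
        return index * 10
```

Under these assumptions, `fetch_batch(PlainTensDataset(), [0, 1, 2])` falls back to per-sample access, while `fetch_batch(TensDataset(), [0, 1])` detects the inherited stub, takes the batched path, and raises `NotImplementedError`. Leaving the method undefined on the base class avoids that failure for subclasses that only implement `__getitem__`.
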

@@ -296,6 +303,14 @@ def __getitem__(self, idx):
             return self.dataset[[self.indices[i] for i in idx]]
         return self.dataset[self.indices[idx]]

+    def __getitems__(self, indices: List[int]) -> List[T_co]:
+        # add batched sampling support when parent dataset supports it.
+        # see torch.utils.data._utils.fetch._MapDatasetFetcher
+        if callable(getattr(self.dataset, "__getitems__", None)):
+            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
+        else:
+            return [self.dataset[self.indices[idx]] for idx in indices]
+
     def __len__(self):
         return len(self.indices)
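
To show how the index remapping in the new `Subset.__getitems__` behaves on both branches, here is a pure-Python sketch. `SubsetView` mirrors the logic of the hunk above but is a hypothetical class written for this example, as is `BatchedLetters`; neither is taken from PyTorch.

```python
from typing import Any, List, Sequence


class SubsetView:
    """Hypothetical pure-Python mirror of the Subset logic in this diff."""

    def __init__(self, dataset: Any, indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = list(indices)

    def __getitems__(self, indices: List[int]) -> List[Any]:
        # Translate positions within the subset into parent indices.
        parent_indices = [self.indices[i] for i in indices]
        if callable(getattr(self.dataset, "__getitems__", None)):
            # Parent supports batched access: forward one call.
            return self.dataset.__getitems__(parent_indices)
        # Otherwise fall back to per-sample access on the parent.
        return [self.dataset[i] for i in parent_indices]

    def __len__(self) -> int:
        return len(self.indices)


class BatchedLetters:
    """Hypothetical parent dataset that records batched calls."""

    LETTERS = "abcdef"

    def __init__(self) -> None:
        self.batched_calls = 0

    def __getitem__(self, index: int) -> str:
        return self.LETTERS[index]

    def __getitems__(self, indices: List[int]) -> List[str]:
        self.batched_calls += 1
        return [self.LETTERS[i] for i in indices]


parent = BatchedLetters()
view = SubsetView(parent, [5, 3, 1])
batch = view.__getitems__([0, 2])  # subset positions 0 and 2 map to parent indices 5 and 1
```

Here `batch` is `["f", "b"]` and `parent.batched_calls` is 1, since the whole batch is forwarded in a single call. Wrapping a plain `list("abcdef")`, which has no `__getitems__`, produces the same samples through the fallback branch.
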
