[DataLoader] `__getitems__` added to description of Dataset API and better supported within `Subset` #100375
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/100375
✅ No failures as of commit 954fbd6. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
torch/utils/data/dataset.py (outdated)

```
@@ -291,6 +298,11 @@ def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        if getattr(dataset, "__getitems__", None):
```
Nit: if you don't want to rely on the falsy evaluation of None, you could also make this:

```
- if getattr(dataset, "__getitems__", None):
+ if getattr(dataset, "__getitems__", False):
```
I think Python does have a proper way to do it:

```
- if getattr(dataset, "__getitems__", None):
+ if hasattr(dataset, "__getitems__"):
```
@ejguan This is a bit different semantically: the original would still evaluate false if `a.__getitems__ = False`, etc., while `hasattr` would return true in that case.
The code originally in this PR was `hasattr(dataset, "__getitems__") and dataset.__getitems__`, which caused mypy to complain, unlike this version.
Wait, what is the case where `__getitems__` becomes an attribute rather than a method? Then why not `hasattr(dataset, "__getitems__") and callable(dataset.__getitems__)`?
@ejguan That's a double lookup on `__getitems__`; `callable(getattr(dataset, "__getitems__", None))` is more efficient. I am not sure in what case `__getitems__` could become an attribute, but it is better to code defensively here.

```
- if getattr(dataset, "__getitems__", None):
+ if callable(getattr(dataset, "__getitems__", None)):
```
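To make the trade-off concrete, here is a torch-free sketch of how the three proposed checks behave. `Plain`, `Falsy`, and `Proper` are hypothetical classes (not from the PR); `Falsy` models the defensive case where `__getitems__` exists but is a non-callable attribute:

```python
class Plain:
    # no __getitems__ at all
    pass

class Falsy:
    # hypothetical: __getitems__ exists but is a non-callable, falsy attribute
    __getitems__ = False

class Proper:
    # __getitems__ is a real batched-access method
    def __getitems__(self, indices):
        return [i * 2 for i in indices]

for obj in (Plain(), Falsy(), Proper()):
    by_getattr = bool(getattr(obj, "__getitems__", None))       # truthiness check
    by_hasattr = hasattr(obj, "__getitems__")                   # existence check
    by_callable = callable(getattr(obj, "__getitems__", None))  # callability check
    print(type(obj).__name__, by_getattr, by_hasattr, by_callable)
# Plain: all False; Falsy: only hasattr is True; Proper: all True
```

Only the `callable(...)` form rejects the `Falsy` case while doing a single attribute lookup, which is the point of the suggestion above.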
Originally I kept the same behavior as torch.utils.data._utils.fetch._MapDatasetFetcher:

```
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
```
Both call sites should probably be changed, to be honest.
Probably better to do the same for the fetcher, but it's out of scope for this PR.
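For illustration, the fetcher-side dispatch being discussed can be sketched torch-free; `RangeDataset` and `fetch` below are illustrative stand-ins, not the real torch code:

```python
class RangeDataset:
    """Toy map-style dataset that opts into batched access via __getitems__."""
    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        return idx * idx

    def __getitems__(self, indices):
        # one call services the whole batch instead of len(indices) calls
        return [i * i for i in indices]

def fetch(dataset, batch_indices):
    # dispatch mirroring the callable(getattr(...)) check discussed above
    getitems = getattr(dataset, "__getitems__", None)
    if callable(getitems):
        return getitems(batch_indices)
    return [dataset[i] for i in batch_indices]

print(fetch(RangeDataset(10), [1, 2, 3]))  # → [1, 4, 9]
```

A dataset without `__getitems__` (or with a non-callable one) falls through to the per-index loop, so both kinds of dataset work through the same `fetch`.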
torch/utils/data/dataset.py (outdated)

```
@@ -299,6 +311,9 @@ def __getitem__(self, idx):
    def __len__(self):
        return len(self.indices)

    def _getitems(self, idx):
```
Why not just add `__getitems__` and raise an error if `self.dataset` doesn't have `__getitems__` implemented?
Because torch.utils.data._utils.fetch._MapDatasetFetcher would then get a false positive on its check:

```
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
```
Then we can do the similar thing for `Subset.__getitems__`: call the parent's `__getitems__` if it's available; otherwise, call `self.dataset[idx]` while iterating `idx` over `possibly_batched_index`.
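That fallback could look roughly like the following torch-free sketch, where `SubsetSketch` is a stand-in for `torch.utils.data.Subset` (not the code merged in this PR):

```python
class SubsetSketch:
    """Subset-of-a-dataset sketch with a batched __getitems__ fallback."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices):
        # translate subset-local indices to parent indices, then use the
        # parent's batched path if it exists, else fall back to item-by-item
        parent_indices = [self.indices[i] for i in indices]
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__(parent_indices)
        return [self.dataset[i] for i in parent_indices]

data = [10, 11, 12, 13, 14]          # a plain list has no __getitems__
sub = SubsetSketch(data, [4, 2, 0])
print(sub.__getitems__([0, 1]))      # → [14, 12]
```

With a plain list as the parent, the per-index fallback runs; a parent that defines a callable `__getitems__` would be handed the whole translated batch in one call.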
good idea
The change looks good to me now. I am importing it internally to see if this PR breaks any internal systems.
@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
`__getitems__` added to description of Dataset API and better supported within `Subset`
LGTM!
@pytorchbot merge -r

@pytorchbot successfully started a rebase job. Check the current status here.
`Subset` dataset now supports `__getitems__`.

Successfully rebased.

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed, first few of them are: Meta Internal-Only Changes Check. Details for Dev Infra team: raised by workflow job.

@ejguan, could you check the tests pipeline?

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

Merge failed. Reason: 1 job has failed, first few of them are: Meta Internal-Only Changes Check.

@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

It seems something is wrong with the CI. I will land it from internal. Thanks.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…etter supported within `Subset` (pytorch#100375)

DataLoader supports batched loading from Mapped Datasets. This is the fetcher's implementation of auto-detection of batch loading support (torch.utils.data._utils.fetch._MapDatasetFetcher):

```
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
```

The description of the Dataset API now shows this feature. Additionally, the `Subset` dataset now supports `__getitems__` if the parent dataset supports it.

Pull Request resolved: pytorch#100375
Approved by: https://github.com/ejguan, https://github.com/NivekT