Skip to content

Commit

Permalink
[DataLoader] Add context to NotImplementedErrors in dataset.py (#100667)
Browse files Browse the repository at this point in the history
Add helpful context message to `NotImplementedError`'s thrown by Dataset and IterableDataset, reminding users that they must implement `__getitem__`/`__iter__` in subclasses. Currently, users are presented with a bare `NotImplementedError` without describing the remedy.
Pull Request resolved: #100667
Approved by: https://github.com/NivekT
  • Loading branch information
cshimmin authored and pytorchmergebot committed May 5, 2023
1 parent a3989b2 commit 2f41bc5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Dataset(Generic[T_co]):
"""

def __getitem__(self, index) -> T_co:
raise NotImplementedError
raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
Expand Down Expand Up @@ -169,7 +169,7 @@ class IterableDataset(Dataset[T_co]):
[3, 4, 5, 6]
"""
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
Expand Down

0 comments on commit 2f41bc5

Please sign in to comment.