-
Notifications
You must be signed in to change notification settings - Fork 21.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Type annotations for
util.data
. (#18963)
Summary: I haven't had a chance to rigorously try these out yet so don't merge yet. Closes #18725. Pull Request resolved: #18963 Differential Revision: D14832897 Pulled By: ezyang fbshipit-source-id: 4780e7a34126bc66ddbfd9d808dfc9e0edd77e68
- Loading branch information
1 parent
a2ac260
commit 0565141
Showing
6 changed files
with
107 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .sampler import Sampler as Sampler, SequentialSampler as SequentialSampler, RandomSampler as RandomSampler, \ | ||
SubsetRandomSampler as SubsetRandomSampler, WeightedRandomSampler as WeightedRandomSampler, BatchSampler as BatchSampler | ||
from .distributed import DistributedSampler as DistributedSampler | ||
from .dataset import Dataset as Dataset, TensorDataset as TensorDataset, ConcatDataset as ConcatDataset, \ | ||
Subset as Subset, random_split as random_split | ||
from .dataloader import DataLoader as DataLoader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List | ||
from . import Dataset, Sampler | ||
|
||
T_co = TypeVar('T_co', covariant=True) | ||
T = TypeVar('T') | ||
_worker_init_fn_t = Callable[[int], None] | ||
|
||
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that | ||
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. | ||
# See https://github.com/python/mypy/issues/3737. | ||
_collate_fn_t = Callable[[List[T]], Any] | ||
|
||
class DataLoader(Generic[T_co]): | ||
dataset: Dataset[T_co] | ||
batch_size: int | ||
num_workers: int | ||
pin_memory: bool | ||
drop_last: bool | ||
timeout: float | ||
|
||
@overload | ||
def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=..., sampler: Sampler[int]=..., | ||
num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=..., | ||
drop_last: bool=..., timeout: float=..., worker_init_fn: _worker_init_fn_t=...) -> None: ... | ||
@overload | ||
def __init__(self, dataset: Dataset[T_co], batch_sampler: Sampler[Sequence[int]]=..., num_workers: int=..., | ||
collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=..., | ||
worker_init_fn: _worker_init_fn_t=...) -> None: ... | ||
|
||
def __len__(self) -> int: ... | ||
# We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since | ||
# '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this. | ||
def __iter__(self) -> '_DataLoaderIter':... | ||
|
||
class _DataLoaderIter: | ||
def __init__(self, loader: DataLoader) -> None:... | ||
def __len__(self) -> int: ... | ||
def __iter__(self) -> _DataLoaderIter: ... | ||
def __next__(self) -> Any: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import TypeVar, Generic, Iterable, Sequence, List, Tuple | ||
from ... import Tensor | ||
|
||
T_co = TypeVar('T_co', covariant=True) | ||
T = TypeVar('T') | ||
class Dataset(Generic[T_co]): | ||
def __getitem__(self, index: int) -> T_co: ... | ||
def __len__(self) -> int: ... | ||
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ... | ||
|
||
class TensorDataset(Dataset[Tuple[Tensor, ...]]): | ||
tensors: List[Tensor] | ||
|
||
def __init__(self, *tensors: Tensor) -> None: ... | ||
|
||
class ConcatDataset(Dataset[T_co]): | ||
datasets: List[Dataset[T_co]] | ||
cumulative_sizes: List[int] | ||
|
||
def __init__(self, datasets: Iterable[Dataset]) -> None: ... | ||
|
||
class Subset(Dataset[T_co]): | ||
dataset: Dataset[T_co] | ||
indices: Sequence[int] | ||
|
||
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ... | ||
|
||
def random_split(dataset: Dataset[T], lengths: Sequence[int]) -> List[Subset[T]]: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import TypeVar, Optional, Iterable | ||
from . import Sampler, Dataset | ||
|
||
T_co = TypeVar('T_co', covariant=True) | ||
class DistributedSampler(Sampler[T_co]): | ||
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ... | ||
def __iter__(self) -> Iterable[int]: ... | ||
def __len__(self) -> int: ... | ||
def set_epoch(self, epoch: int) -> None: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized | ||
|
||
T_co = TypeVar('T_co', covariant=True) | ||
class Sampler(Generic[T_co]): | ||
def __init__(self, data_source: Sized) -> None: ... | ||
def __iter__(self) -> Iterator[T_co]: ... | ||
def __len__(self) -> int: ... | ||
|
||
class SequentialSampler(Sampler[int]): | ||
pass | ||
|
||
class RandomSampler(Sampler[int]): | ||
num_samples: int | ||
|
||
def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ... | ||
|
||
class SubsetRandomSampler(Sampler[int]): | ||
def __init__(self, indices: Sequence[int]) -> None: ... | ||
|
||
class WeightedRandomSampler(Sampler[int]): | ||
def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ... | ||
|
||
class BatchSampler(Sampler[List[int]]): | ||
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ... |