diff --git a/setup.py b/setup.py
index 20213eb6205cd..b50792f447d7d 100644
--- a/setup.py
+++ b/setup.py
@@ -735,6 +735,7 @@ def print_box(msg):
         'cuda/*.pyi',
         'optim/*.pyi',
         'autograd/*.pyi',
+        'utils/data/*.pyi',
         'lib/*.so*',
         'lib/*.dylib*',
         'lib/*.dll',
diff --git a/torch/utils/data/__init__.pyi b/torch/utils/data/__init__.pyi
new file mode 100644
index 0000000000000..f4ca405fa2edf
--- /dev/null
+++ b/torch/utils/data/__init__.pyi
@@ -0,0 +1,6 @@
+from .sampler import Sampler as Sampler, SequentialSampler as SequentialSampler, RandomSampler as RandomSampler, \
+    SubsetRandomSampler as SubsetRandomSampler, WeightedRandomSampler as WeightedRandomSampler, BatchSampler as BatchSampler
+from .distributed import DistributedSampler as DistributedSampler
+from .dataset import Dataset as Dataset, TensorDataset as TensorDataset, ConcatDataset as ConcatDataset, \
+    Subset as Subset, random_split as random_split
+from .dataloader import DataLoader as DataLoader
diff --git a/torch/utils/data/dataloader.pyi b/torch/utils/data/dataloader.pyi
new file mode 100644
index 0000000000000..6932933b3d987
--- /dev/null
+++ b/torch/utils/data/dataloader.pyi
@@ -0,0 +1,39 @@
+from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List
+from . import Dataset, Sampler
+
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way
+# to have that type parameter default to a concrete type when the user doesn't pass a custom `collate_fn`.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[T]], Any]
+
+class DataLoader(Generic[T_co]):
+    dataset: Dataset[T_co]
+    batch_size: int
+    num_workers: int
+    pin_memory: bool
+    drop_last: bool
+    timeout: float
+
+    @overload
+    def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=..., sampler: Sampler[int]=...,
+                 num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=...,
+                 drop_last: bool=..., timeout: float=..., worker_init_fn: _worker_init_fn_t=...) -> None: ...
+    @overload
+    def __init__(self, dataset: Dataset[T_co], batch_sampler: Sampler[Sequence[int]]=..., num_workers: int=...,
+                 collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=...,
+                 worker_init_fn: _worker_init_fn_t=...) -> None: ...
+
+    def __len__(self) -> int: ...
+    # We quote '_DataLoaderIter' because it isn't defined yet, and its definition can't be moved up
+    # because '_DataLoaderIter' itself references 'DataLoader'. Pending updates to PEP 484 will fix this.
+    def __iter__(self) -> '_DataLoaderIter': ...
+
+class _DataLoaderIter:
+    def __init__(self, loader: DataLoader) -> None: ...
+    def __len__(self) -> int: ...
+    def __iter__(self) -> _DataLoaderIter: ...
+    def __next__(self) -> Any: ...
diff --git a/torch/utils/data/dataset.pyi b/torch/utils/data/dataset.pyi
new file mode 100644
index 0000000000000..ac85c474eec2c
--- /dev/null
+++ b/torch/utils/data/dataset.pyi
@@ -0,0 +1,28 @@
+from typing import TypeVar, Generic, Iterable, Sequence, List, Tuple
+from ... import Tensor
+
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+class Dataset(Generic[T_co]):
+    def __getitem__(self, index: int) -> T_co: ...
+    def __len__(self) -> int: ...
+    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
+
+class TensorDataset(Dataset[Tuple[Tensor, ...]]):
+    tensors: List[Tensor]
+
+    def __init__(self, *tensors: Tensor) -> None: ...
+
+class ConcatDataset(Dataset[T_co]):
+    datasets: List[Dataset[T_co]]
+    cumulative_sizes: List[int]
+
+    def __init__(self, datasets: Iterable[Dataset[T_co]]) -> None: ...
+
+class Subset(Dataset[T_co]):
+    dataset: Dataset[T_co]
+    indices: Sequence[int]
+
+    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ...
+
+def random_split(dataset: Dataset[T], lengths: Sequence[int]) -> List[Subset[T]]: ...
diff --git a/torch/utils/data/distributed.pyi b/torch/utils/data/distributed.pyi
new file mode 100644
index 0000000000000..a9c787ace8737
--- /dev/null
+++ b/torch/utils/data/distributed.pyi
@@ -0,0 +1,9 @@
+from typing import TypeVar, Optional, Iterator
+from . import Sampler, Dataset
+
+T_co = TypeVar('T_co', covariant=True)
+class DistributedSampler(Sampler[T_co]):
+    def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...) -> None: ...
+    def __iter__(self) -> Iterator[T_co]: ...
+    def __len__(self) -> int: ...
+    def set_epoch(self, epoch: int) -> None: ...
diff --git a/torch/utils/data/sampler.pyi b/torch/utils/data/sampler.pyi
new file mode 100644
index 0000000000000..da69876048d41
--- /dev/null
+++ b/torch/utils/data/sampler.pyi
@@ -0,0 +1,24 @@
+from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
+
+T_co = TypeVar('T_co', covariant=True)
+class Sampler(Generic[T_co]):
+    def __init__(self, data_source: Sized) -> None: ...
+    def __iter__(self) -> Iterator[T_co]: ...
+    def __len__(self) -> int: ...
+
+class SequentialSampler(Sampler[int]):
+    pass
+
+class RandomSampler(Sampler[int]):
+    num_samples: int
+
+    def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ...
+
+class SubsetRandomSampler(Sampler[int]):
+    def __init__(self, indices: Sequence[int]) -> None: ...
+
+class WeightedRandomSampler(Sampler[int]):
+    def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ...
+
+class BatchSampler(Sampler[List[int]]):
+    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ...
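For reviewers who want to try the stubs out: the snippet below is a minimal sketch (not part of the diff) of the kind of user code these annotations make checkable. The file name `demo.py` and the `collate_tuples` helper are illustrative assumptions, not APIs introduced by this PR; everything else uses only what the stubs above declare.

    # Hypothetical demo.py: run `mypy demo.py` to check it against the new .pyi files.
    from typing import List, Tuple

    import torch
    from torch import Tensor
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    # TensorDataset is typed as Dataset[Tuple[Tensor, ...]] in dataset.pyi.
    dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))

    # Matches _collate_fn_t = Callable[[List[T]], Any]; per the comment in
    # dataloader.pyi (python/mypy#3737), its return type cannot parameterize
    # DataLoader, so batches still come back typed as Any.
    def collate_tuples(batch: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
        return tuple(torch.stack(column) for column in zip(*batch))

    # String annotation: DataLoader is generic only in the stubs, so the
    # runtime class is never subscripted when the module executes.
    loader: 'DataLoader[Tuple[Tensor, ...]]' = DataLoader(
        dataset,
        batch_size=4,
        sampler=RandomSampler(dataset),  # Sampler[int], per the first overload
        collate_fn=collate_tuples,
    )

    for inputs, targets in loader:
        assert inputs.shape == (4, 3)

Passing a wrong argument type, e.g. `DataLoader(dataset, sampler=[0, 1, 2])`, should now be rejected by the checker rather than failing at runtime.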