From cdbb6b9fde5680322d5c12f00698f5d5bbe1f96a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sat, 18 Jul 2020 20:28:15 +0200 Subject: [PATCH] vision --- torchvision/datasets/vision.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/torchvision/datasets/vision.py b/torchvision/datasets/vision.py index 7ee5a84dfcc..7aa25f79eb8 100644 --- a/torchvision/datasets/vision.py +++ b/torchvision/datasets/vision.py @@ -1,12 +1,19 @@ import os import torch import torch.utils.data as data +from typing import Any, Callable, List, Optional, Tuple class VisionDataset(data.Dataset): _repr_indent = 4 - def __init__(self, root, transforms=None, transform=None, target_transform=None): + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: if isinstance(root, torch._six.string_classes): root = os.path.expanduser(root) self.root = root @@ -25,13 +32,13 @@ def __init__(self, root, transforms=None, transform=None, target_transform=None) transforms = StandardTransform(transform, target_transform) self.transforms = transforms - def __getitem__(self, index): + def __getitem__(self, index: int) -> Any: raise NotImplementedError - def __len__(self): + def __len__(self) -> int: raise NotImplementedError - def __repr__(self): + def __repr__(self) -> str: head = "Dataset " + self.__class__.__name__ body = ["Number of datapoints: {}".format(self.__len__())] if self.root is not None: @@ -42,33 +49,33 @@ def __repr__(self): lines = [head] + [" " * self._repr_indent + line for line in body] return '\n'.join(lines) - def _format_transform_repr(self, transform, head): + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() return (["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) - def extra_repr(self): + def extra_repr(self) -> str: return "" class StandardTransform(object): - def __init__(self, transform=None, target_transform=None): + def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: self.transform = transform self.target_transform = target_transform - def __call__(self, input, target): + def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: if self.transform is not None: input = self.transform(input) if self.target_transform is not None: target = self.target_transform(target) return input, target - def _format_transform_repr(self, transform, head): + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() return (["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) - def __repr__(self): + def __repr__(self) -> str: body = [self.__class__.__name__] if self.transform is not None: body += self._format_transform_repr(self.transform,