diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 2543b6c514d..68f7470e6ab 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple from functools import partial from multiprocessing import Pool +from torch import Tensor from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity from .folder import find_classes, make_dataset @@ -15,7 +16,7 @@ from .vision import VisionDataset -def _dl_wrap(tarpath, videopath, line): +def _dl_wrap(tarpath: str, videopath: str, line: str) -> None: download_and_extract_archive(line, tarpath, videopath) @@ -90,14 +91,14 @@ def __init__( frames_per_clip: int, num_classes: str = "400", split: str = "train", - frame_rate: Optional[float] = None, + frame_rate: Optional[int] = None, step_between_clips: int = 1, transform: Optional[Callable] = None, extensions: Tuple[str, ...] = ("avi", "mp4"), download: bool = False, num_download_workers: int = 1, num_workers: int = 1, - _precomputed_metadata: Optional[Dict] = None, + _precomputed_metadata: Optional[Dict[str, Any]] = None, _video_width: int = 0, _video_height: int = 0, _video_min_dimension: int = 0, @@ -187,7 +188,7 @@ def _download_videos(self) -> None: poolproc = Pool(self.num_download_workers) poolproc.map(part, lines) - def _make_ds_structure(self): + def _make_ds_structure(self) -> None: """move videos from split_folder/ ├── clip1.avi @@ -228,13 +229,13 @@ def _make_ds_structure(self): ) @property - def metadata(self): + def metadata(self) -> Dict[str, Any]: return self.video_clips.metadata - def __len__(self): + def __len__(self) -> int: return self.video_clips.num_clips() - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: video, audio, info, video_idx = self.video_clips.get_clip(idx) if not self._legacy: # [T,H,W,C] --> [T,C,H,W] @@ -303,7 +304,7 @@ def __init__( download: Any = None, num_download_workers: Any = None, **kwargs: Any - ): + ) -> None: warnings.warn( "Kinetics400 is deprecated and will be removed in a future release." "It was replaced by Kinetics(..., num_classes=\"400\").")