Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from typing import Any, Callable, Dict, Optional, Tuple
from functools import partial
from multiprocessing import Pool
from torch import Tensor

from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset


def _dl_wrap(tarpath, videopath, line):
def _dl_wrap(tarpath: str, videopath: str, line: str) -> None:
download_and_extract_archive(line, tarpath, videopath)


Expand Down Expand Up @@ -90,14 +91,14 @@ def __init__(
frames_per_clip: int,
num_classes: str = "400",
split: str = "train",
frame_rate: Optional[float] = None,
frame_rate: Optional[int] = None,
step_between_clips: int = 1,
transform: Optional[Callable] = None,
extensions: Tuple[str, ...] = ("avi", "mp4"),
download: bool = False,
num_download_workers: int = 1,
num_workers: int = 1,
_precomputed_metadata: Optional[Dict] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
Expand Down Expand Up @@ -187,7 +188,7 @@ def _download_videos(self) -> None:
poolproc = Pool(self.num_download_workers)
poolproc.map(part, lines)

def _make_ds_structure(self):
def _make_ds_structure(self) -> None:
"""move videos from
split_folder/
├── clip1.avi
Expand Down Expand Up @@ -228,13 +229,13 @@ def _make_ds_structure(self):
)

@property
def metadata(self):
def metadata(self) -> Dict[str, Any]:
return self.video_clips.metadata

def __len__(self):
def __len__(self) -> int:
return self.video_clips.num_clips()

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, info, video_idx = self.video_clips.get_clip(idx)
if not self._legacy:
# [T,H,W,C] --> [T,C,H,W]
Expand Down Expand Up @@ -303,7 +304,7 @@ def __init__(
download: Any = None,
num_download_workers: Any = None,
**kwargs: Any
):
) -> None:
warnings.warn(
"Kinetics400 is deprecated and will be removed in a future release."
"It was replaced by Kinetics(..., num_classes=\"400\").")
Expand Down