diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index d59a23efb4d..a211f2f4f51 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -26,7 +26,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None, - ): + ) -> None: super().__init__(root, transforms, transform, target_transform) from pycocotools.coco import COCO @@ -37,7 +37,7 @@ def _load_image(self, id: int) -> Image.Image: path = self.coco.loadImgs(id)[0]["file_name"] return Image.open(os.path.join(self.root, path)).convert("RGB") - def _load_target(self, id) -> List[Any]: + def _load_target(self, id: int) -> List[Any]: return self.coco.loadAnns(self.coco.getAnnIds(id)) def __getitem__(self, index: int) -> Tuple[Any, Any]: @@ -95,5 +95,5 @@ class CocoCaptions(CocoDetection): """ - def _load_target(self, id) -> List[str]: + def _load_target(self, id: int) -> List[str]: return [ann["caption"] for ann in super()._load_target(id)]