Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_datasets_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import FrozenMapping, FrozenBunch
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch


def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions test/test_prototype_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)


def make_image(**kwargs):
    """Create a 3-channel ``features.Image`` with a random spatial size in [16, 32]."""
    height, width = torch.randint(16, 33, (2,)).tolist()
    return features.Image(make_tensor((3, height, width)), **kwargs)


def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
Expand Down Expand Up @@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)):


# Maps each feature type to the factory that produces a random instance of it.
MAKE_DATA_MAP = {
    features.Image: make_image,
    features.BoundingBox: make_bounding_box,
}

Expand Down
51 changes: 51 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from torchvision.prototype import transforms, features
from torchvision.prototype.utils._internal import sequence_to_str


# Every concrete Feature subclass that the features namespace exposes publicly.
FEATURE_TYPES = {
    candidate
    for public_name, candidate in vars(features).items()
    if not public_name.startswith("_")
    and isinstance(candidate, type)
    and issubclass(candidate, features.Feature)
    and candidate is not features.Feature
}

# Every concrete Transform subclass that the transforms namespace exposes publicly.
TRANSFORM_TYPES = tuple(
    candidate
    for public_name, candidate in vars(transforms).items()
    if not public_name.startswith("_")
    and isinstance(candidate, type)
    and issubclass(candidate, transforms.Transform)
    and candidate is not transforms.Transform
)


def test_feature_type_support():
    """Every public feature type must be registered in ``Transform._BUILTIN_FEATURE_TYPES``."""
    builtin = set(transforms.Transform._BUILTIN_FEATURE_TYPES)
    missing_feature_types = FEATURE_TYPES - builtin
    if not missing_feature_types:
        return

    names = sorted(feature_type.__name__ for feature_type in missing_feature_types)
    raise AssertionError(
        f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at "
        f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. "
        f"Please add it/them to the collection."
    )


@pytest.mark.parametrize(
    "transform_type",
    [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
    ids=lambda transform_type: transform_type.__name__,
)
def test_no_op(transform_type):
    """Every feature type a transform does not support must be explicitly declared a no-op.

    ``Identity`` is excluded since it is a no-op for everything by definition.
    """
    unsupported_features = (
        FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
    )
    if unsupported_features:
        names = sorted([feature_type.__name__ for feature_type in unsupported_features])
        # NOTE: fixed the duplicated word ("the the") in the error message below.
        raise AssertionError(
            f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as "
            f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
            f"or add them to the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
        )
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
Enumerator,
getitem,
read_mat,
FrozenMapping,
)
from torchvision.prototype.utils._internal import FrozenMapping


class ImageNet(Dataset):
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str

from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
from ._resource import OnlineResource


Expand Down
85 changes: 0 additions & 85 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import enum
import gzip
import io
Expand All @@ -7,7 +6,6 @@
import os.path
import pathlib
import pickle
import textwrap
from typing import (
Sequence,
Callable,
Expand All @@ -18,10 +16,7 @@
Iterator,
Dict,
Optional,
NoReturn,
IO,
Iterable,
Mapping,
Sized,
)
from typing import cast
Expand All @@ -38,10 +33,6 @@
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"make_repr",
"FrozenMapping",
"FrozenBunch",
"create_categories_file",
"read_mat",
"image_buffer_from_array",
"SequenceIterator",
Expand All @@ -62,82 +53,6 @@
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"


def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
    """Build a ``name(key=value, ...)`` repr string.

    Falls back to one indented ``key=value`` pair per line when the single-line
    form exceeds the terminal width (``COLUMNS`` env var, default 80) or when a
    value's string form spans multiple lines.
    """
    # Materialize once: ``items`` is only required to be an Iterable, and we may
    # need to join the pairs a second time for the multiline layout. Previously a
    # one-shot iterator was exhausted by the first join, yielding an empty body.
    pairs = [f"{key}={value}" for key, value in items]

    prefix = f"{name}("
    postfix = ")"
    body = ", ".join(pairs)

    line_length = int(os.environ.get("COLUMNS", 80))
    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
    multiline_body = len(body.splitlines()) > 1
    if not (body_too_long or multiline_body):
        return prefix + body + postfix

    body = textwrap.indent(",\n".join(pairs), " " * 2)
    return f"{prefix}\n{body}\n{postfix}"


class FrozenMapping(Mapping[K, D]):
    """An immutable, hashable mapping.

    Accepts the same arguments as ``dict``. Item assignment and deletion raise
    ``RuntimeError``. All values must be hashable, since the hash is computed
    eagerly at construction time.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        data = dict(*args, **kwargs)
        # Write straight into __dict__ so subclasses may block __setattr__.
        self.__dict__["__data__"] = data
        # frozenset makes the hash independent of insertion order, keeping it
        # consistent with the content-based equality defined below.
        self.__dict__["__final_hash__"] = hash(frozenset(data.items()))

    def __getitem__(self, item: K) -> D:
        return cast(Mapping[K, D], self.__dict__["__data__"])[item]

    def __iter__(self) -> Iterator[K]:
        return iter(self.__dict__["__data__"].keys())

    def __len__(self) -> int:
        return len(self.__dict__["__data__"])

    def __setitem__(self, key: K, value: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delitem__(self, key: K) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __hash__(self) -> int:
        return cast(int, self.__dict__["__final_hash__"])

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, FrozenMapping):
            return NotImplemented

        # Compare contents rather than hashes: equal hashes do not imply equal
        # mappings (hash collisions would have made unequal mappings "equal").
        return self.__dict__["__data__"] == other.__dict__["__data__"]

    def __repr__(self) -> str:
        return repr(self.__dict__["__data__"])


class FrozenBunch(FrozenMapping):
    """A ``FrozenMapping`` whose entries can also be read as attributes.

    Attribute assignment and deletion raise ``RuntimeError``, mirroring the
    item-level immutability of the base class.
    """

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as missing:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from missing

    def __setattr__(self, name: Any, value: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delattr__(self, name: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return make_repr(cls_name, self.items())


def create_categories_file(
    root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
) -> None:
    """Write ``categories`` to ``<root>/<name>.categories`` as CSV, one category per row.

    ``fmtparams`` is forwarded to :func:`csv.writer` (e.g. ``delimiter``).
    """
    # A bare string is one single-column row. Without this guard csv.writer
    # iterates the string character by character ("cat" -> "c,a,t").
    rows = [[category] if isinstance(category, str) else category for category in categories]
    with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file:
        csv.writer(file, **fmtparams).writerows(rows)


def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
try:
import scipy.io as sio
Expand Down
41 changes: 41 additions & 0 deletions torchvision/prototype/datasets/utils/_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import collections.abc
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast

from torchvision.prototype.features import BoundingBox, Image

T = TypeVar("T")


class SampleQuery:
    """Search a (possibly nested) sample structure for one piece of information.

    Sequences and mappings are traversed recursively; every other object is a
    leaf that is offered to the query function.
    """

    def __init__(self, sample: Any) -> None:
        self.sample = sample

    @staticmethod
    def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
        """Yield every non-None result of ``fn`` over the leaves of ``sample``."""
        # str/bytes register as Sequence, but descending into them recurses
        # forever (a 1-char string yields itself), so treat them as leaves.
        if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)) and not isinstance(
            sample, (str, bytes)
        ):
            for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
                yield from SampleQuery._query_recursively(item, fn)
        else:
            result = fn(sample)
            if result is not None:
                yield result

    def query(self, fn: Callable[[Any], Optional[T]]) -> T:
        """Apply ``fn`` to every leaf and return the single distinct non-None result.

        Raises:
            RuntimeError: if no leaf or more than one distinct value matched.
        """
        results = set(self._query_recursively(self.sample, fn))
        if not results:
            raise RuntimeError("Query turned up empty.")
        elif len(results) > 1:
            raise RuntimeError(f"Found more than one result: {results}")

        return results.pop()

    def image_size(self) -> Tuple[int, int]:
        """Return the (height, width) shared by all images / bounding boxes in the sample."""

        def fn(sample: Any) -> Optional[Tuple[int, int]]:
            if isinstance(sample, Image):
                return cast(Tuple[int, int], sample.shape[-2:])
            elif isinstance(sample, BoundingBox):
                return sample.image_size
            else:
                return None

        return self.query(fn)
9 changes: 9 additions & 0 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ class Image(Feature):
color_spaces = ColorSpace
color_space: ColorSpace

@classmethod
def _to_tensor(cls, data, *, dtype, device):
    """Convert ``data`` to a 3d (channels, height, width) tensor.

    A 2d input is treated as a single-channel image and gains a leading
    channel dimension; any other rank is rejected.
    """
    image = torch.as_tensor(data, dtype=dtype, device=device)
    if image.ndim == 3:
        return image
    if image.ndim == 2:
        return image.unsqueeze(0)
    raise ValueError("Only single images with 2 or 3 dimensions are allowed.")

@classmethod
def _parse_meta_data(
cls,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from functools import partial
from typing import Any, Optional

from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Any, Optional, Tuple

import torch.nn as nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode

from ...models.densenet import DenseNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES

Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Optional, Union

from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode

from ....models.detection.faster_rcnn import (
Expand All @@ -12,7 +13,6 @@
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval

from ....models.detection.keypoint_rcnn import (
_resnet_fpn_extractor,
_validate_trainable_layers,
KeypointRCNN,
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from ..resnet import ResNet50Weights, resnet50
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode

from ....models.detection.mask_rcnn import (
Expand All @@ -10,7 +11,6 @@
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights, resnet50
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode

from ....models.detection.retinanet import (
Expand All @@ -11,7 +12,6 @@
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights, resnet50
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode

from ....models.detection.ssd import (
Expand All @@ -9,7 +10,6 @@
DefaultBoxGenerator,
SSD,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..vgg import VGG16Weights, vgg16
Expand Down
Loading