Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove vanilla tensors from prototype datasets samples #5018

Merged
merged 1 commit into from Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 19 additions & 30 deletions torchvision/prototype/datasets/_builtin/coco.py
Expand Up @@ -32,23 +32,10 @@
getitem,
path_accessor,
)
from torchvision.prototype.features import BoundingBox, Label
from torchvision.prototype.features._feature import DEFAULT
from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.utils._internal import FrozenMapping


class CocoLabel(Label):
super_category: Optional[str]

@classmethod
def _parse_meta_data(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
super_category: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None), super_category=(super_category, None))


class Coco(Dataset):
def _make_info(self) -> DatasetInfo:
name = "coco"
Expand Down Expand Up @@ -111,27 +98,24 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
categories = [self.info.categories[label] for label in labels]
return dict(
# TODO: create a segmentation feature
segmentations=torch.stack(
[
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
for ann in anns
]
segmentations=Feature(
torch.stack(
[
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
for ann in anns
]
)
),
areas=torch.tensor([ann["area"] for ann in anns]),
crowds=torch.tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
areas=Feature([ann["area"] for ann in anns]),
crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
bounding_boxes=BoundingBox(
[ann["bbox"] for ann in anns],
format="xywh",
image_size=image_size,
),
labels=[
CocoLabel(
label,
category=category,
super_category=self.info.extra.category_to_super_category[category],
)
for label, category in zip(labels, categories)
],
labels=Label(labels),
categories=categories,
super_categories=[self.info.extra.category_to_super_category[category] for category in categories],
ann_ids=[ann["id"] for ann in anns],
)

Expand All @@ -141,7 +125,12 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
ann_ids=[ann["id"] for ann in anns],
)

_ANN_DECODERS = OrderedDict([("instances", _decode_instances_anns), ("captions", _decode_captions_ann)])
_ANN_DECODERS = OrderedDict(
[
("instances", _decode_instances_anns),
("captions", _decode_captions_ann),
]
)

_META_FILE_PATTERN = re.compile(
fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_feature.py
Expand Up @@ -12,7 +12,7 @@


class Feature(torch.Tensor):
_META_ATTRS: Set[str]
_META_ATTRS: Set[str] = set()
_meta_data: Dict[str, Any]

def __init_subclass__(cls):
Expand Down
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/_transform.py
Expand Up @@ -360,7 +360,13 @@ def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any:
else:
feature_type = type(sample)
if not self.supports(feature_type):
if not issubclass(feature_type, features.Feature) or feature_type in self.NO_OP_FEATURE_TYPES:
if (
not issubclass(feature_type, features.Feature)
# issubclass is not a strict check, but also allows the type checked against. Thus, we need to
# check it separately
or feature_type is features.Feature
or feature_type in self.NO_OP_FEATURE_TYPES
):
return sample

raise TypeError(
Expand Down