Dataset checks I #413

Merged · 13 commits · Feb 3, 2020
9 changes: 9 additions & 0 deletions eta/core/data.py
@@ -326,6 +326,15 @@ def merge_schema(self, schema):
'''Merges the given CategoricalAttributeSchema into this schema.'''
self.categories.update(schema.categories)

+    def iter_name_values(self):
+        '''Iterates over all (attr.name, attr.value) pairs.
+
+        Returns:
+            a generator that yields (attr.name, attr.value) tuples
+        '''
+        for value in self.categories:
+            yield self.name, value

@staticmethod
def get_kwargs(d):
'''Extracts the relevant keyword arguments for this schema from the
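For context, a minimal sketch of how the new iter_name_values() method might be used; the constructor call here is an assumption, since the diff does not show how CategoricalAttributeSchema is built:

from eta.core.data import CategoricalAttributeSchema

# Assumed constructor: an attribute name plus its set of category values
schema = CategoricalAttributeSchema("weather", categories={"rain", "sun"})

for name, value in schema.iter_name_values():
    print(name, value)  # ("weather", "rain") and ("weather", "sun"), in set order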
1 change: 1 addition & 0 deletions eta/core/datasets/__init__.py
@@ -8,6 +8,7 @@
'''

from .builders import *
+from .standardize import *
from .labeled_datasets import *
from .split_methods import *
from .transformers import *
70 changes: 35 additions & 35 deletions eta/core/datasets/labeled_datasets.py
@@ -204,8 +204,8 @@ def __iter__(self):

Returns:
iterator: iterator over (data, labels) pairs, where data is an
-                object returned by self._read_data() and labels is an object
-                returned by self._read_labels() from the respective paths
+                object returned by self.read_data() and labels is an object
+                returned by self.read_labels() from the respective paths
of a data file and corresponding labels file
'''
return zip(self.iter_data(), self.iter_labels())
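
For context, iteration might look like the following sketch, where `dataset` is any LabeledDataset instance and process() is a hypothetical consumer:

# Load-and-process loop over (data, labels) pairs
for data, labels in dataset:
    process(data, labels)

# Path-only loop that avoids reading any files
for data_path, labels_path in dataset.iter_paths():
    print(data_path, labels_path)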
@@ -227,11 +227,11 @@ def iter_data(self):
'''Iterates over the data in the dataset.

Returns:
-            iterator: iterator over objects returned by self._read_data()
+            iterator: iterator over objects returned by self.read_data()
from the paths to data files
'''
for data_path in self.iter_data_paths():
-            yield self._read_data(data_path)
+            yield self.read_data(data_path)

def iter_data_paths(self):
'''Iterates over the paths to data files in the dataset.
@@ -246,11 +246,11 @@ def iter_labels(self):
'''Iterates over the labels in the dataset.

Returns:
-            iterator: iterator over objects returned by self._read_labels()
+            iterator: iterator over objects returned by self.read_labels()
from the paths to labels files
'''
for labels_path in self.iter_labels_paths():
-            yield self._read_labels(labels_path)
+            yield self.read_labels(labels_path)

def iter_labels_paths(self):
'''Iterates over the paths to labels files in the dataset.
@@ -441,9 +441,9 @@ def add_file(self, data_path, labels_path, new_data_filename=None,

# Update the filename attribute in the labels JSON if necessary
if new_data_filename != os.path.basename(data_path):
-            labels_ = self._read_labels(new_labels_path)
+            labels_ = self.read_labels(new_labels_path)
            labels_.filename = new_data_filename
-            self._write_labels(labels_, new_labels_path)
+            self.write_labels(labels_, new_labels_path)

# First remove any other records with the same data filename
self.dataset_index.cull_with_function(
@@ -466,9 +466,9 @@ def add_data(self, data, labels, data_filename, labels_filename,

Args:
data: input data in a format that can be passed to
-                self._write_data()
+                self.write_data()
            labels: input labels in a format that can be passed to
-                self._write_labels()
+                self.write_labels()
data_filename: filename for the data in the dataset
labels_filename: filename for the labels in the dataset
error_on_duplicates: whether to raise an error if a data file
Expand All @@ -488,8 +488,8 @@ def add_data(self, data, labels, data_filename, labels_filename,
labels_path = os.path.join(
self.dataset_dir, self._LABELS_SUBDIR, labels_filename)

-        self._write_data(data, data_path)
-        self._write_labels(labels, labels_path)
+        self.write_data(data, data_path)
+        self.write_labels(labels, labels_path)

# First remove any other records with the same data filename
self.dataset_index.cull_with_function(
@@ -700,14 +700,14 @@ def apply_to_data(self, func):

Args:
func: function that takes in a data element in the format
-                returned by `self._read_data()` and outputs transformed
+                returned by `self.read_data()` and outputs transformed
data in the same format

Returns:
self
'''
for data, path in zip(self.iter_data(), self.iter_data_paths()):
-            self._write_data(func(data), path)
+            self.write_data(func(data), path)

return self

@@ -734,15 +734,15 @@ def apply_to_labels(self, func):

Args:
func: function that takes in a labels object in the format
-                returned by `self._read_labels()` and outputs transformed
+                returned by `self.read_labels()` and outputs transformed
labels in the same format

Returns:
self
'''
for labels, path in zip(
self.iter_labels(), self.iter_labels_paths()):
-            self._write_labels(func(labels), path)
+            self.write_labels(func(labels), path)

return self
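
A short sketch of this transform hook; the transform below is purely illustrative:

# Hypothetical transform: stamp a placeholder into any labels object
# that is missing its filename field
def fill_placeholder_filename(labels):
    if labels.filename is None:
        labels.filename = "unknown"
    return labels

dataset.apply_to_labels(fill_placeholder_filename)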

@@ -905,7 +905,7 @@ def _build_index_map(self):
data_file)
self._data_to_labels_map[data_file] = labels_file

-    def _read_data(self, path):
+    def read_data(self, path):
'''Reads data from a data file at the given path.

Subclasses must implement this based on the particular data format for
Expand All @@ -917,9 +917,9 @@ def _read_data(self, path):
Returns:
a data object in the particular format for the subclass
'''
-        raise NotImplementedError("subclasses must implement _read_data()")
+        raise NotImplementedError("subclasses must implement read_data()")

-    def _read_labels(self, path):
+    def read_labels(self, path):
'''Reads a labels object from a labels JSON file at the given path.

Subclasses must implement this based on the particular labels format
Expand All @@ -931,33 +931,33 @@ def _read_labels(self, path):
Returns:
a labels object in the particular format for the subclass
'''
-        raise NotImplementedError("subclasses must implement _read_labels()")
+        raise NotImplementedError("subclasses must implement read_labels()")

-    def _write_data(self, data, path):
+    def write_data(self, data, path):
'''Writes data to a data file at the given path.

Subclasses must implement this based on the particular data format for
the subclass. The method should accept input `data` of the same type
-        as output by `self._read_data()`.
+        as output by `self.read_data()`.

Args:
data: a data element to be written to a file
path: path to write the data
'''
-        raise NotImplementedError("subclasses must implement _write_data()")
+        raise NotImplementedError("subclasses must implement write_data()")

-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
'''Writes a labels object to a labels JSON file at the given path.

Subclasses must implement this based on the particular labels format
for the subclass. The method should accept input `labels` of the same
-        type as output by `self._read_labels()`.
+        type as output by `self.read_labels()`.

Args:
labels: a labels object to be written to a file
path: path to write the labels JSON file
'''
-        raise NotImplementedError("subclasses must implement _write_labels()")
+        raise NotImplementedError("subclasses must implement write_labels()")

def _build_metadata(self, path):
'''Reads metadata from a data file at the given path and builds an
@@ -1151,19 +1151,19 @@ def compute_average_video_duration(self):

return np.mean(video_durations)

-    def _read_data(self, path):
+    def read_data(self, path):
return etav.FFmpegVideoReader(path)

-    def _read_labels(self, path):
+    def read_labels(self, path):
return etav.VideoLabels.from_json(path)

-    def _write_data(self, data, path):
+    def write_data(self, data, path):
with etav.FFmpegVideoWriter(
path, data.frame_rate, data.frame_size) as writer:
for img in data:
writer.write(img)

-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
labels.write_json(path)

def _build_metadata(self, path):
@@ -1235,7 +1235,7 @@ def write_annotated_data(self, output_dir_path, annotation_config=None):
img, image_labels, annotation_config=annotation_config)
output_path = os.path.join(
output_dir_path, os.path.basename(image_path))
-            self._write_data(img_annotated, output_path)
+            self.write_data(img_annotated, output_path)

@classmethod
def validate_dataset(cls, dataset_path):
@@ -1268,16 +1268,16 @@ def validate_dataset(cls, dataset_path):
if not os.path.isfile(labels_path):
raise LabeledDatasetError("File not found: %s" % labels_path)

-    def _read_data(self, path):
+    def read_data(self, path):
return etai.read(path)

-    def _read_labels(self, path):
+    def read_labels(self, path):
return etai.ImageLabels.from_json(path)

-    def _write_data(self, data, path):
+    def write_data(self, data, path):
etai.write(data, path)

-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
labels.write_json(path)

def _build_metadata(self, path):
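
The rename from _read_data() and friends to read_data() etc. promotes these hooks from private helpers to the public interface, which is what lets the new standardize.py module below call them on arbitrary datasets. A minimal sketch of a custom subclass, assuming LabeledDataset is exported from eta.core.datasets; the JSON data format is invented for illustration, and other hooks such as _build_metadata() are omitted:

import json

from eta.core.datasets import LabeledDataset


class LabeledJSONDataset(LabeledDataset):
    '''Hypothetical dataset whose data and labels files are both JSON.'''

    def read_data(self, path):
        with open(path, "rt") as f:
            return json.load(f)

    def read_labels(self, path):
        with open(path, "rt") as f:
            return json.load(f)

    def write_data(self, data, path):
        with open(path, "wt") as f:
            json.dump(data, f)

    def write_labels(self, labels, path):
        with open(path, "wt") as f:
            json.dump(labels, f)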
135 changes: 135 additions & 0 deletions eta/core/datasets/standardize.py
@@ -0,0 +1,135 @@
'''
Tools for standardizing datasets by automated means

Copyright 2017-2020 Voxel51, Inc.
voxel51.com

Tyler Ganter, tyler@voxel51.com
'''
# pragma pylint: disable=redefined-builtin
# pragma pylint: disable=unused-wildcard-import
# pragma pylint: disable=wildcard-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from builtins import *
# pragma pylint: enable=redefined-builtin
# pragma pylint: enable=unused-wildcard-import
# pragma pylint: enable=wildcard-import

import logging
import os

import eta.core.image as etai
import eta.core.utils as etau
import eta.core.video as etav

from .labeled_datasets import LabeledDatasetError, LabeledImageDataset, \
LabeledVideoDataset


logger = logging.getLogger(__name__)


def ensure_labels_filename_property(dataset, audit_only=True):
    '''Audits the labels.filename field of each record in a dataset and
    optionally populates it where missing.

Args:
dataset: a `LabeledDataset` instance
audit_only: If False, modifies the labels in place to populate the
filename attribute

Returns:
a tuple of:
missing_count: integer count of labels files without a
labels.filename field
mismatch_count: integer count of labels files with a labels.filename
field inconsistent with the data record filename

Raises:
LabeledDatasetError if audit_only==False and a mismatching filename is
found.
'''
logger.info("Checking labels.filename's for labeled dataset...")

missing_count = 0
mismatch_count = 0

for idx, (data_path, labels_path) in enumerate(dataset.iter_paths()):
if idx % 20 == 0:
logger.info("%4d/%4d" % (idx, len(dataset)))

data_filename = os.path.basename(data_path)
labels = dataset.read_labels(labels_path)

if labels.filename is None:
missing_count += 1

if not audit_only:
labels.filename = data_filename
dataset.write_labels(labels, labels_path)

elif labels.filename != data_filename:
mismatch_count += 1

if not audit_only:
raise LabeledDatasetError(
"Filename: '%s' in labels file does not match data"
" filename: '%s'." % (labels.filename, data_filename)
)

logger.info("Complete: %d missing filenames and %d mismatched filenames"
% (missing_count, mismatch_count))

return missing_count, mismatch_count


def check_dataset_syntax(dataset, target_schema, audit_only=True):
    '''Audits the syntax of the labels in a dataset against a target schema
    and optionally standardizes the labels in place.

Args:
dataset: a `LabeledDataset` instance
target_schema: an `ImageLabelsSchema` or `VideoLabelsSchema` matching
the dataset type
audit_only: If False, modifies the labels in place to fix syntax

Returns:
a tuple of:
fixable_schema: schema of values that can be (or were) substituted
with target_schema syntax
unfixable_schema: schema of values that cannot be mapped to
the target_schema
'''
logger.info("Checking consistent syntax for labeled dataset...")

if isinstance(dataset, LabeledImageDataset):
checker = etai.ImageLabelsSyntaxChecker(target_schema)
elif isinstance(dataset, LabeledVideoDataset):
checker = etav.VideoLabelsSyntaxChecker(target_schema)
else:
raise ValueError(
"Unexpected input type: `%s`" % etau.get_class_name(dataset))

modified_count = 0

for idx, labels_path in enumerate(dataset.iter_labels_paths()):
if idx % 20 == 0:
logger.info("%4d/%4d" % (idx, len(dataset)))

labels = dataset.read_labels(labels_path)

was_modified = checker.check(labels)

modified_count += int(was_modified)
if not audit_only and was_modified:
labels.write_json(labels_path)

logger.info(
"Complete: %d/%d files %supdated"
% (modified_count, len(dataset), "can be " if audit_only else "")
)

return checker.fixable_schema, checker.unfixable_schema
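
Putting the two checks together, a usage sketch: the dataset and schema paths are hypothetical, and ImageLabelsSchema.from_json() is assumed to follow ETA's usual from_json() pattern (as VideoLabels.from_json() does above):

import eta.core.image as etai

from eta.core.datasets import (
    LabeledImageDataset, check_dataset_syntax,
    ensure_labels_filename_property)

dataset = LabeledImageDataset("/path/to/dataset/manifest.json")
target_schema = etai.ImageLabelsSchema.from_json("/path/to/schema.json")

# Audit-only passes: report problems without touching any files
missing, mismatched = ensure_labels_filename_property(dataset)
fixable, unfixable = check_dataset_syntax(dataset, target_schema)

# Setting audit_only=False applies the fixes in place (the filename
# check raises LabeledDatasetError if it hits a mismatched filename)
ensure_labels_filename_property(dataset, audit_only=False)
check_dataset_syntax(dataset, target_schema, audit_only=False)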