Small improvements with a vague description #435

Merged (5 commits) on Feb 20, 2020
6 changes: 6 additions & 0 deletions eta/core/data.py
@@ -326,6 +326,12 @@ def merge_schema(self, schema):
         '''Merges the given CategoricalAttributeSchema into this schema.'''
         self.categories.update(schema.categories)
 
+    def serialize(self, *args, **kwargs):
+        d = super(CategoricalAttributeSchema, self).serialize(*args, **kwargs)
+        if "categories" in d:
+            d["categories"].sort()
+        return d
+
     @staticmethod
     def get_kwargs(d):
         '''Extracts the relevant keyword arguments for this schema from the
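The new serialize() override sorts the categories before serialization, making the JSON output deterministic: two schemas containing the same categories now serialize identically regardless of the order in which the categories were added. A minimal sketch of the effect (the CategoricalAttributeSchema constructor arguments here are assumed, for illustration only):

    from eta.core.data import CategoricalAttributeSchema

    # Constructor arguments are assumed for illustration
    schema = CategoricalAttributeSchema("weather", categories={"sun", "rain"})

    # Without the sort, the serialized order of the set-backed `categories`
    # field could vary between runs; with it, the list is always sorted
    print(schema.serialize()["categories"])  # ["rain", "sun"]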
70 changes: 35 additions & 35 deletions eta/core/datasets/labeled_datasets.py
@@ -204,8 +204,8 @@ def __iter__(self):
 
         Returns:
             iterator: iterator over (data, labels) pairs, where data is an
-                object returned by self._read_data() and labels is an object
-                returned by self._read_labels() from the respective paths
+                object returned by self.read_data() and labels is an object
+                returned by self.read_labels() from the respective paths
                 of a data file and corresponding labels file
         '''
         return zip(self.iter_data(), self.iter_labels())
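This docstring change reflects the PR's central rename: the per-file reader/writer hooks lose their leading underscores and become public API. Iteration itself is unchanged; a brief usage sketch (construction of a LabeledImageDataset from a manifest path is assumed, and the path is hypothetical):

    import eta.core.datasets.labeled_datasets as etald

    dataset = etald.LabeledImageDataset("/datasets/example/manifest.json")

    # Yields (data, labels) pairs via read_data() and read_labels()
    for data, labels in dataset:
        print(type(data), type(labels))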
@@ -227,11 +227,11 @@ def iter_data(self):
         '''Iterates over the data in the dataset.
 
         Returns:
-            iterator: iterator over objects returned by self._read_data()
+            iterator: iterator over objects returned by self.read_data()
                 from the paths to data files
         '''
         for data_path in self.iter_data_paths():
-            yield self._read_data(data_path)
+            yield self.read_data(data_path)
 
     def iter_data_paths(self):
         '''Iterates over the paths to data files in the dataset.
@@ -246,11 +246,11 @@ def iter_labels(self):
         '''Iterates over the labels in the dataset.
 
         Returns:
-            iterator: iterator over objects returned by self._read_labels()
+            iterator: iterator over objects returned by self.read_labels()
                 from the paths to labels files
         '''
         for labels_path in self.iter_labels_paths():
-            yield self._read_labels(labels_path)
+            yield self.read_labels(labels_path)
 
     def iter_labels_paths(self):
         '''Iterates over the paths to labels files in the dataset.
@@ -441,9 +441,9 @@ def add_file(self, data_path, labels_path, new_data_filename=None,
 
         # Update the filename attribute in the labels JSON if necessary
         if new_data_filename != os.path.basename(data_path):
-            labels_ = self._read_labels(new_labels_path)
+            labels_ = self.read_labels(new_labels_path)
             labels_.filename = new_data_filename
-            self._write_labels(labels_, new_labels_path)
+            self.write_labels(labels_, new_labels_path)
 
         # First remove any other records with the same data filename
         self.dataset_index.cull_with_function(
@@ -466,9 +466,9 @@ def add_data(self, data, labels, data_filename, labels_filename,
 
         Args:
             data: input data in a format that can be passed to
-                self._write_data()
+                self.write_data()
             labels: input labels in a format that can be passed to
-                self._write_labels()
+                self.write_labels()
             data_filename: filename for the data in the dataset
             labels_filename: filename for the labels in the dataset
             error_on_duplicates: whether to raise an error if a data file
@@ -488,8 +488,8 @@ def add_data(self, data, labels, data_filename, labels_filename,
         labels_path = os.path.join(
             self.dataset_dir, self._LABELS_SUBDIR, labels_filename)
 
-        self._write_data(data, data_path)
-        self._write_labels(labels, labels_path)
+        self.write_data(data, data_path)
+        self.write_labels(labels, labels_path)
 
         # First remove any other records with the same data filename
         self.dataset_index.cull_with_function(
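add_data() now writes through the public write_data()/write_labels() methods. A hedged usage sketch for an image dataset (the manifest path is hypothetical, and the no-argument ImageLabels constructor is assumed; assigning labels.filename follows the pattern used in add_file() above):

    import numpy as np

    import eta.core.datasets.labeled_datasets as etald
    import eta.core.image as etai

    dataset = etald.LabeledImageDataset("/datasets/example/manifest.json")

    img = np.zeros((480, 640, 3), dtype=np.uint8)  # a blank image
    labels = etai.ImageLabels()
    labels.filename = "blank.png"

    # Writes the image and labels into the dataset's data/labels subdirs
    dataset.add_data(img, labels, "blank.png", "blank.json")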
@@ -700,14 +700,14 @@ def apply_to_data(self, func):
 
         Args:
             func: function that takes in a data element in the format
-                returned by `self._read_data()` and outputs transformed
+                returned by `self.read_data()` and outputs transformed
                 data in the same format
 
         Returns:
             self
         '''
         for data, path in zip(self.iter_data(), self.iter_data_paths()):
-            self._write_data(func(data), path)
+            self.write_data(func(data), path)
 
         return self
 
@@ -734,15 +734,15 @@ def apply_to_labels(self, func):
 
         Args:
             func: function that takes in a labels object in the format
-                returned by `self._read_labels()` and outputs transformed
+                returned by `self.read_labels()` and outputs transformed
                 labels in the same format
 
         Returns:
             self
         '''
         for labels, path in zip(
                 self.iter_labels(), self.iter_labels_paths()):
-            self._write_labels(func(labels), path)
+            self.write_labels(func(labels), path)
 
         return self
 
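apply_to_labels() reads each labels file with read_labels(), applies func, and writes the result back with write_labels(). A small sketch with a hypothetical transform (the manifest path is hypothetical):

    import eta.core.datasets.labeled_datasets as etald

    dataset = etald.LabeledImageDataset("/datasets/example/manifest.json")

    def lowercase_filename(labels):
        # Hypothetical transform: normalize the recorded filename
        if labels.filename is not None:
            labels.filename = labels.filename.lower()
        return labels

    # Rewrites every labels JSON file in the dataset in place
    dataset.apply_to_labels(lowercase_filename)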
@@ -915,7 +915,7 @@ def _build_index_map(self):
                     data_file)
             self._data_to_labels_map[data_file] = labels_file
 
-    def _read_data(self, path):
+    def read_data(self, path):
         '''Reads data from a data file at the given path.
 
         Subclasses must implement this based on the particular data format for
@@ -927,9 +927,9 @@ def _read_data(self, path):
 
         Returns:
             a data object in the particular format for the subclass
         '''
-        raise NotImplementedError("subclasses must implement _read_data()")
+        raise NotImplementedError("subclasses must implement read_data()")
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         '''Reads a labels object from a labels JSON file at the given path.
 
         Subclasses must implement this based on the particular labels format
@@ -941,33 +941,33 @@ def _read_labels(self, path):
 
         Returns:
             a labels object in the particular format for the subclass
         '''
-        raise NotImplementedError("subclasses must implement _read_labels()")
+        raise NotImplementedError("subclasses must implement read_labels()")
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         '''Writes data to a data file at the given path.
 
         Subclasses must implement this based on the particular data format for
         the subclass. The method should accept input `data` of the same type
-        as output by `self._read_data()`.
+        as output by `self.read_data()`.
 
         Args:
             data: a data element to be written to a file
             path: path to write the data
         '''
-        raise NotImplementedError("subclasses must implement _write_data()")
+        raise NotImplementedError("subclasses must implement write_data()")
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         '''Writes a labels object to a labels JSON file at the given path.
 
         Subclasses must implement this based on the particular labels format
         for the subclass. The method should accept input `labels` of the same
-        type as output by `self._read_labels()`.
+        type as output by `self.read_labels()`.
 
         Args:
             labels: a labels object to be written to a file
             path: path to write the labels JSON file
         '''
-        raise NotImplementedError("subclasses must implement _write_labels()")
+        raise NotImplementedError("subclasses must implement write_labels()")
 
     def _build_metadata(self, path):
         '''Reads metadata from a data file at the given path and builds an
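These four methods define the contract that labeled dataset subclasses implement, and the rename makes that contract public. A minimal sketch of a custom subclass under the new names (LabeledTextDataset and its plain-text format are hypothetical, and the LabeledDataset base class name is assumed from this module):

    import json

    from eta.core.datasets.labeled_datasets import LabeledDataset

    class LabeledTextDataset(LabeledDataset):
        '''Hypothetical dataset whose data files are plain text and whose
        labels files are JSON dicts.
        '''

        def read_data(self, path):
            with open(path, "r") as f:
                return f.read()

        def read_labels(self, path):
            with open(path, "r") as f:
                return json.load(f)

        def write_data(self, data, path):
            with open(path, "w") as f:
                f.write(data)

        def write_labels(self, labels, path):
            with open(path, "w") as f:
                json.dump(labels, f)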
@@ -1175,19 +1175,19 @@ def compute_average_video_duration(self):
 
         return np.mean(video_durations)
 
-    def _read_data(self, path):
+    def read_data(self, path):
         return etav.FFmpegVideoReader(path)
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         return etav.VideoLabels.from_json(path)
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         with etav.FFmpegVideoWriter(
                 path, data.frame_rate, data.frame_size) as writer:
             for img in data:
                 writer.write(img)
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         labels.write_json(path)
 
     def _build_metadata(self, path):
@@ -1273,7 +1273,7 @@ def write_annotated_data(self, output_dir_path, annotation_config=None):
                 img, image_labels, annotation_config=annotation_config)
             output_path = os.path.join(
                 output_dir_path, os.path.basename(image_path))
-            self._write_data(img_annotated, output_path)
+            self.write_data(img_annotated, output_path)
 
     @classmethod
     def validate_dataset(cls, dataset_path):
@@ -1306,16 +1306,16 @@ def validate_dataset(cls, dataset_path):
         if not os.path.isfile(labels_path):
             raise LabeledDatasetError("File not found: %s" % labels_path)
 
-    def _read_data(self, path):
+    def read_data(self, path):
         return etai.read(path)
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         return etai.ImageLabels.from_json(path)
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         etai.write(data, path)
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         labels.write_json(path)
 
     def _build_metadata(self, path):
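With the readers and writers public on the concrete classes as well, callers can round-trip individual files directly rather than going through the iteration helpers. A hedged sketch (the manifest path is hypothetical):

    import eta.core.datasets.labeled_datasets as etald

    dataset = etald.LabeledImageDataset("/datasets/example/manifest.json")

    for data_path, labels_path in zip(
            dataset.iter_data_paths(), dataset.iter_labels_paths()):
        labels = dataset.read_labels(labels_path)
        # ... inspect or modify `labels` here ...
        dataset.write_labels(labels, labels_path)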
18 changes: 14 additions & 4 deletions eta/core/utils.py
@@ -2257,26 +2257,34 @@ def remove_none_values(d):
     return {k: v for k, v in iteritems(d) if v is not None}
 
 
-def find_duplicate_files(path_list):
+def find_duplicate_files(path_list, verbose=False):
     '''Returns a list of lists of file paths from the input that have
     identical contents to each other.
 
     Args:
         path_list: list of file paths in which to look for duplicate files
+        verbose: if True, log progress
 
     Returns:
         duplicates: a list of lists, where each list contains a group of
             file paths that all have identical content. File paths in
             `path_list` that don't have any duplicates will not appear in
             the output.
     '''
-    hash_buckets = _get_file_hash_buckets(path_list)
+    if verbose:
+        logger.info("Finding duplicates out of %d files..." % len(path_list))
+
+    hash_buckets = _get_file_hash_buckets(path_list, verbose=verbose)
 
     duplicates = []
     for file_group in itervalues(hash_buckets):
         if len(file_group) >= 2:
             duplicates.extend(_find_duplicates_brute_force(file_group))
 
+    if verbose:
+        duplicate_count = sum(len(x) for x in duplicates) - len(duplicates)
+        logger.info("Complete: %d duplicates found" % duplicate_count)
+
     return duplicates
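The logged duplicate count treats each group of n identical files as contributing n - 1 duplicates, which is what sum(len(x) for x in duplicates) - len(duplicates) computes. A usage sketch of the new flag (the glob pattern is hypothetical):

    import glob

    from eta.core.utils import find_duplicate_files

    paths = glob.glob("/data/images/*.jpg")  # hypothetical file list

    # With verbose=True, progress is logged every 100 files hashed
    duplicates = find_duplicate_files(paths, verbose=True)
    for group in duplicates:
        print("Identical contents: %s" % (group,))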
@@ -2306,9 +2314,11 @@ def find_matching_file_pairs(path_list1, path_list2):
     return pairs
 
 
-def _get_file_hash_buckets(path_list):
+def _get_file_hash_buckets(path_list, verbose):
     hash_buckets = defaultdict(list)
-    for path in path_list:
+    for idx, path in enumerate(path_list):
+        if verbose and idx % 100 == 0:
+            logger.info("\thashing file %d/%d" % (idx, len(path_list)))
         if not os.path.isfile(path):
             logger.warning(
                 "File '%s' is a directory or does not exist. "
2 changes: 1 addition & 1 deletion eta/core/video.py
@@ -1616,7 +1616,7 @@ def validate_event(self, event):
     def attributes(self):
         '''Returns the list of class attributes that will be serialized.
 
-        Args:
+        Returns:
             a list of attribute names
         '''
         return ["attrs", "frames", "objects", "events"]
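The fix here is purely to the docstring: attributes() takes no arguments and returns a list, so the section header should be Returns:, not Args:. For context, this list is what ETA's serialization machinery iterates over when writing JSON; a simplified sketch of that pattern (not the actual eta.core.serial implementation):

    def serialize_sketch(obj):
        # Build a dict from the attribute names that attributes() reports
        return {name: getattr(obj, name) for name in obj.attributes()}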