diff --git a/eta/core/data.py b/eta/core/data.py
index 73f0bef07..17487e381 100644
--- a/eta/core/data.py
+++ b/eta/core/data.py
@@ -326,6 +326,12 @@ def merge_schema(self, schema):
         '''Merges the given CategoricalAttributeSchema into this schema.'''
         self.categories.update(schema.categories)
 
+    def serialize(self, *args, **kwargs):
+        d = super(CategoricalAttributeSchema, self).serialize(*args, **kwargs)
+        if "categories" in d:
+            d["categories"].sort()
+        return d
+
     @staticmethod
     def get_kwargs(d):
         '''Extracts the relevant keyword arguments for this schema from the
diff --git a/eta/core/datasets/labeled_datasets.py b/eta/core/datasets/labeled_datasets.py
index e9b3e583a..44b0614ba 100644
--- a/eta/core/datasets/labeled_datasets.py
+++ b/eta/core/datasets/labeled_datasets.py
@@ -204,8 +204,8 @@ def __iter__(self):
 
         Returns:
             iterator: iterator over (data, labels) pairs, where data is an
-                object returned by self._read_data() and labels is an object
-                returned by self._read_labels() from the respective paths
+                object returned by self.read_data() and labels is an object
+                returned by self.read_labels() from the respective paths
                 of a data file and corresponding labels file
         '''
         return zip(self.iter_data(), self.iter_labels())
@@ -227,11 +227,11 @@ def iter_data(self):
         '''Iterates over the data in the dataset.
 
         Returns:
-            iterator: iterator over objects returned by self._read_data()
+            iterator: iterator over objects returned by self.read_data()
                 from the paths to data files
         '''
         for data_path in self.iter_data_paths():
-            yield self._read_data(data_path)
+            yield self.read_data(data_path)
 
     def iter_data_paths(self):
         '''Iterates over the paths to data files in the dataset.
@@ -246,11 +246,11 @@ def iter_labels(self):
         '''Iterates over the labels in the dataset.
 
         Returns:
-            iterator: iterator over objects returned by self._read_labels()
+            iterator: iterator over objects returned by self.read_labels()
                 from the paths to labels files
         '''
         for labels_path in self.iter_labels_paths():
-            yield self._read_labels(labels_path)
+            yield self.read_labels(labels_path)
 
     def iter_labels_paths(self):
         '''Iterates over the paths to labels files in the dataset.
@@ -441,9 +441,9 @@ def add_file(self, data_path, labels_path, new_data_filename=None,
 
         # Update the filename attribute in the labels JSON if necessary
         if new_data_filename != os.path.basename(data_path):
-            labels_ = self._read_labels(new_labels_path)
+            labels_ = self.read_labels(new_labels_path)
             labels_.filename = new_data_filename
-            self._write_labels(labels_, new_labels_path)
+            self.write_labels(labels_, new_labels_path)
 
         # First remove any other records with the same data filename
         self.dataset_index.cull_with_function(
@@ -466,9 +466,9 @@ def add_data(self, data, labels, data_filename, labels_filename,
 
         Args:
             data: input data in a format that can be passed to
-                self._write_data()
+                self.write_data()
             labels: input labels in a format that can be passed to
-                self._write_labels()
+                self.write_labels()
             data_filename: filename for the data in the dataset
             labels_filename: filename for the labels in the dataset
             error_on_duplicates: whether to raise an error if a data file
@@ -488,8 +488,8 @@ def add_data(self, data, labels, data_filename, labels_filename,
         labels_path = os.path.join(
             self.dataset_dir, self._LABELS_SUBDIR, labels_filename)
 
-        self._write_data(data, data_path)
-        self._write_labels(labels, labels_path)
+        self.write_data(data, data_path)
+        self.write_labels(labels, labels_path)
 
         # First remove any other records with the same data filename
         self.dataset_index.cull_with_function(
@@ -700,14 +700,14 @@ def apply_to_data(self, func):
 
         Args:
             func: function that takes in a data element in the format
-                returned by `self._read_data()` and outputs transformed
+                returned by `self.read_data()` and outputs transformed
                 data in the same format
 
         Returns:
             self
         '''
         for data, path in zip(self.iter_data(), self.iter_data_paths()):
-            self._write_data(func(data), path)
+            self.write_data(func(data), path)
 
         return self
 
@@ -734,7 +734,7 @@ def apply_to_labels(self, func):
 
         Args:
             func: function that takes in a labels object in the format
-                returned by `self._read_labels()` and outputs transformed
+                returned by `self.read_labels()` and outputs transformed
                 labels in the same format
 
         Returns:
@@ -742,7 +742,7 @@ def apply_to_labels(self, func):
         '''
         for labels, path in zip(
                 self.iter_labels(), self.iter_labels_paths()):
-            self._write_labels(func(labels), path)
+            self.write_labels(func(labels), path)
 
         return self
 
@@ -915,7 +915,7 @@ def _build_index_map(self):
                     data_file)
             self._data_to_labels_map[data_file] = labels_file
 
-    def _read_data(self, path):
+    def read_data(self, path):
         '''Reads data from a data file at the given path.
 
         Subclasses must implement this based on the particular data format for
@@ -927,9 +927,9 @@ def _read_data(self, path):
         Returns:
             a data object in the particular format for the subclass
         '''
-        raise NotImplementedError("subclasses must implement _read_data()")
+        raise NotImplementedError("subclasses must implement read_data()")
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         '''Reads a labels object from a labels JSON file at the given path.
 
         Subclasses must implement this based on the particular labels format
@@ -941,33 +941,33 @@ def _read_labels(self, path):
         Returns:
             a labels object in the particular format for the subclass
         '''
-        raise NotImplementedError("subclasses must implement _read_labels()")
+        raise NotImplementedError("subclasses must implement read_labels()")
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         '''Writes data to a data file at the given path.
 
         Subclasses must implement this based on the particular data format for
         the subclass. The method should accept input `data` of the same type
-        as output by `self._read_data()`.
+        as output by `self.read_data()`.
 
         Args:
             data: a data element to be written to a file
             path: path to write the data
         '''
-        raise NotImplementedError("subclasses must implement _write_data()")
+        raise NotImplementedError("subclasses must implement write_data()")
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         '''Writes a labels object to a labels JSON file at the given path.
 
         Subclasses must implement this based on the particular labels format
         for the subclass. The method should accept input `labels` of the same
-        type as output by `self._read_labels()`.
+        type as output by `self.read_labels()`.
 
         Args:
             labels: a labels object to be written to a file
             path: path to write the labels JSON file
         '''
-        raise NotImplementedError("subclasses must implement _write_labels()")
+        raise NotImplementedError("subclasses must implement write_labels()")
 
     def _build_metadata(self, path):
         '''Reads metadata from a data file at the given path and builds an
@@ -1175,19 +1175,19 @@ def compute_average_video_duration(self):
 
         return np.mean(video_durations)
 
-    def _read_data(self, path):
+    def read_data(self, path):
         return etav.FFmpegVideoReader(path)
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         return etav.VideoLabels.from_json(path)
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         with etav.FFmpegVideoWriter(
                 path, data.frame_rate, data.frame_size) as writer:
             for img in data:
                 writer.write(img)
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         labels.write_json(path)
 
     def _build_metadata(self, path):
@@ -1273,7 +1273,7 @@ def write_annotated_data(self, output_dir_path, annotation_config=None):
                 img, image_labels, annotation_config=annotation_config)
             output_path = os.path.join(
                 output_dir_path, os.path.basename(image_path))
-            self._write_data(img_annotated, output_path)
+            self.write_data(img_annotated, output_path)
 
     @classmethod
     def validate_dataset(cls, dataset_path):
@@ -1306,16 +1306,16 @@ def validate_dataset(cls, dataset_path):
             if not os.path.isfile(labels_path):
                 raise LabeledDatasetError("File not found: %s" % labels_path)
 
-    def _read_data(self, path):
+    def read_data(self, path):
         return etai.read(path)
 
-    def _read_labels(self, path):
+    def read_labels(self, path):
         return etai.ImageLabels.from_json(path)
 
-    def _write_data(self, data, path):
+    def write_data(self, data, path):
         etai.write(data, path)
 
-    def _write_labels(self, labels, path):
+    def write_labels(self, labels, path):
         labels.write_json(path)
 
     def _build_metadata(self, path):
diff --git a/eta/core/utils.py b/eta/core/utils.py
index 48ab764c3..8fbec61b9 100644
--- a/eta/core/utils.py
+++ b/eta/core/utils.py
@@ -2257,12 +2257,13 @@ def remove_none_values(d):
     return {k: v for k, v in iteritems(d) if v is not None}
 
 
-def find_duplicate_files(path_list):
+def find_duplicate_files(path_list, verbose=False):
     '''Returns a list of lists of file paths from the input, that have
     identical contents to each other.
 
     Args:
         path_list: list of file paths in which to look for duplicate files
+        verbose: if True, log progress
 
     Returns:
         duplicates: a list of lists, where each list contains a group of
@@ -2270,13 +2271,20 @@ def find_duplicate_files(path_list):
         `path_list` that don't have any duplicates will not appear in the
         output.
     '''
-    hash_buckets = _get_file_hash_buckets(path_list)
+    if verbose:
+        logger.info("Finding duplicates out of %d files..." % len(path_list))
+
+    hash_buckets = _get_file_hash_buckets(path_list, verbose=verbose)
 
     duplicates = []
     for file_group in itervalues(hash_buckets):
         if len(file_group) >= 2:
             duplicates.extend(_find_duplicates_brute_force(file_group))
 
+    if verbose:
+        duplicate_count = sum(len(x) for x in duplicates) - len(duplicates)
+        logger.info("Complete: %d duplicates found" % duplicate_count)
+
     return duplicates
 
 
@@ -2306,9 +2314,11 @@ def find_matching_file_pairs(path_list1, path_list2):
     return pairs
 
 
-def _get_file_hash_buckets(path_list):
+def _get_file_hash_buckets(path_list, verbose):
     hash_buckets = defaultdict(list)
-    for path in path_list:
+    for idx, path in enumerate(path_list):
+        if verbose and idx % 100 == 0:
+            logger.info("\thashing file %d/%d" % (idx, len(path_list)))
         if not os.path.isfile(path):
             logger.warning(
                 "File '%s' is a directory or does not exist. "
diff --git a/eta/core/video.py b/eta/core/video.py
index 09a7bc4e8..762bb2bf0 100644
--- a/eta/core/video.py
+++ b/eta/core/video.py
@@ -1616,7 +1616,7 @@ def validate_event(self, event):
     def attributes(self):
         '''Returns the list of class attributes that will be serialized.
 
-        Args:
+        Returns:
             a list of attribute names
         '''
         return ["attrs", "frames", "objects", "events"]