diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 00da4ac23..4b68e61c3 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1617,6 +1617,12 @@ def append_unique(old, new): } instance_type_to_idx = {Instance: 0, PredictedInstance: 1} + # Each instance we create will have an index in the dataset, keep track of + # these so we can quickly add from_predicted links on a second pass. + instance_to_idx = {} + instances_with_from_predicted = [] + instances_from_predicted = [] + # If we are appending, we need look inside to see what frame, instance, and point # ids we need to start from. This gives us offsets to use. if append and "points" in f: @@ -1633,9 +1639,7 @@ def append_unique(old, new): point_id = 0 pred_point_id = 0 instance_id = 0 - frame_id = 0 - all_from_predicted = [] - from_predicted_id = 0 + for frame_id, label in enumerate(labels): frames[frame_id] = ( frame_id + frame_id_offset, @@ -1645,6 +1649,11 @@ def append_unique(old, new): instance_id + instance_id_offset + len(label.instances), ) for instance in label.instances: + + # Add this instance to our lookup structure we will need for from_predicted + # links + instance_to_idx[instance] = instance_id + parray = instance.get_points_array(copy=False, full=True) instance_type = type(instance) @@ -1659,8 +1668,8 @@ def append_unique(old, new): # Keep track of any from_predicted instance links, we will insert the # correct instance_id in the dataset after we are done. 
if instance.from_predicted: - all_from_predicted.append(instance.from_predicted) - from_predicted_id = from_predicted_id + 1 + instances_with_from_predicted.append(instance_id) + instances_from_predicted.append(instance.from_predicted) # Copy all the data instances[instance_id] = ( @@ -1688,6 +1697,21 @@ def append_unique(old, new): instance_id = instance_id + 1 + # Add from_predicted links + for instance_id, from_predicted in zip( + instances_with_from_predicted, instances_from_predicted + ): + try: + instances[instance_id]["from_predicted"] = instance_to_idx[ + from_predicted + ] + except KeyError: + # If we haven't encountered the from_predicted instance yet then don't save the link. + # It's possible for a user to create a regular instance from a predicted instance and then + # delete all predicted instances from the file, but in this case I don't think there's any reason + # to remember which predicted instance the regular instance came from. + pass + + # We pre-allocated our points array with max possible size considering the max # skeleton size, drop any unused points. points = points[0:point_id] @@ -1785,6 +1809,10 @@ def load_hdf5( tracks = labels.tracks.copy() tracks.extend([None]) + # A dict to keep track of instances that have a from_predicted link. The key is the + # instance and the value is the index of the instance. 
+ from_predicted_lookup = {} + # Create the instances instances = [] for i in instances_dset: @@ -1806,6 +1834,13 @@ def load_hdf5( ) instances.append(instance) + if i["from_predicted"] != -1: + from_predicted_lookup[instance] = i["from_predicted"] + + # Make a second pass to add any from_predicted links + for instance, from_predicted_idx in from_predicted_lookup.items(): + instance.from_predicted = instances[from_predicted_idx] + # Create the labeled frames frames = [ LabeledFrame( diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 0695030a3..f141ae6fc 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -692,3 +692,24 @@ def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): loaded_labels = Labels.load_hdf5(filename=filename) _check_labels_match(labels, loaded_labels) + + +def test_hdf5_from_predicted(multi_skel_vid_labels, tmpdir): + labels = multi_skel_vid_labels + filename = os.path.join(tmpdir, "test.h5") + + # Add some predicted instances to create from_predicted links + for frame_num, frame in enumerate(labels): + if frame_num % 20 == 0: + frame.instances[0].from_predicted = PredictedInstance.from_instance( + frame.instances[0], float(frame_num) + ) + frame.instances.append(frame.instances[0].from_predicted) + + # Save and load, compare the results + Labels.save_hdf5(filename=filename, labels=labels) + loaded_labels = Labels.load_hdf5(filename=filename) + + for frame_num, frame in enumerate(loaded_labels): + if frame_num % 20 == 0: + assert frame.instances[0].from_predicted.score == float(frame_num)