Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prioritize user instances when coping from prior frame #1658

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,6 +3019,23 @@ def set_visible_nodes(

return has_missing_nodes

@staticmethod
def find_last_user_instance(
prev_frame: LabeledFrame,
) -> Optional[Instance]:
"""Find last user instance to copy from

Args:
prev_frame: The last labeled frame from which we obtain the last user instance.

Returns:
The last user instance in the previous frame (if present), otherwise null
"""

user_instances = prev_frame.user_instances
if len(user_instances) > 0:
return user_instances[-1]

@staticmethod
def find_instance_to_copy_from(
context: CommandContext,
Expand Down Expand Up @@ -3067,25 +3084,48 @@ def find_instance_to_copy_from(
prev_idx = AddInstance.get_previous_frame_index(context)

if prev_idx is not None:
prev_instances = context.labels.find(
prev_frame = context.labels.find(
context.state["video"], prev_idx, return_new=True
)[0].instances
)[0]
prev_instances = prev_frame.instances
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]

if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance

from_prev_frame = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of the three ways am instance is copied from the previous frame, this is the only one that doesn't seem to match the intended functionality. This case tries to find an instance that exists in the previous frame, but not in the current frame and add it to the current frame. We can use that same logic, but just prioritize the user instance. Maybe something like:

Suggested change
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance
from_prev_frame = True
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
prev_user_instances = prev_instances.user_instances
current_user_instances = context.state["labeled_frame"].instances
if len(prev_instances.user_instances) > len(context.state["labeled_frame"].user_instances):
user_instance = prev_user_instances[len(current_user_instances)]
else:
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]
from_prev_frame = True

I liked that you added an AddInstance.find_last_user method, maybe we can edit that function a bit to find_user_instance(index: int = -1, frame: LabeledFrame) where we use a default argument for the index to always get the last instance if the index is not provided?

elif init_method == "best" and (
context.state["labeled_frame"].instances
):
# Otherwise, if there are already instances in current frame,
# copy the points from the last instance added to frame.
copy_instance = context.state["labeled_frame"].instances[-1]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in current frame
user_instance = AddInstance.find_last_user_instance(
context.state["labeled_frame"]
)
if user_instance is not None:
copy_instance = user_instance

elif len(prev_instances):
# Otherwise use the last instance added to previous frame.
copy_instance = prev_instances[-1]

if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance

Copy link

@coderabbitai coderabbitai bot Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method find_instance_to_copy_from has been updated to include calls to find_last_user_instance. This change ensures that when copying instances from a prior frame, the last user instance is prioritized over predicted instances. The logic appears to be correct, and the use of isinstance checks ensures that the method is robust to different instance types. However, there is a repeated pattern that could be refactored to reduce code duplication.

- if isinstance(copy_instance, PredictedInstance):
-     # Set copy instance to last user instance in previous frame, if present
-     user_instance = AddInstance.find_last_user_instance(prev_frame)
-     if user_instance is not None:
-         copy_instance = user_instance
+ # Refactor repeated logic into a method
+ copy_instance = self._replace_with_user_instance_if_needed(copy_instance, prev_frame)

Committable suggestion

IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
prev_frame = context.labels.find(
context.state["video"], prev_idx, return_new=True
)[0].instances
)[0]
prev_instances = prev_frame.instances
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance
from_prev_frame = True
elif init_method == "best" and (
context.state["labeled_frame"].instances
):
# Otherwise, if there are already instances in current frame,
# copy the points from the last instance added to frame.
copy_instance = context.state["labeled_frame"].instances[-1]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in current frame
user_instance = AddInstance.find_last_user_instance(
context.state["labeled_frame"]
)
if user_instance is not None:
copy_instance = user_instance
elif len(prev_instances):
# Otherwise use the last instance added to previous frame.
copy_instance = prev_instances[-1]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance
prev_frame = context.labels.find(
context.state["video"], prev_idx, return_new=True
)[0]
prev_instances = prev_frame.instances
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]
copy_instance = self._replace_with_user_instance_if_needed(copy_instance, prev_frame)
from_prev_frame = True
elif init_method == "best" and (
context.state["labeled_frame"].instances
):
# Otherwise, if there are already instances in current frame,
# copy the points from the last instance added to frame.
copy_instance = context.state["labeled_frame"].instances[-1]
copy_instance = self._replace_with_user_instance_if_needed(copy_instance, context.state["labeled_frame"])
elif len(prev_instances):
# Otherwise use the last instance added to previous frame.
copy_instance = prev_instances[-1]
copy_instance = self._replace_with_user_instance_if_needed(copy_instance, prev_frame)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, the same pattern is used over an over, maybe we can just include that into the new method you created.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@roomrys Absolutely, the repeated logic can be encapsulated within the find_last_user_instance method to streamline the code. This will ensure that the method not only finds the last user instance but also replaces the predicted instance with the user instance when appropriate. I'll adjust the method accordingly.

from_prev_frame = True

from_predicted = from_predicted if hasattr(from_predicted, "score") else None
Expand Down
146 changes: 144 additions & 2 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SaveProjectAs,
get_new_version_filename,
)
from sleap.instance import Instance, LabeledFrame
from sleap.instance import Instance, LabeledFrame, Point
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand Down Expand Up @@ -216,7 +216,6 @@ def assert_videos_written(num_videos: int, labels_path: str = None):
context.state["filename"] = None

if csv:

context.state["filename"] = centered_pair_predictions_hdf5_path

params = {"all_videos": True, "csv": csv}
Expand Down Expand Up @@ -617,6 +616,149 @@ def test_CopyInstance(min_tracks_2node_labels: Labels):
assert context.state["clipboard_instance"] == instance


def test_CopyPriorFramePreviousUser(centered_pair_predictions: Labels):
"""Test that we copy prior user frame when available."""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No user instances in current frame, only predicted
assert len(context.state["labeled_frame"].user_instances) == 0
assert len(context.state["labeled_frame"].predicted_instances) == 2

# Modify previous frame to have a user instance
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances

# No user instances in previous frame, only predicted
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=prev_frame,
)
prev_instances.append(user_inst)

# Confirm there is one user instance in previous frame
assert len(prev_frame.user_instances) == 1

context.newInstance(init_method="prior_frame")

# Confirm that the newly added user instance is the same as the sole user instance in the previous frame+
newly_added_instance = context.state["labeled_frame"].user_instances[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a better check would be to insert the user_inst (instead of append) since the original logic that we wanted to change would have just taken the last instance in the prev_instances list in any case. I would also add assertion statements here to ensure that we are in the case that you want to test (and to help me, the reviewer), i.e. for the first case:

assert len(prev_instances) > len(context.state["labeled_frame"].instances):

assert newly_added_instance.video == user_inst.video
assert newly_added_instance.points == user_inst.points
assert newly_added_instance.track == user_inst.track


def test_CopyPriorFramePreviousUser2(centered_pair_predictions: Labels):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use a better name that describes what the 2 means, or you could bundle it into the previous test if you want and section with assert statements, i.e. for the second case we would make some changes to our labels to enter the second case and:

assert len(prev_instances) <= len(context.state["labeled_frame"].instances)  # not the first case
assert context.state["labeled_frame"].instances  # but yes to the second case

"""Test that we copy user instance in previous frame when the current frame has no instances."""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["labeled_frame"].instances = []
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No instances in current frame
assert len(context.state["labeled_frame"].instances) == 0

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances

# Confirm that previous frame has 2 predicted instances
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

# Add user instance to previous frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
prev_instances.insert(0, user_inst)

# Confirm addition of user instance in previous frame
assert len(prev_frame.user_instances) == 1

context.newInstance(init_method="best")

# Confirm that user instance in current frame is the same as the user instance in previous frame
current_user_instance = context.state["labeled_frame"].user_instances[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! yes the insert! The only comment here would be to add the assert statement for which case we are in:

assert len(prev_instances) <= len(context.state["labeled_frame"].instances)  # no to the first case
assert context.state["labeled_frame"].instances  # but yes to the second case

assert current_user_instance.video == user_inst.video
assert current_user_instance.points == user_inst.points
assert current_user_instance.track == user_inst.track


def test_CopyPriorFrameCurrentUser(centered_pair_predictions: Labels):
"""Test that we copy user instance in current frame when the current frame has >= amount of instances
as the previous frame."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this is a bit too long of a summary doc string (see google style guide), but I am super pleased that you added some docstrings in the first place!

context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

assert len(context.state["labeled_frame"].predicted_instances) == 2

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

# Add user instance to current frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
context.state["labeled_frame"].instances.insert(0, user_inst)

# Confirm addition of user instance in current frame
assert len(context.state["labeled_frame"].user_instances) == 1

# Confirm that current frame has more instances than previous frame
assert len(context.state["labeled_frame"].instances) > len(prev_frame.instances)
Comment on lines +766 to +767
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome


# Remove tracks from all instances so that unused_predictions is empty
for inst in context.state["labeled_frame"].instances:
inst.track = None
inst.from_predicted = inst

context.newInstance(init_method="best")

# Confirm that both user instances in the current frame are the same+
previous_user_instance = context.state["labeled_frame"].user_instances[0]
newly_added_instance = context.state["labeled_frame"].user_instances[1]
assert newly_added_instance.video == previous_user_instance.video
assert newly_added_instance.points == previous_user_instance.points
assert newly_added_instance.track == previous_user_instance.track


def test_PasteInstance(min_tracks_2node_labels: Labels):
"""Test that pasting an instance works as expected."""
labels = min_tracks_2node_labels
Expand Down
Loading