Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prioritize user instances when coping from prior frame #1658

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,6 +3019,47 @@ def set_visible_nodes(

return has_missing_nodes

@staticmethod
def find_last_user_instance(
frame_to_copy_from: LabeledFrame,
index: int = -1,
) -> Optional[Instance]:
"""Find last user instance to copy from

Args:
frame_to_copy_from: The last labeled frame from which we obtain the last user instance.

Returns:
The last user instance in the frame_to_copy_from (if present), otherwise null
"""

user_instances = frame_to_copy_from.user_instances
if len(user_instances) > 0:
return user_instances[index]

@staticmethod
def replace_with_user_instance_if_needed(
copy_instance: Optional[Union[Instance, PredictedInstance]],
frame_to_copy_from: LabeledFrame,
):
"""Replace copy_instance with user instance if needed.

Args:
copy_instance: The current copy instance.
frame_to_copy_from: The last labeled frame from which we obtain the last user instance.

Returns:
The current copy_instance or the user_instance it has been replaced with.
"""

if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in frame to copy from, if present
user_instance = AddInstance.find_last_user_instance(frame_to_copy_from)
if user_instance is not None:
return user_instance

return copy_instance

@staticmethod
def find_instance_to_copy_from(
context: CommandContext,
Expand Down Expand Up @@ -3067,25 +3108,44 @@ def find_instance_to_copy_from(
prev_idx = AddInstance.get_previous_frame_index(context)

if prev_idx is not None:
prev_instances = context.labels.find(
prev_frame = context.labels.find(
context.state["video"], prev_idx, return_new=True
)[0].instances
)[0]
prev_instances = prev_frame.instances
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]
prev_user_instances = prev_frame.user_instances
current_user_instances = context.state[
"labeled_frame"
].user_instances
if len(prev_user_instances) > len(current_user_instances):
copy_instance = prev_user_instances[len(current_user_instances)]
else:
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]

from_prev_frame = True
elif init_method == "best" and (
context.state["labeled_frame"].instances
):
# Otherwise, if there are already instances in current frame,
# copy the points from the last instance added to frame.
copy_instance = context.state["labeled_frame"].instances[-1]

copy_instance = AddInstance.replace_with_user_instance_if_needed(
copy_instance, context.state["labeled_frame"]
)

elif len(prev_instances):
# Otherwise use the last instance added to previous frame.
copy_instance = prev_instances[-1]

copy_instance = AddInstance.replace_with_user_instance_if_needed(
copy_instance, prev_frame
)

from_prev_frame = True

from_predicted = from_predicted if hasattr(from_predicted, "score") else None
Expand Down
168 changes: 166 additions & 2 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SaveProjectAs,
get_new_version_filename,
)
from sleap.instance import Instance, LabeledFrame
from sleap.instance import Instance, LabeledFrame, Point
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand Down Expand Up @@ -216,7 +216,6 @@ def assert_videos_written(num_videos: int, labels_path: str = None):
context.state["filename"] = None

if csv:

context.state["filename"] = centered_pair_predictions_hdf5_path

params = {"all_videos": True, "csv": csv}
Expand Down Expand Up @@ -617,6 +616,171 @@ def test_CopyInstance(min_tracks_2node_labels: Labels):
assert context.state["clipboard_instance"] == instance


def test_CopyPriorFrame_MoreInstancesInPreviousFrame(centered_pair_predictions: Labels):
"""Test that we copy prior user frame when available.

This case is triggered when there are more user instances in the previous frame than the current frame in which
case we want to copy a user instances from the previous frame.
"""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No user instances in current frame, only predicted
assert len(context.state["labeled_frame"].user_instances) == 0
assert len(context.state["labeled_frame"].predicted_instances) == 2

# Modify previous frame to have a user instance
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances

# No user instances in previous frame, only predicted
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=prev_frame,
)
prev_instances.insert(0, user_inst)

# Confirm there is one user instance in previous frame
assert len(prev_frame.user_instances) == 1

# Confirm that there are more instances and user instances in previous frame than current frame
assert len(prev_instances) > len(context.state["labeled_frame"].instances)
assert len(prev_frame.user_instances) > len(
context.state["labeled_frame"].user_instances
)

context.newInstance(init_method="prior_frame")

# Confirm that the newly added user instance is the same as the sole user instance in the previous frame+
newly_added_instance = context.state["labeled_frame"].user_instances[0]
assert newly_added_instance.video == user_inst.video
assert newly_added_instance.points == user_inst.points
assert newly_added_instance.track == user_inst.track


def test_CopyPriorFrame_EquivalentInstancesInCurrentAndPreviousFrame(
centered_pair_predictions: Labels,
):
"""Copy user instance when same amount of instances in previous and current frame.

This case is triggered when there are the same amount of instancecs in the previous and current frame
and the copy method is "prior_frame". In this case, we want to take the last available user instance
from the previous frame when available.
"""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No instances in current frame
assert len(context.state["labeled_frame"].instances) == 2

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_frame.instances.pop()
prev_instances = prev_frame.instances

# Confirm that previous frame has 1 predicted instance
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 1

# Add user instance to previous frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
prev_instances.insert(0, user_inst)

# Confirm addition of user instance in previous frame
assert len(prev_frame.user_instances) == 1

# Confirm same amount of instances in current and previous frame
assert len(prev_frame.instances) == len(context.state["labeled_frame"].instances)

context.newInstance(init_method="prior_frame")

# Confirm that user instance in current frame is the same as the user instance in previous frame
current_user_instance = context.state["labeled_frame"].user_instances[0]
assert current_user_instance.video == user_inst.video
assert current_user_instance.points == user_inst.points
assert current_user_instance.track == user_inst.track


def test_CopyPriorFrame_GetFromCurrentFrame(centered_pair_predictions: Labels):
"""Get user instance from current frame when copy method is "best".

This case is triggered when the current frame has >= the amount of instances as the previous frame and the copy method
is "best". In this case we want to choose the last available user instance in the current frame.
"""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

assert len(context.state["labeled_frame"].predicted_instances) == 2

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

# Add user instance to current frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
context.state["labeled_frame"].instances.insert(0, user_inst)

# Confirm addition of user instance in current frame
assert len(context.state["labeled_frame"].user_instances) == 1

# Confirm that current frame has more instances than previous frame
assert len(context.state["labeled_frame"].instances) > len(prev_frame.instances)
Comment on lines +766 to +767
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome


# Remove tracks from all instances so that unused_predictions is empty
for inst in context.state["labeled_frame"].instances:
inst.track = None
inst.from_predicted = inst

context.newInstance(init_method="best")

# Confirm that both user instances in the current frame are the same
previous_user_instance = context.state["labeled_frame"].user_instances[0]
newly_added_instance = context.state["labeled_frame"].user_instances[1]
assert newly_added_instance.video == previous_user_instance.video
assert newly_added_instance.points == previous_user_instance.points
assert newly_added_instance.track == previous_user_instance.track


def test_PasteInstance(min_tracks_2node_labels: Labels):
"""Test that pasting an instance works as expected."""
labels = min_tracks_2node_labels
Expand Down