Skip to content

Commit

Permalink
fixed copy prior frame issue
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhavtrip29 committed Jan 12, 2024
1 parent 14b5b78 commit 5ac817b
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 4 deletions.
44 changes: 42 additions & 2 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,6 +3019,23 @@ def set_visible_nodes(

return has_missing_nodes

@staticmethod
def find_last_user_instance(
prev_frame: LabeledFrame,
) -> Optional[Instance]:
"""Find last user instance to copy from
Args:
prev_frame: The last labeled frame from which we obtain the last user instance.
Returns:
The last user instance in the previous frame (if present), otherwise null
"""

user_instances = prev_frame.user_instances
if len(user_instances) > 0:
return user_instances[-1]

@staticmethod
def find_instance_to_copy_from(
context: CommandContext,
Expand Down Expand Up @@ -3067,25 +3084,48 @@ def find_instance_to_copy_from(
prev_idx = AddInstance.get_previous_frame_index(context)

if prev_idx is not None:
prev_instances = context.labels.find(
prev_frame = context.labels.find(
context.state["video"], prev_idx, return_new=True
)[0].instances
)[0]
prev_instances = prev_frame.instances
if len(prev_instances) > len(context.state["labeled_frame"].instances):
# If more instances in previous frame than current, then use the
# first unmatched instance.
copy_instance = prev_instances[
len(context.state["labeled_frame"].instances)
]

if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance

from_prev_frame = True
elif init_method == "best" and (
context.state["labeled_frame"].instances
):
# Otherwise, if there are already instances in current frame,
# copy the points from the last instance added to frame.
copy_instance = context.state["labeled_frame"].instances[-1]
if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in current frame
user_instance = AddInstance.find_last_user_instance(
context.state["labeled_frame"]
)
if user_instance is not None:
copy_instance = user_instance

elif len(prev_instances):
# Otherwise use the last instance added to previous frame.
copy_instance = prev_instances[-1]

if isinstance(copy_instance, PredictedInstance):
# Set copy instance to last user instance in previous frame, if present
user_instance = AddInstance.find_last_user_instance(prev_frame)
if user_instance is not None:
copy_instance = user_instance

from_prev_frame = True

from_predicted = from_predicted if hasattr(from_predicted, "score") else None
Expand Down
146 changes: 144 additions & 2 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SaveProjectAs,
get_new_version_filename,
)
from sleap.instance import Instance, LabeledFrame
from sleap.instance import Instance, LabeledFrame, Point
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand Down Expand Up @@ -216,7 +216,6 @@ def assert_videos_written(num_videos: int, labels_path: str = None):
context.state["filename"] = None

if csv:

context.state["filename"] = centered_pair_predictions_hdf5_path

params = {"all_videos": True, "csv": csv}
Expand Down Expand Up @@ -617,6 +616,149 @@ def test_CopyInstance(min_tracks_2node_labels: Labels):
assert context.state["clipboard_instance"] == instance


def test_CopyPriorFramePreviousUser(centered_pair_predictions: Labels):
"""Test that we copy prior user frame when available."""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No user instances in current frame, only predicted
assert len(context.state["labeled_frame"].user_instances) == 0
assert len(context.state["labeled_frame"].predicted_instances) == 2

# Modify previous frame to have a user instance
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances

# No user instances in previous frame, only predicted
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=prev_frame,
)
prev_instances.append(user_inst)

# Confirm there is one user instance in previous frame
assert len(prev_frame.user_instances) == 1

context.newInstance(init_method="prior_frame")

# Confirm that the newly added user instance is the same as the sole user instance in the previous frame+
newly_added_instance = context.state["labeled_frame"].user_instances[0]
assert newly_added_instance.video == user_inst.video
assert newly_added_instance.points == user_inst.points
assert newly_added_instance.track == user_inst.track


def test_CopyPriorFramePreviousUser2(centered_pair_predictions: Labels):
"""Test that we copy user instance in previous frame when the current frame has no instances."""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["labeled_frame"].instances = []
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

# No instances in current frame
assert len(context.state["labeled_frame"].instances) == 0

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances

# Confirm that previous frame has 2 predicted instances
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

# Add user instance to previous frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
prev_instances.insert(0, user_inst)

# Confirm addition of user instance in previous frame
assert len(prev_frame.user_instances) == 1

context.newInstance(init_method="best")

# Confirm that user instance in current frame is the same as the user instance in previous frame
current_user_instance = context.state["labeled_frame"].user_instances[0]
assert current_user_instance.video == user_inst.video
assert current_user_instance.points == user_inst.points
assert current_user_instance.track == user_inst.track


def test_CopyPriorFrameCurrentUser(centered_pair_predictions: Labels):
"""Test that we copy user instance in current frame when the current frame has >= amount of instances
as the previous frame."""
context = CommandContext.from_labels(centered_pair_predictions)
context.state["labeled_frame"] = centered_pair_predictions.find(
centered_pair_predictions.videos[0], frame_idx=124
)[0]
context.state["video"] = centered_pair_predictions.videos[0]
context.state["frame_idx"] = 124
context.state["skeleton"] = centered_pair_predictions.skeleton

assert len(context.state["labeled_frame"].predicted_instances) == 2

# Get previous frame
prev_idx = 123
prev_frame = context.labels.find(context.state["video"], prev_idx, return_new=True)[
0
]
prev_instances = prev_frame.instances
assert len(prev_frame.user_instances) == 0
assert len(prev_frame.predicted_instances) == 2

# Add user instance to current frame
skeleton = centered_pair_predictions.skeleton
user_inst = Instance(
skeleton=skeleton,
points={node: Point(1, 1) for node in skeleton.nodes},
frame=context.state["labeled_frame"],
)
context.state["labeled_frame"].instances.insert(0, user_inst)

# Confirm addition of user instance in current frame
assert len(context.state["labeled_frame"].user_instances) == 1

# Confirm that current frame has more instances than previous frame
assert len(context.state["labeled_frame"].instances) > len(prev_frame.instances)

# Remove tracks from all instances so that unused_predictions is empty
for inst in context.state["labeled_frame"].instances:
inst.track = None
inst.from_predicted = inst

context.newInstance(init_method="best")

# Confirm that both user instances in the current frame are the same+
previous_user_instance = context.state["labeled_frame"].user_instances[0]
newly_added_instance = context.state["labeled_frame"].user_instances[1]
assert newly_added_instance.video == previous_user_instance.video
assert newly_added_instance.points == previous_user_instance.points
assert newly_added_instance.track == previous_user_instance.track


def test_PasteInstance(min_tracks_2node_labels: Labels):
"""Test that pasting an instance works as expected."""
labels = min_tracks_2node_labels
Expand Down

0 comments on commit 5ac817b

Please sign in to comment.