Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to analysis export for exporting predictions for all frames including those with no predictions #1624

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,16 +496,30 @@ def add_submenu_choices(menu, title, options, key):
add_menu_item(
export_csv_menu,
"export_csv_current",
"Current Video...",
"Current Video (only tracked frames)...",
self.commands.exportCSVFile,
)
add_menu_item(
export_csv_menu,
"export_csv_all",
"All Videos...",
"All Videos (only tracked frames)...",
lambda: self.commands.exportCSVFile(all_videos=True),
)

export_csv_menu.addSeparator()
add_menu_item(
export_csv_menu,
"export_csv_current_all_frames",
"Current Video (all frames)...",
lambda: self.commands.exportCSVFile(all_frames=True),
)
add_menu_item(
export_csv_menu,
"export_csv_all_all_frames",
"All Videos (all frames)...",
lambda: self.commands.exportCSVFile(all_videos=True, all_frames=True),
)

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)

fileMenu.addSeparator()
Expand Down
31 changes: 21 additions & 10 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,13 @@ def saveProjectAs(self):
"""Show gui to save project as a new file."""
self.execute(SaveProjectAs)

def exportAnalysisFile(self, all_videos: bool = False):
def exportAnalysisFile(self, all_videos: bool = False, all_frames: bool = False):
"""Shows gui for exporting analysis h5 file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False)
self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=False)

def exportCSVFile(self, all_videos: bool = False):
def exportCSVFile(self, all_videos: bool = False, all_frames: bool = False):
"""Shows gui for exporting analysis csv file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True)
self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=True)

def exportNWB(self):
"""Show gui for exporting nwb file."""
Expand Down Expand Up @@ -1142,12 +1142,23 @@ def do_action(cls, context: CommandContext, params: dict):
adaptor = NixAdaptor
else:
adaptor = SleapAnalysisAdaptor
adaptor.write(
filename=output_path,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

if 'all_frames' in params and params['all_frames']:
adaptor.write(
filename=output_path,
all_frames=True,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)
else:
adaptor.write(
filename=output_path,
all_frames=False,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
Expand Down
15 changes: 11 additions & 4 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def write_occupancy_file(
print(f"Saved as {output_path}")


def write_csv_file(output_path, data_dict):
def write_csv_file(output_path, data_dict, all_frames):

"""Write CSV file with data from given dictionary.

Expand Down Expand Up @@ -348,14 +348,21 @@ def write_csv_file(output_path, data_dict):
tracks.append(detection)

tracks = pd.DataFrame(tracks)
tracks.to_csv(output_path, index=False)

if all_frames:
tracks = tracks.set_index('frame_idx')
tracks = tracks.reindex(range(0, len(data_dict['track_occupancy'])), fill_value=np.nan)
tracks = tracks.reset_index(drop=False)
tracks.to_csv(output_path, index=False)
else:
tracks.to_csv(output_path, index=False)


def main(
labels: Labels,
output_path: str,
labels_path: str = None,
all_frames: bool = True,
all_frames: bool = False,
video: Video = None,
csv: bool = False,
):
Comment on lines 348 to 368
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This review was outside the patches, so it was mapped to the patch with the greatest overlap. Original lines [361-447]

Consider adding error handling in the main function to manage potential exceptions that could be raised during the execution of the script. This would improve the robustness and user experience by providing more informative error messages and handling edge cases gracefully.

Expand Down Expand Up @@ -435,7 +442,7 @@ def main(
)

if csv:
write_csv_file(output_path, data_dict)
write_csv_file(output_path, data_dict, all_frames=all_frames)
else:
write_occupancy_file(output_path, data_dict, transpose=True)

Expand Down
8 changes: 5 additions & 3 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def write(
filename: str,
source_object: Labels,
source_path: str = None,
all_frames: bool = False,
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.
Expand All @@ -53,18 +54,19 @@ def write(
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
video: The :py:class:`Video` from which toget data from. If no `video` is
all_frames: A boolean flag to determine whether to include all frames or
only those with tracking data in the export.
video: The :py:class:`Video` from which to get data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
analysis file will be written.
"""
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
labels=source_object,
output_path=filename,
labels_path=source_path,
all_frames=True,
all_frames=all_frames,
video=video,
csv=True,
)
3 changes: 2 additions & 1 deletion sleap/io/format/nix.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def write(
filename: str,
source_object: object,
source_path: Optional[str] = None,
all_frames: bool = False,
video: Optional[Video] = None,
):
"""Writes the object to a file."""
Expand Down Expand Up @@ -460,4 +461,4 @@ def write_data(block, source: Labels, video: Video):
print(f"\n\tWriting failed with following error:\n{e}!")
finally:
if nix_file is not None:
nix_file.close()
nix_file.close()
3 changes: 2 additions & 1 deletion sleap/io/format/sleap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def write(
filename: str,
source_object: Labels,
source_path: str = None,
all_frames: bool = False,
video: Video = None,
Comment on lines +132 to 133
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The write function has been updated to include a new all_frames parameter, which is consistent with the PR's objective to allow exporting data with rows for all frames. However, the docstring for the write function has not been updated to include the new parameter. It is important to maintain accurate documentation, so the docstring should be updated to describe the all_frames parameter and its effect on the function's behavior.

):
"""Writes analysis file for :py:class:`Labels` `source_object`.
vtsai881 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -147,6 +148,6 @@ def write(
labels=source_object,
output_path=filename,
labels_path=source_path,
all_frames=True,
all_frames=all_frames,
video=video,
)
60 changes: 60 additions & 0 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,66 @@ def assert_videos_written(num_videos: int, labels_path: str = None):
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with all_videos True and all_frames True
params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with all_videos False and all_frames True
params = {"all_videos": False, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=1, labels_path=context.state["filename"])

# Test with all_videos False and all_frames False
params = {"all_videos": False, "all_frames": False, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=1, labels_path=context.state["filename"])

# Add labels path and test with all_videos True and all_frames True (single video)
context.state["filename"] = str(tmpdir.with_name("path.to.labels"))
params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Add a video (no labels) and test with all_videos True and all_frames True
labels.add_video(small_robot_mp4_vid)

params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with videos with the same filename
(tmpdir / "session1").mkdir()
(tmpdir / "session2").mkdir()
shutil.copy(
centered_pair_predictions.video.backend.filename,
tmpdir / "session1" / "video.mp4",
)
shutil.copy(small_robot_mp4_vid.backend.filename, tmpdir / "session2" / "video.mp4")
labels.videos[0].backend.filename = str(tmpdir / "session1" / "video.mp4")
labels.videos[1].backend.filename = str(tmpdir / "session2" / "video.mp4")
params = {"all_videos": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Remove all videos and test
all_videos = list(labels.videos)
for video in all_videos:
labels.remove_video(labels.videos[-1])

params = {"all_videos": True, "all_frames": True, "csv": csv}
# Test with videos with the same filename
(tmpdir / "session1").mkdir()
(tmpdir / "session2").mkdir()
Expand Down
Loading