Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/murthylab/sleap into nn-…
Browse files Browse the repository at this point in the history
…interface
  • Loading branch information
ntabris committed Jan 23, 2020
2 parents cbec811 + ee84530 commit 4bb4cb3
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 43 deletions.
Empty file added sleap/config/path_prefixes.yaml
Empty file.
5 changes: 3 additions & 2 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,11 +1242,12 @@ def showLearningDialog(self, mode: str):
self._child_windows[mode].frame_selection = self._frames_for_prediction()
self._child_windows[mode].open()

def learningFinished(self):
def learningFinished(self, new_count: int):
"""Called when inference finishes."""
# we ran inference so update display/ui
self.on_data_update([UpdateTopic.all])
self.commands.changestack_push("new predictions")
if new_count:
self.commands.changestack_push("new predictions")

def visualizeOutputs(self):
"""Gui for adding overlay with live visualization of predictions."""
Expand Down
15 changes: 13 additions & 2 deletions sleap/gui/dataviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from PySide2 import QtCore, QtWidgets, QtGui

import numpy as np
import os

from operator import itemgetter
Expand Down Expand Up @@ -140,7 +141,11 @@ def headerData(

return None

def sort(self, column_idx: int, order: QtCore.Qt.SortOrder):
def sort(
self,
column_idx: int,
order: QtCore.Qt.SortOrder = QtCore.Qt.SortOrder.AscendingOrder,
):
"""Sorts table by given column and order."""
prop = self.properties[column_idx]
reverse = order == QtCore.Qt.SortOrder.DescendingOrder
Expand All @@ -150,8 +155,14 @@ def sort(self, column_idx: int, order: QtCore.Qt.SortOrder):
if "video" in self.properties and "frame" in self.properties:
sort_function = itemgetter("video", "frame")

def string_safe_sort(x):
try:
return float(sort_function(x))
except ValueError:
return -np.inf

self.beginResetModel()
self._data.sort(key=sort_function, reverse=reverse)
self._data.sort(key=string_safe_sort, reverse=reverse)
self.endResetModel()

def get_from_idx(self, index: QtCore.QModelIndex):
Expand Down
4 changes: 2 additions & 2 deletions sleap/gui/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class InferenceDialog(QtWidgets.QDialog):
mode: String which specified mode ("learning", "expert", or "inference").
"""

learningFinished = QtCore.Signal()
learningFinished = QtCore.Signal(int)

def __init__(
self,
Expand Down Expand Up @@ -415,7 +415,7 @@ def run(self):
frames_to_predict=frames_to_predict,
)

self.learningFinished.emit()
self.learningFinished.emit(new_counts)

if new_counts >= 0:
QtWidgets.QMessageBox(
Expand Down
10 changes: 3 additions & 7 deletions sleap/gui/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Dict, List, Union

from sleap.util import get_config_file
from sleap import util


class ShortcutDialog(QtWidgets.QDialog):
Expand Down Expand Up @@ -141,9 +141,7 @@ class Shortcuts(object):
)

def __init__(self):
shortcut_yaml = get_config_file("shortcuts.yaml")
with open(shortcut_yaml, "r") as f:
shortcuts = yaml.load(f, Loader=yaml.SafeLoader)
shortcuts = util.get_config_yaml("shortcuts.yaml")

for action in shortcuts:
key_string = shortcuts.get(action, None)
Expand All @@ -158,9 +156,7 @@ def __init__(self):

def save(self):
"""Saves all shortcuts to shortcut file."""
shortcut_yaml = get_config_file("shortcuts.yaml")
with open(shortcut_yaml, "w") as f:
yaml.dump(self._shortcuts, f)
util.save_config_yaml("shortcuts.yaml", self._shortcuts)

def __getitem__(self, idx: Union[slice, int, str]) -> Union[str, Dict[str, str]]:
"""
Expand Down
56 changes: 28 additions & 28 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ def _update_from_labels(self, merge: bool = False):
# Add any videos that are present in the labels but
# missing from the video list
if merge or len(self.videos) == 0:
# find videos in labeled frames that aren't yet in top level videos
new_videos = {label.video for label in self.labels} - set(self.videos)
# find videos in labeled frames or suggestions
# that aren't yet in top level videos
lf_videos = {label.video for label in self.labels}
suggestion_videos = {sug.video for sug in self.suggestions}
new_videos = lf_videos.union(suggestion_videos) - set(self.videos)
# just add the new videos so we don't re-order current list
if len(new_videos):
self.videos.extend(list(new_videos))
Expand Down Expand Up @@ -1331,31 +1334,24 @@ def make_video_callback(cls, search_paths: Optional[List] = None) -> Callable:
search_paths = search_paths or []

def video_callback(video_list, new_paths=search_paths):
# Check each video
for video_item in video_list:
if "backend" in video_item and "filename" in video_item["backend"]:
current_filename = video_item["backend"]["filename"]
# check if we can find video
if not os.path.exists(current_filename):

current_basename = os.path.basename(current_filename)
# handle unix, windows, or mixed paths
if current_basename.find("/") > -1:
current_basename = current_basename.split("/")[-1]
if current_basename.find("\\") > -1:
current_basename = current_basename.split("\\")[-1]

# First see if we can find the file in another directory,
# and if not, prompt the user to find the file.

# We'll check in the current working directory, and if the user has
# already found any missing videos, check in the directory of those.
for path_dir in new_paths:
check_path = os.path.join(path_dir, current_basename)
if os.path.exists(check_path):
# we found the file in a different directory
video_item["backend"]["filename"] = check_path
break
filenames = [item["backend"]["filename"] for item in video_list]
missing = pathutils.list_file_missing(filenames)

# Try changing the prefix using saved patterns
if sum(missing):
pathutils.fix_paths_with_saved_prefix(filenames, missing)

# Check for file in search_path directories
if sum(missing) and new_paths:
for i, filename in enumerate(filenames):
fixed_path = find_path_using_paths(filename, new_paths)
if fixed_path != filename:
filenames[i] = fixed_path
missing[i] = False

# Replace the video filenames with changes by user
for i, item in enumerate(video_list):
item["backend"]["filename"] = filenames[i]

return video_callback

Expand Down Expand Up @@ -1383,7 +1379,11 @@ def gui_video_callback(video_list, new_paths=search_paths):
filenames = [item["backend"]["filename"] for item in video_list]
missing = pathutils.list_file_missing(filenames)

# First check for file in search_path directories
# Try changing the prefix using saved patterns
if sum(missing):
pathutils.fix_paths_with_saved_prefix(filenames, missing)

# Check for file in search_path directories
if sum(missing) and new_paths:
for i, filename in enumerate(filenames):
fixed_path = find_path_using_paths(filename, new_paths)
Expand Down
39 changes: 38 additions & 1 deletion sleap/io/pathutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

import os
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple

from sleap import util


def list_file_missing(filenames):
Expand Down Expand Up @@ -64,6 +66,10 @@ def filenames_prefix_change(
filenames[i] = try_filename
check[i] = False

# Save prefix change in config file so that it can be used
# automatically in the future
save_path_prefix_replacement(old_prefix, new_prefix)


def fix_path_separator(path: str):
return path.replace("\\", "/")
Expand Down Expand Up @@ -99,3 +105,34 @@ def find_changed_subpath(old_path: str, new_path: str) -> Tuple[str, str]:
new_inital = new_path[: new_char_idx + 1] if new_char_idx < -1 else new_path

return (old_initial, new_inital)


def fix_paths_with_saved_prefix(filenames, missing: Optional[List[bool]] = None):
path_prefix_conversions = util.get_config_yaml("path_prefixes.yaml")

if path_prefix_conversions is None:
return

for i, filename in enumerate(filenames):
if missing is not None:
if not missing[i]:
continue
elif os.path.exists(filename):
continue

for old_prefix, new_prefix in path_prefix_conversions.items():
if filename.startswith(old_prefix):
try_filename = filename.replace(old_prefix, new_prefix)
try_filename = fix_path_separator(try_filename)

if os.path.exists(try_filename):
filenames[i] = try_filename
if missing is not None:
missing[i] = False
continue


def save_path_prefix_replacement(old_prefix: str, new_prefix: str):
data = util.get_config_yaml("path_prefixes.yaml") or dict()
data[old_prefix] = new_prefix
util.save_config_yaml("path_prefixes.yaml", data)
17 changes: 16 additions & 1 deletion sleap/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import psutil
import json
import rapidjson
import yaml

from typing import Any, Dict, Hashable, Iterable, List, Optional

from sleap.io import pathutils


def json_loads(json_str: str) -> Dict:
"""
Expand Down Expand Up @@ -267,7 +270,7 @@ def get_config_file(shortname: str) -> str:
if not os.path.exists(desired_path):
package_path = get_package_file(f"sleap/config/{shortname}")
if not os.path.exists(package_path):
return FileNotFoundError(
raise FileNotFoundError(
f"Cannot locate {shortname} config file at {desired_path} or {package_path}."
)
# Make sure there's a ~/.sleap/ directory to store user version of the
Expand All @@ -283,6 +286,18 @@ def get_config_file(shortname: str) -> str:
return desired_path


def get_config_yaml(shortname: str) -> dict:
config_path = get_config_file(shortname)
with open(config_path, "r") as f:
return yaml.load(f, Loader=yaml.SafeLoader)


def save_config_yaml(shortname: str, data: Any) -> dict:
yaml_path = get_config_file(shortname)
with open(yaml_path, "w") as f:
yaml.dump(data, f)


def make_scoped_dictionary(
flat_dict: Dict[str, Any], exclude_nones: bool = True
) -> Dict[str, Dict[str, Any]]:
Expand Down
12 changes: 12 additions & 0 deletions tests/gui/test_dataviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,15 @@ def test_skeleton_nodes(qtbot, centered_pair_predictions):
)
table.selectRow(1)
assert table.model().data(table.currentIndex()) == "21/24"


def test_table_sort_string(qtbot):
table_model = GenericTableModel(
items=[dict(a=1, b=2), dict(a=2, b="")], properties=["a", "b"]
)

table = GenericTableView(is_sortable=True, model=table_model)

# Make sure we can sort with both numbers and strings (i.e., "")
table.model().sort(0)
table.model().sort(1)
13 changes: 13 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,3 +821,16 @@ def test_path_fix(tmpdir):
# Make sure we got the actual video path by searching that directory
assert len(labels.videos) == 1
assert labels.videos[0].filename == "tests/data/videos/small_robot.mp4"


def test_local_path_save(tmpdir, monkeypatch):

filename = "test.h5"

# Set current working directory (monkeypatch isolates other tests)
monkeypatch.chdir(tmpdir)

# Try saving with relative path
Labels.save_file(filename=filename, labels=Labels())

assert os.path.exists(os.path.join(tmpdir, filename))

0 comments on commit 4bb4cb3

Please sign in to comment.