From b72c0a20d1ade3e7e6e5c00f2904a9d6d34f684c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 6 Sep 2019 16:26:24 -0400 Subject: [PATCH 001/176] include pid in tmp dir name for json.zip this prevents conflicts when multiple processes are using the same json.zip file, as when training models in parallel. --- sleap/io/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 42686a268..05177a32e 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -907,7 +907,7 @@ def load_json(cls, filename: str, # Make a tmpdir, located in the directory that the file exists, to unzip # its contents. tmp_dir = os.path.join(os.path.dirname(filename), - f"tmp_{os.path.basename(filename)}") + f"tmp_{os.getpid()}_{os.path.basename(filename)}") if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir, ignore_errors=True) try: From 229802285877f5b2acc09b9dcd2eca51ee533b74 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 6 Sep 2019 16:52:40 -0400 Subject: [PATCH 002/176] more tests --- tests/io/test_video.py | 6 +++++- tests/test_util.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index f694cc19a..e7d18bc1d 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -3,7 +3,7 @@ import numpy as np -from sleap.io.video import Video +from sleap.io.video import Video, HDF5Video, MediaVideo from tests.fixtures.videos import TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE # FIXME: @@ -11,6 +11,10 @@ # of redundant test code here. # See: https://github.com/pytest-dev/pytest/issues/349 +def test_from_filename(): + assert type(Video.from_filename(TEST_H5_FILE).backend) == HDF5Video + assert type(Video.from_filename(TEST_SMALL_ROBOT_MP4_FILE).backend) == MediaVideo + def test_hdf5_get_shape(hdf5_vid): assert(hdf5_vid.shape == (42, 512, 512, 1)) diff --git a/tests/test_util.py b/tests/test_util.py index 50ce8da83..cbc6e72b7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,7 @@ from typing import List, Dict -from sleap.util import attr_to_dtype +from sleap.util import attr_to_dtype, frame_list def test_attr_to_dtype(): """ @@ -41,3 +41,6 @@ class TestAttr3: with pytest.raises(TypeError): attr_to_dtype(TestAttr3) +def test_frame_list(): + assert frame_list("3-5") == [3,4,5] + assert frame_list("7,10") == [7,10] \ No newline at end of file From f80bdeeaed70ce56a8ac06a893e91d31c676d152 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 6 Sep 2019 17:10:00 -0400 Subject: [PATCH 003/176] more tests --- tests/io/test_dataset.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 70f2d37ff..351d0fafd 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -155,6 +155,23 @@ def test_label_accessors(centered_pair_labels): assert len(labels.find(video)) == 70 assert labels[video] == labels.find(video) + f = labels.frames(video, from_frame_idx=1) + assert next(f).frame_idx == 15 + assert next(f).frame_idx == 31 + + f = labels.frames(video, from_frame_idx=31, reverse=True) + assert next(f).frame_idx == 15 + + f = labels.frames(video, from_frame_idx=0, reverse=True) + assert next(f).frame_idx == 1092 + next(f) + next(f) + # test that iterator now has fewer items left + assert len(list(f)) == 70-3 + + assert labels.instance_count(video, 15) == 2 + assert labels.instance_count(video, 7) == 0 + assert labels[0].video == video assert labels[0].frame_idx == 0 @@ -166,6 +183,7 @@ 
def test_label_accessors(centered_pair_labels): assert labels.find(video, 954)[0] == labels[61] assert labels.find_first(video) == labels[0] assert labels.find_first(video, 954) == labels[61] + assert labels.find_last(video) == labels[69] assert labels[video, 954] == labels[61] assert labels[video, 0] == labels[0] assert labels[video] == labels.labels From 96b1a4aa2a6d541cfb79743391530d64c982e5fc Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 6 Sep 2019 17:58:32 -0400 Subject: [PATCH 004/176] bug fix to pass new tests --- sleap/io/dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 05177a32e..3197759e5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -262,10 +262,13 @@ def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): if video not in self._frame_idx_map: return None # Get sorted list of frame indexes for this video - frame_idxs = sorted(self._frame_idx_map[video].keys(), reverse=reverse) + frame_idxs = sorted(self._frame_idx_map[video].keys()) - # Find the next frame index after the specified frame - next_frame_idx = min(filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0]) + # Find the next frame index after (before) the specified frame + if not reverse: + next_frame_idx = min(filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0]) + else: + next_frame_idx = max(filter(lambda x: x < from_frame_idx, frame_idxs), default=frame_idxs[-1]) cut_list_idx = frame_idxs.index(next_frame_idx) # Shift list of frame indices to start with specified frame From 7e68b3d230af9652b944739fb43c18e0b8e1d3bb Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 9 Sep 2019 08:47:29 -0400 Subject: [PATCH 005/176] don't autoconvert to grayscale --- sleap/nn/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index ee72d6dcb..6782b1a74 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -149,7 +149,7 @@ def predict(self, if isinstance(input_video, dict): vid = Video.cattr().structure(input_video, Video) elif isinstance(input_video, str): - vid = Video.from_filename(input_video) + vid = Video.from_filename(input_video, grayscale=False) else: raise AttributeError(f"Unable to load input video: {input_video}") From 3685c809a698c1ccbf6b3addd4917bc306f34da5 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 9 Sep 2019 09:40:50 -0400 Subject: [PATCH 006/176] re-arranged methods in Labels class --- sleap/io/dataset.py | 268 +++++++++++++++++++++------------------ tests/io/test_dataset.py | 18 +++ 2 files changed, 162 insertions(+), 124 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 3197759e5..fb52a262e 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -74,6 +74,7 @@ def json_dumps(d: Dict, filename: str = None): """ LABELS_JSON_FILE_VERSION = "2.0.0" + @attr.s(auto_attribs=True) class Labels(MutableSequence): """ @@ -91,6 +92,8 @@ class Labels(MutableSequence): skeletons: A list of skeletons that these labels may or may not reference. tracks: A list of tracks that instances can belong to. suggestions: A dict with a list for each video of suggested frames to label. + negative_anchors: A dict with list of anchor coordinates + for negative training samples for each video. 
""" labeled_frames: List[LabeledFrame] = attr.ib(default=attr.Factory(list)) @@ -182,10 +185,6 @@ def labels(self): """ Alias for labeled_frames """ return self.labeled_frames - @property - def user_labeled_frames(self): - return [lf for lf in self.labeled_frames if lf.has_user_instances] - def __len__(self): return len(self.labeled_frames) @@ -227,6 +226,59 @@ def __getitem__(self, key): else: raise KeyError("Invalid label indexing arguments.") + def __setitem__(self, index, value: LabeledFrame): + # TODO: Maybe we should remove this method altogether? + self.labeled_frames.__setitem__(index, value) + self._update_containers(value) + + def _update_containers(self, new_label: LabeledFrame): + """ Ensure that top-level containers are kept updated with new + instances of objects that come along with new labels. """ + + if new_label.video not in self.videos: + self.videos.append(new_label.video) + + for skeleton in {instance.skeleton for instance in new_label}: + if skeleton not in self.skeletons: + self.skeletons.append(skeleton) + for node in skeleton.nodes: + if node not in self.nodes: + self.nodes.append(node) + + # Add any new Tracks as well + for instance in new_label.instances: + if instance.track and instance.track not in self.tracks: + self.tracks.append(instance.track) + + # Sort the tracks again + self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) + + # Update cache datastructures + if new_label.video not in self._lf_by_video: + self._lf_by_video[new_label.video] = [] + if new_label.video not in self._frame_idx_map: + self._frame_idx_map[new_label.video] = dict() + self._lf_by_video[new_label.video].append(new_label) + self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label + + def insert(self, index, value: LabeledFrame): + if value in self or (value.video, value.frame_idx) in self: + return + + self.labeled_frames.insert(index, value) + self._update_containers(value) + + def append(self, value: LabeledFrame): + self.insert(len(self) + 1, value) + + def __delitem__(self, key): + self.labeled_frames.remove(self.labeled_frames[key]) + + def remove(self, value: LabeledFrame): + self.labeled_frames.remove(value) + self._lf_by_video[new_label.video].remove(value) + del self._frame_idx_map[new_label.video][value.frame_idx] + def find(self, video: Video, frame_idx: Union[int, range] = None, return_new: bool=False) -> List[LabeledFrame]: """ Search for labeled frames given video and/or frame index. @@ -310,6 +362,12 @@ def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: if label.video == video and (frame_idx is None or (label.frame_idx == frame_idx)): return label + @property + def user_labeled_frames(self): + return [lf for lf in self.labeled_frames if lf.has_user_instances] + + # Methods for instances + def instance_count(self, video: Video, frame_idx: int) -> int: count = 0 labeled_frame = self.find_first(video, frame_idx) @@ -317,6 +375,33 @@ def instance_count(self, video: Video, frame_idx: int) -> int: count = len([inst for inst in labeled_frame.instances if type(inst)==Instance]) return count + + @property + def all_instances(self): + return list(self.instances()) + + @property + def user_instances(self): + return [inst for inst in self.all_instances if type(inst) == Instance] + + def instances(self, video: Video = None, skeleton: Skeleton = None): + """ Iterate through all instances in the labels, optionally with filters. 
+ + Args: + video: Only iterate through instances in this video + skeleton: Only iterate through instances with this skeleton + + Yields: + Instance: The next labeled instance + """ + for label in self.labels: + if video is None or label.video == video: + for instance in label.instances: + if skeleton is None or instance.skeleton == skeleton: + yield instance + + # Methods for tracks + def get_track_occupany(self, video: Video): try: return self._track_occupancy[video] @@ -440,82 +525,66 @@ def does_track_match(inst, tr, labeled_frame): def find_track_instances(self, *args, **kwargs) -> List[Instance]: return [inst for lf, inst in self.find_track_occupancy(*args, **kwargs)] - @property - def all_instances(self): - return list(self.instances()) - - @property - def user_instances(self): - return [inst for inst in self.all_instances if type(inst) == Instance] - - def instances(self, video: Video = None, skeleton: Skeleton = None): - """ Iterate through all instances in the labels, optionally with filters. - - Args: - video: Only iterate through instances in this video - skeleton: Only iterate through instances with this skeleton - - Yields: - Instance: The next labeled instance + # Methods for suggestions + + def get_video_suggestions(self, video:Video) -> list: """ - for label in self.labels: - if video is None or label.video == video: - for instance in label.instances: - if skeleton is None or instance.skeleton == skeleton: - yield instance - - def _update_containers(self, new_label: LabeledFrame): - """ Ensure that top-level containers are kept updated with new - instances of objects that come along with new labels. """ - - if new_label.video not in self.videos: - self.videos.append(new_label.video) - - for skeleton in {instance.skeleton for instance in new_label}: - if skeleton not in self.skeletons: - self.skeletons.append(skeleton) - for node in skeleton.nodes: - if node not in self.nodes: - self.nodes.append(node) - - # Add any new Tracks as well - for instance in new_label.instances: - if instance.track and instance.track not in self.tracks: - self.tracks.append(instance.track) + Returns the list of suggested frames for the specified video + or suggestions for all videos (if no video specified). + """ + return self.suggestions.get(video, list()) - # Sort the tracks again - self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) + def get_suggestions(self) -> list: + """Return all suggestions as a list of (video, frame) tuples.""" + suggestion_list = [(video, frame_idx) + for video in self.videos + for frame_idx in self.get_video_suggestions(video) + ] + return suggestion_list - # Update cache datastructures - if new_label.video not in self._lf_by_video: - self._lf_by_video[new_label.video] = [] - if new_label.video not in self._frame_idx_map: - self._frame_idx_map[new_label.video] = dict() - self._lf_by_video[new_label.video].append(new_label) - self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label + def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: + """Returns a (video, frame_idx) tuple.""" + # make sure we have valid seek_direction + if seek_direction not in (-1, 1): return (None, None) + # make sure the video belongs to this Labels object + if video not in self.videos: return (None, None) - def __setitem__(self, index, value: LabeledFrame): - # TODO: Maybe we should remove this method altogether? 
- self.labeled_frames.__setitem__(index, value) - self._update_containers(value) + all_suggestions = self.get_suggestions() - def insert(self, index, value: LabeledFrame): - if value in self or (value.video, value.frame_idx) in self: - return + # If we're currently on a suggestion, then follow order of list + if (video, frame_idx) in all_suggestions: + suggestion_idx = all_suggestions.index((video, frame_idx)) + new_idx = (suggestion_idx+seek_direction)%len(all_suggestions) + video, frame_suggestion = all_suggestions[new_idx] - self.labeled_frames.insert(index, value) - self._update_containers(value) + # Otherwise, find the prev/next suggestion sorted by frame order + else: + # look for next (or previous) suggestion in current video + if seek_direction == 1: + frame_suggestion = min((i for i in self.get_video_suggestions(video) if i > frame_idx), default=None) + else: + frame_suggestion = max((i for i in self.get_video_suggestions(video) if i < frame_idx), default=None) + if frame_suggestion is not None: return (video, frame_suggestion) + # if we didn't find suggestion in current video, + # then we want earliest frame in next video with suggestions + next_video_idx = (self.videos.index(video) + seek_direction) % len(self.videos) + video = self.videos[next_video_idx] + if seek_direction == 1: + frame_suggestion = min((i for i in self.get_video_suggestions(video)), default=None) + else: + frame_suggestion = max((i for i in self.get_video_suggestions(video)), default=None) + return (video, frame_suggestion) - def append(self, value: LabeledFrame): - self.insert(len(self) + 1, value) + def set_suggestions(self, suggestions:Dict[Video, list]): + """Sets the suggested frames.""" + self.suggestions = suggestions - def __delitem__(self, key): - self.labeled_frames.remove(self.labeled_frames[key]) + def delete_suggestions(self, video): + """Deletes suggestions for specified video.""" + if video in self.suggestions: + del self.suggestions[video] - def remove(self, value: LabeledFrame): - self.labeled_frames.remove(value) - self._lf_by_video[new_label.video].remove(value) - del self._frame_idx_map[new_label.video][value.frame_idx] + # Methods for videos def add_video(self, video: Video): """ Add a video to the labels if it is not already in it. @@ -546,8 +615,7 @@ def remove_video(self, video: Video): self.labeled_frames.remove(label) # Delete data that's indexed by video - if video in self.suggestions: - del self.suggestions[video] + self.delete_suggestions(video) if video in self.negative_anchors: del self.negative_anchors[video] @@ -560,6 +628,8 @@ def remove_video(self, video: Video): if video in self._frame_idx_map: del self._frame_idx_map[video] + # Methods for negative anchors + def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): """Adds a location for a negative training sample. @@ -572,57 +642,7 @@ def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): self.negative_anchors[video] = [] self.negative_anchors[video].append((frame_idx, *where)) - def get_video_suggestions(self, video:Video) -> list: - """ - Returns the list of suggested frames for the specified video - or suggestions for all videos (if no video specified). 
- """ - return self.suggestions.get(video, list()) - - def get_suggestions(self) -> list: - """Return all suggestions as a list of (video, frame) tuples.""" - suggestion_list = [(video, frame_idx) - for video in self.videos - for frame_idx in self.get_video_suggestions(video) - ] - return suggestion_list - - def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: - """Returns a (video, frame_idx) tuple.""" - # make sure we have valid seek_direction - if seek_direction not in (-1, 1): return (None, None) - # make sure the video belongs to this Labels object - if video not in self.videos: return (None, None) - - all_suggestions = self.get_suggestions() - - # If we're currently on a suggestion, then follow order of list - if (video, frame_idx) in all_suggestions: - suggestion_idx = all_suggestions.index((video, frame_idx)) - new_idx = (suggestion_idx+seek_direction)%len(all_suggestions) - video, frame_suggestion = all_suggestions[new_idx] - - # Otherwise, find the prev/next suggestion sorted by frame order - else: - # look for next (or previous) suggestion in current video - if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video) if i > frame_idx), default=None) - else: - frame_suggestion = max((i for i in self.get_video_suggestions(video) if i < frame_idx), default=None) - if frame_suggestion is not None: return (video, frame_suggestion) - # if we didn't find suggestion in current video, - # then we want earliest frame in next video with suggestions - next_video_idx = (self.videos.index(video) + seek_direction) % len(self.videos) - video = self.videos[next_video_idx] - if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video)), default=None) - else: - frame_suggestion = max((i for i in self.get_video_suggestions(video)), default=None) - return (video, frame_suggestion) - - def set_suggestions(self, suggestions:Dict[Video, list]): - """Sets the suggested frames.""" - self.suggestions = suggestions + # Methods for saving/loading def extend_from(self, new_frames): """Merge data from another Labels object or list of LabeledFrames into self. 
diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 351d0fafd..67c1e13c1 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -6,6 +6,7 @@ from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance from sleap.io.video import Video, MediaVideo from sleap.io.dataset import Labels, load_labels_json_old +from sleap.gui.suggestions import VideoFrameSuggestions TEST_H5_DATASET = 'tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5' @@ -284,6 +285,23 @@ def test_instance_access(): assert len(list(labels.instances(video=dummy_video))) == 20 assert len(list(labels.instances(video=dummy_video2))) == 30 +def test_suggestions(small_robot_mp4_vid): + dummy_video = small_robot_mp4_vid + dummy_skeleton = Skeleton() + dummy_instance = Instance(dummy_skeleton) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) + + labels = Labels() + labels.append(dummy_frame) + + suggestions = dict() + suggestions[dummy_video] = VideoFrameSuggestions.suggest( + dummy_video, + params=dict(method="random", per_video=13)) + labels.set_suggestions(suggestions) + + assert len(labels.get_video_suggestions(dummy_video)) == 13 + def test_load_labels_mat(mat_labels): assert len(mat_labels.nodes) == 6 From 759328261ffecbf3d821dd92612c2da1d0b2f8bc Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 9 Sep 2019 13:24:48 -0400 Subject: [PATCH 007/176] bug fix when creating Instance w/ zero points --- sleap/instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/instance.py b/sleap/instance.py index d03f8cfd0..4d38f0f4f 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -351,7 +351,7 @@ def __attrs_post_init__(self): # If the user did not pass a points list initialize a point array for future # points. - if self._points is None: + if self._points is None or len(self._points) == 0: # Initialize an empty point array that is the size of the skeleton. self._points = self._point_array_type.make_default(len(self.skeleton.nodes)) From e8151b151880d3326597e9e4e17c1609d870c90f Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 9 Sep 2019 16:16:03 -0400 Subject: [PATCH 008/176] predict on range of frames --- sleap/gui/overlays/base.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 0aac7faf8..3281edb35 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -21,6 +21,7 @@ class ModelData: model: 'keras.Model' video: Video do_rescale: bool=False + adjust_vals: bool=True def __getitem__(self, i): """Data data for frame i from predictor.""" @@ -42,17 +43,19 @@ def __getitem__(self, i): frame_result = inference_transform.invert_scale(frame_result) # We just want the single image results - frame_result = frame_result[0] - - # If max value is below 1, amplify values so max is 1. - # This allows us to visualize model with small ptp value - # even though this model may not give us adequate predictions. - max_val = np.max(frame_result) - if max_val < 1: - frame_result = frame_result/np.max(frame_result) - - # Clip values to ensure that they're within [0, 1] - frame_result = np.clip(frame_result, 0, 1) + if type(i) != slice: + frame_result = frame_result[0] + + if self.adjust_vals: + # If max value is below 1, amplify values so max is 1. + # This allows us to visualize model with small ptp value + # even though this model may not give us adequate predictions. 
+ max_val = np.max(frame_result) + if max_val < 1: + frame_result = frame_result/np.max(frame_result) + + # Clip values to ensure that they're within [0, 1] + frame_result = np.clip(frame_result, 0, 1) return frame_result From 41eb461379b9d50726e5b7c42658f0302f06f989 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 08:39:29 -0400 Subject: [PATCH 009/176] opencv-python 3.4.2.17 => 3.4.1.15 This seems to get around issue with missing DLLs. --- .conda/bld.bat | 2 +- environment.yml | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.conda/bld.bat b/.conda/bld.bat index 3c9668685..aa46b2eb7 100644 --- a/.conda/bld.bat +++ b/.conda/bld.bat @@ -13,7 +13,7 @@ rem # this out myself, ughhh. set PIP_NO_INDEX=False set PIP_NO_DEPENDENCIES=False set PIP_IGNORE_INSTALLED=False -pip install cattrs==1.0.0rc opencv-python==3.4.2.17 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore +pip install cattrs==1.0.0rc opencv-python==3.4.1.15 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore rem # Use and update environment.yml call to install pip dependencies. This is slick. rem # While environment.yml contains the non pip dependencies, the only thing left diff --git a/environment.yml b/environment.yml index 7d929639a..d9d3f0684 100644 --- a/environment.yml +++ b/environment.yml @@ -17,7 +17,7 @@ dependencies: - python-rapidjson - pip - pip: - - opencv-python==3.4.2.17 + - opencv-python==3.4.1.15 - PySide2==5.12.0 - imgaug - cattrs==1.0.0rc0 diff --git a/requirements.txt b/requirements.txt index dd0c26892..2d2c82a6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ tensorflow keras h5py python-rapidjson -opencv-python==3.4.2.17 +opencv-python==3.4.1.15 pandas psutil PySide2 From 798235d35aecb1b74de67363f1586dd7ae6db899 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 09:11:58 -0400 Subject: [PATCH 010/176] expose gaussian size and sigma from peak_tf_inference --- sleap/nn/inference.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6782b1a74..5d8f5fac5 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -67,9 +67,11 @@ class Predictor: nms_min_thresh: A threshold of non-max suppression peak finding in confidence maps. All values below this minimum threshold will be set to zero before peak finding algorithm is run. - nms_sigma: Gaussian blur is applied to confidence maps before - non-max supression peak finding occurs. This is the - standard deviation of the kernel applied to the image. + nms_kernal_size: Gaussian blur is applied to confidence maps before + non-max supression peak finding occurs. This is size of the + kernel applied to the image. + nms_sigma: For Gassian blur applied to confidence maps, this + is the standard deviation of the kernel. 
min_score_to_node_ratio: FIXME min_score_midpts: FIXME min_score_integral: FIXME @@ -90,6 +92,7 @@ class Predictor: read_chunk_size: int = 256 save_frequency: int = 100 # chunks nms_min_thresh = 0.3 + nms_kernal_size = 9 nms_sigma = 3 min_score_to_node_ratio: float = 0.2 min_score_midpts: float = 0.05 @@ -590,6 +593,8 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] model = conf_model["model"], data = imgs.astype("float32")/255, min_thresh=self.nms_min_thresh, + gaussian_size=self.nms_kernel_size, + gaussian_sigma=self.nms_sigma, downsample_factor=int(1/paf_model["multiscale"]), upsample_factor=int(1/conf_model["multiscale"]), return_confmaps=self.save_confmaps_pafs From 90910e11d5d4fe77943ca198ec3307252e9c44a1 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 09:40:51 -0400 Subject: [PATCH 011/176] bug fixes to last commit --- sleap/nn/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 5d8f5fac5..6e6a4bc17 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -67,7 +67,7 @@ class Predictor: nms_min_thresh: A threshold of non-max suppression peak finding in confidence maps. All values below this minimum threshold will be set to zero before peak finding algorithm is run. - nms_kernal_size: Gaussian blur is applied to confidence maps before + nms_kernel_size: Gaussian blur is applied to confidence maps before non-max supression peak finding occurs. This is size of the kernel applied to the image. nms_sigma: For Gassian blur applied to confidence maps, this @@ -92,8 +92,8 @@ class Predictor: read_chunk_size: int = 256 save_frequency: int = 100 # chunks nms_min_thresh = 0.3 - nms_kernal_size = 9 - nms_sigma = 3 + nms_kernel_size: int = 9 + nms_sigma: float = 3. min_score_to_node_ratio: float = 0.2 min_score_midpts: float = 0.05 min_score_integral: float = 0.6 From 7724818c231f2e1990b6514259fa8485edeb30e6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 09:41:39 -0400 Subject: [PATCH 012/176] determine channels from centroid model --- sleap/nn/inference.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6e6a4bc17..abd870a2f 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -143,6 +143,18 @@ def predict(self, logger.info(f"Predict is async: {is_async}") + # Find out how many channels the model was trained on + + model_channels = 3 # default + + if ModelOutputType.CENTROIDS in self.sleap_models: + centroid_model = self.fetch_model( + input_size = None, + output_types = [ModelOutputType.CENTROIDS]) + model_channels = centroid_model["model"].input_shape[-1] + + grayscale = (model_channels == 1) + # Open the video if we need it. 
try: @@ -152,7 +164,7 @@ def predict(self, if isinstance(input_video, dict): vid = Video.cattr().structure(input_video, Video) elif isinstance(input_video, str): - vid = Video.from_filename(input_video, grayscale=False) + vid = Video.from_filename(input_video, grayscale=grayscale) else: raise AttributeError(f"Unable to load input video: {input_video}") From 87eaf1dfc760808b071d0102d3088c2e3f7a7eb4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 09:45:41 -0400 Subject: [PATCH 013/176] remove unused crop_iou_threshold --- sleap/nn/inference.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index abd870a2f..7f811716f 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -79,7 +79,6 @@ class Predictor: with_tracking: whether to run tracking after inference flow_window: The number of frames that tracking should look back when trying to identify instances. - crop_iou_threshold: FIXME single_per_crop: FIXME output_path: the output path to save the results save_confmaps_pafs: whether to save confmaps/pafs @@ -100,7 +99,6 @@ class Predictor: add_last_edge: bool = True with_tracking: bool = False flow_window: int = 15 - crop_iou_threshold: float = .9 single_per_crop: bool = False crop_padding: int = 40 crop_growth: int = 64 @@ -229,8 +227,7 @@ def predict(self, # Use centroid predictions to get subchunks of crops subchunks_to_process = self.centroid_crop_inference( - mov_full, frames_idx, - iou_threshold=self.crop_iou_threshold) + mov_full, frames_idx) else: # Scale without centroid cropping @@ -407,8 +404,7 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: def centroid_crop_inference(self, imgs: np.ndarray, - frames_idx: List[int], - iou_threshold: float=.9) \ + frames_idx: List[int]) \ -> List[Tuple[np.ndarray, DataTransform]]: """ Takes stack of images and runs centroid inference to get crops. From edb9b42ebe2b780eca7bff2bf1040d57b1484bd6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 10:14:55 -0400 Subject: [PATCH 014/176] args to control merge/size for centroid boxes --- sleap/nn/inference.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 7f811716f..d50d497ec 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -404,7 +404,9 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: def centroid_crop_inference(self, imgs: np.ndarray, - frames_idx: List[int]) \ + frames_idx: List[int], + box_size: int=None, + do_merge: bool=True) \ -> List[Tuple[np.ndarray, DataTransform]]: """ Takes stack of images and runs centroid inference to get crops. 
@@ -452,13 +454,16 @@ def centroid_crop_inference(self, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma) - # Get training bounding box size to determine (min) centroid crop size - crop_model_package = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CONFIDENCE_MAP]) - crop_size = crop_model_package["bounding_box_size"] - bb_half = (crop_size + self.crop_padding)//2 + if box_size is None: + # Get training bounding box size to determine (min) centroid crop size + crop_model_package = self.fetch_model( + input_size = None, + output_types = [ModelOutputType.CONFIDENCE_MAP]) + crop_size = crop_model_package["bounding_box_size"] + bb_half = (crop_size + self.crop_padding)//2 + else: + bb_half = box_size//2 logger.info(f" Centroid crop box size: {bb_half*2}") @@ -483,11 +488,17 @@ def centroid_crop_inference(self, boxes.append((peak_x-bb_half, peak_y-bb_half, peak_x+bb_half, peak_y+bb_half)) - # Merge overlapping boxes and pad to multiple of crop size - merged_boxes = merge_boxes_with_overlap_and_padding( - boxes=boxes, - pad_factor_box=(self.crop_growth, self.crop_growth), - within=crop_within) + if do_merge: + # Merge overlapping boxes and pad to multiple of crop size + merged_boxes = merge_boxes_with_overlap_and_padding( + boxes=boxes, + pad_factor_box=(self.crop_growth, self.crop_growth), + within=crop_within) + else: + # Just return the boxes centered around each centroid. + # Note that these aren't guaranteed to be within the + # image bounds, so take care if using these to crop. + merged_boxes = boxes # Keep track of all boxes, grouped by size and frame idx for box in merged_boxes: From 9c8ea9fd9ea6f7bc26e4cff48262ee4e8e20a387 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 11:13:17 -0400 Subject: [PATCH 015/176] use / for imgstore paths --- sleap/io/dataset.py | 4 +++- sleap/io/video.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index fb52a262e..2ed4877c7 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1312,7 +1312,9 @@ def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', if v == lf.video and (all_labels or lf.has_user_instances)] - frames_filename = os.path.join(output_dir, f'frame_data_vid{v_idx}') + # Join with "/" instead of os.path.join() since we want + # path to work on Windows and Posix systems + frames_filename = output_dir + f'/frame_data_vid{v_idx}' vid = v.to_imgstore(path=frames_filename, frame_numbers=frame_nums, format=format) # Close the video for now diff --git a/sleap/io/video.py b/sleap/io/video.py index 5caad1294..a868269b8 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -349,7 +349,8 @@ def __attrs_post_init__(self): # If the filename does not contain metadata.yaml, append it to the filename # assuming that this is a directory that contains the imgstore. if 'metadata.yaml' not in self.filename: - self.filename = os.path.join(self.filename, 'metadata.yaml') + # Use "/" since this works on Windows and posix + self.filename = self.filename '/metadata.yaml' # Make relative path into absolute, ImgStores don't work properly it seems # without full paths if we change working directories. 
Video.fixup_path will From 0d9b067ab42127c3e3eccc0b1e7bc9197ed47f2e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 10 Sep 2019 11:15:47 -0400 Subject: [PATCH 016/176] use / for imgstore paths (bug fix) --- sleap/io/video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index a868269b8..5a34396c9 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -350,7 +350,7 @@ def __attrs_post_init__(self): # assuming that this is a directory that contains the imgstore. if 'metadata.yaml' not in self.filename: # Use "/" since this works on Windows and posix - self.filename = self.filename '/metadata.yaml' + self.filename = self.filename + '/metadata.yaml' # Make relative path into absolute, ImgStores don't work properly it seems # without full paths if we change working directories. Video.fixup_path will From e066b391f7553cbef404f68db80b7383d5648498 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 10:13:21 -0400 Subject: [PATCH 017/176] show number of user labeled frames in status bar --- sleap/gui/app.py | 7 +++++++ sleap/io/dataset.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 19e3e725d..191d67f62 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1576,6 +1576,13 @@ def updateStatusMessage(self, message = None): if self.player.seekbar.hasSelection(): start, end = self.player.seekbar.getSelection() message += f" (selection: {start}-{end})" + message += f" Labeled Frames: " + if self.video is not None: + message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + if len(self.labels.videos) > 1: + message += " in video, " + if len(self.labels.videos) > 1: + message += f"{len(self.labels.user_labeled_frames)} in project" self.statusBar().showMessage(message) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 2ed4877c7..ed6da889a 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -366,6 +366,9 @@ def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: def user_labeled_frames(self): return [lf for lf in self.labeled_frames if lf.has_user_instances] + def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: + return [lf for lf in self.labeled_frames if lf.has_user_instances and lf.video == video] + # Methods for instances def instance_count(self, video: Video, frame_idx: int) -> int: From dc912b06602c129f83f3795d6aa7f21337f96531 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 10:22:52 -0400 Subject: [PATCH 018/176] allow 0-1000 negative samples --- sleap/config/active.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sleap/config/active.yaml b/sleap/config/active.yaml index 1c061da42..01afe9684 100644 --- a/sleap/config/active.yaml +++ b/sleap/config/active.yaml @@ -66,6 +66,7 @@ expert: label: Negative samples (if cropping) type: int default: 0 + range: 0,1000 - name: batch_size label: Batch Size @@ -110,6 +111,7 @@ learning: label: Negative samples type: int default: 20 + range: 0,1000 - name: batch_size label: Batch Size From b941894c571081959b689522684f95c8ec4117a4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 10:30:06 -0400 Subject: [PATCH 019/176] show full video filename (don't truncate) --- sleap/gui/dataviews.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 8cbcdf52c..f22cf77d7 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -60,11 +60,7 @@ 
def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): video = self.videos[idx] if prop == "filename": - # show parent dir + name - parent_dir = os.path.split(os.path.dirname(video.filename))[-1] - file_name = os.path.basename(video.filename) - trunc_name = os.path.join(parent_dir, file_name) - return trunc_name + return video.filename elif prop == "frames": return video.frames elif prop == "height": From e7f37fe2352202c9499c21bce9517eca0517cb6b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 11:28:57 -0400 Subject: [PATCH 020/176] bug fix, use tolist() to get list of Python ints --- sleap/gui/suggestions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 6287c3a9c..0666fc0ce 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -131,7 +131,7 @@ def proofreading( low_instances = np.nansum(scores < score_limit, axis=1) # Find all the frames with at least low scoring instances - result = list(idxs[low_instances >= instance_limit]) + result = idxs[low_instances >= instance_limit].tolist() return result From 8fe792020d25e2c679e47215b401ab79531d6a91 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 12:22:23 -0400 Subject: [PATCH 021/176] bug fix when confmaps/paf multiscale different --- sleap/nn/inference.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index d50d497ec..557ca5ad4 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -607,6 +607,8 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Find peaks t0 = time() + multiscale_diff = paf_model["multiscale"] / conf_model["multiscale"] + peaks, peak_vals, confmaps = \ peak_tf_inference( model = conf_model["model"], @@ -614,12 +616,12 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] min_thresh=self.nms_min_thresh, gaussian_size=self.nms_kernel_size, gaussian_sigma=self.nms_sigma, - downsample_factor=int(1/paf_model["multiscale"]), + downsample_factor=int(1/multiscale_diff), upsample_factor=int(1/conf_model["multiscale"]), return_confmaps=self.save_confmaps_pafs ) - transform.scale = transform.scale * paf_model["multiscale"] + transform.scale = transform.scale * multiscale_diff logger.info(" Inferred confmaps and found-peaks (gpu) [%.1fs]" % (time() - t0)) logger.info(f" peaks: {len(peaks)}") From 508b2e466a413c3855ac6cf8a6900fb017194edf Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 12:23:07 -0400 Subject: [PATCH 022/176] disable blur when upsampling (it's not working) --- sleap/nn/peakfinding_tf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/peakfinding_tf.py b/sleap/nn/peakfinding_tf.py index 7d7aa6d14..f19ee605c 100644 --- a/sleap/nn/peakfinding_tf.py +++ b/sleap/nn/peakfinding_tf.py @@ -124,7 +124,7 @@ def peak_tf_inference(model, data, n, h, w, c = confmaps.get_shape().as_list() - if gaussian_size: + if gaussian_size and upsample_factor == 1: # Make Gaussian Kernel with desired specs. 
gauss_kernel = gaussian_kernel(size=gaussian_size, mean=0.0, std=gaussian_sigma) From 58da178f08ff16a1f39055487b4987249f5eb600 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 13:04:45 -0400 Subject: [PATCH 023/176] support for header (graph), use mark objects also fix track colors so they match instance colors --- sleap/gui/slider.py | 315 +++++++++++++++++++++++++++++++------------- 1 file changed, 226 insertions(+), 89 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 24333215d..dae8177bf 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -5,13 +5,55 @@ from PySide2.QtWidgets import QApplication, QWidget, QLayout, QAbstractSlider from PySide2.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem from PySide2.QtWidgets import QSizePolicy, QLabel, QGraphicsRectItem -from PySide2.QtGui import QPainter, QPen, QBrush, QColor, QKeyEvent -from PySide2.QtCore import Qt, Signal, QRect, QRectF +from PySide2.QtGui import QPainter, QPen, QBrush, QColor, QKeyEvent, QPolygonF, QPainterPath +from PySide2.QtCore import Qt, Signal, QRect, QRectF, QPointF from sleap.gui.overlays.tracks import TrackColorManager -from operator import itemgetter -from itertools import groupby +import attr +import itertools +import numpy as np +from typing import Union + +@attr.s(auto_attribs=True, cmp=False) +class SliderMark: + type: str + val: float + end_val: float=None + row: int=None + track: 'Track'=None + _color: Union[tuple,str]="black" + + @property + def color(self): + colors = dict(simple="black", + filled="blue", + open="blue", + predicted="red") + + if self.type in colors: + return colors[self.type] + else: + return self._color + + @color.setter + def color(self, val): + self._color = val + + @property + def QColor(self): + c = self.color + if type(c) == str: + return QColor(c) + else: + return QColor(*c) + + @property + def filled(self): + if self.type == "open": + return False + else: + return True class VideoSlider(QGraphicsView): """Drop-in replacement for QSlider with additional features. 
@@ -50,24 +92,28 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # ScrollBarAsNeeded - self._color_manager = color_manager or TrackColorManager() + self._color_manager = color_manager + self._track_rows = 0 self._track_height = 3 + self._header_height = 0 + self._min_height = 19 + self._header_height - height = 19 - slider_rect = QRect(0, 0, 200, height-3) - handle_width = 6 - handle_rect = QRect(0, 1, handle_width, slider_rect.height()-2) - self.setMinimumHeight(height) - self.setMaximumHeight(height) - + # Add border rect + slider_rect = QRect(0, 0, 200, self._min_height-3) self.slider = self.scene.addRect(slider_rect) self.slider.setPen(QPen(QColor("black"))) + # Add drag handle rect + handle_width = 6 + handle_rect = QRect(0, self._handleTop(), handle_width, self._handleHeight()) + self.setMinimumHeight(self._min_height) + self.setMaximumHeight(self._min_height) self.handle = self.scene.addRect(handle_rect) self.handle.setPen(QPen(QColor(80, 80, 80))) self.handle.setBrush(QColor(128, 128, 128, 128)) + # Add (hidden) rect to highlight selection self.select_box = self.scene.addRect(QRect(0, 1, 0, slider_rect.height()-2)) self.select_box.setPen(QPen(QColor(80, 80, 255))) self.select_box.setBrush(QColor(80, 80, 255, 128)) @@ -82,6 +128,17 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.setValue(val) self.setMarks(marks) + pen = QPen(QColor(80, 80, 255), .5) + pen.setCosmetic(True) + self.poly = self.scene.addPath(QPainterPath(), pen, self.select_box.brush()) + self.headerSeries = dict() + self.drawHeader() + + def _pointsToPath(self, points): + path = QPainterPath() + path.addPolygon(QPolygonF(points)) + return path + def setTracksFromLabels(self, labels, video): """Set slider marks using track information from `Labels` object. 
@@ -91,66 +148,112 @@ def setTracksFromLabels(self, labels, video): labels: the `labels` with tracks and labeled_frames video: the video for which to show marks """ + + if self._color_manager is None: + self._color_manager = TrackColorManager(labels=labels) + lfs = labels.find(video) slider_marks = [] - track_idx = 0 + track_row = 0 # Add marks with track track_occupancy = labels.get_track_occupany(video) for track in labels.tracks: -# track_idx = labels.tracks.index(track) if track in track_occupancy and not track_occupancy[track].is_empty: for occupancy_range in track_occupancy[track].list: - slider_marks.append((track_idx, *occupancy_range)) - track_idx += 1 + slider_marks.append(SliderMark("track", val=occupancy_range[0], end_val=occupancy_range[1], row=track_row, color=self._color_manager.get_color(track))) + track_row += 1 # Add marks without track if None in track_occupancy: for occupancy_range in track_occupancy[None].list: - slider_marks.extend(range(*occupancy_range)) + for val in range(*occupancy_range): + slider_marks.append(SliderMark("simple", val=val)) # list of frame_idx for simple markers for labeled frames labeled_marks = [lf.frame_idx for lf in lfs] user_labeled = [lf.frame_idx for lf in lfs if len(lf.user_instances)] - # "f" for suggestions with instances and "o" for those without - # "f" means "filled", "o" means "open" - # "p" for suggestions with only predicted instances - def mark_type(frame): - if frame in user_labeled: - return "f" - elif frame in labeled_marks: - return "p" + + for frame_idx in labels.get_video_suggestions(video): + if frame_idx in user_labeled: + mark_type = "filled" + elif frame_idx in labeled_marks: + mark_type = "predicted" else: - return "o" - # list of (type, frame) tuples for suggestions - suggestion_marks = [(mark_type(frame_idx), frame_idx) - for frame_idx in labels.get_video_suggestions(video)] - # combine marks for labeled frame and marks for suggested frames - slider_marks.extend(suggestion_marks) - - self.setTracks(track_idx) + mark_type = "open" + slider_marks.append(SliderMark(mark_type, val=frame_idx)) + + self.setTracks(track_row) # total number of tracks to show self.setMarks(slider_marks) + # self.setHeaderSeries(lfs) + self.updatedTracks.emit() - def setTracks(self, tracks): + def setHeaderSeries(self, lfs): + # calculate total point distance for instances from last labeled frame + def inst_velocity(lf, last_lf): + val = 0 + for inst in lf: + if last_lf is not None: + last_inst = last_lf.find(track=inst.track) + if last_inst: + points_a = inst.points_array(invisible_as_nan=True) + points_b = last_inst[0].points_array(invisible_as_nan=True) + point_dist = np.linalg.norm(points_a - points_b, axis=1) + inst_dist = np.sum(point_dist) # np.nanmean(point_dist) + val += inst_dist if not np.isnan(inst_dist) else 0 + return val + + series = dict() + + last_lf = None + for lf in lfs: + val = inst_velocity(lf, last_lf) + last_lf = lf + if not np.isnan(val): + series[lf.frame_idx] = val #len(lf.instances) + + self.headerSeries = series + self.drawHeader() + + def setTracks(self, track_rows): """Set the number of tracks to show in slider. 
Args: - tracks: the number of tracks to show + track_rows: the number of tracks to show """ + self._track_rows = track_rows + self.updateHeight() + + def updateHeight(self): + """Update the height of the slider.""" + + tracks = self._track_rows if tracks == 0: - min_height = max_height = 19 + min_height = self._min_height + max_height = self._min_height else: - min_height = max(19, 8 + (self._track_height * min(tracks, 20))) - max_height = max(19, 8 + (self._track_height * tracks)) + # Start with padding height + extra_height = 8 + self._header_height + min_height = extra_height + max_height = extra_height + + # Add height for tracks + min_height += self._track_height * min(tracks, 20) + max_height += self._track_height * tracks + + # Make sure min/max height is at least 19, even if few tracks + min_height = max(self._min_height, min_height) + max_height = max(self._min_height, max_height) self.setMaximumHeight(max_height) self.setMinimumHeight(min_height) self.resizeEvent() def _toPos(self, val, center=False): + """Convert value to x position on slider.""" x = val x -= self._val_min x /= max(1, self._val_max-self._val_min) @@ -160,6 +263,7 @@ def _toPos(self, val, center=False): return x def _toVal(self, x, center=False): + """Convert x position to value.""" val = x val /= self._sliderWidth() val *= max(1, self._val_max-self._val_min) @@ -304,6 +408,9 @@ def setMarks(self, marks): self.clearMarks() if marks is not None: for mark in marks: + if not isinstance(mark, SliderMark): + mark = SliderMark("simple", mark) + print(mark) self.addMark(mark, update=False) self.updatePos() @@ -321,46 +428,29 @@ def addMark(self, new_mark, update=True): new_mark: value to mark """ # check if mark is within slider range - if self._mark_val(new_mark) > self._val_max: return - if self._mark_val(new_mark) < self._val_min: return + if new_mark.val > self._val_max: return + if new_mark.val < self._val_min: return self._marks.add(new_mark) + v_top_pad = 3 + self._header_height + v_bottom_pad = 3 + width = 0 - filled = True - if type(new_mark) == tuple: - if type(new_mark[0]) == int: - # colored track if mark has format: (track_number, start_frame_idx, end_frame_idx) - track = new_mark[0] - v_offset = 3 + (self._track_height * track) - height = 1 - color = QColor(*self._color_manager.get_color(track)) - else: - # rect (open/filled) if format: ("o", frame_idx) or ("f", frame_idx) - # ("p", frame_idx) when only predicted instances on frame - mark_type = new_mark[0] - v_offset = 3 - height = self.slider.rect().height()-6 - if mark_type == "o": - width = 2 - filled = False - color = QColor("blue") - elif mark_type == "f": - width = 2 - color = QColor("blue") - elif mark_type == "p": - width = 0 - color = QColor("red") + if new_mark.type == "track": + v_offset = v_top_pad + (self._track_height * new_mark.row) + height = 1 else: - # line if mark has format: frame_idx - v_offset = 3 - height = self.slider.rect().height()-6 - color = QColor("black") + v_offset = v_top_pad + height = self.slider.rect().height()-(v_offset+v_bottom_pad) + + width = 2 if new_mark.type in ("open", "filled") else 0 + color = new_mark.QColor pen = QPen(color, .5) pen.setCosmetic(True) - brush = QBrush(color) if filled else QBrush() + brush = QBrush(color) if new_mark.filled else QBrush() line = self.scene.addRect(-width//2, v_offset, width, height, pen, brush) @@ -368,26 +458,24 @@ def addMark(self, new_mark, update=True): if update: self.updatePos() def _mark_val(self, mark): - return mark[1] if type(mark) == tuple else mark + return mark.val 
def updatePos(self): - """Update the visual position of handle and slider annotations.""" + """Update the visual x position of handle and slider annotations.""" x = self._toPos(self.value()) self.handle.setPos(x, 0) + for mark in self._mark_items.keys(): + width = 0 - if type(mark) == tuple: - in_track = True - v = mark[1] - if type(mark[0]) == int: - width_in_frames = mark[2] - mark[1] - width = max(2, self._toPos(width_in_frames)) - elif mark[0] == "o": - width = 2 - else: - in_track = False - v = mark - x = self._toPos(v, center=True) + if mark.type == "track": + width_in_frames = mark.end_val - mark.val + width = max(2, self._toPos(width_in_frames)) + + elif mark.type in ("open", "filled"): + width = 2 + + x = self._toPos(mark.val, center=True) self._mark_items[mark].setPos(x, 0) rect = self._mark_items[mark].rect() @@ -395,6 +483,45 @@ def updatePos(self): self._mark_items[mark].setRect(rect) + def drawHeader(self): + if len(self.headerSeries) == 0 or self._header_height == 0: + self.poly.setPath(QPainterPath()) + return + + step = max(self.headerSeries.keys())//int(self._sliderWidth()) + step = max(step, 1) + count = max(self.headerSeries.keys())//step*step + + sampled = np.full((count), 0.0) + for key, val in self.headerSeries.items(): + if key < count: + sampled[key] = val + sampled = np.max(sampled.reshape(count//step,step), axis=1) + series = {i*step:sampled[i] for i in range(count//step)} + +# series = {key:self.headerSeries[key] for key in sorted(self.headerSeries.keys())} + + series_min = np.min(sampled) - 1 + series_max = np.max(sampled) + series_scale = (self._header_height-5)/(series_max - series_min) + + def toYPos(val): + return self._header_height-((val-series_min)*series_scale) + + step_chart = False # use steps rather than smooth line + + points = [] + points.append((self._toPos(0, center=True), toYPos(series_min))) + for idx, val in series.items(): + points.append((self._toPos(idx, center=True), toYPos(val))) + if step_chart: + points.append((self._toPos(idx+step, center=True), toYPos(val))) + points.append((self._toPos(max(series.keys()) + 1, center=True), toYPos(series_min))) + + # Convert to list of QPointF objects + points = list(itertools.starmap(QPointF,points)) + self.poly.setPath(self._pointsToPath(points)) + def moveHandle(self, x, y): """Move handle in response to mouse position. @@ -441,16 +568,28 @@ def resizeEvent(self, event=None): slider_rect.setHeight(height-3) if event is not None: slider_rect.setWidth(event.size().width()-1) - handle_rect.setHeight(slider_rect.height()-2) - select_box_rect.setHeight(slider_rect.height()-2) + handle_rect.setHeight(self._handleHeight()) + select_box_rect.setHeight(self._handleHeight()) self.slider.setRect(slider_rect) self.handle.setRect(handle_rect) self.select_box.setRect(select_box_rect) self.updatePos() + self.drawHeader() super(VideoSlider, self).resizeEvent(event) + def _handleTop(self): + return 1 + self._header_height + + def _handleHeight(self, slider_rect=None): + if slider_rect is None: + slider_rect = self.slider.rect() + + handle_bottom_offset = 1 + handle_height = slider_rect.height() - (self._handleTop()+handle_bottom_offset) + return handle_height + def mousePressEvent(self, event): """Override method to move handle for mouse press/drag. 
@@ -521,6 +660,7 @@ def paint(self, *args, **kwargs): """Method required by Qt.""" super(VideoSlider, self).paint(*args, **kwargs) + if __name__ == "__main__": app = QApplication([]) @@ -528,10 +668,7 @@ def paint(self, *args, **kwargs): min=0, max=20, val=15, marks=(10,15)#((0,10),(0,15),(1,10),(1,11),(2,12)), tracks=3 ) - window.setTracks(5) -# mark_positions = ((0,10),(0,15),(1,10),(1,11),(2,12),(3,12),(3,13),(3,14),(4,15),(4,16),(4,21)) - mark_positions = [("o",i) for i in range(3,15,4)] + [("f",18)] - window.setMarks(mark_positions) + window.valueChanged.connect(lambda x: print(x)) window.show() From 1d211b05c6345a1a4a3478c8e9aea78701e67c80 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 13:18:39 -0400 Subject: [PATCH 024/176] note change when skeleton added --- sleap/gui/app.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 191d67f62..4965471be 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -657,6 +657,8 @@ def openSkeleton(self): if len(sk_list): self.skeleton = sk_list[0] + self.changestack_push("new skeleton") + # Update data model self.update_data_views() @@ -762,6 +764,7 @@ def generateSuggestions(self, params): params=params) self.labels.set_suggestions(new_suggestions) + self.update_data_views() self.updateSeekbarMarks() From e503c28d9fafa52077a372b838d8c7505daed78d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 13:43:05 -0400 Subject: [PATCH 025/176] create output dir if doesn't exist --- sleap/nn/inference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 557ca5ad4..1d3361f08 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -185,6 +185,11 @@ def predict(self, # Initialize tracking tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) + # Create output directory if it doesn't exist + try: + os.mkdir(os.path.dirname(self.output_path)) + except FileExistsError: + pass # Delete the output file if it exists already if os.path.exists(self.output_path): os.unlink(self.output_path) From 3a091fbe49e76b9cb567342f874f8a729604f133 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 11 Sep 2019 19:00:47 -0400 Subject: [PATCH 026/176] check labels is not None when updating gui --- sleap/gui/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 4965471be..66dd6c014 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -423,9 +423,9 @@ def update_gui_state(self): has_selected_instance = (self.player.view.getSelection() is not None) has_unsaved_changes = self.changestack_has_changes() has_multiple_videos = (self.labels is not None and len(self.labels.videos) > 1) - has_labeled_frames = any((lf.video == self.video for lf in self.labels)) - has_suggestions = (len(self.labels.suggestions) > 0) - has_tracks = (len(self.labels.tracks) > 0) + has_labeled_frames = self.labels is not None and any((lf.video == self.video for lf in self.labels)) + has_suggestions = self.labels is not None and (len(self.labels.suggestions) > 0) + has_tracks = self.labels is not None and (len(self.labels.tracks) > 0) has_multiple_instances = (self.labeled_frame is not None and len(self.labeled_frame.instances) > 1) # todo: exclude predicted instances from count has_nodes_selected = (self.skeletonEdgesSrc.currentIndex() > -1 and From a4cb24b75faef00baf8a3f4e5ed66361fb0b8c57 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 09:07:26 -0400 Subject: [PATCH 
027/176] add method to find instance in frame --- sleap/instance.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sleap/instance.py b/sleap/instance.py index 4d38f0f4f..597cf0d1d 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -791,6 +791,14 @@ def __setitem__(self, index, value: Instance): # Modify the instance to have a reference back to this frame value.frame = self + def find(self, track=-1, user=False): + instances = self.instances + if user: + instances = list(filter(lambda inst: type(inst) == Instance, instances)) + if track != -1: # use -1 since we want to accept None as possible value + instances = list(filter(lambda inst: inst.track == track, instances)) + return instances + @property def instances(self): """ From 660b753066696fc21f525673c5d5a14c623a021c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 10:17:59 -0400 Subject: [PATCH 028/176] Attempt to fix bugs from active learning results The instances returned from active learning had nodes or skeletons that weren't matching those already in the project, and this was messing up lots of things. Now we merge video by video, and then check nodes by name when creating instances from predicted instances in the gui. This seems to fix issues 148 and 152. --- sleap/gui/active.py | 52 +++++++++++++++++---------------------------- sleap/gui/app.py | 6 +++--- sleap/instance.py | 8 +++---- 3 files changed, 26 insertions(+), 40 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index d47611f52..913cb7a2f 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -319,36 +319,17 @@ def run(self): save_confmaps_pafs = form_data.get("_save_confmaps_pafs", False) # Run active learning pipeline using the TrainingJobs - new_lfs = run_active_learning_pipeline( - labels_filename = self.labels_filename, - labels = self.labels, - training_jobs = training_jobs, - frames_to_predict = frames_to_predict, - with_tracking = with_tracking, - save_confmaps_pafs = save_confmaps_pafs) - - # remove labeledframes without any predicted instances - new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) - # Update labels with results of active learning - - new_tracks = {inst.track for lf in new_lfs for inst in lf.instances if inst.track is not None} - if len(new_tracks) < 50: - self.labels.tracks = list(set(self.labels.tracks).union(new_tracks)) - # if there are more than 50 predicted tracks, assume this is wrong (FIXME?) - elif len(new_tracks): - for lf in new_lfs: - for inst in lf.instances: - inst.track = None - - # Update Labels with new data - # add new labeled frames - self.labels.extend_from(new_lfs) - # combine instances from labeledframes with same video/frame_idx - self.labels.merge_matching_frames() + new_counts = run_active_learning_pipeline( + labels_filename = self.labels_filename, + labels = self.labels, + training_jobs = training_jobs, + frames_to_predict = frames_to_predict, + with_tracking = with_tracking, + save_confmaps_pafs = save_confmaps_pafs) self.learningFinished.emit() - QtWidgets.QMessageBox(text=f"Active learning has finished. Instances were predicted on {len(new_lfs)} frames.").exec_() + QtWidgets.QMessageBox(text=f"Active learning has finished. 
Instances were predicted on {new_counts} frames.").exec_() def view_datagen(self): from sleap.nn.datagen import generate_training_data, \ @@ -671,7 +652,7 @@ def run_active_learning_pipeline( # Run the Predictor for suggested frames # We want to predict for suggested frames that don't already have user instances - new_labeled_frames = [] + new_labeled_frame_count = 0 user_labeled_frames = labels.user_labeled_frames # show message while running inference @@ -698,21 +679,26 @@ def run_active_learning_pipeline( new_labels_json = result.get() new_labels = Labels.from_json(new_labels_json, match_to=labels) - video_lfs = new_labels.labeled_frames + # Add new frames to labels + # (we're doing this for each video as we go since there was a problem + # when we tried to add frames for all videos together.) + new_lfs = new_labels.labeled_frames + new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) + labels.extend_from(new_lfs) + labels.merge_matching_frames() + + new_labeled_frame_count += len(new_lfs) else: QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() result.get() else: import time time.sleep(1) - video_lfs = [] - - new_labeled_frames.extend(video_lfs) # close message window win.close() - return new_labeled_frames + return new_labeled_frame_count if __name__ == "__main__": import sys diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 66dd6c014..d72b83fa1 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1056,7 +1056,7 @@ def doubleClickInstance(self, instance): in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() for node in self.skeleton.nodes: - if node not in instance.nodes or instance[node].isnan(): + if node.name not in instance.node_names or instance[node].isnan(): # pick random points within currently zoomed view x = in_view_rect.x() + (in_view_rect.width() * 0.1) \ + (np.random.rand() * in_view_rect.width() * 0.8) @@ -1114,9 +1114,9 @@ def newInstance(self, copy_instance=None): in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() # go through each node in skeleton - for node in self.skeleton.nodes: + for node in self.skeleton.node_names: # if we're copying from a skeleton that has this node - if copy_instance is not None and node in copy_instance.nodes and not copy_instance[node].isnan(): + if copy_instance is not None and node in copy_instance and not copy_instance[node].isnan(): # just copy x, y, and visible # we don't want to copy a PredictedPoint or score attribute new_instance[node] = Point( diff --git a/sleap/instance.py b/sleap/instance.py index 4d38f0f4f..f9bc68bdd 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -565,8 +565,8 @@ def fix_array(self): # Add points into new array for i, node in enumerate(self._nodes): - if node in self.skeleton.nodes: - new_array[self.skeleton.nodes.index(node)] = self._points[i] + if node.name in self.skeleton.node_names: + new_array[self.skeleton.node_names.index(node.name)] = self._points[i] # Update points and nodes for this instance self._points = new_array @@ -691,10 +691,10 @@ def make_instance_cattr(): converter.register_unstructure_hook(PredictedPointArray, lambda x: None) def unstructure_instance(x: Instance): - # Unstructure everything but the points array and frame attribute + # Unstructure everything but the points array, nodes, and frame attribute d = {field.name: converter.unstructure(x.__getattribute__(field.name)) for field in 
attr.fields(x.__class__) - if field.name not in ['_points', 'frame']} + if field.name not in ['_points', '_nodes', 'frame']} # Replace the point array with a dict d['_points'] = converter.unstructure({k: v for k, v in x.nodes_points}) From e1005654be8dcf8b85939038ef00a2b207d6aac0 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 15:37:22 -0400 Subject: [PATCH 029/176] don't test for strict skeleton identity --- tests/nn/test_training.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index e583538c2..52f661507 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -30,10 +30,13 @@ def test_training_job_json(tmpdir, multi_skel_vid_labels, backbone): # Load the JSON back in loaded_run = TrainingJob.load_json(json_path) - assert loaded_run == train_run - - # Make sure the skeletons match too, not sure what the difference - # between __eq__ and matches on skeleton is at this point. + # Make sure the skeletons match (even though not eq) for sk1, sk2 in zip(loaded_run.model.skeletons, train_run.model.skeletons): assert sk1.matches(sk2) + # Now remove the skeletons since we want to check eq on everything else + loaded_run.model.skeletons = [] + train_run.model.skeletons = [] + + assert loaded_run == train_run + From 47d0d2ba2ae413265e063ca04bb974da195c7a60 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 15:39:09 -0400 Subject: [PATCH 030/176] Hash skeleton by id, remove __eq__ override We were hashing skeletons by name, and __eq__ was checking for same name and match(). This meant that serializing/deserializing a skeleton would result in something == to the original so we'd only save a single skeleton in the labels project (even if both skeletons were used by instances). Of course we'd usually want to unify the instances so they all use the same skeleton (assuming the skeletons do match), but the skeleton object shouldn't make things break if you don't unify. --- sleap/skeleton.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index d3450f460..6b4af6a18 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -887,19 +887,8 @@ def load_mat(cls, filename: str): def __str__(self): return "%s(name=%r)" % (self.__class__.__name__, self.name) - def __eq__(self, other: 'Skeleton'): - - # First check names, duh! - if other.name != self.name: - return False - - # Then check if the graphs match - return self.matches(other) - def __hash__(self): """ - Construct a hash from skeleton name, which we force to be immutable so hashes - will not change. + Construct a hash from skeleton id. """ - return hash(self.name) - + return id(self) From 1410dc9d1b67d697523b91b4aa59967cf10932e6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 15:44:12 -0400 Subject: [PATCH 031/176] Add unify arg to Labels.extend_from(), new tests. We use unify if we want objects in the new frames to be replaced with corresponding objects already in the labels. Note that if we're adding two frames that could be unified with each other, they won't be unified unless they can both be unified with data already in the labels to which they're being added. 
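For illustration only (not part of the patch itself), a minimal usage sketch of the
new argument, mirroring the tests added below (same fixtures and calls):

    from sleap.skeleton import Skeleton
    from sleap.instance import Instance, LabeledFrame
    from sleap.io.video import Video
    from sleap.io.dataset import Labels

    vid = Video.from_filename("foo.mp4")
    # Two matching but distinct skeleton objects, loaded separately
    skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
    skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")

    lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)])
    lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)])

    labels = Labels(labeled_frames=[lf_a])
    # unify=True: instances in lf_b get re-pointed at skeleton_a, the matching
    # skeleton already present in `labels`, so only one skeleton is serialized.
    labels.extend_from([lf_b], unify=True)
    # unify=False (the default) would instead keep skeleton_b as a second,
    # distinct skeleton object in the project.
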
--- sleap/io/dataset.py | 15 +++++++-- tests/io/test_dataset.py | 69 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index ed6da889a..25a39e0b5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -647,11 +647,14 @@ def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): # Methods for saving/loading - def extend_from(self, new_frames): - """Merge data from another Labels object or list of LabeledFrames into self. + def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool=False): + """ + Merge data from another Labels object or list of LabeledFrames into self. Arg: new_frames: the object from which to copy data + unify: whether to replace objects in new frames with + corresponding objects from current `Labels` data Returns: bool, True if we added frames, False otherwise """ @@ -662,6 +665,14 @@ def extend_from(self, new_frames): if not isinstance(new_frames, list) or len(new_frames) == 0: return False if not isinstance(new_frames[0], LabeledFrame): return False + # If unify, we want to replace objects in the frames with + # corresponding objects from the current labels. + # We do this by deserializing/serializing with match_to. + if unify: + new_json = Labels(labeled_frames=new_frames).to_dict() + new_labels = Labels.from_json(new_json, match_to=self) + new_frames = new_labels.labeled_frames + # copy the labeled frames self.labeled_frames.extend(new_frames) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 67c1e13c1..7a0aeb74a 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -268,6 +268,75 @@ def test_label_mutability(): assert len(labels) == 1 assert len(labels.labeled_frames[0].instances) == 10 +def skeleton_ids_from_label_instances(labels): + return list(map(id, (lf.instances[0].skeleton for lf in labels.labeled_frames))) + +def test_duplicate_skeletons_serializing(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + new_labels = Labels(labeled_frames=[lf_a, lf_b]) + new_labels_json = new_labels.to_dict() + +def test_distinct_skeletons_serializing(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b.add_node("foo") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + new_labels = Labels(labeled_frames=[lf_a, lf_b]) + + # Make sure we can serialize this + new_labels_json = new_labels.to_dict() + +def test_unify_skeletons(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + labels = Labels() + labels.extend_from([lf_a], unify=True) + labels.extend_from([lf_b], unify=True) + ids = skeleton_ids_from_label_instances(labels) + + # Make sure that 
skeleton_b got replaced with skeleton_a when we + # added the frame with "unify" set + assert len(set(ids)) == 1 + + # Make sure we can serialize this + labels.to_dict() + +def test_dont_unify_skeletons(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + labels = Labels(labeled_frames=[lf_a]) + labels.extend_from([lf_b], unify=False) + ids = skeleton_ids_from_label_instances(labels) + + # Make sure we still have two distinct skeleton objects + assert len(set(ids)) == 2 + + # Make sure we can serialize this + labels.to_dict() def test_instance_access(): labels = Labels() From 0cf087e8c3932f435c62096ea4e957431fb2fe20 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 12 Sep 2019 17:22:40 -0400 Subject: [PATCH 032/176] update empty gui skeleton from labels skeleton --- sleap/gui/app.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index d72b83fa1..bea561035 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -462,6 +462,9 @@ def update_gui_state(self): self._buttons["delete instance"].setEnabled(self.instancesTable.currentIndex().isValid()) def update_data_views(self): + if len(self.skeleton.nodes) == 0 and len(self.labels.skeletons): + self.skeleton = self.labels.skeletons[0] + self.videosTable.model().videos = self.labels.videos self.skeletonNodesTable.model().skeleton = self.skeleton From c84128dcdeecdfa5baa5b07f79ba5101e21a5185 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 10:38:06 -0400 Subject: [PATCH 033/176] disable save_confmaps_pafs since not currently working --- sleap/nn/inference.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 1d3361f08..4701c859b 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -589,9 +589,14 @@ def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame # Save confmaps if self.output_path is not None and self.save_confmaps_pafs: - save_visual_outputs( - output_path = self.output_path, - data = dict(confmaps=confmaps, box=imgs)) + logger.warning("Not saving confmaps because feature currently not working.") + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # save_visual_outputs( + # output_path = self.output_path, + # data = dict(confmaps=confmaps, box=imgs)) return predicted_frames_chunk @@ -669,10 +674,15 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Save confmaps and pafs if self.output_path is not None and self.save_confmaps_pafs: - save_visual_outputs( - output_path = self.output_path, - data = dict(confmaps=confmaps, pafs=pafs, - frame_idxs=transform.frame_idxs, bounds=transform.bounding_boxes)) + logger.warning("Not saving confmaps/pafs because feature currently not working.") + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. 
It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # save_visual_outputs( + # output_path = self.output_path, + # data = dict(confmaps=confmaps, pafs=pafs, + # frame_idxs=transform.frame_idxs, bounds=transform.bounding_boxes)) return predicted_frames_chunk From 006ad93d6d8b8566c2319e4f59544e7740085e8b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 10:45:36 -0400 Subject: [PATCH 034/176] refactor active learning code into more functions --- sleap/gui/active.py | 265 ++++++++++++++++++++++++++++---------------- 1 file changed, 168 insertions(+), 97 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index 913cb7a2f..5d79f0b53 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -316,7 +316,13 @@ def run(self): with_tracking = True else: frames_to_predict = dict() - save_confmaps_pafs = form_data.get("_save_confmaps_pafs", False) + save_confmaps_pafs = False + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # If you want to enable, uncomment this: + # save_confmaps_pafs = form_data.get("_save_confmaps_pafs", False) # Run active learning pipeline using the TrainingJobs new_counts = run_active_learning_pipeline( @@ -324,8 +330,7 @@ def run(self): labels = self.labels, training_jobs = training_jobs, frames_to_predict = frames_to_predict, - with_tracking = with_tracking, - save_confmaps_pafs = save_confmaps_pafs) + with_tracking = with_tracking) self.learningFinished.emit() @@ -461,8 +466,8 @@ def select_job(self, model_type, idx): def make_default_training_jobs(): - from sleap.nn.model import Model, ModelOutputType - from sleap.nn.training import Trainer, TrainingJob + from sleap.nn.model import Model + from sleap.nn.training import Trainer from sleap.nn.architectures import unet, leap # Build Models (wrapper for Keras model with some metadata) @@ -530,7 +535,6 @@ def find_saved_jobs(job_dir, jobs=None): Returns: dict of {ModelOutputType: list of (filename, TrainingJob) tuples} """ - from sleap.nn.training import TrainingJob files = os.listdir(job_dir) @@ -559,23 +563,39 @@ def find_saved_jobs(job_dir, jobs=None): return jobs +def add_frames_from_json(labels: Labels, new_labels_json: str): + # Deserialize the new frames, matching to the existing videos/skeletons if possible + new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames + + # Remove any frames without instances + new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) + + # Now add them to labels and merge labeled frames with same video/frame_idx + labels.extend_from(new_lfs) + labels.merge_matching_frames() + + return len(new_lfs) + def run_active_learning_pipeline( labels_filename: str, - labels: Labels=None, - training_jobs: Dict=None, - frames_to_predict: Dict=None, - with_tracking: bool=False, - save_confmaps_pafs: bool=False, - skip_learning: bool=False): - # Imports here so we don't load TensorFlow before necessary - from sleap.nn.monitor import LossViewer - from sleap.nn.training import TrainingJob - from sleap.nn.model import ModelOutputType - from sleap.nn.inference import Predictor + labels: Labels, + training_jobs: Dict['ModelOutputType', 'TrainingJob']=None, + frames_to_predict: Dict[Video, List[int]]=None, + with_tracking: bool=False) -> int: + """Run training (as needed) and inference. 
- from PySide2 import QtWidgets + Args: + labels_filename: Path to already saved current labels object. + labels: The current labels object; results will be added to this. + training_jobs: The TrainingJobs with params/hyperparams for training. + frames_to_predict: Dict that gives list of frame indices for each video. + with_tracking: Whether to run tracking code after we predict instances. + This should be used only when predicting on continuous set of frames. - labels = labels or Labels.load_json(labels_filename) + Returns: + Number of new frames added to labels. + + """ # Prepare our TrainingJobs @@ -586,119 +606,170 @@ def run_active_learning_pipeline( # Set the parameters specific to this run for job in training_jobs.values(): job.labels_filename = labels_filename -# job.trainer.scale = scale - # Run the TrainingJobs + save_dir = os.path.join(os.path.dirname(labels_filename), "models") - save_dir = os.path.join(os.path.dirname(labels_filename), "models") + # Train the TrainingJobs + trained_jobs = run_active_training(labels, training_jobs, save_dir) - # open training monitor window - win = LossViewer() - win.resize(600, 400) - win.show() + # Check that all the models were trained + if None in trained_jobs.values(): + return 0 + + # Run the Predictor for suggested frames + new_labeled_frame_count = \ + run_active_inference(labels, trained_jobs, save_dir, frames_to_predict, with_tracking) + + return new_labeled_frame_count + +def run_active_training( + labels: Labels, + training_jobs: Dict['ModelOutputType', 'TrainingJob'], + save_dir:str, + gui:bool = True) -> Dict['ModelOutputType', 'TrainingJob']: + """ + Run training for each training job. + + Args: + labels: Labels object from which we'll get training data. + training_jobs: Dict of the jobs to train. + save_dir: Path to the directory where we'll save inference results. + gui: Whether to show gui windows and process gui events. + + Returns: + Dict of trained jobs corresponding with input training jobs. 
+ """ + + trained_jobs = dict() + + if gui: + from sleap.nn.monitor import LossViewer + + # open training monitor window + win = LossViewer() + win.resize(600, 400) + win.show() for model_type, job in training_jobs.items(): if getattr(job, "use_trained_model", False): # set path to TrainingJob already trained from previous run json_name = f"{job.run_name}.json" - training_jobs[model_type] = os.path.join(job.save_dir, json_name) - print(f"Using already trained model: {training_jobs[model_type]}") + trained_jobs[model_type] = os.path.join(job.save_dir, json_name) + print(f"Using already trained model: {trained_jobs[model_type]}") else: - print("Resetting monitor window.") - win.reset(what=str(model_type)) - win.setWindowTitle(f"Training Model - {str(model_type)}") + if gui: + print("Resetting monitor window.") + win.reset(what=str(model_type)) + win.setWindowTitle(f"Training Model - {str(model_type)}") + print(f"Start training {str(model_type)}...") - if not skip_learning: - # run training - pool, result = job.trainer.train_async(model=job.model, labels=labels, - save_dir=save_dir) + # Start training in separate process + # This makes it easier to ensure that tensorflow released memory when done + pool, result = job.trainer.train_async(model=job.model, labels=labels, + save_dir=save_dir) - while not result.ready(): + # Wait for training results + while not result.ready(): + if gui: QtWidgets.QApplication.instance().processEvents() - # win.check_messages() - result.wait(.01) + result.wait(.01) - if result.successful(): - # get the path to the resulting TrainingJob file - training_jobs[model_type] = result.get() - print(f"Finished training {str(model_type)}.") - else: - training_jobs[model_type] = None + if result.successful(): + # get the path to the resulting TrainingJob file + trained_jobs[model_type] = result.get() + print(f"Finished training {str(model_type)}.") + else: + if gui: win.close() QtWidgets.QMessageBox(text=f"An error occured while training {str(model_type)}. Your command line terminal may have more information about the error.").exec_() - result.get() - - - if not skip_learning: - for model_type, job in training_jobs.items(): - # load job from json - training_jobs[model_type] = TrainingJob.load_json(training_jobs[model_type]) - - # close training monitor window - win.close() + trained_jobs[model_type] = None + result.get() + + # Load the jobs we just trained + for model_type, job in trained_jobs.items(): + # Replace path to saved TrainingJob with the deseralized object + if trained_jobs[model_type] is not None: + trained_jobs[model_type] = TrainingJob.load_json(trained_jobs[model_type]) + + if gui: + # close training monitor window + win.close() + + return trained_jobs + +def run_active_inference( + labels: Labels, + training_jobs: Dict['ModelOutputType', 'TrainingJob'], + save_dir:str, + frames_to_predict: Dict[Video, List[int]], + with_tracking: bool, + gui: bool=True) -> int: + """Run inference on specified frames using models from training_jobs. - if not skip_learning: - timestamp = datetime.now().strftime("%y%m%d_%H%M%S") - inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") + Args: + labels: The current labels object; results will be added to this. + training_jobs: The TrainingJobs with trained models to use. + save_dir: Path to the directory where we'll save inference results. + frames_to_predict: Dict that gives list of frame indices for each video. + with_tracking: Whether to run tracking code after we predict instances. 
+ This should be used only when predicting on continuous set of frames. + gui: Whether to show gui windows and process gui events. - # Create Predictor from the results of training - predictor = Predictor(sleap_models=training_jobs, - with_tracking=with_tracking, - output_path=inference_output_path, - save_confmaps_pafs=save_confmaps_pafs) + Returns: + Number of new frames added to labels. + """ + from sleap.nn.inference import Predictor - # Run the Predictor for suggested frames - # We want to predict for suggested frames that don't already have user instances + total_new_lf_count = 0 + timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") - new_labeled_frame_count = 0 - user_labeled_frames = labels.user_labeled_frames + # Create Predictor from the results of training + predictor = Predictor(sleap_models=training_jobs, + with_tracking=with_tracking, + output_path=inference_output_path) - # show message while running inference - win = QtWidgets.QProgressDialog() - win.setLabelText(" Running inference on selected frames... ") - win.show() - QtWidgets.QApplication.instance().processEvents() + if gui: + # show message while running inference + win = QtWidgets.QProgressDialog() + win.setLabelText(" Running inference on selected frames... ") + win.show() + QtWidgets.QApplication.instance().processEvents() for video, frames in frames_to_predict.items(): if len(frames): - if not skip_learning: - # run predictions for desired frames in this video - # video_lfs = predictor.predict(input_video=video, frames=frames, output_path=inference_output_path) - pool, result = predictor.predict_async( - input_video=video, - frames=frames) + # Run inference for desired frames in this video + pool, result = predictor.predict_async( + input_video=video, + frames=frames) - while not result.ready(): + while not result.ready(): + if gui: QtWidgets.QApplication.instance().processEvents() - result.wait(.01) + result.wait(.01) - if result.successful(): - new_labels_json = result.get() - new_labels = Labels.from_json(new_labels_json, match_to=labels) + if result.successful(): + new_labels_json = result.get() - # Add new frames to labels - # (we're doing this for each video as we go since there was a problem - # when we tried to add frames for all videos together.) - new_lfs = new_labels.labeled_frames - new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) - labels.extend_from(new_lfs) - labels.merge_matching_frames() + # Add new frames to labels + # (we're doing this for each video as we go since there was a problem + # when we tried to add frames for all videos together.) + new_lf_count = add_frames_from_json(labels, new_labels_json) - new_labeled_frame_count += len(new_lfs) - else: - QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() - result.get() + total_new_lf_count += new_lf_count else: - import time - time.sleep(1) + if gui: + QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() + result.get() # close message window - win.close() + if gui: + win.close() - return new_labeled_frame_count + return total_new_lf_count if __name__ == "__main__": import sys From 97fc2a93d4fc1048fcc5a30bbcd4b089a8fbbd5f Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 12:20:18 -0400 Subject: [PATCH 035/176] Partial test coverage for active learning. 
Doesn't test gui, training, or inference; just the helper functions that load/create training jobs and merge results of inference into labels. --- .../set_a/default_centroids.json | 1 + .../set_a/default_confmaps.json | 1 + .../training_profiles/set_a/default_pafs.json | 1 + .../set_b/test_confmaps.json | 1 + tests/gui/test_active.py | 116 ++++++++++++++++++ 5 files changed, 120 insertions(+) create mode 100644 tests/data/training_profiles/set_a/default_centroids.json create mode 100644 tests/data/training_profiles/set_a/default_confmaps.json create mode 100644 tests/data/training_profiles/set_a/default_pafs.json create mode 100644 tests/data/training_profiles/set_b/test_confmaps.json create mode 100644 tests/gui/test_active.py diff --git a/tests/data/training_profiles/set_a/default_centroids.json b/tests/data/training_profiles/set_a/default_centroids.json new file mode 100644 index 000000000..2e0f59b6d --- /dev/null +++ b/tests/data/training_profiles/set_a/default_centroids.json @@ -0,0 +1 @@ +{"model": {"output_type": 2, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 16, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 4, "num_epochs": 100, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 0.25, "sigma": 5.0, "instance_crop": false}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_a/default_confmaps.json b/tests/data/training_profiles/set_a/default_confmaps.json new file mode 100644 index 000000000..4503d7e8b --- /dev/null +++ b/tests/data/training_profiles/set_a/default_confmaps.json @@ -0,0 +1 @@ +{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_a/default_pafs.json b/tests/data/training_profiles/set_a/default_pafs.json new file mode 100644 index 000000000..5c04a2acc --- /dev/null +++ b/tests/data/training_profiles/set_a/default_pafs.json @@ -0,0 +1 @@ +{"model": {"output_type": 1, 
"backbone": {"down_blocks": 3, "up_blocks": 3, "upsampling_layers": true, "num_filters": 32, "interp": "bilinear"}, "skeletons": null, "backbone_name": "LeapCNN"}, "trainer": {"val_size": 0.15, "optimizer": "adam", "learning_rate": 5e-5, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 100, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-6, "reduce_lr_factor": 0.5, "reduce_lr_patience": 8, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_b/test_confmaps.json b/tests/data/training_profiles/set_b/test_confmaps.json new file mode 100644 index 000000000..2245a173c --- /dev/null +++ b/tests/data/training_profiles/set_b/test_confmaps.json @@ -0,0 +1 @@ +{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"num_epochs": 17}} \ No newline at end of file diff --git a/tests/gui/test_active.py b/tests/gui/test_active.py new file mode 100644 index 000000000..5d340cd1d --- /dev/null +++ b/tests/gui/test_active.py @@ -0,0 +1,116 @@ +import os + +from sleap.skeleton import Skeleton +from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance +from sleap.io.video import Video +from sleap.io.dataset import Labels +from sleap.nn.model import ModelOutputType +from sleap.gui.active import make_default_training_jobs, find_saved_jobs, add_frames_from_json + +def test_make_default_training_jobs(): + jobs = make_default_training_jobs() + + assert ModelOutputType.CONFIDENCE_MAP in jobs + assert ModelOutputType.PART_AFFINITY_FIELD in jobs + + for output_type in jobs: + assert jobs[output_type].model.output_type == output_type + assert jobs[output_type].best_model_filename is None + +def test_find_saved_jobs(): + jobs_a = find_saved_jobs("tests/data/training_profiles/set_a") + assert len(jobs_a) == 3 + assert len(jobs_a[ModelOutputType.CONFIDENCE_MAP]) == 1 + + jobs_b = find_saved_jobs("tests/data/training_profiles/set_b") + assert len(jobs_b) == 1 + + path, job = jobs_b[ModelOutputType.CONFIDENCE_MAP][0] + assert os.path.basename(path) == "test_confmaps.json" + assert job.trainer.num_epochs == 17 + + # Add jobs from set_a to already loaded jobs from set_b + jobs_c = find_saved_jobs("tests/data/training_profiles/set_a", jobs_b) + assert len(jobs_c) == 3 + + # Make sure we now have two confmap jobs + assert len(jobs_c[ModelOutputType.CONFIDENCE_MAP]) == 2 + + # Make sure set_a was added after items from set_b + paths = [name for (name, job) in jobs_c[ModelOutputType.CONFIDENCE_MAP]] + assert os.path.basename(paths[0]) == "test_confmaps.json" + assert os.path.basename(paths[1]) == "default_confmaps.json" + +def test_add_frames_from_json(): + vid_a = Video.from_filename("foo.mp4") + vid_b = Video.from_filename("bar.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = 
LabeledFrame(vid_a, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid_b, frame_idx=3, instances=[Instance(skeleton_b)]) + + empty_labels = Labels() + labels_with_video = Labels(videos=[vid_a]) + labels_with_skeleton = Labels(skeletons=[skeleton_a]) + + new_labels_a = Labels(labeled_frames=[lf_a]) + new_labels_b = Labels(labeled_frames=[lf_b]) + + json_a = new_labels_a.to_dict() + json_b = new_labels_b.to_dict() + + # Test with empty labels + + assert len(empty_labels.labeled_frames) == 0 + assert len(empty_labels.skeletons) == 0 + assert len(empty_labels.skeletons) == 0 + + add_frames_from_json(empty_labels, json_a) + assert len(empty_labels.labeled_frames) == 1 + assert len(empty_labels.videos) == 1 + assert len(empty_labels.skeletons) == 1 + + add_frames_from_json(empty_labels, json_b) + assert len(empty_labels.labeled_frames) == 2 + assert len(empty_labels.videos) == 2 + assert len(empty_labels.skeletons) == 1 + + empty_labels.to_dict() + + # Test with labels that have video + + assert len(labels_with_video.labeled_frames) == 0 + assert len(labels_with_video.skeletons) == 0 + assert len(labels_with_video.videos) == 1 + + add_frames_from_json(labels_with_video, json_a) + assert len(labels_with_video.labeled_frames) == 1 + assert len(labels_with_video.videos) == 1 + assert len(labels_with_video.skeletons) == 1 + + add_frames_from_json(labels_with_video, json_b) + assert len(labels_with_video.labeled_frames) == 2 + assert len(labels_with_video.videos) == 2 + assert len(labels_with_video.skeletons) == 1 + + labels_with_video.to_dict() + + # Test with labels that have skeleton + + assert len(labels_with_skeleton.labeled_frames) == 0 + assert len(labels_with_skeleton.skeletons) == 1 + assert len(labels_with_skeleton.videos) == 0 + + add_frames_from_json(labels_with_skeleton, json_a) + assert len(labels_with_skeleton.labeled_frames) == 1 + assert len(labels_with_skeleton.videos) == 1 + assert len(labels_with_skeleton.skeletons) == 1 + + add_frames_from_json(labels_with_skeleton, json_b) + assert len(labels_with_skeleton.labeled_frames) == 2 + assert len(labels_with_skeleton.videos) == 2 + assert len(labels_with_skeleton.skeletons) == 1 + + labels_with_skeleton.to_dict() \ No newline at end of file From 3a53b76015c328b24c3d5096c4d2bf4095889c30 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 12:35:49 -0400 Subject: [PATCH 036/176] allow saving project with just skeleton also, don't break when loading project with no video --- sleap/gui/app.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index bea561035..9fce06df8 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -552,7 +552,8 @@ def importData(self, filename=None, do_load=True): self.update_data_views() # Load first video - self.loadVideo(self.labels.videos[0], 0) + if len(self.labels.videos): + self.loadVideo(self.labels.videos[0], 0) # Update track menu options self.updateTrackMenu() @@ -660,7 +661,9 @@ def openSkeleton(self): if len(sk_list): self.skeleton = sk_list[0] - self.changestack_push("new skeleton") + if self.skeleton not in self.labels: + self.labels.skeletons.append(self.skeleton) + self.changestack_push("new skeleton") # Update data model self.update_data_views() From 055a8603c07997360de63865beec11007874007a Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 12:53:38 -0400 Subject: [PATCH 037/176] test modifying instance skeleton --- tests/test_instance.py | 18 ++++++++++++++++++ 1 file changed, 18 
insertions(+) diff --git a/tests/test_instance.py b/tests/test_instance.py index dbe402cad..5fe50e295 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -154,6 +154,24 @@ def test_points_array(skeleton): pts = instance1.points_array(invisible_as_nan=True) assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() +def test_modifying_skeleton(skeleton): + node_names = ["left-wing", "head", "right-wing"] + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} + + instance1 = Instance(skeleton=skeleton, points=points) + + assert len(instance1.points()) == 3 + + skeleton.add_node('new test node') + + instance1.fix_array() # update with changes from skeleton + instance1['new test node'] = Point(7,8) + + assert len(instance1.points()) == 4 + + skeleton.delete_node('head') + assert len(instance1.points()) == 3 + def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): """ Test whether links between labeled frames and instances are kept From 4a121d938ebc628213c85991b4f9994ec6524e61 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 13:06:25 -0400 Subject: [PATCH 038/176] match nodes by id (not name) when updating skel. (partially reverts change from 660b753) --- sleap/instance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index 7eb0c88f1..6f737f6ab 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -565,8 +565,8 @@ def fix_array(self): # Add points into new array for i, node in enumerate(self._nodes): - if node.name in self.skeleton.node_names: - new_array[self.skeleton.node_names.index(node.name)] = self._points[i] + if node in self.skeleton.nodes: + new_array[self.skeleton.nodes.index(node)] = self._points[i] # Update points and nodes for this instance self._points = new_array From 0d9a96f0129775af27a91ee6285fdba32e23c1e7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 13:26:28 -0400 Subject: [PATCH 039/176] bug fix to previousLabeledFrameIndex() It wasn't returning anything. This was causing us to not find instance in previous frame to copy from when adding a new instance. --- sleap/gui/app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9fce06df8..f3a7724b7 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1453,6 +1453,8 @@ def previousLabeledFrameIndex(self): except: return + return next_idx + def previousLabeledFrame(self): prev_idx = self.previousLabeledFrameIndex() if prev_idx is not None: From 35df69ef1754aa2fb29416064bb5e0697c51b4b7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 13:30:43 -0400 Subject: [PATCH 040/176] minor refactoring in code to find copy instance --- sleap/gui/app.py | 53 +++++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index f3a7724b7..00bfc89f7 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1080,43 +1080,50 @@ def newInstance(self, copy_instance=None): # FIXME: filter by skeleton type from_predicted = copy_instance - unused_predictions = self.labeled_frame.unused_predictions - from_prev_frame = False + if copy_instance is None: selected_idx = self.player.view.getSelection() if selected_idx is not None: # If the user has selected an instance, copy that one. 
copy_instance = self.labeled_frame.instances[selected_idx] from_predicted = copy_instance - elif len(unused_predictions): + + if copy_instance is None: + unused_predictions = self.labeled_frame.unused_predictions + if len(unused_predictions): # If there are predicted instances that don't correspond to an instance # in this frame, use the first predicted instance without matching instance. copy_instance = unused_predictions[0] from_predicted = copy_instance - else: - # Otherwise, if there are instances in previous frames, - # copy the points from one of those instances. - prev_idx = self.previousLabeledFrameIndex() - if prev_idx is not None: - prev_instances = self.labels.find(self.video, prev_idx, return_new=True)[0].instances - if len(prev_instances) > len(self.labeled_frame.instances): - # If more instances in previous frame than current, then use the - # first unmatched instance. - copy_instance = prev_instances[len(self.labeled_frame.instances)] - from_prev_frame = True - elif len(self.labeled_frame.instances): - # Otherwise, if there are already instances in current frame, - # copy the points from the last instance added to frame. - copy_instance = self.labeled_frame.instances[-1] - elif len(prev_instances): - # Otherwise use the last instance added to previous frame. - copy_instance = prev_instances[-1] - from_prev_frame = True + + if copy_instance is None: + # Otherwise, if there are instances in previous frames, + # copy the points from one of those instances. + prev_idx = self.previousLabeledFrameIndex() + + if prev_idx is not None: + prev_instances = self.labels.find(self.video, prev_idx, return_new=True)[0].instances + if len(prev_instances) > len(self.labeled_frame.instances): + # If more instances in previous frame than current, then use the + # first unmatched instance. + copy_instance = prev_instances[len(self.labeled_frame.instances)] + from_prev_frame = True + elif len(self.labeled_frame.instances): + # Otherwise, if there are already instances in current frame, + # copy the points from the last instance added to frame. + copy_instance = self.labeled_frame.instances[-1] + elif len(prev_instances): + # Otherwise use the last instance added to previous frame. 
+ copy_instance = prev_instances[-1] + from_prev_frame = True + from_predicted = from_predicted if hasattr(from_predicted, "score") else None + + # Now create the new instance new_instance = Instance(skeleton=self.skeleton, from_predicted=from_predicted) - # the rect that's currently visibile in the window view + # Get the rect that's currently visibile in the window view in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() # go through each node in skeleton From e67fa0972399ca28e7c236f7d0fdcc2ff07735e4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 15:52:53 -0400 Subject: [PATCH 041/176] only update relevant dataviews (for speed) --- sleap/gui/app.py | 52 ++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9fce06df8..066d65f6f 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -461,32 +461,40 @@ def update_gui_state(self): self._buttons["remove video"].setEnabled(self.videosTable.currentIndex().isValid()) self._buttons["delete instance"].setEnabled(self.instancesTable.currentIndex().isValid()) - def update_data_views(self): + def update_data_views(self, *update): + update = update or ("video", "skeleton", "labels", "frame", "suggestions") + if len(self.skeleton.nodes) == 0 and len(self.labels.skeletons): self.skeleton = self.labels.skeletons[0] - self.videosTable.model().videos = self.labels.videos + if "video" in update: + self.videosTable.model().videos = self.labels.videos - self.skeletonNodesTable.model().skeleton = self.skeleton - self.skeletonEdgesTable.model().skeleton = self.skeleton - self.skeletonEdgesSrc.model().skeleton = self.skeleton - self.skeletonEdgesDst.model().skeleton = self.skeleton + if "skeleton" in update: + self.skeletonNodesTable.model().skeleton = self.skeleton + self.skeletonEdgesTable.model().skeleton = self.skeleton + self.skeletonEdgesSrc.model().skeleton = self.skeleton + self.skeletonEdgesDst.model().skeleton = self.skeleton + + if "labels" in update: + self.instancesTable.model().labels = self.labels + self.instancesTable.model().color_manager = self._color_manager - self.instancesTable.model().labels = self.labels - self.instancesTable.model().labeled_frame = self.labeled_frame - self.instancesTable.model().color_manager = self._color_manager + if "frame" in update: + self.instancesTable.model().labeled_frame = self.labeled_frame - self.suggestionsTable.model().labels = self.labels + if "suggestions" in update: + self.suggestionsTable.model().labels = self.labels - # update count of suggested frames w/ labeled instances - suggestion_status_text = "" - suggestion_list = self.labels.get_suggestions() - if len(suggestion_list): - suggestion_label_counts = [self.labels.instance_count(video, frame_idx) - for (video, frame_idx) in suggestion_list] - labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) - suggestion_status_text = f"{labeled_count}/{len(suggestion_list)} labeled" - self.suggested_count_label.setText(suggestion_status_text) + # update count of suggested frames w/ labeled instances + suggestion_status_text = "" + suggestion_list = self.labels.get_suggestions() + if len(suggestion_list): + suggestion_label_counts = [self.labels.instance_count(video, frame_idx) + for (video, frame_idx) in suggestion_list] + labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) + suggestion_status_text = f"{labeled_count}/{len(suggestion_list)} labeled" + 
self.suggested_count_label.setText(suggestion_status_text) def keyPressEvent(self, event: QKeyEvent): if event.key() == Qt.Key_Q: @@ -596,7 +604,7 @@ def addVideo(self, filename=None): self.loadVideo(video, len(self.labels.videos)-1) # Update data model/view - self.update_data_views() + self.update_data_views("video") def removeVideo(self): # Get selected video @@ -771,7 +779,7 @@ def generateSuggestions(self, params): self.labels.set_suggestions(new_suggestions) - self.update_data_views() + self.update_data_views("suggestions") self.updateSeekbarMarks() def _frames_for_prediction(self): @@ -1574,7 +1582,7 @@ def newFrame(self, player, frame_idx, selected_idx): # Update related displays self.updateStatusMessage() - self.update_data_views() + self.update_data_views("frame") # Trigger event after the overlays have been added player.view.updatedViewer.emit() From 5d35763244d2b0b979d30182fe9c93718741d628 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 16:22:07 -0400 Subject: [PATCH 042/176] cache video data shown in table --- sleap/gui/dataviews.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index f22cf77d7..abd5c61c6 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -39,16 +39,24 @@ class VideosTableModel(QtCore.QAbstractTableModel): def __init__(self, videos: list): super(VideosTableModel, self).__init__() - self._videos = videos + self.videos = videos @property def videos(self): - return self._videos + return self._cache @videos.setter def videos(self, val): self.beginResetModel() - self._videos = val + self._cache = [] + for video in val: + row_data = dict( + filename=video.filename, + frames=video.frames, + height=video.height, + width=video.width, + channels=video.channels) + self._cache.append(row_data) self.endResetModel() def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): @@ -59,16 +67,8 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): if len(self.videos) > (idx - 1): video = self.videos[idx] - if prop == "filename": - return video.filename - elif prop == "frames": - return video.frames - elif prop == "height": - return video.height - elif prop == "width": - return video.width - elif prop == "channels": - return video.channels + if prop in video: + return video[prop] return None From 3f607349915ceaaddd2cd8367e4711a7d0cb97ff Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 16:37:15 -0400 Subject: [PATCH 043/176] use widget in status bar to show permanent message --- sleap/gui/app.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index d01da5cae..510d78d0e 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -136,6 +136,8 @@ def initialize_gui(self): ####### Status bar ####### self.statusBar() # Initialize status bar + self.statusWidget = QLabel("") + self.statusBar().addWidget(self.statusWidget) self.load_overlays() @@ -1597,20 +1599,28 @@ def newFrame(self, player, frame_idx, selected_idx): player.view.updatedViewer.emit() def updateStatusMessage(self, message = None): - if message is None: - message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" - if self.player.seekbar.hasSelection(): - start, end = self.player.seekbar.getSelection() - message += f" (selection: {start}-{end})" - message += f" Labeled Frames: " - if self.video is not None: - message += 
f"{len(self.labels.get_video_user_labeled_frames(self.video))}" - if len(self.labels.videos) > 1: - message += " in video, " + # show temporary message in status bar + self.statusBar().showMessage(message) + + # show permanent message in status widget + message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" + if self.player.seekbar.hasSelection(): + start, end = self.player.seekbar.getSelection() + message += f" (selection: {start}-{end})" + + if len(self.labels.videos) > 1: + message += f" of video {self.labels.videos.index(self.video)}" + + message += f" Labeled Frames: " + if self.video is not None: + message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" if len(self.labels.videos) > 1: - message += f"{len(self.labels.user_labeled_frames)} in project" + message += " in video, " + if len(self.labels.videos) > 1: + message += f"{len(self.labels.user_labeled_frames)} in project" + + self.statusWidget.setText(message) - self.statusBar().showMessage(message) def main(*args, **kwargs): app = QApplication([]) From c5144b61f6569a98f83decb634826845b62cc904 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 17:09:24 -0400 Subject: [PATCH 044/176] Revert "use widget in status bar" This reverts commit 3f607349915ceaaddd2cd8367e4711a7d0cb97ff. Change prevented window from maximizing on Windows. --- sleap/gui/app.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 510d78d0e..d01da5cae 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -136,8 +136,6 @@ def initialize_gui(self): ####### Status bar ####### self.statusBar() # Initialize status bar - self.statusWidget = QLabel("") - self.statusBar().addWidget(self.statusWidget) self.load_overlays() @@ -1599,28 +1597,20 @@ def newFrame(self, player, frame_idx, selected_idx): player.view.updatedViewer.emit() def updateStatusMessage(self, message = None): - # show temporary message in status bar - self.statusBar().showMessage(message) - - # show permanent message in status widget - message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" - if self.player.seekbar.hasSelection(): - start, end = self.player.seekbar.getSelection() - message += f" (selection: {start}-{end})" - - if len(self.labels.videos) > 1: - message += f" of video {self.labels.videos.index(self.video)}" - - message += f" Labeled Frames: " - if self.video is not None: - message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + if message is None: + message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" + if self.player.seekbar.hasSelection(): + start, end = self.player.seekbar.getSelection() + message += f" (selection: {start}-{end})" + message += f" Labeled Frames: " + if self.video is not None: + message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + if len(self.labels.videos) > 1: + message += " in video, " if len(self.labels.videos) > 1: - message += " in video, " - if len(self.labels.videos) > 1: - message += f"{len(self.labels.user_labeled_frames)} in project" - - self.statusWidget.setText(message) + message += f"{len(self.labels.user_labeled_frames)} in project" + self.statusBar().showMessage(message) def main(*args, **kwargs): app = QApplication([]) From 4e08c92f7dc892223e26f136d1c1bb2c2da7b661 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 17:20:41 -0400 Subject: [PATCH 045/176] ignore event to set empty tooltip this happens when you hover over the menubar on Windows, 
and it was causing the statusbar message to go away. --- sleap/gui/app.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index d01da5cae..9bfdb65c8 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1,7 +1,7 @@ from PySide2 import QtCore, QtWidgets -from PySide2.QtCore import Qt +from PySide2.QtCore import Qt, QEvent -from PySide2.QtGui import QKeyEvent, QKeySequence +from PySide2.QtGui import QKeyEvent, QKeySequence, QStatusTipEvent from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout @@ -80,6 +80,12 @@ def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs if video is not None: self.addVideo(video) + def event(self, e): + if e.type() == QEvent.StatusTip: + if e.tip() == '': + return True + return super().event(e) + def changestack_push(self, change=None): """Add to stack of changes made by user.""" # Currently the change doesn't store any data, and we're only using this From 7bbd6d1305b9c621ff35ae20f518c44ded2862b7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 13 Sep 2019 17:26:42 -0400 Subject: [PATCH 046/176] show current video index in statusbar --- sleap/gui/app.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9bfdb65c8..3a9b68087 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1608,6 +1608,10 @@ def updateStatusMessage(self, message = None): if self.player.seekbar.hasSelection(): start, end = self.player.seekbar.getSelection() message += f" (selection: {start}-{end})" + + if len(self.labels.videos) > 1: + message += f" of video {self.labels.videos.index(self.video)}" + message += f" Labeled Frames: " if self.video is not None: message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" From 455414323c9f324360b2ac82bf1ddef8f9912c0a Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Sat, 14 Sep 2019 07:38:11 -0400 Subject: [PATCH 047/176] add active learning option to not use pafs --- sleap/config/active.yaml | 5 +++++ sleap/gui/active.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sleap/config/active.yaml b/sleap/config/active.yaml index 01afe9684..fa23b0b48 100644 --- a/sleap/config/active.yaml +++ b/sleap/config/active.yaml @@ -26,6 +26,11 @@ expert: type: bool default: False +- name: _dont_use_pafs + label: Single-instance mode (without pafs) + type: bool + default: False + - name: _view_paf label: View Edge Profile... 
type: button diff --git a/sleap/gui/active.py b/sleap/gui/active.py index 5d79f0b53..0aef11c8b 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -256,8 +256,12 @@ def _get_model_types_to_use(self): form_data = self.form_widget.get_form_data() types_to_use = [] + # always include confidence maps types_to_use.append(ModelOutputType.CONFIDENCE_MAP) - types_to_use.append(ModelOutputType.PART_AFFINITY_FIELD) + + # by default we want to use part affinity fields + if not form_data.get("_dont_use_pafs", False): + types_to_use.append(ModelOutputType.PART_AFFINITY_FIELD) # by default we want to use centroids if form_data.get("_use_centroids", True): From 542ffdeeda4a86e197ce59a851c3a5e7e2c37f45 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 08:43:08 -0400 Subject: [PATCH 048/176] test _dont_use_pafs gui option --- tests/gui/test_active.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/gui/test_active.py b/tests/gui/test_active.py index 5d340cd1d..04e1e80ae 100644 --- a/tests/gui/test_active.py +++ b/tests/gui/test_active.py @@ -5,7 +5,25 @@ from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.nn.model import ModelOutputType -from sleap.gui.active import make_default_training_jobs, find_saved_jobs, add_frames_from_json +from sleap.gui.active import ActiveLearningDialog, make_default_training_jobs, find_saved_jobs, add_frames_from_json + +def test_active_gui(qtbot, centered_pair_labels): + win = ActiveLearningDialog( + labels_filename="foo.json", + labels=centered_pair_labels, + mode="expert") + win.show() + qtbot.addWidget(win) + + # Make sure we include pafs by default + jobs = win._get_current_training_jobs() + assert ModelOutputType.PART_AFFINITY_FIELD in jobs + + # Test option to not include pafs + assert "_dont_use_pafs" in win.form_widget.fields + win.form_widget.set_form_data(dict(_dont_use_pafs=True)) + jobs = win._get_current_training_jobs() + assert ModelOutputType.PART_AFFINITY_FIELD not in jobs def test_make_default_training_jobs(): jobs = make_default_training_jobs() From 349f3d83bf811d310e60df911b05939d07ffd056 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 10:03:59 -0400 Subject: [PATCH 049/176] resize output from multiscale confmaps/centroids --- sleap/gui/overlays/base.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 3281edb35..7bdeab221 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -21,6 +21,7 @@ class ModelData: model: 'keras.Model' video: Video do_rescale: bool=False + output_scale: float=1.0 adjust_vals: bool=True def __getitem__(self, i): @@ -30,16 +31,17 @@ def __getitem__(self, i): # Trim to size that works for model frame_img = frame_img[:, :self.video.height//8*8, :self.video.width//8*8, :] + inference_transform = DataTransform() if self.do_rescale: # Scale input image if model trained on scaled images - inference_transform = DataTransform() frame_img = inference_transform.scale_to( imgs=frame_img, target_size=self.model.input_shape[1:3]) # Get predictions frame_result = self.model.predict(frame_img.astype("float32") / 255) - if self.do_rescale: + if self.do_rescale or self.output_scale != 1.0: + inference_transform.scale *= self.output_scale frame_result = inference_transform.invert_scale(frame_result) # We just want the single image results @@ -137,9 +139,18 @@ def from_model(cls, filename, video, **kwargs): do_rescale = 
model_data["scale"] < 1 + # Determine how the output from the model should be scaled + img_output_scale = 1.0 # image rescaling + obj_output_scale = 1.0 # scale to pass to overlay object + + if model_output_type == ModelOutputType.PART_AFFINITY_FIELD: + obj_output_scale = model_data["multiscale"] + else: + img_output_scale = model_data["multiscale"] + # Construct the ModelData object that runs inference - data_object = ModelData(model, video, do_rescale=do_rescale) + data_object = ModelData(model, video, do_rescale=do_rescale, output_scale=img_output_scale) # Determine whether to use confmap or paf overlay @@ -157,7 +168,7 @@ def from_model(cls, filename, video, **kwargs): # will be passed to the overlay object to do its own upscaling # (at least for pafs). - transform = DataTransform(scale=model_data["multiscale"]) + transform = DataTransform(scale=obj_output_scale) return cls( data=data_object, From 5babe18cebf85c9e01b1e711fa3cdfc18abcf637 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 11:48:08 -0400 Subject: [PATCH 050/176] iqr*1.5 for outliers, ensure > 0 for log scale --- sleap/nn/monitor.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 92dcf0f1b..e7e842bcf 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -143,7 +143,7 @@ def add_datapoint(self, x, y, which="batch"): # Redraw batch ever 40 points (faster than plotting each) if x % 40 == 0: xs, ys = self.X, self.Y - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys)] + points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] self.series["batch"].replace(points) # Set X scale to show all points @@ -151,10 +151,17 @@ def add_datapoint(self, x, y, which="batch"): self.chart.axisX().setRange(min(self.X) - dx, max(self.X) + dx) # Set Y scale to exclude outliers - dy = np.ptp(self.Y) * 0.04 - low, high = np.quantile(self.Y, (.02, .98)) + q1, q3 = np.quantile(self.Y, (.25, .75)) + iqr = q3-q1 # interquartile range + low = q1 - iqr * 1.5 + high = q3 + iqr * 1.5 - self.chart.axisY().setRange(low - dy, high + dy) + low = max(low, min(self.Y) - .2) # keep within range of data + low = max(low, 1e-5) # for log scale, low cannot be 0 + + high = min(high, max(self.Y) + .2) + + self.chart.axisY().setRange(low, high) else: self.series[which].append(x, y) @@ -211,7 +218,7 @@ def check_messages(self, timeout=10): def test_point(x=[0]): x[0] += 1 i = x[0]+1 - win.add_datapoint(i, i%30+1) + win.add_datapoint(i, i%30) t = QtCore.QTimer() t.timeout.connect(test_point) From bec6931d5178aca7c2ceaa3f23ba18b41fa60361 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 13:34:22 -0400 Subject: [PATCH 051/176] use Instance for selection in viewer --- sleap/gui/app.py | 22 +++++++++++++-------- sleap/gui/dataviews.py | 16 +++++++++++---- sleap/gui/video.py | 36 ++++++++++++++-------------------- tests/gui/test_video_player.py | 19 +++++++++++------- 4 files changed, 53 insertions(+), 40 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 3a9b68087..5aa0099e6 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -373,12 +373,18 @@ def _make_dock(name, widgets=[], tab_with=None): instances_layout.addWidget(hbw) def update_instance_table_selection(): - cur_video_instance = self.player.view.getSelection() - if cur_video_instance is None: cur_video_instance = -1 - table_index = self.instancesTable.model().createIndex(cur_video_instance, 0) - self.instancesTable.setCurrentIndex(table_index) + inst_selected = 
self.player.view.getSelectionInstance() - self.instancesTable.selectionChangedSignal.connect(lambda row: self.player.view.selectInstance(row, from_all=True, signal=False)) + if not inst_selected: return + + idx = -1 + if inst_selected in self.labeled_frame.instances_to_show: + idx = self.labeled_frame.instances_to_show.index(inst_selected) + + table_row_idx = self.instancesTable.model().createIndex(idx, 0) + self.instancesTable.setCurrentIndex(table_row_idx) + + self.instancesTable.selectionChangedSignal.connect(lambda inst: self.player.view.selectInstance(inst, signal=False)) self.player.view.updatedSelection.connect(update_instance_table_selection) # update track UI when change to track name @@ -1581,7 +1587,7 @@ def openKeyRef(self): def openAbout(self): pass - def newFrame(self, player, frame_idx, selected_idx): + def newFrame(self, player, frame_idx, selected_inst): """Called each time a new frame is drawn.""" # Store the current LabeledFrame (or make new, empty object) @@ -1592,8 +1598,8 @@ def newFrame(self, player, frame_idx, selected_idx): overlay.add_to_scene(self.video, frame_idx) # Select instance if there was already selection - if selected_idx > -1: - player.view.selectInstance(selected_idx) + if selected_inst is not None: + player.view.selectInstance(selected_inst) # Update related displays self.updateStatusMessage() diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index abd5c61c6..cfa07d8ad 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -21,7 +21,7 @@ from sleap.gui.overlays.tracks import TrackColorManager from sleap.io.video import Video from sleap.io.dataset import Labels -from sleap.instance import LabeledFrame +from sleap.instance import LabeledFrame, Instance from sleap.skeleton import Skeleton, Node @@ -241,7 +241,7 @@ class LabeledFrameTable(QTableView): """Table view widget backed by a custom data model for displaying lists of Video instances. """ - selectionChangedSignal = QtCore.Signal(int) + selectionChangedSignal = QtCore.Signal(Instance) def __init__(self, labeled_frame: LabeledFrame = None, labels: Labels = None): super(LabeledFrameTable, self).__init__() @@ -250,11 +250,19 @@ def __init__(self, labeled_frame: LabeledFrame = None, labels: Labels = None): self.setSelectionMode(QAbstractItemView.SingleSelection) def selectionChanged(self, new, old): + """Return `Instance` selected in table.""" super(LabeledFrameTable, self).selectionChanged(new, old) - row_idx = -1 + + instance = None if len(new.indexes()): row_idx = new.indexes()[0].row() - self.selectionChangedSignal.emit(row_idx) + try: + instance = self.model().labeled_frame.instances_to_show[row_idx] + except: + # Usually means that there's no labeled_frame + pass + + self.selectionChangedSignal.emit(instance) class LabeledFrameTableModel(QtCore.QAbstractTableModel): diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 5bd3a472a..58ff0daa5 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -28,7 +28,7 @@ import math import numpy as np -from typing import Callable +from typing import Callable, Union from PySide2.QtWidgets import QGraphicsItem, QGraphicsObject # The PySide2.QtWidgets.QGraphicsObject class provides a base class for all graphics items that require signals, slots and properties. 
@@ -54,7 +54,7 @@ class QtVideoPlayer(QWidget): changedData: Emitted whenever data is changed by user """ - changedPlot = Signal(QWidget, int, int) + changedPlot = Signal(QWidget, int, Instance) changedData = Signal(Instance) def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): @@ -175,28 +175,20 @@ def plot(self, idx=None): self.frame_idx = idx self.seekbar.setValue(self.frame_idx) - # Save index of selected instance - selected_idx = self.view.getSelection() - selected_idx = -1 if selected_idx is None else selected_idx # use -1 for no selection + # Store which Instance is selected + selected_inst = self.view.getSelectionInstance() # Clear existing objects self.view.clear() # Convert ndarray to QImage - # TODO: handle RGB and other formats - # https://stackoverflow.com/questions/34232632/convert-python-opencv-image-numpy-array-to-pyqt-qpixmap-image - # https://stackoverflow.com/questions/55063499/pyqt5-convert-cv2-image-to-qimage - # image = QImage(frame.copy().data, frame.shape[1], frame.shape[0], frame.shape[1], QImage.Format_Grayscale8) - # image = QImage(frame.copy().data, frame.shape[1], frame.shape[0], QImage.Format_Grayscale8) - - # Magic bullet: image = qimage2ndarray.array2qimage(frame) # Display image self.view.setImage(image) - # Emit signal (it's better to use the signal than a callback) - self.changedPlot.emit(self, idx, selected_idx) + # Emit signal + self.changedPlot.emit(self, idx, selected_inst) def nextFrame(self, dt=1): """ Go to next frame. @@ -531,18 +523,20 @@ def nextSelection(self): # signal that the selection has changed (so we can update visual display) self.updatedSelection.emit() - def selectInstance(self, select_idx, from_all=False, signal=True): + def selectInstance(self, select: Union[Instance, int], signal=True): """ - Select a particular skeleton instance. + Select a particular instance in view. 
Args: - select_idx: index of skeleton to select + select: either `Instance` or index of instance in view + Returns: + None """ - instances = self.selectable_instances if not from_all else self.all_instances self.clearSelection(signal=False) - if select_idx < len(instances): - for idx, instance in enumerate(instances): - instance.selected = (select_idx == idx) + + for idx, instance in enumerate(self.all_instances): + instance.selected = (select == idx or select == instance.instance) + # signal that the selection has changed (so we can update visual display) if signal: self.updatedSelection.emit() diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index 7a227fce9..e83bae07f 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -17,11 +17,11 @@ def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels): vp = QtVideoPlayer(small_robot_mp4_vid) qtbot.addWidget(vp) - test_frame_idx = 0 - labeled_frames = [_ for _ in centered_pair_labels if _.frame_idx == test_frame_idx] + test_frame_idx = 63 + labeled_frames = centered_pair_labels.labeled_frames def plot_instances(vp, idx): - for instance in labeled_frames[idx].instances: + for instance in labeled_frames[test_frame_idx].instances: vp.addInstance(instance=instance, color=(0,0,128)) vp.changedPlot.connect(plot_instances) @@ -31,15 +31,15 @@ def plot_instances(vp, idx): vp.plot() # Check that all instances are included in viewer - assert len(vp.instances) == len(labeled_frames[0].instances) + assert len(vp.instances) == len(labeled_frames[test_frame_idx].instances) vp.zoomToFit() # Check that we zoomed correctly - assert(vp.view.zoomFactor > 2) - + assert(vp.view.zoomFactor > 1) + vp.instances[0].updatePoints(complete=True) - + # Check that node is marked as complete assert vp.instances[0].childItems()[3].point.complete @@ -50,6 +50,11 @@ def plot_instances(vp, idx): qtbot.keyClick(vp, QtCore.Qt.Key_QuoteLeft) assert vp.view.getSelection() == 1 + # Check that selection by Instance works + for inst in labeled_frames[test_frame_idx].instances: + vp.view.selectInstance(inst) + assert vp.view.getSelectionInstance() == inst + # Check that sequence selection works with qtbot.waitCallback() as cb: vp.view.clearSelection() From 04558dd17f5fd71d75fbfa3f96211b89bac783e0 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 16:58:50 -0400 Subject: [PATCH 052/176] gui controls for log scale/outliers/batches to show --- sleap/nn/monitor.py | 145 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 117 insertions(+), 28 deletions(-) diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index e7e842bcf..f26d2c755 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -15,6 +15,11 @@ def __init__(self, zmq_context=None, show_controller=True, parent=None): self.show_controller = show_controller self.stop_button = None + self.redraw_batch_interval = 40 + self.batches_to_show = 200 # -1 to show all + self.ignore_outliers = False + self.log_scale = True + self.reset() self.setup_zmq(zmq_context) @@ -58,18 +63,29 @@ def reset(self, what=""): self.chart.addSeries(self.series["epoch_loss"]) self.chart.addSeries(self.series["val_loss"]) - # self.chart.createDefaultAxes() axisX = QtCharts.QtCharts.QValueAxis() axisX.setLabelFormat("%d") axisX.setTitleText("Batches") self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - axisY = QtCharts.QtCharts.QLogValueAxis() - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - 
axisY.setTitleText("Loss") - axisY.setBase(10) + # create the different Y axes that can be used + self.axisY = dict() + + self.axisY["log"] = QtCharts.QtCharts.QLogValueAxis() + self.axisY["log"].setBase(10) + + self.axisY["linear"] = QtCharts.QtCharts.QValueAxis() + + # settings that apply to all Y axes + for axisY in self.axisY.values(): + axisY.setLabelFormat("%f") + axisY.setLabelsVisible(True) + axisY.setMinorTickCount(1) + axisY.setTitleText("Loss") + + # use the default Y axis + axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] + self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) for series in self.chart.series(): @@ -86,17 +102,44 @@ def reset(self, what=""): layout.addWidget(self.chartView) if self.show_controller: + control_layout = QtWidgets.QHBoxLayout() + + field = QtWidgets.QCheckBox("Log Scale") + field.setChecked(self.log_scale) + field.stateChanged.connect(lambda x: self.toggle("log_scale")) + control_layout.addWidget(field) + + field = QtWidgets.QCheckBox("Ignore Outliers") + field.setChecked(self.ignore_outliers) + field.stateChanged.connect(lambda x: self.toggle("ignore_outliers")) + control_layout.addWidget(field) + + control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) + + field = QtWidgets.QComboBox() + self.batch_options = "200,1000,5000,All".split(",") + for opt in self.batch_options: + field.addItem(opt) + field.currentIndexChanged.connect(lambda x: self.set_batches_to_show(self.batch_options[x])) + control_layout.addWidget(field) + + control_layout.addStretch(1) + self.stop_button = QtWidgets.QPushButton("Stop Training") self.stop_button.clicked.connect(self.stop) - layout.addWidget(self.stop_button) + control_layout.addWidget(self.stop_button) + + + widget = QtWidgets.QWidget() + widget.setLayout(control_layout) + layout.addWidget(widget) wid = QtWidgets.QWidget() wid.setLayout(layout) self.setCentralWidget(wid) - # Only show that last 2000 batch values - self.X = deque(maxlen=2000) - self.Y = deque(maxlen=2000) + self.X = [] + self.Y = [] self.t0 = None self.current_job_output_type = what @@ -106,6 +149,40 @@ def reset(self, what=""): self.last_batch_number = 0 self.is_running = False + def toggle(self, what): + if what == "log_scale": + self.log_scale = not self.log_scale + self.update_y_axis() + elif what == "ignore_outliers": + self.ignore_outliers = not self.ignore_outliers + elif what == "entire_history": + if self.batches_to_show > 0: + self.batches_to_show = -1 + else: + self.batches_to_show = 200 + + def set_batches_to_show(self, val): + if val.isdigit(): + self.batches_to_show = int(val) + else: + self.batches_to_show = -1 + + def update_y_axis(self): + to = "log" if self.log_scale else "linear" + # remove other axes + for name, axisY in self.axisY.items(): + if name != to: + if axisY in self.chart.axes(): + self.chart.removeAxis(axisY) + for series in self.chart.series(): + if axisY in series.attachedAxes(): + series.detachAxis(axisY) + # add axis + axisY = self.axisY[to] + self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) + for series in self.chart.series(): + series.attachAxis(axisY) + def setup_zmq(self, zmq_context): # Progress monitoring self.ctx_given = (zmq_context is not None) @@ -132,7 +209,6 @@ def stop(self): self.stop_button.setText("Stopping...") self.stop_button.setEnabled(False) - def add_datapoint(self, x, y, which="batch"): # Keep track of all batch points @@ -140,28 +216,41 @@ def add_datapoint(self, x, y, which="batch"): self.X.append(x) self.Y.append(y) - # Redraw batch ever 40 points (faster than plotting 
each) - if x % 40 == 0: - xs, ys = self.X, self.Y + # Redraw batch at intervals (faster than plotting each) + if x % self.redraw_batch_interval == 0: + + if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: + xs, ys = self.X, self.Y + else: + xs, ys = self.X[-self.batches_to_show:], self.Y[-self.batches_to_show:] + points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] self.series["batch"].replace(points) # Set X scale to show all points dx = 0.5 - self.chart.axisX().setRange(min(self.X) - dx, max(self.X) + dx) - - # Set Y scale to exclude outliers - q1, q3 = np.quantile(self.Y, (.25, .75)) - iqr = q3-q1 # interquartile range - low = q1 - iqr * 1.5 - high = q3 + iqr * 1.5 - - low = max(low, min(self.Y) - .2) # keep within range of data - low = max(low, 1e-5) # for log scale, low cannot be 0 - - high = min(high, max(self.Y) + .2) + self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) + + if self.ignore_outliers: + # Set Y scale to exclude outliers + q1, q3 = np.quantile(ys, (.25, .75)) + iqr = q3-q1 # interquartile range + low = q1 - iqr * 1.5 + high = q3 + iqr * 1.5 + + low = max(low, min(ys) - .2) # keep within range of data + high = min(high, max(ys) + .2) + else: + # Set Y scale to show all points + dy = np.ptp(ys) * 0.02 + low = min(ys) - dy + high = max(ys) + dy + + if self.log_scale: + low = max(low, 1e-5) # for log scale, low cannot be 0 self.chart.axisY().setRange(low, high) + else: self.series[which].append(x, y) @@ -218,7 +307,7 @@ def check_messages(self, timeout=10): def test_point(x=[0]): x[0] += 1 i = x[0]+1 - win.add_datapoint(i, i%30) + win.add_datapoint(i, i%30+1) t = QtCore.QTimer() t.timeout.connect(test_point) From 012bd027a695354c3f6b9aa38df291acfd75bfce Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 17:16:03 -0400 Subject: [PATCH 053/176] tweak bounds for when ignoring outliers --- sleap/nn/monitor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index f26d2c755..360b2f0ee 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -232,14 +232,15 @@ def add_datapoint(self, x, y, which="batch"): self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) if self.ignore_outliers: + dy = np.ptp(ys) * 0.02 # Set Y scale to exclude outliers q1, q3 = np.quantile(ys, (.25, .75)) iqr = q3-q1 # interquartile range low = q1 - iqr * 1.5 high = q3 + iqr * 1.5 - low = max(low, min(ys) - .2) # keep within range of data - high = min(high, max(ys) + .2) + low = max(low, min(ys) - dy) # keep within range of data + high = min(high, max(ys) + dy) else: # Set Y scale to show all points dy = np.ptp(ys) * 0.02 From 0c08939bed69c835e64fc358866931e343526116 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 18:48:19 -0400 Subject: [PATCH 054/176] pip install . --- appveyor.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 986f2c7b2..be4a80ef4 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -56,8 +56,9 @@ install: # We need to install this separately, what a mess. # - pip install PySide2 opencv-python imgaug cattrs -# Install dev requirements too. + # Install dev requirements too. - pip install -r dev_requirements.txt + - pip install . build: off test_script: From 8d86b4c8f50de26d193d7232ce6cb61f6fa41a56 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 16 Sep 2019 19:00:50 -0400 Subject: [PATCH 055/176] Install sleap package in appveyor. This allows Requirements.parse("sleap") to work from tests. 
--- appveyor.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/appveyor.yml b/appveyor.yml index be4a80ef4..d1133066a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -58,6 +58,8 @@ install: # Install dev requirements too. - pip install -r dev_requirements.txt + + # Install sleap package - pip install . build: off From b664832cd677f3b979af5124cce344bcd97f3fbd Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 06:12:20 -0400 Subject: [PATCH 056/176] Make sure video matches video for ModelData object --- sleap/gui/overlays/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 7bdeab221..b001d00c5 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -72,6 +72,9 @@ class DataOverlay: def add_to_scene(self, video, frame_idx): if self.data is None: return + # Make sure video matches video for ModelData object + if hasattr(self.data, "video") and self.data.video != video: return + if self.transform is None: self._add(self.player.view.scene, self.overlay_class(self.data[frame_idx])) From 2abefd53bc65c687332b4a3f8163b486246022ec Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 06:20:47 -0400 Subject: [PATCH 057/176] Update video for ModelData if shapes match --- sleap/gui/overlays/base.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index b001d00c5..d05c227ca 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -72,8 +72,17 @@ class DataOverlay: def add_to_scene(self, video, frame_idx): if self.data is None: return - # Make sure video matches video for ModelData object - if hasattr(self.data, "video") and self.data.video != video: return + # Check if video matches video for ModelData object + if hasattr(self.data, "video") and self.data.video != video: + video_shape = (video.height, video.width, video.channels) + prior_shape = (self.data.video.height, self.data.video.width, self.data.video.channels) + # Check if the videos are both compatible with the loaded model + if video_shape == prior_shape: + # Shapes match so we can apply model to this video + self.data.video = video + else: + # Shapes don't match so don't do anything with this video + return if self.transform is None: self._add(self.player.view.scene, self.overlay_class(self.data[frame_idx])) From d1e5e2fc9b68fd2128bc283aba6b397fa4165a9e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 06:35:49 -0400 Subject: [PATCH 058/176] Add visible_points_array property and rename p_a() points_array() renamed to get_points_array() --- sleap/gui/app.py | 2 +- sleap/gui/slider.py | 4 ++-- sleap/info/metrics.py | 10 +++++----- sleap/info/write_tracking_h5.py | 2 +- sleap/instance.py | 8 ++++++-- sleap/io/dataset.py | 2 +- sleap/nn/datagen.py | 2 +- sleap/nn/tracking.py | 8 ++++---- tests/nn/test_inference.py | 4 ++-- tests/test_instance.py | 8 ++++---- 10 files changed, 27 insertions(+), 23 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 5aa0099e6..968837512 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -952,7 +952,7 @@ def delete_area_callback(x0, y0, x1, y1): max_corner = (x1, y1) def is_bounded(inst): - points_array = inst.points_array(invisible_as_nan=True) + points_array = inst.visible_points_array valid_points = points_array[~np.isnan(points_array).any(axis=1)] is_gt_min = np.all(valid_points >= min_corner) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 
dae8177bf..4fdb642cf 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -199,8 +199,8 @@ def inst_velocity(lf, last_lf): if last_lf is not None: last_inst = last_lf.find(track=inst.track) if last_inst: - points_a = inst.points_array(invisible_as_nan=True) - points_b = last_inst[0].points_array(invisible_as_nan=True) + points_a = inst.visible_points_array + points_b = last_inst[0].visible_points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) inst_dist = np.sum(point_dist) # np.nanmean(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index b3acddb8c..ee4e958c7 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -158,8 +158,8 @@ def point_dist( inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: """Given two instances, returns array of distances for corresponding nodes.""" - points_a = inst_a.points_array(invisible_as_nan=True) - points_b = inst_b.points_array(invisible_as_nan=True) + points_a = inst_a.visible_points_array + points_b = inst_b.visible_points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist @@ -171,8 +171,8 @@ def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], matrix_size = (len(inst_a.skeleton.nodes), len(inst_b.skeleton.nodes)) pairwise_distance_matrix = np.full(matrix_size, 0) - points_a = inst_a.points_array(invisible_as_nan=True) - points_b = inst_b.points_array(invisible_as_nan=True) + points_a = inst_a.visible_points_array + points_b = inst_b.visible_points_array # Calculate the distance between any pair of inst A and inst B points for idx_a in range(points_a.shape[0]): @@ -205,7 +205,7 @@ def compare_instance_lists( def list_points_array(instances: List[Union[Instance, PredictedInstance]]) -> np.ndarray: """Given list of Instances, returns (instances * nodes * 2) matrix.""" - points_arrays = list(map(lambda inst: inst.points_array(invisible_as_nan=True), instances)) + points_arrays = list(map(lambda inst: inst.visible_points_array, instances)) return np.stack(points_arrays) def point_match_count(dist_array: np.ndarray, thresh: float=5) -> int: diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index fe312a3dd..143d87083 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -65,7 +65,7 @@ def video_callback(video_list, new_paths=[os.path.dirname(args.data_path)]): occupancy_matrix[track_i, frame_i] = 1 - inst_points = inst.points_array(invisible_as_nan=True) + inst_points = inst.visible_points_array prediction_matrix[frame_i, ..., track_i] = inst_points print(f"track_occupancy: {occupancy_matrix.shape}") diff --git a/sleap/instance.py b/sleap/instance.py index 6f737f6ab..5ced949c1 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -572,7 +572,7 @@ def fix_array(self): self._points = new_array self._nodes = self.skeleton.nodes - def points_array(self, copy: bool = True, + def get_points_array(self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False) -> np.ndarray: """ @@ -607,10 +607,14 @@ def points_array(self, copy: bool = True, return parray + @property + def visible_points_array(self) -> np.ndarray: + return self.get_points_array(invisible_as_nan=True) + @property def centroid(self) -> np.ndarray: """Returns instance centroid as (x,y) numpy row vector.""" - points = self.points_array(invisible_as_nan=True) + points = self.visible_points_array centroid = np.nanmedian(points, axis=0) return centroid diff --git 
a/sleap/io/dataset.py b/sleap/io/dataset.py index 25a39e0b5..1e1797679 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1169,7 +1169,7 @@ def append_unique(old, new): frames[frame_id] = (frame_id+frame_id_offset, video_to_idx[label.video], label.frame_idx, instance_id+instance_id_offset, instance_id+instance_id_offset+len(label.instances)) for instance in label.instances: - parray = instance.points_array(copy=False, full=True) + parray = instance.get_points_array(copy=False, full=True) instance_type = type(instance) # Check whether we are working with a PredictedInstance or an Instance. diff --git a/sleap/nn/datagen.py b/sleap/nn/datagen.py index d3860b678..78a550938 100644 --- a/sleap/nn/datagen.py +++ b/sleap/nn/datagen.py @@ -142,7 +142,7 @@ def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: flo def lf_points_from_singleton(lf_singleton): if len(lf_singleton) == 0: return [] lf = lf_singleton[0] - points = [inst.points_array(invisible_as_nan=True)*scale + points = [inst.visible_points_array*scale for inst in lf.user_instances] return points diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 2facf5bb3..aac286de2 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -67,7 +67,7 @@ def frame_idx(self) -> int: """ return self.frame.frame_idx - def points_array(self, *args, **kwargs): + def get_points_array(self, *args, **kwargs): """ Return the ShiftedInstance as a numpy array. ShiftedInstance stores its points as an array always, unlike the Instance class. This method provides @@ -222,7 +222,7 @@ def process(self, self.last_frame_index = t t = frame.frame_idx - instances_pts = [i.points_array() for i in frame.instances] + instances_pts = [i.get_points_array() for i in frame.instances] # If we do not have any active tracks, we will spawn one for each # matched instance and continue to the next frame. 
@@ -240,7 +240,7 @@ def process(self, # Get all points in reference frame instances_ref = self.tracks.get_frame_instances(self.last_frame_index, max_shift=self.window - 1) - pts_ref = [instance.points_array() for instance in instances_ref] + pts_ref = [instance.get_points_array() for instance in instances_ref] tmp = min([instance.frame_idx for instance in instances_ref] + [instance.source.frame_idx for instance in instances_ref @@ -305,7 +305,7 @@ def process(self, cost_matrix = np.full((len(unassigned_pts), len(shifted_tracks)), np.nan) for i, track in enumerate(shifted_tracks): # Get shifted points for current track - track_pts = np.stack([instance.points_array() + track_pts = np.stack([instance.get_points_array() for instance in shifted_instances if instance.track == track], axis=0) # track_instances x nodes x 2 diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 21f5a2f76..1f26e6866 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -92,6 +92,6 @@ def test_peaks_with_scaling(): # make sure that each true instance has points matching one of the new instances for inst_a, inst_b in zip(true_labels.labeled_frames[i].instances, new_labels.labeled_frames[i].instances): - assert inst_a.points_array().shape == inst_b.points_array().shape + assert inst_a.get_points_array().shape == inst_b.get_points_array().shape # FIXME: new instances have nans, so for now just check first 5 points - assert np.allclose(inst_a.points_array()[0:5], inst_b.points_array()[0:5], atol=1/scale) + assert np.allclose(inst_a.get_points_array()[0:5], inst_b.get_points_array()[0:5], atol=1/scale) diff --git a/tests/test_instance.py b/tests/test_instance.py index 5fe50e295..61f98a5eb 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -130,7 +130,7 @@ def test_points_array(skeleton): instance1 = Instance(skeleton=skeleton, points=points) - pts = instance1.points_array() + pts = instance1.get_points_array() assert pts.shape == (len(skeleton.nodes), 2) assert np.allclose(pts[skeleton.node_to_index('left-wing'), :], [2, 5]) @@ -141,17 +141,17 @@ def test_points_array(skeleton): # Now change a point, make sure it is reflected instance1['head'].x = 0 instance1['thorax'] = Point(1, 2) - pts = instance1.points_array() + pts = instance1.get_points_array() assert np.allclose(pts[skeleton.node_to_index('head'), :], [0, 4]) assert np.allclose(pts[skeleton.node_to_index('thorax'), :], [1, 2]) # Make sure that invisible points are nan iff invisible_as_nan=True instance1['thorax'] = Point(1, 2, visible=False) - pts = instance1.points_array() + pts = instance1.get_points_array() assert not np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() - pts = instance1.points_array(invisible_as_nan=True) + pts = instance1.visible_points_array assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() def test_modifying_skeleton(skeleton): From 2d3f26e43173ab71936ee6611f9d4080727f1300 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 06:40:55 -0400 Subject: [PATCH 059/176] Change Instance.points() method to property. 
--- sleap/instance.py | 5 +++-- sleap/io/dataset.py | 2 +- tests/nn/test_inference.py | 2 +- tests/test_instance.py | 10 +++++----- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index 5ced949c1..f5faf7e2a 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -501,7 +501,7 @@ def matches(self, other): if type(self) is not type(other): return False - if list(self.points()) != list(other.points()): + if list(self.points) != list(other.points): return False if not self.skeleton.matches(other.skeleton): @@ -541,9 +541,10 @@ def nodes_points(self): Returns: The instance's (node, point) tuple pairs for all labelled point. """ - names_to_points = dict(zip(self.nodes, self.points())) + names_to_points = dict(zip(self.nodes, self.points)) return names_to_points.items() + @property def points(self) -> Tuple[Point]: """ Return the list of labelled points, in order they were labelled. diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 1e1797679..f5efeadff 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1396,7 +1396,7 @@ def load_mat(cls, filename): x = points_[node_idx][0][i] y = points_[node_idx][1][i] new_inst[node] = Point(x, y) - if len(new_inst.points()): + if len(new_inst.points): new_frame = LabeledFrame(video=vid, frame_idx=i) new_frame.instances = new_inst, labeled_frames.append(new_frame) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 1f26e6866..b9c1b4088 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -19,7 +19,7 @@ def check_labels(labels): for i in labels.all_instances: assert type(i) == PredictedInstance - assert type(i.points()[0]) == PredictedPoint + assert type(i.points[0]) == PredictedPoint # Make sure frames are in order for i, frame in enumerate(labels): diff --git a/tests/test_instance.py b/tests/test_instance.py index 61f98a5eb..82e93abf2 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -68,8 +68,8 @@ def test_instance_point_iter(skeleton): instance = Instance(skeleton=skeleton, points=points) assert [node.name for node in instance.nodes] == ['head', 'left-wing', 'right-wing'] - assert np.allclose([p.x for p in instance.points()], [1, 2, 3]) - assert np.allclose([p.y for p in instance.points()], [4, 5, 6]) + assert np.allclose([p.x for p in instance.points], [1, 2, 3]) + assert np.allclose([p.y for p in instance.points], [4, 5, 6]) # Make sure we can iterate over tuples for (node, point) in instance.nodes_points: @@ -160,17 +160,17 @@ def test_modifying_skeleton(skeleton): instance1 = Instance(skeleton=skeleton, points=points) - assert len(instance1.points()) == 3 + assert len(instance1.points) == 3 skeleton.add_node('new test node') instance1.fix_array() # update with changes from skeleton instance1['new test node'] = Point(7,8) - assert len(instance1.points()) == 4 + assert len(instance1.points) == 4 skeleton.delete_node('head') - assert len(instance1.points()) == 3 + assert len(instance1.points) == 3 def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): """ From e68299aac5b6996d13719d6e0a16b0f81e1051d1 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 08:40:02 -0400 Subject: [PATCH 060/176] re-ordered colors used for overlays --- sleap/gui/overlays/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index d05c227ca..694ba796f 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -190,9 +190,9 @@ 
def from_model(cls, filename, video, **kwargs): h5_colors = [ [204, 81, 81], - [127, 51, 51], [81, 204, 204], [51, 127, 127], + [127, 51, 51], [142, 204, 81], [89, 127, 51], [142, 81, 204], From a3b9b0a5f08c3394309c40b9cfa3c8ecbd33c897 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 08:40:45 -0400 Subject: [PATCH 061/176] prevent training/inference if model overlay loaded --- sleap/gui/app.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 968837512..bab43072c 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -820,6 +820,10 @@ def remove_user_labeled(video, frames, user_labeled_frames=self.labels.user_labe def _show_learning_window(self, mode): from sleap.gui.active import ActiveLearningDialog + if "inference" in self.overlays: + QMessageBox(text=f"In order to use this function you must first quit and re-open sLEAP to release resources used by visualizing model outputs.").exec_() + return + if self._child_windows.get(mode, None) is None: self._child_windows[mode] = ActiveLearningDialog(self.filename, self.labels, mode) self._child_windows[mode].learningFinished.connect(self.learningFinished) From aa43f616f9a81e6ad6a3b6588f279a3c633c5b30 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 15:03:27 -0400 Subject: [PATCH 062/176] Gui and class for viewing/changing menu shortcuts. --- sleap/config/shortcuts.yaml | 64 ++++++++--------- sleap/gui/app.py | 33 ++++----- sleap/gui/shortcuts.py | 138 ++++++++++++++++++++++++++++++++++++ tests/gui/test_shortcuts.py | 12 ++++ 4 files changed, 196 insertions(+), 51 deletions(-) create mode 100644 sleap/gui/shortcuts.py create mode 100644 tests/gui/test_shortcuts.py diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index cda364e6c..4e64205ed 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -1,32 +1,32 @@ -"new": QKeySequence.New -"open": QKeySequence.Open -"save": QKeySequence.Save -"save as": QKeySequence.SaveAs -"close": QKeySequence.Close -"add videos": Qt.CTRL + Qt.Key_A -"next video": QKeySequence.Forward -"prev video": QKeySequence.Back -"goto frame": Qt.CTRL + Qt.Key_J -"mark frame": Qt.CTRL + Qt.Key_M -"goto marked": Qt.CTRL + Qt.SHIFT + Qt.Key_M -"add instance": Qt.CTRL + Qt.Key_I -"delete instance": Qt.CTRL + Qt.Key_Backspace -"delete track": Qt.CTRL + Qt.SHIFT + Qt.Key_Backspace -"transpose": Qt.CTRL + Qt.Key_T -"select next": QKeySequence(Qt.Key.Key_QuoteLeft) -"clear selection": QKeySequence(Qt.Key.Key_Escape) -"goto next": Qt.CTRL + Qt.Key_Period -"goto prev": -"goto next user": Qt.CTRL + Qt.Key_Greater -"goto next suggestion": QKeySequence.FindNext -"goto prev suggestion": QKeySequence.FindPrevious -"goto next track": Qt.CTRL + Qt.Key_E -"show labels": Qt.CTRL + Qt.Key_Tab -"show edges": Qt.CTRL + Qt.SHIFT + Qt.Key_Tab -"show trails": -"color predicted": -"fit": Qt.CTRL + Qt.Key_Equal -"learning": Qt.CTRL + Qt.Key_L -"export clip": -"delete clip": -"delete area": Qt.CTRL + Qt.Key_K +add instance: Ctrl+I +add videos: Ctrl+A +clear selection: Esc +close: QKeySequence.Close +color predicted: +delete area: Ctrl+K +delete clip: +delete instance: Ctrl+Backspace +delete track: Ctrl+Shift+Backspace +export clip: +fit: Ctrl+= +goto frame: Ctrl+J +goto marked: Ctrl+Shift+M +goto next suggestion: QKeySequence.FindNext +goto next track spawn: Ctrl+E +goto next user: Ctrl+> +goto next labeled: Ctrl+. 
+goto prev suggestion: QKeySequence.FindPrevious +goto prev labeled: +learning: Ctrl+L +mark frame: Ctrl+M +new: Ctrl+N +next video: QKeySequence.Forward +open: Ctrl+O +prev video: QKeySequence.Back +save as: QKeySequence.SaveAs +save: Ctrl+S +select next: '`' +show edges: Ctrl+Shift+Tab +show labels: Ctrl+Tab +show trails: +transpose: Ctrl+T diff --git a/sleap/gui/app.py b/sleap/gui/app.py index bab43072c..e7ceda1b4 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -31,6 +31,7 @@ LabeledFrameTable, SkeletonNodeModel, SuggestionsTable from sleap.gui.importvideos import ImportVideos from sleap.gui.formbuilder import YamlFormWidget +from sleap.gui.shortcuts import Shortcuts, ShortcutDialog from sleap.gui.suggestions import VideoFrameSuggestions from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay @@ -122,15 +123,7 @@ def filename(self, x): def initialize_gui(self): - shortcut_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/shortcuts.yaml") - with open(shortcut_yaml, 'r') as f: - shortcuts = yaml.load(f, Loader=yaml.SafeLoader) - - for action in shortcuts: - key_string = shortcuts.get(action, None) - key_string = "" if key_string is None else key_string - if "." in key_string: - shortcuts[action] = eval(key_string) + shortcuts = Shortcuts() ####### Video player ####### self.player = QtVideoPlayer(color_manager=self._color_manager) @@ -150,12 +143,12 @@ def initialize_gui(self): ### File Menu ### fileMenu = self.menuBar().addMenu("File") - self._menu_actions["new"] = fileMenu.addAction("&New Project", self.newProject, shortcuts["new"]) - self._menu_actions["open"] = fileMenu.addAction("&Open Project...", self.openProject, shortcuts["open"]) + self._menu_actions["new"] = fileMenu.addAction("New Project", self.newProject, shortcuts["new"]) + self._menu_actions["open"] = fileMenu.addAction("Open Project...", self.openProject, shortcuts["open"]) fileMenu.addSeparator() self._menu_actions["add videos"] = fileMenu.addAction("Add Videos...", self.addVideo, shortcuts["add videos"]) fileMenu.addSeparator() - self._menu_actions["save"] = fileMenu.addAction("&Save", self.saveProject, shortcuts["save"]) + self._menu_actions["save"] = fileMenu.addAction("Save", self.saveProject, shortcuts["save"]) self._menu_actions["save as"] = fileMenu.addAction("Save As...", self.saveProjectAs, shortcuts["save as"]) fileMenu.addSeparator() self._menu_actions["close"] = fileMenu.addAction("Quit", self.close, shortcuts["close"]) @@ -164,15 +157,15 @@ def initialize_gui(self): goMenu = self.menuBar().addMenu("Go") - self._menu_actions["goto next"] = goMenu.addAction("Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next"]) - self._menu_actions["goto prev"] = goMenu.addAction("Previous Labeled Frame", self.previousLabeledFrame, shortcuts["goto prev"]) + self._menu_actions["goto next labeled"] = goMenu.addAction("Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next labeled"]) + self._menu_actions["goto prev labeled"] = goMenu.addAction("Previous Labeled Frame", self.previousLabeledFrame, shortcuts["goto prev labeled"]) self._menu_actions["goto next user"] = goMenu.addAction("Next User Labeled Frame", self.nextUserLabeledFrame, shortcuts["goto next user"]) self._menu_actions["goto next suggestion"] = goMenu.addAction("Next Suggestion", self.nextSuggestedFrame, shortcuts["goto next suggestion"]) self._menu_actions["goto prev suggestion"] = goMenu.addAction("Previous Suggestion", lambda:self.nextSuggestedFrame(-1), shortcuts["goto prev suggestion"]) - 
self._menu_actions["goto next track"] = goMenu.addAction("Next Track Spawn Frame", self.nextTrackFrame, shortcuts["goto next track"]) + self._menu_actions["goto next track spawn"] = goMenu.addAction("Next Track Spawn Frame", self.nextTrackFrame, shortcuts["goto next track spawn"]) goMenu.addSeparator() @@ -457,13 +450,13 @@ def update_gui_state(self): self._menu_actions["next video"].setEnabled(has_multiple_videos) self._menu_actions["prev video"].setEnabled(has_multiple_videos) - self._menu_actions["goto next"].setEnabled(has_labeled_frames) - self._menu_actions["goto prev"].setEnabled(has_labeled_frames) + self._menu_actions["goto next labeled"].setEnabled(has_labeled_frames) + self._menu_actions["goto prev labeled"].setEnabled(has_labeled_frames) self._menu_actions["goto next suggestion"].setEnabled(has_suggestions) self._menu_actions["goto prev suggestion"].setEnabled(has_suggestions) - self._menu_actions["goto next track"].setEnabled(has_tracks) + self._menu_actions["goto next track spawn"].setEnabled(has_tracks) # Update buttons self._buttons["add edge"].setEnabled(has_nodes_selected) @@ -1586,8 +1579,10 @@ def toggleAutoZoom(self): def openDocumentation(self): pass + def openKeyRef(self): - pass + ShortcutDialog().exec_() + def openAbout(self): pass diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py new file mode 100644 index 000000000..0bf7b797d --- /dev/null +++ b/sleap/gui/shortcuts.py @@ -0,0 +1,138 @@ +from PySide2 import QtWidgets, QtCore +from PySide2.QtCore import Qt +from PySide2.QtGui import QKeySequence + +import sys +import yaml + +from pkg_resources import Requirement, resource_filename + +class ShortcutDialog(QtWidgets.QDialog): + + _column_len = 13 + + def __init__(self, *args, **kwargs): + super(ShortcutDialog, self).__init__(*args, **kwargs) + + self.setWindowTitle("Keyboard Shortcuts") + self.load_shortcuts() + self.make_form() + + def accept(self): + for action, widget in self.key_widgets.items(): + self.shortcuts[action] = widget.keySequence().toString() + self.shortcuts.save() + + super(ShortcutDialog, self).accept() + + def load_shortcuts(self): + self.shortcuts = Shortcuts() + + def make_form(self): + self.key_widgets = dict() # dict to store QKeySequenceEdit widgets + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(self.make_shortcuts_widget()) + layout.addWidget(QtWidgets.QLabel("Any changes to keyboard shortcuts will not take effect until you quit and re-open the application.")) + layout.addWidget(self.make_buttons_widget()) + self.setLayout(layout) + + def make_buttons_widget(self): + buttons = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + return buttons + + def make_shortcuts_widget(self): + shortcuts = self.shortcuts + + widget = QtWidgets.QWidget() + layout = QtWidgets.QHBoxLayout() + + # show shortcuts in columns + for a in range(0, len(shortcuts), self._column_len): + b = min(len(shortcuts), a + self._column_len) + column_widget = self.make_column_widget(shortcuts[a:b]) + layout.addWidget(column_widget) + widget.setLayout(layout) + return widget + + def make_column_widget(self, shortcuts): + column_widget = QtWidgets.QWidget() + column_layout = QtWidgets.QFormLayout() + for action in shortcuts: + item = QtWidgets.QKeySequenceEdit(shortcuts[action]) + column_layout.addRow(action.title(), item) + self.key_widgets[action] = item + column_widget.setLayout(column_layout) + return column_widget + + +def 
dict_cut(d, a, b):
+    return dict(list(d.items())[a:b])
+
+class Shortcuts:
+
+    _shortcuts = None
+    _names = ("new", "open", "save", "save as", "close",
+              "add videos", "next video", "prev video",
+              "goto frame", "mark frame", "goto marked",
+              "add instance", "delete instance", "delete track",
+              "transpose", "select next", "clear selection",
+              "goto next labeled", "goto prev labeled", "goto next user",
+              "goto next suggestion", "goto prev suggestion",
+              "goto next track spawn",
+              "show labels", "show edges", "show trails",
+              "color predicted", "fit", "learning",
+              "export clip", "delete clip", "delete area")
+
+    def __init__(self):
+        shortcut_yaml = resource_filename(Requirement.parse("sleap"), "sleap/config/shortcuts.yaml")
+        with open(shortcut_yaml, 'r') as f:
+            shortcuts = yaml.load(f, Loader=yaml.SafeLoader)
+
+        for action in shortcuts:
+            key_string = shortcuts.get(action, None)
+            key_string = "" if key_string is None else key_string
+
+            try:
+                shortcuts[action] = eval(key_string)
+            except:
+                shortcuts[action] = QKeySequence.fromString(key_string)
+
+        self._shortcuts = shortcuts
+
+    def save(self):
+        shortcut_yaml = resource_filename(Requirement.parse("sleap"), "sleap/config/shortcuts.yaml")
+        with open(shortcut_yaml, 'w') as f:
+            yaml.dump(self._shortcuts, f)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            # dict with names and values
+            return {self._names[i]:self[i] for i in range(*idx.indices(len(self)))}
+        elif isinstance(idx, int):
+            # value
+            idx = self._names[idx]
+            return self[idx]
+        else:
+            # value
+            if idx in self._shortcuts:
+                return self._shortcuts[idx]
+            return ""
+
+    def __setitem__(self, idx, val):
+        if type(idx) == int:
+            idx = self._names[idx]
+            self[idx] = val
+        else:
+            self._shortcuts[idx] = val
+
+    def __len__(self):
+        return len(self._names)
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication()
+    win = ShortcutDialog()
+    win.show()
+    app.exec_()
\ No newline at end of file
diff --git a/tests/gui/test_shortcuts.py b/tests/gui/test_shortcuts.py
new file mode 100644
index 000000000..d6524dcb9
--- /dev/null
+++ b/tests/gui/test_shortcuts.py
@@ -0,0 +1,12 @@
+from PySide2.QtGui import QKeySequence
+
+from sleap.gui.shortcuts import Shortcuts
+
+def test_shortcuts():
+    shortcuts = Shortcuts()
+
+    assert shortcuts["new"] == shortcuts[0]
+    assert shortcuts["new"] == QKeySequence.fromString("Ctrl+N")
+    shortcuts["new"] = QKeySequence.fromString("Ctrl+Shift+N")
+    assert shortcuts["new"] == QKeySequence.fromString("Ctrl+Shift+N")
+    assert list(shortcuts[0:2].keys()) == ["new", "open"]
\ No newline at end of file

From 3737d84d4fc00c45572e7b4f2f5925468c8a2c1b Mon Sep 17 00:00:00 2001
From: Nat Tabris
Date: Tue, 17 Sep 2019 15:54:31 -0400
Subject: [PATCH 064/176] Overlay to show negative anchors, command to clear

---
 sleap/gui/app.py              | 12 ++++++++++++
 sleap/gui/overlays/anchors.py | 27 +++++++++++++++++++++++++++
 sleap/io/dataset.py           | 20 ++++++++++++++++++--
 tests/io/test_dataset.py      | 12 ++++++++++++
 4 files changed, 69 insertions(+), 2 deletions(-)
 create mode 100644 sleap/gui/overlays/anchors.py

diff --git a/sleap/gui/app.py b/sleap/gui/app.py
index e7ceda1b4..3b65479b7 100644
--- a/sleap/gui/app.py
+++ b/sleap/gui/app.py
@@ -36,6 +36,7 @@

 from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay
 from sleap.gui.overlays.instance import InstanceOverlay
+from sleap.gui.overlays.anchors import NegativeAnchorOverlay

 OPEN_IN_NEW = True

@@ -245,6 +246,7 @@ def initialize_gui(self):
         self._menu_actions["learning expert"] = predictionMenu.addAction("Expert Controls...", self.runLearningExpert)
         predictionMenu.addSeparator()
         self._menu_actions["negative sample"] = predictionMenu.addAction("Mark Negative Training Sample...", self.markNegativeAnchor)
+        self._menu_actions["clear negative samples"] = predictionMenu.addAction("Clear Current Frame Negative Samples", self.clearFrameNegativeAnchors)
         predictionMenu.addSeparator()
         self._menu_actions["visualize models"] = predictionMenu.addAction("Visualize Model Outputs...", self.visualizeOutputs)
         self._menu_actions["import predictions"] = predictionMenu.addAction("Import Predictions...", self.importPredictions)
@@ -414,6 +416,10 @@ def update_instance_table_selection():
         self.update_gui_timer.start(0.1)

     def load_overlays(self):
+        self.overlays["negative"] = NegativeAnchorOverlay(
+            labels = self.labels,
+            scene = self.player.view.scene)
+
         self.overlays["trails"] = TrackTrailOverlay(
             labels = self.labels,
             scene = self.player.view.scene,
@@ -1029,11 +1035,17 @@ def click_callback(x, y):
             self.updateStatusMessage()
             self.labels.add_negative_anchor(self.video, self.player.frame_idx, (x, y))
             self.changestack_push("add negative anchors")
+            self.plotFrame()

         # Prompt the user to select area
         self.updateStatusMessage(f"Please click where you want a negative sample...")
         self.player.onPointSelection(click_callback)

+    def clearFrameNegativeAnchors(self):
+        self.labels.remove_negative_anchors(self.video,
self.player.frame_idx) + self.changestack_push("remove negative anchors") + self.plotFrame() + def importPredictions(self): filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] filenames, selected_filter = QFileDialog.getOpenFileNames(self, dir=None, caption="Import labeled data...", filter=";;".join(filters)) diff --git a/sleap/gui/overlays/anchors.py b/sleap/gui/overlays/anchors.py new file mode 100644 index 000000000..04a2306f3 --- /dev/null +++ b/sleap/gui/overlays/anchors.py @@ -0,0 +1,27 @@ +import attr + +from PySide2 import QtWidgets, QtGui + +from sleap.gui.video import QtVideoPlayer +from sleap.io.dataset import Labels + +@attr.s(auto_attribs=True) +class NegativeAnchorOverlay: + + labels: Labels=None + scene: QtWidgets.QGraphicsScene=None + pen = QtGui.QPen(QtGui.QColor("red")) + line_len: int=3 + + def add_to_scene(self, video, frame_idx): + if self.labels is None: return + if video not in self.labels.negative_anchors: return + + anchors = self.labels.negative_anchors[video] + for idx, x, y in anchors: + if frame_idx == idx: + self._add(x,y) + + def _add(self, x, y): + self.scene.addLine(x-self.line_len, y-self.line_len, x+self.line_len, y+self.line_len, self.pen) + self.scene.addLine(x+self.line_len, y-self.line_len, x-self.line_len, y+self.line_len, self.pen) \ No newline at end of file diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index f5efeadff..0a73134b6 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -378,7 +378,7 @@ def instance_count(self, video: Video, frame_idx: int) -> int: count = len([inst for inst in labeled_frame.instances if type(inst)==Instance]) return count - + @property def all_instances(self): return list(self.instances()) @@ -529,7 +529,7 @@ def find_track_instances(self, *args, **kwargs) -> List[Instance]: return [inst for lf, inst in self.find_track_occupancy(*args, **kwargs)] # Methods for suggestions - + def get_video_suggestions(self, video:Video) -> list: """ Returns the list of suggested frames for the specified video @@ -645,6 +645,22 @@ def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): self.negative_anchors[video] = [] self.negative_anchors[video].append((frame_idx, *where)) + def remove_negative_anchors(self, video:Video, frame_idx: int): + """Removes negative training samples for given video and frame. 
+ + Args: + video: the `Video` for which we're removing negative samples + frame_idx: frame index + Returns: + None + """ + if video not in self.negative_anchors: return + + anchors = [(idx, x, y) + for idx, x, y in self.negative_anchors[video] + if idx != frame_idx] + self.negative_anchors[video] = anchors + # Methods for saving/loading def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool=False): diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 7a0aeb74a..2d60f28b9 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -371,6 +371,18 @@ def test_suggestions(small_robot_mp4_vid): assert len(labels.get_video_suggestions(dummy_video)) == 13 +def test_negative_anchors(): + video = Video.from_filename("foo.mp4") + labels = Labels() + + labels.add_negative_anchor(video, 1, (3, 4)) + labels.add_negative_anchor(video, 1, (7, 8)) + labels.add_negative_anchor(video, 2, (5, 9)) + + assert len(labels.negative_anchors[video]) == 3 + + labels.remove_negative_anchors(video, 1) + assert len(labels.negative_anchors[video]) == 1 def test_load_labels_mat(mat_labels): assert len(mat_labels.nodes) == 6 From 5bcba24526eb0f05716449dbbcf53e34dfa7e288 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 17 Sep 2019 21:03:22 -0400 Subject: [PATCH 065/176] Overlay to show track labels (when appropriate) Overlay shows when user has instance selected and is holding down control (so they could be setting track from keyboard shortcut). --- sleap/gui/app.py | 17 +++++----- sleap/gui/overlays/tracks.py | 60 ++++++++++++++++++++++++++++++++++-- sleap/gui/video.py | 2 +- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 3b65479b7..700f86b4c 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -34,7 +34,7 @@ from sleap.gui.shortcuts import Shortcuts, ShortcutDialog from sleap.gui.suggestions import VideoFrameSuggestions -from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay +from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay, TrackListOverlay from sleap.gui.overlays.instance import InstanceOverlay from sleap.gui.overlays.anchors import NegativeAnchorOverlay @@ -416,6 +416,11 @@ def update_instance_table_selection(): self.update_gui_timer.start(0.1) def load_overlays(self): + self.overlays["track_labels"] = TrackListOverlay( + labels = self.labels, + view = self.player.view, + color_manager = self._color_manager) + self.overlays["negative"] = NegativeAnchorOverlay( labels = self.labels, scene = self.player.view.scene) @@ -441,6 +446,7 @@ def update_gui_state(self): # todo: exclude predicted instances from count has_nodes_selected = (self.skeletonEdgesSrc.currentIndex() > -1 and self.skeletonEdgesDst.currentIndex() > -1) + control_key_down = QApplication.queryKeyboardModifiers() == Qt.ControlModifier # Update menus @@ -472,6 +478,9 @@ def update_gui_state(self): self._buttons["remove video"].setEnabled(self.videosTable.currentIndex().isValid()) self._buttons["delete instance"].setEnabled(self.instancesTable.currentIndex().isValid()) + # Update overlays + self.overlays["track_labels"].visible = control_key_down and has_selected_instance + def update_data_views(self, *update): update = update or ("video", "skeleton", "labels", "frame", "suggestions") @@ -507,12 +516,6 @@ def update_data_views(self, *update): suggestion_status_text = f"{labeled_count}/{len(suggestion_list)} labeled" self.suggested_count_label.setText(suggestion_status_text) - def 
keyPressEvent(self, event: QKeyEvent): - if event.key() == Qt.Key_Q: - self.close() - else: - event.ignore() # Kicks the event up to parent - def plotFrame(self, *args, **kwargs): """Wrap call to player.plot so we can redraw/update things.""" if self.video is None: return diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index e968d3783..b698e3d0d 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -163,7 +163,7 @@ class TrackTrailOverlay: color_manager: TrackColorManager=TrackColorManager(labels) trail_length: int=4 show: bool=False - + def get_track_trails(self, frame_selection, track: Track): """Get data needed to draw track trail. @@ -210,7 +210,7 @@ def get_frame_selection(self, video: Video, frame_idx: int): def get_tracks_in_frame(self, video: Video, frame_idx: int): """Return list of tracks that have instance in specified frame.""" - + tracks_in_frame = [inst.track for lf in self.labels.find(video, frame_idx) for inst in lf] return tracks_in_frame @@ -251,4 +251,58 @@ def add_to_scene(self, video: Video, frame_idx: int): @staticmethod def map_to_qt_polygon(point_list): """Converts a list of (x, y)-tuples to a `QPolygonF`.""" - return QtGui.QPolygonF(list(itertools.starmap(QtCore.QPointF, point_list))) \ No newline at end of file + return QtGui.QPolygonF(list(itertools.starmap(QtCore.QPointF, point_list))) + + +@attr.s(auto_attribs=True) +class TrackListOverlay: + """Class to show track number and names in overlay. + """ + + labels: Labels=None + view: QtWidgets.QGraphicsView=None + color_manager: TrackColorManager=TrackColorManager(labels) + text_box = None + + def add_to_scene(self, video: Video, frame_idx: int): + from sleap.gui.video import QtTextWithBackground + + html = "" + num_to_show = min(9, len(self.labels.tracks)) + + for i, track in enumerate(self.labels.tracks[:num_to_show]): + idx = i+1 + + if html: html += "
" + color = self.color_manager.get_color(track) + html_color = f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}" + track_text = f"{track.name}" + if str(idx) != track.name: + track_text += f" ({idx})" + html += f"{track_text}" + + text_box = QtTextWithBackground() + text_box.setDefaultTextColor(QtGui.QColor("white")) + text_box.setHtml(html) + text_box.setOpacity(.7) + + self.text_box = text_box + self.visible = False + + self.view.scene.addItem(self.text_box) + + @property + def visible(self): + if self.text_box is None: return False + return self.text_box.isVisible() + + @visible.setter + def visible(self, val): + if self.text_box is None: return + if val: + pos = self.view.mapToScene(10, 10) + if pos.x() > 0: + self.text_box.setPos(pos) + else: + self.text_box.setPos(10, 10) + self.text_box.setVisible(val) \ No newline at end of file diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 58ff0daa5..abd430bda 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -1175,7 +1175,6 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, self.track_label = QtTextWithBackground(parent=self) self.track_label.setDefaultTextColor(QColor(*self.color)) - self.track_label.setFlag(QGraphicsItem.ItemIgnoresTransformations) instance_label_text = "" if self.instance.track is not None: @@ -1343,6 +1342,7 @@ class QtTextWithBackground(QGraphicsTextItem): def __init__(self, *args, **kwargs): super(QtTextWithBackground, self).__init__(*args, **kwargs) + self.setFlag(QGraphicsItem.ItemIgnoresTransformations) def boundingRect(self): """ Method required by Qt. From e2db6baf9e0d65e9b13197ee09c6c3d49d8f260f Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 18 Sep 2019 09:24:06 -0400 Subject: [PATCH 066/176] Autodetect initial video settings in import gui. --- sleap/gui/importvideos.py | 70 ++++++++++++++++++++++++++++----------- tests/gui/test_import.py | 9 ++++- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/sleap/gui/importvideos.py b/sleap/gui/importvideos.py index 3ba51fa99..dce9cfa11 100644 --- a/sleap/gui/importvideos.py +++ b/sleap/gui/importvideos.py @@ -85,7 +85,8 @@ def __init__(self, file_names:list, *args, **kwargs): { "name": "dataset", "type": "function_menu", - "options": "_get_h5_dataset_options" + "options": "_get_h5_dataset_options", + "required": True }, { "name": "input_format", @@ -225,7 +226,7 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs): self.setFrameStyle(QFrame.Panel) self.options_widget.changed.connect(self.update_video) - self.update_video() + self.update_video(initial=True) def is_enabled(self): """Am I enabled? @@ -252,27 +253,34 @@ def get_data(self) -> dict: } return video_data - def update_video(self): + def update_video(self, initial=False): """Update preview video using current param values. - + + Args: + initial: if True, then get video settings that are used by + the `Video` object when they aren't specified as params Returns: None. """ - - video_params = self.options_widget.get_values() + + video_params = self.options_widget.get_values(only_required=initial) + try: if self.import_type["video_class"] is not None: self.video = self.import_type["video_class"](**video_params) else: self.video = None - + self.preview_widget.load_video(self.video) except Exception as e: - print(e) + print(f"Unable to load video with these parameters. 
Error: {e}") # if we got an error showing video with those settings, clear the video preview self.video = None self.preview_widget.clear_video() + if initial and self.video is not None: + self.options_widget.set_values_from_video(self.video) + def boundingRect(self) -> QRectF: """Method required by Qt.""" return QRectF() @@ -346,11 +354,12 @@ def make_layout(self) -> QLayout: self.widget_elements = widget_elements return widget_layout - def get_values(self): + def get_values(self, only_required=False): """Method to get current user-selected values for import parameters. Args: - None. + only_required: Only return the parameters that are required + for instantiating `Video` object Returns: Dict of param keys/values. @@ -366,16 +375,39 @@ def get_values(self): for param_item in param_list: name = param_item["name"] type = param_item["type"] - value = None - if type == "radio": - value = self.widget_elements[name].checkedButton().text() - elif type == "check": - value = self.widget_elements[name].isChecked() - elif type == "function_menu": - value = self.widget_elements[name].currentText() - param_values[name] = value + is_required = param_item.get("required", False) + + if not only_required or is_required: + value = None + if type == "radio": + value = self.widget_elements[name].checkedButton().text() + elif type == "check": + value = self.widget_elements[name].isChecked() + elif type == "function_menu": + value = self.widget_elements[name].currentText() + param_values[name] = value return param_values - + + def set_values_from_video(self, video): + """Set the form fields using attributes on video.""" + param_list = self.import_type["params"] + for param in param_list: + name = param["name"] + type = param["type"] + print(name,type) + if hasattr(video, name): + val = getattr(video, name) + print(name,val) + widget = self.widget_elements[name] + if hasattr(widget, "isChecked"): + widget.setChecked(val) + elif hasattr(widget, "value"): + widget.setValue(val) + elif hasattr(widget, "currentText"): + widget.setCurrentText(str(val)) + elif hasattr(widget, "text"): + widget.setText(str(val)) + def _get_h5_dataset_options(self) -> list: """Method to get a list of all datasets in hdf5 file. 
diff --git a/tests/gui/test_import.py b/tests/gui/test_import.py index 1ed473338..8fd84f1e8 100644 --- a/tests/gui/test_import.py +++ b/tests/gui/test_import.py @@ -31,4 +31,11 @@ def test_gui_import(qtbot): qtbot.mouseClick(btn, QtCore.Qt.LeftButton) assert import_item.is_enabled() - assert len(importer.get_data()) == 2 \ No newline at end of file + assert len(importer.get_data()) == 2 + +def test_video_import_detect_params(): + importer = ImportParamDialog(["tests/data/videos/centered_pair_small.mp4", "tests/data/videos/small_robot.mp4"]) + data = importer.get_data() + + assert data[0]["params"]["grayscale"] == True + assert data[1]["params"]["grayscale"] == False From 3ba5e902c9482c37ce5e877db1a1bab784874171 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 09:45:40 -0400 Subject: [PATCH 067/176] add save_file method, detects format from name --- sleap/io/dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 0a73134b6..62dd3e95c 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1307,6 +1307,7 @@ def load_hdf5(cls, filename: str, @classmethod def load_file(cls, filename: str, *args, **kwargs): + """Load file, detecting format from filename.""" if filename.endswith((".h5", ".hdf5")): return cls.load_hdf5(filename, *args, **kwargs) elif filename.endswith((".json", ".json.zip")): @@ -1319,6 +1320,18 @@ def load_file(cls, filename: str, *args, **kwargs): else: raise ValueError(f"Cannot detect filetype for {filename}") + @classmethod + def save_file(cls, labels: 'Labels', filename: str, *args, **kwargs): + """Save file, detecting format from filename.""" + if filename.endswith((".json", ".zip")): + compress = filename.endswith(".zip") + cls.save_json(labels = labels, filename = filename, + compress = compress) + elif filename.endswith(".h5"): + cls.save_hdf5(labels = labels, filename = filename) + else: + raise ValueError(f"Cannot detect filetype for {filename}") + def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', all_labels: bool = False): """ Write all labeled frames from all videos to a collection of imgstore datasets. From 016e7bbf032fcd659d132845df5bfdbd2f6c9164 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 09:46:16 -0400 Subject: [PATCH 068/176] catch errors during save and notify user --- sleap/gui/app.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index aa84f71dc..764f6f058 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1357,17 +1357,7 @@ def saveProject(self): if self.filename is not None: filename = self.filename - if filename.endswith((".json", ".json.zip")): - compress = filename.endswith(".zip") - Labels.save_json(labels = self.labels, filename = filename, - compress = compress) - elif filename.endswith(".h5"): - Labels.save_hdf5(labels = self.labels, filename = filename) - - # Mark savepoint in change stack - self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. 
- self.plotFrame() + self._trySave(self.filename) else: # No filename (must be new project), so treat as "Save as" self.saveProjectAs() @@ -1385,23 +1375,27 @@ def saveProjectAs(self): if len(filename) == 0: return - if filename.endswith((".json", ".zip")): - compress = filename.endswith(".zip") - Labels.save_json(labels = self.labels, filename = filename, compress = compress) - self.filename = filename - # Mark savepoint in change stack - self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. - self.plotFrame() - elif filename.endswith(".h5"): - Labels.save_hdf5(labels = self.labels, filename = filename) + if self._trySave(filename): + # If save was successful self.filename = filename + + def _trySave(self, filename): + success = False + try: + Labels.save_file(labels = self.labels, filename = filename) + success = True # Mark savepoint in change stack self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. - self.plotFrame() - else: - QMessageBox(text=f"File not saved. Try saving as json.").exec_() + + except Exception as e: + message = f"An error occured when attempting to save:\n {e}\n\n" + message += "Try saving your project with a different filename or in a different format." + QtWidgets.QMessageBox(text=message).exec_() + + # Redraw. Not sure why, but sometimes we need to do this. + self.plotFrame() + + return success def closeEvent(self, event): if not self.changestack_has_changes(): From 4a05b45a8e077246d52ffcd3a6682c9ddcdfe66f Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 11:04:57 -0400 Subject: [PATCH 069/176] ignore empty tracks for track occupancy h5 --- sleap/info/write_tracking_h5.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 143d87083..ebf487c5c 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -67,7 +67,17 @@ def video_callback(video_list, new_paths=[os.path.dirname(args.data_path)]): inst_points = inst.visible_points_array prediction_matrix[frame_i, ..., track_i] = inst_points - + + occupied_track_mask = np.sum(occupancy_matrix, axis=1) > 0 +# print(track_names[occupied_track_mask]) + + # Ignore unoccupied tracks + if(np.sum(~occupied_track_mask)): + print(f"ignoring {np.sum(~occupied_track_mask)} empty tracks") + occupancy_matrix = occupancy_matrix[occupied_track_mask] + prediction_matrix = prediction_matrix[...,occupied_track_mask] + track_names = [track_names[i] for i in range(len(track_names)) if occupied_track_mask[i]] + print(f"track_occupancy: {occupancy_matrix.shape}") print(f"tracks: {prediction_matrix.shape}") From d57d1da720ae44a632883cb033d4d2271d8f89c4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 12:33:21 -0400 Subject: [PATCH 070/176] keep w/in array bounds (when user adding new profile) --- sleap/gui/active.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index 0aef11c8b..21a51924c 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -217,16 +217,18 @@ def update_gui(self): conf_job, _ = self._get_current_job(ModelOutputType.CONFIDENCE_MAP) paf_job, _ = self._get_current_job(ModelOutputType.PART_AFFINITY_FIELD) - if conf_job.trainer.scale != paf_job.trainer.scale: - can_run = False - error_messages.append(f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs 
({paf_job.trainer.scale})") - if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop: - can_run = False - crop_model_name = "confmaps" if conf_job.trainer.instance_crop else "pafs" - error_messages.append(f"exactly one model ({crop_model_name}) was trained on crops") - if use_centroids and not conf_job.trainer.instance_crop: - can_run = False - error_messages.append(f"models used with centroids must be trained on cropped images") + # only check compatible if we have both profiles + if conf_job is not None and paf_job is not None: + if conf_job.trainer.scale != paf_job.trainer.scale: + can_run = False + error_messages.append(f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})") + if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop: + can_run = False + crop_model_name = "confmaps" if conf_job.trainer.instance_crop else "pafs" + error_messages.append(f"exactly one model ({crop_model_name}) was trained on crops") + if use_centroids and not conf_job.trainer.instance_crop: + can_run = False + error_messages.append(f"models used with centroids must be trained on cropped images") message = "" if not can_run: @@ -243,6 +245,11 @@ def _get_current_job(self, model_type): field = self.training_profile_widgets[model_type] idx = field.currentIndex() + # Check that selection corresponds to something we're loaded + # (it won't when user is adding a new profile) + if idx >= len(self.job_options[model_type]): + return None, None + job_filename, job = self.job_options[model_type][idx] if model_type == ModelOutputType.CENTROIDS: From 59ea43a4cef6e7b21900b51c61e7fb40fd9e79d3 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 14:04:28 -0400 Subject: [PATCH 071/176] add start property to RangeList --- sleap/rangelist.py | 6 ++++++ tests/test_rangelist.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/sleap/rangelist.py b/sleap/rangelist.py index 4c5e0701d..a9a29bb70 100644 --- a/sleap/rangelist.py +++ b/sleap/rangelist.py @@ -28,6 +28,12 @@ def is_empty(self): """Returns True if the list is empty.""" return len(self.list) == 0 + @property + def start(self): + """Returns the start value of range (or None if empty).""" + if self.is_empty: return None + return self.list[0][0] + def add(self, val, tolerance=0): """Adds a single value, merges to last range if contiguous.""" if len(self.list) and self.list[-1][1] + tolerance >= val: diff --git a/tests/test_rangelist.py b/tests/test_rangelist.py index 3f813362c..518049509 100644 --- a/tests/test_rangelist.py +++ b/tests/test_rangelist.py @@ -11,6 +11,10 @@ def test_rangelist(): a.remove((5,8)) assert a.list == [(1, 2), (3, 5), (8, 20), (50, 100)] + + assert a.start == 1 + a.remove((1,3)) + assert a.start == 3 b = RangeList() b.add(1) From 89ea959bee4a3b9ce6694384ed0d93354d08dc3c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 14:05:21 -0400 Subject: [PATCH 072/176] use track_range.start to next find spawn frame --- sleap/gui/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 764f6f058..a30c35d40 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1530,8 +1530,8 @@ def nextSuggestedFrame(self, seek_direction=1): def nextTrackFrame(self): cur_idx = self.player.frame_idx - video_tracks = {inst.track for lf in self.labels.find(self.video) for inst in lf if inst.track is not None} - next_idx = min([track.spawned_on for track in video_tracks if track.spawned_on > cur_idx], 
default=-1) + track_ranges = self.labels.get_track_occupany(self.video) + next_idx = min([track_range.start for track_range in track_ranges.values() if track_range.start > cur_idx], default=-1) if next_idx > -1: self.plotFrame(next_idx) From 416392c50d7712caab8ee29ff752c4132d339d2c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 15:24:27 -0400 Subject: [PATCH 073/176] show labeled frame and user frame counts --- sleap/info/labels.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sleap/info/labels.py b/sleap/info/labels.py index 775725a75..20c169a6f 100644 --- a/sleap/info/labels.py +++ b/sleap/info/labels.py @@ -17,6 +17,8 @@ print(f"Video files:") + total_user_frames = 0 + for vid in labels.videos: lfs = labels.find(vid) @@ -25,9 +27,15 @@ tracks = {inst.track for lf in lfs for inst in lf} concurrent_count = max((len(lf.instances) for lf in lfs)) + user_frames = len(labels.get_video_user_labeled_frames(vid)) + + total_user_frames += user_frames print(f" {vid.filename}") - print(f" labeled from {first_idx} to {last_idx}") + print(f" labeled frames from {first_idx} to {last_idx}") + print(f" labeled frames: {len(lfs)}") + print(f" user labeled frames: {user_frames}") print(f" tracks: {len(tracks)}") print(f" max instances in frame: {concurrent_count}") + print(f"Total user labeled frames: {total_user_frames}") \ No newline at end of file From 5a89781cc9be6b36a26bf09e43ae9a951be17696 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 15:27:02 -0400 Subject: [PATCH 074/176] remove duplicate instances when merging frames --- sleap/instance.py | 19 +++++++++++++++++-- tests/io/test_dataset.py | 10 +++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index f5faf7e2a..eb7392781 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -880,7 +880,8 @@ def instances_to_show(self): return inst_to_show @staticmethod - def merge_frames(labeled_frames, video): + def merge_frames(labeled_frames, video, remove_redundant=True): + redundant_count = 0 frames_found = dict() # move instances into first frame with matching frame_idx for idx, lf in enumerate(labeled_frames): @@ -888,7 +889,19 @@ def merge_frames(labeled_frames, video): if lf.frame_idx in frames_found.keys(): # move instances dst_idx = frames_found[lf.frame_idx] - labeled_frames[dst_idx].instances.extend(lf.instances) + if remove_redundant: + for new_inst in lf.instances: + redundant = False + for old_inst in labeled_frames[dst_idx].instances: + if new_inst.matches(old_inst): + redundant = True + if not hasattr(new_inst, "score"): + redundant_count += 1 + break + if not redundant: + labeled_frames[dst_idx].instances.append(new_inst) + else: + labeled_frames[dst_idx].instances.extend(lf.instances) lf.instances = [] else: # note first lf with this frame_idx @@ -896,5 +909,7 @@ def merge_frames(labeled_frames, video): # remove labeled frames with no instances labeled_frames = list(filter(lambda lf: len(lf.instances), labeled_frames)) + if redundant_count: + print(f"skipped {redundant_count} redundant instances") return labeled_frames diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 2d60f28b9..2daaf661c 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -260,7 +260,15 @@ def test_label_mutability(): labels.remove_video(dummy_video) assert len(labels.find(dummy_video)) == 0 - dummy_frames3 = [LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) for _ in range(10)] + 
dummy_frames3 = [] + dummy_skeleton.add_node("node") + + # Add 10 instances with different points (so they aren't "redundant") + for i in range(10): + instance = Instance(skeleton=dummy_skeleton, points=dict(node=Point(i,i))) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[instance,]) + dummy_frames3.append(dummy_frame) + labels.labeled_frames.extend(dummy_frames3) assert len(labels) == 10 assert len(labels.labeled_frames[0].instances) == 1 From 6d3959a2e22c338b99bfc3433a2110c085b3402c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 19 Sep 2019 16:03:48 -0400 Subject: [PATCH 075/176] rename and move Import... menu command Moved from Predict -> Import Predictions to File -> Import Labels since it doesn't just import predicted data. --- sleap/gui/app.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index a30c35d40..edab2ea0b 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -146,6 +146,7 @@ def initialize_gui(self): fileMenu = self.menuBar().addMenu("File") self._menu_actions["new"] = fileMenu.addAction("New Project", self.newProject, shortcuts["new"]) self._menu_actions["open"] = fileMenu.addAction("Open Project...", self.openProject, shortcuts["open"]) + self._menu_actions["import predictions"] = fileMenu.addAction("Import Labels...", self.importPredictions) fileMenu.addSeparator() self._menu_actions["add videos"] = fileMenu.addAction("Add Videos...", self.addVideo, shortcuts["add videos"]) fileMenu.addSeparator() @@ -249,7 +250,6 @@ def initialize_gui(self): self._menu_actions["clear negative samples"] = predictionMenu.addAction("Clear Current Frame Negative Samples", self.clearFrameNegativeAnchors) predictionMenu.addSeparator() self._menu_actions["visualize models"] = predictionMenu.addAction("Visualize Model Outputs...", self.visualizeOutputs) - self._menu_actions["import predictions"] = predictionMenu.addAction("Import Predictions...", self.importPredictions) predictionMenu.addSeparator() self._menu_actions["remove predictions"] = predictionMenu.addAction("Delete All Predictions...", self.deletePredictions) self._menu_actions["remove clip predictions"] = predictionMenu.addAction("Delete Predictions from Clip...", self.deleteClipPredictions, shortcuts["delete clip"]) @@ -1060,14 +1060,7 @@ def importPredictions(self): gui_video_callback = Labels.make_gui_video_callback( search_paths=[os.path.dirname(filename)]) - if filename.endswith((".h5", ".hdf5")): - new_labels = Labels.load_hdf5( - filename, - match_to=self.labels, - video_callback=gui_video_callback) - - elif filename.endswith((".json", ".json.zip")): - new_labels = Labels.load_json( + new_labels = Labels.load_file( filename, match_to=self.labels, video_callback=gui_video_callback) From 73d56fe90c6975d76191712a66b2962ed954454f Mon Sep 17 00:00:00 2001 From: Talmo Date: Thu, 19 Sep 2019 20:29:38 -0400 Subject: [PATCH 076/176] Small inference fixes in prediction --- sleap/nn/inference.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 4701c859b..5c3f95ede 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -156,7 +156,7 @@ def predict(self, # Open the video if we need it. 
try: - input_video.get_frame(0) + input_video.get_frame(frames[0]) vid = input_video except AttributeError: if isinstance(input_video, dict): @@ -183,16 +183,17 @@ def predict(self, raise ValueError("Predictor has no model.") # Initialize tracking - tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) + if self.with_tracking: + tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) - # Create output directory if it doesn't exist - try: - os.mkdir(os.path.dirname(self.output_path)) - except FileExistsError: - pass - # Delete the output file if it exists already - if os.path.exists(self.output_path): - os.unlink(self.output_path) + if self.output_path: + # Delete the output file if it exists already + if os.path.exists(self.output_path): + os.unlink(self.output_path) + + # Create output directory if it doesn't exist + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) # Process chunk-by-chunk! t0_start = time() From 8431428ea6a84f2a5287d838499e74cca7b0b943 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 09:08:07 -0400 Subject: [PATCH 077/176] Defer loading imgstore until accessing video data. --- sleap/io/video.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index 5a34396c9..fbd697cbd 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -343,6 +343,8 @@ class ImgStoreVideo: filename: str = attr.ib(default=None) index_by_original: bool = attr.ib(default=True) + _store_ = None + _img_ = None def __attrs_post_init__(self): @@ -358,7 +360,6 @@ def __attrs_post_init__(self): self.filename = os.path.abspath(self.filename) self.__store = None - self.open() # The properties and methods below complete our contract with the # higher level Video interface. @@ -375,6 +376,22 @@ def matches(self, other): """ return self.filename == other.filename and self.index_by_original == other.index_by_original + @property + def __store(self): + if self._store_ is None: + self.open() + return self._store_ + + @__store.setter + def __store(self, val): + self._store_ = val + + @property + def __img(self): + if self._img_ is None: + self.open() + return self._img_ + @property def frames(self): return self.__store.frame_count @@ -445,12 +462,12 @@ def open(self): Returns: None """ - if not self.imgstore: + if not self._store_: # Open the imgstore - self.__store = imgstore.new_for_filename(self.filename) + self._store_ = imgstore.new_for_filename(self.filename) # Read a frame so we can compute shape an such - self.__img, (frame_number, frame_timestamp) = self.__store.get_next_image() + self._img_, (frame_number, frame_timestamp) = self._store_.get_next_image() def close(self): """ From b5acacfb350afa062999821877a07e4b470f2b87 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 09:45:16 -0400 Subject: [PATCH 078/176] Use weak_filename_match() to match video paths. This checks the last three parts of path, after changing \ to / and removing pid from tmp_pid_... directory names. 
--- sleap/io/dataset.py | 5 ++--- sleap/util.py | 15 ++++++++++++++- tests/test_util.py | 15 +++++++++++++-- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 62dd3e95c..7c2b96667 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -40,7 +40,7 @@ make_instance_cattr, PointArray, PredictedPointArray from sleap.rangelist import RangeList from sleap.io.video import Video -from sleap.util import uniquify +from sleap.util import uniquify, weak_filename_match def json_loads(json_str: str): @@ -907,8 +907,7 @@ def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) for idx, vid in enumerate(videos): for old_vid in match_to.videos: # compare last three parts of path - weak_match = vid.filename.split("/")[-3:] == old_vid.filename.split("/")[-3:] - if vid.filename == old_vid.filename or weak_match: + if vid.filename == old_vid.filename or weak_filename_match(vid.filename, old_vid.filename): # use video from match videos[idx] = old_vid break diff --git a/sleap/util.py b/sleap/util.py index 289c33ff2..d333d91ff 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -3,6 +3,7 @@ unless they really have no other place. """ import os +import re import h5py as h5 import numpy as np @@ -115,4 +116,16 @@ def uniquify(seq): # Raymond Hettinger # https://twitter.com/raymondh/status/944125570534621185 - return list(dict.fromkeys(seq)) \ No newline at end of file + return list(dict.fromkeys(seq)) + +def weak_filename_match(filename_a, filename_b): + """Check if paths probably point to same file.""" + filename_a = filename_a.replace("\\","/") + filename_b = filename_b.replace("\\","/") + + # remove unique pid so we can match tmp directories for same zip + filename_a = re.sub("/tmp_\d+_", "tmp_", filename_a) + filename_b = re.sub("/tmp_\d+_", "tmp_", filename_b) + + # check if last three parts of path match + return filename_a.split("/")[-3:] == filename_b.split("/")[-3:] \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index cbc6e72b7..f17429f0f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,7 @@ from typing import List, Dict -from sleap.util import attr_to_dtype, frame_list +from sleap.util import attr_to_dtype, frame_list, weak_filename_match def test_attr_to_dtype(): """ @@ -43,4 +43,15 @@ class TestAttr3: def test_frame_list(): assert frame_list("3-5") == [3,4,5] - assert frame_list("7,10") == [7,10] \ No newline at end of file + assert frame_list("7,10") == [7,10] + +def test_weak_match(): + assert weak_filename_match("one/two", "one/two") + assert weak_filename_match( + "M:\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_11576_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml", + "D:\\projects\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_99713_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml") + assert weak_filename_match("zero/one/two/three.mp4","other\\one\\two\\three.mp4") + + assert not weak_filename_match("one/two/three", "two/three") + assert not weak_filename_match("one/two/three.mp4","one/two/three.avi") + assert not weak_filename_match("foo.mp4","bar.mp4") From bca244d80d06e4a6dac8883f0e25a913bcc94310 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 13:31:38 -0400 Subject: [PATCH 079/176] Include empty frames in track occupancy matrix The matrix will now include a column for each frame index between the first and last labeled frames (so empty frames at beginning and end of video will still not be included). 
--- sleap/info/write_tracking_h5.py | 34 +++------------------------------ 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index ebf487c5c..0043911b6 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -12,44 +12,16 @@ parser.add_argument("data_path", help="Path to labels json file") args = parser.parse_args() - def video_callback(video_list, new_paths=[os.path.dirname(args.data_path)]): - # Check each video - for video_item in video_list: - if "backend" in video_item and "filename" in video_item["backend"]: - current_filename = video_item["backend"]["filename"] - # check if we can find video - if not os.path.exists(current_filename): - is_found = False - - current_basename = os.path.basename(current_filename) - # handle unix, windows, or mixed paths - if current_basename.find("/") > -1: - current_basename = current_basename.split("/")[-1] - if current_basename.find("\\") > -1: - current_basename = current_basename.split("\\")[-1] - - # First see if we can find the file in another directory, - # and if not, prompt the user to find the file. - - # We'll check in the current working directory, and if the user has - # already found any missing videos, check in the directory of those. - for path_dir in new_paths: - check_path = os.path.join(path_dir, current_basename) - if os.path.exists(check_path): - # we found the file in a different directory - video_item["backend"]["filename"] = check_path - is_found = True - break - + video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) labels = Labels.load_file(args.data_path, video_callback=video_callback) - frame_count = len(labels) track_count = len(labels.tracks) track_names = [np.string_(track.name) for track in labels.tracks] node_count = len(labels.skeletons[0].nodes) frame_idxs = [lf.frame_idx for lf in labels] frame_idxs.sort() + frame_count = frame_idxs[-1] - frame_idxs[0] + 1 # count should include unlabeled frames # Desired MATLAB format: # "track_occupancy" tracks * frames @@ -60,7 +32,7 @@ def video_callback(video_list, new_paths=[os.path.dirname(args.data_path)]): prediction_matrix = np.full((frame_count, node_count, 2, track_count), np.nan, dtype=float) for lf, inst in [(lf, inst) for lf in labels for inst in lf.instances]: - frame_i = frame_idxs.index(lf.frame_idx) + frame_i = lf.frame_idx - frame_idxs[0] track_i = labels.tracks.index(inst.track) occupancy_matrix[track_i, frame_i] = 1 From 0295c51fe8c717c9fd6957cf8129f06a9a017994 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 15:24:35 -0400 Subject: [PATCH 080/176] Keep list order when merging in new videos --- sleap/io/dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 7c2b96667..d3894d303 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -127,7 +127,11 @@ def _update_from_labels(self, merge=False): # Add any videos that are present in the labels but # missing from the video list if merge or len(self.videos) == 0: - self.videos = list(set(self.videos).union({label.video for label in self.labels})) + # find videos in labeled frames that aren't yet in top level videos + new_videos = {label.video for label in self.labels} - set(self.videos) + # just add the new videos so we don't re-order current list + if len(new_videos): + self.videos.extend(list(new_videos)) # Ditto for skeletons if merge or len(self.skeletons) == 0: From 
be660ef0d4e06949acaa5702d2d659d417e87b6b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 16:45:20 -0400 Subject: [PATCH 081/176] Check track_range not None before comparing (bug fix) --- sleap/gui/app.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index edab2ea0b..64d7d1a36 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1524,7 +1524,11 @@ def nextSuggestedFrame(self, seek_direction=1): def nextTrackFrame(self): cur_idx = self.player.frame_idx track_ranges = self.labels.get_track_occupany(self.video) - next_idx = min([track_range.start for track_range in track_ranges.values() if track_range.start > cur_idx], default=-1) + next_idx = min([track_range.start + for track_range in track_ranges.values() + if track_range.start is not None + and track_range.start > cur_idx], + default=-1) if next_idx > -1: self.plotFrame(next_idx) From 8d9f18e4c2a1392edc0e07e4208436cd0dc624ca Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 17:11:32 -0400 Subject: [PATCH 082/176] Only load defered imgstore once (bug fix) --- sleap/io/video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index fbd697cbd..5b8dde6c2 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -429,7 +429,7 @@ def get_frame(self, frame_number) -> np.ndarray: """ # Check if we need to open the imgstore and do it if needed - if not self.imgstore: + if not self._store_: self.open() if self.index_by_original: From b10fc502d1656548b2207425a1649d8002f9c0f3 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 20 Sep 2019 17:27:02 -0400 Subject: [PATCH 083/176] Add dummy frame to otherwise empty imgstore. --- sleap/io/video.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sleap/io/video.py b/sleap/io/video.py index 5b8dde6c2..f3f5df94e 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -761,6 +761,11 @@ def to_imgstore(self, path, for frame_num in frame_numbers: store.add_image(self.get_frame(frame_num), frame_num, time.time()) + # If there are no frames to save for this video, add a dummy frame + # since we can't save an empty imgstore. + if len(frame_numbers) == 0: + store.add_image(np.zeros((self.shape[1], self.shape[2], self.shape[3])), 0, time.time()) + store.close() # Return an ImgStoreVideo object referencing this new imgstore. 
From 45585a9c1d97a9a47db22696ccd95690f011ce45 Mon Sep 17 00:00:00 2001 From: Talmo Date: Sun, 22 Sep 2019 23:27:25 -0400 Subject: [PATCH 084/176] Fix stacked hourglass implementation - Added output scale calculation to Model class - Fix cattrs structuring when deserializing models - Add SHG params to training editor - Miscellaneous typos --- sleap/config/training_editor.yaml | 36 ++++++++++++++++----- sleap/io/video.py | 6 ++-- sleap/nn/architectures/__init__.py | 3 +- sleap/nn/architectures/hourglass.py | 36 ++++++++++++--------- sleap/nn/loadmodel.py | 12 ++----- sleap/nn/model.py | 49 +++++++++++++++++++++++++++-- sleap/nn/training.py | 49 ++++++++++++++++------------- sleap/skeleton.py | 3 +- 8 files changed, 135 insertions(+), 59 deletions(-) diff --git a/sleap/config/training_editor.yaml b/sleap/config/training_editor.yaml index 5ac01e6ad..c11f18a17 100644 --- a/sleap/config/training_editor.yaml +++ b/sleap/config/training_editor.yaml @@ -6,6 +6,13 @@ model: options: confmaps,pafs,centroids default: confmaps +- name: arch # backbone_name + label: Architecture + type: list + default: + options: LeapCNN,UNet,StackedHourglass + # options: LeapCNN,UNet,StackedHourglass,StackedUNet + - name: down_blocks label: Down Blocks type: int @@ -35,20 +42,35 @@ model: label: Upsampling Layers type: bool default: False - + - name: interp label: Interpolation type: list default: bilinear options: bilinear -# skeletons? +# stacked model: +- name: num_stacks + label: Stacks + type: int + default: 3 -- name: arch # backbone_name - label: Architecture - type: list - default: - options: LeapCNN,StackedHourglass,UNet,StackedUNet +# - name: batch_norm +# label: Batch norm +# type: bool +# default: True + +# - name: intermediate_inputs +# label: Intermediate inputs +# type: bool +# default: True + +# - name: initial_stride +# label: Initial stride +# type: int +# default: 1 + +# skeletons? datagen: diff --git a/sleap/io/video.py b/sleap/io/video.py index f3f5df94e..ae3df6523 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -9,9 +9,11 @@ import numpy as np import attr import cattr +import logging from typing import Iterable, Union, List +logger = logging.getLogger(__name__) @attr.s(auto_attribs=True, cmp=False) class HDF5Video: @@ -774,7 +776,7 @@ def to_imgstore(self, path, @staticmethod def cattr(): """ - Return a cattr converter for serialiazing\deseriializing Video objects. + Return a cattr converter for serialiazing/deserializing Video objects. Returns: A cattr converter. 
@@ -842,6 +844,6 @@ def fixup_path(path, raise_error=False) -> str: if raise_error: raise FileNotFoundError(f"Cannot find a video file: {path}") else: - print(f"Cannot find a video file: {path}") + logger.warning(f"Cannot find a video file: {path}") return path diff --git a/sleap/nn/architectures/__init__.py b/sleap/nn/architectures/__init__.py index ba2d97b94..01e9ebacd 100644 --- a/sleap/nn/architectures/__init__.py +++ b/sleap/nn/architectures/__init__.py @@ -4,5 +4,6 @@ # TODO: We can set this up to find all classes under sleap.nn.architectures available_archs = [LeapCNN, UNet, StackedUNet, StackedHourglass] +available_arch_names = [arch.__name__ for arch in available_archs] -__all__ = ['available_archs'] + [arch.__name__ for arch in available_archs] +__all__ = ["available_archs", "available_arch_names"] + [arch.__name__ for arch in available_archs] diff --git a/sleap/nn/architectures/hourglass.py b/sleap/nn/architectures/hourglass.py index 2f33e343a..70881fac8 100644 --- a/sleap/nn/architectures/hourglass.py +++ b/sleap/nn/architectures/hourglass.py @@ -33,15 +33,17 @@ class StackedHourglass: by concatenating with intermediate outputs upsampling_layers: Use upsampling instead of transposed convolutions. interp: Method to use for interpolation when upsampling smaller features. + initial_stride: Stride of first convolution to use for reducing input resolution. """ - num_hourglass_blocks: int = 3 + num_stacks: int = 3 num_filters: int = 32 depth: int = 3 batch_norm: bool = True intermediate_inputs: bool = True upsampling_layers: bool = True interp: str = "bilinear" + initial_stride: int = 1 def output(self, x_in, num_output_channels): @@ -143,8 +145,8 @@ def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm= return x, x_out -def stacked_hourglass(x_in, num_output_channels, num_hourglass_blocks=3, num_filters=32, depth=3, batch_norm=True, - intermediate_inputs=True, upsampling_layers=True, interp="bilinear"): +def stacked_hourglass(x_in, num_output_channels, num_stacks=3, num_filters=32, depth=3, batch_norm=True, + intermediate_inputs=True, upsampling_layers=True, interp="bilinear", initial_stride=1): """Stacked hourglass block. This function builds and connects multiple hourglass blocks. See `hourglass` for @@ -172,39 +174,43 @@ def stacked_hourglass(x_in, num_output_channels, num_hourglass_blocks=3, num_fil by concatenating with intermediate outputs upsampling_layers: Use upsampling instead of transposed convolutions. interp: Method to use for interpolation when upsampling smaller features. + initial_stride: Stride of first convolution to use for reducing input resolution. Returns: x_outs: List of tf.Tensors of the output of the block of the same width and height as the input with `num_output_channels` channels. 
""" + + # Expand block-specific parameters if scalars provided + num_filters = expand_to_n(num_filters, num_stacks) + depth = expand_to_n(depth, num_stacks) + batch_norm = expand_to_n(batch_norm, num_stacks) + upsampling_layers = expand_to_n(upsampling_layers, num_stacks) + interp = expand_to_n(interp, num_stacks) + # Initial downsampling - x = conv(num_filters, kernel_size=(7, 7))(x_in) + x = conv(num_filters[0], kernel_size=(7, 7), strides=initial_stride)(x_in) # Batchnorm after the intial down sampling - if batch_norm: + if batch_norm[0]: x = BatchNormalization()(x) - - # Expand block-specific parameters if scalars provided - num_filters = expand_to_n(num_filters, num_hourglass_blocks) - depth = expand_to_n(depth, num_hourglass_blocks) - batch_norm = expand_to_n(batch_norm, num_hourglass_blocks) - upsampling_layers = expand_to_n(upsampling_layers, num_hourglass_blocks) - interp = expand_to_n(interp, num_hourglass_blocks) # Make sure first block gets the right number of channels - x = x_in + # x = x_in if x.shape[-1] != num_filters[0]: x = residual_block(x, num_filters[0], batch_norm[0]) # Create individual hourglasses and collect intermediate outputs + x_in = x x_outs = [] - for i in range(num_hourglass_blocks): + for i in range(num_stacks): if i > 0 and intermediate_inputs: x = Concatenate()([x, x_in]) x = residual_block(x, num_filters[i], batch_norm[i]) - x, x_out = hourglass_block(x, num_output_channels, num_filters[i], depth=depth[i], batch_norm=batch_norm[i], upsampling_layers=upsampling_layers[i], interp=interp[i]) + x, x_out = hourglass_block(x, num_output_channels, num_filters[i], + depth=depth[i], batch_norm=batch_norm[i], upsampling_layers=upsampling_layers[i], interp=interp[i]) x_outs.append(x_out) return x_outs diff --git a/sleap/nn/loadmodel.py b/sleap/nn/loadmodel.py index d65e6efc9..e715de99c 100644 --- a/sleap/nn/loadmodel.py +++ b/sleap/nn/loadmodel.py @@ -87,16 +87,10 @@ def get_model_data( job = sleap_models[model_type] # Model input is scaled by to get output - try: - asym = job.model.backbone.down_blocks - job.model.backbone.up_blocks - multiscale = 1/(2**asym) - except: - multiscale = 1 - model_properties = dict( - skeleton = job.model.skeletons[0], - scale = job.trainer.scale, - multiscale = multiscale) + skeleton=job.model.skeletons[0], + scale=job.trainer.scale, + multiscale=job.model.output_scale) return model_properties diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 2afa979e9..9b2782e89 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -84,7 +84,7 @@ def __attrs_post_init__(self): if self.backbone_name is None: self.backbone_name = self.backbone.__class__.__name__ - def output(self, input_tesnor, num_output_channels=None): + def output(self, input_tensor, num_output_channels=None): """ Invoke the backbone function with current backbone_args and backbone_kwargs to produce the model backbone block. 
This is a convenience property for @@ -112,7 +112,7 @@ def output(self, input_tesnor, num_output_channels=None): "Cannot infer num output channels.") - return self.backbone.output(input_tesnor, num_output_channels) + return self.backbone.output(input_tensor, num_output_channels) @property def name(self): @@ -125,3 +125,48 @@ def name(self): """ return self.backbone_name + @property + def output_scale(self): + """Calculates output scale relative to input.""" + + output_scale = 1 + + # TODO: Determine scale within model implementation + if hasattr(self.backbone, "down_blocks") and hasattr(self.backbone, "up_blocks"): + asym = self.backbone.down_blocks - self.backbone.up_blocks + output_scale = 1 / (2 ** asym) + + elif hasattr(self.backbone, "initial_stride"): + output_scale = 1 / self.backbone.initial_stride + + return output_scale + + + @staticmethod + def _structure_model(model_dict, cls): + """Structuring hook for instantiating Model via cattrs. + + This function should be used directly with cattrs as a + structuring hook. It serves the purpose of instantiating + the appropriate backbone class from the string name. + + This is required when backbone classes do not have a + unique attribute name from which to infer the appropriate + class to use. + + Args: + model_dict: Dictionaries containing deserialized Model. + cls: Class to return (not used). + + Returns: + An instantiated Model class with the correct backbone. + + Example: + >> cattr.register_structure_hook(Model, Model.structure_model) + """ + + arch_idx = available_arch_names.index(model_dict["backbone_name"]) + backbone_cls = available_archs[arch_idx] + + return Model(backbone=backbone_cls(**model_dict["backbone"]), + output_type=ModelOutputType(model_dict["output_type"])) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index c400c80ff..648cfcb27 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -223,15 +223,11 @@ def train(self, num_outputs_channels = 1 # Determine input and output sizes - # If there are more downsampling layers than upsampling layers, # then the output (confidence maps or part affinity fields) will # be at a different scale than the input (images). - up_down_diff = model.backbone.down_blocks - model.backbone.up_blocks - output_scale = 1/(2**up_down_diff) - input_img_size = (imgs_train.shape[1], imgs_train.shape[2]) - output_img_size = (input_img_size[0]*output_scale, input_img_size[1]*output_scale) + output_img_size = (int(input_img_size[0] * model.output_scale), int(input_img_size[1] * model.output_scale)) logger.info(f"Training set: {imgs_train.shape} -> {output_img_size}, {num_outputs_channels} channels") logger.info(f"Validation set: {imgs_val.shape} -> {output_img_size}, {num_outputs_channels} channels") @@ -240,8 +236,8 @@ def train(self, img_input = Input((img_height, img_width, img_channels)) # Rectify image sizes not divisible by pooling factor - depth = getattr(model.backbone, 'depth', 0) - depth = depth or getattr(model.backbone, 'down_blocks', 0) + depth = getattr(model.backbone, "depth", 0) + depth = depth or getattr(model.backbone, "down_blocks", 0) if depth: pool_factor = 2 ** depth @@ -258,7 +254,7 @@ def train(self, # Solution: https://www.tensorflow.org/api_docs/python/tf/pad + Lambda layer + corresponding crop at the end? 
# Instantiate the backbone, this builds the Tensorflow graph - x_outs = model.output(input_tesnor=img_input, num_output_channels=num_outputs_channels) + x_outs = model.output(input_tensor=img_input, num_output_channels=num_outputs_channels) # Create training model by combining the input layer and backbone graph. keras_model = keras.Model(inputs=img_input, outputs=x_outs) @@ -269,7 +265,7 @@ def train(self, elif self.optimizer.lower() == "rmsprop": _optimizer = keras.optimizers.RMSprop(lr=self.learning_rate) else: - raise ValueError(f"Unknown optimizer, value = {optimizer}!") + raise ValueError(f"Unknown optimizer, value = {self.optimizer}!") # Compile the Keras model keras_model.compile( @@ -290,11 +286,11 @@ def train(self, if model.output_type == ModelOutputType.CONFIDENCE_MAP: def datagen_function(points): return generate_confmaps_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=output_scale) + sigma=self.sigma, scale=model.output_scale) elif model.output_type == ModelOutputType.PART_AFFINITY_FIELD: def datagen_function(points): return generate_pafs_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=output_scale) + sigma=self.sigma, scale=model.output_scale) elif model.output_type == ModelOutputType.CENTROIDS: def datagen_function(points): return generate_confmaps_from_points(points, None, input_img_size, @@ -339,10 +335,16 @@ def datagen_function(points): os.makedirs(save_path, exist_ok=True) # Setup a list of necessary callbacks to invoke while training. + monitor_metric_name = "val_loss" + if len(keras_model.output_names) > 1: + monitor_metric_name = "val_" + keras_model.output_names[-1] + "_loss" callbacks = self._setup_callbacks( train_run, save_path, train_datagen, tensorboard_dir, control_zmq_port, - progress_report_zmq_port, output_type=str(model.output_type)) + progress_report_zmq_port, + output_type=str(model.output_type), + monitor_metric_name=monitor_metric_name, + ) # Train! history = keras_model.fit_generator( @@ -402,7 +404,9 @@ def train_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: def _setup_callbacks(self, train_run: 'TrainingJob', save_path, train_datagen, tensorboard_dir, control_zmq_port, - progress_report_zmq_port, output_type): + progress_report_zmq_port, + output_type, + monitor_metric_name="val_loss"): """ Setup callbacks for the call to Keras fit_generator. 
@@ -420,14 +424,14 @@ def _setup_callbacks(self, train_run: 'TrainingJob', train_run.newest_model_filename = os.path.relpath(full_path, train_run.save_dir) callbacks.append( ModelCheckpoint(filepath=full_path, - monitor="val_loss", save_best_only=False, + monitor=monitor_metric_name, save_best_only=False, save_weights_only=False, period=1)) if self.save_best_val: full_path = os.path.join(save_path, "best_model.h5") train_run.best_model_filename = os.path.relpath(full_path, train_run.save_dir) callbacks.append( ModelCheckpoint(filepath=full_path, - monitor="val_loss", save_best_only=True, + monitor=monitor_metric_name, save_best_only=True, save_weights_only=False, period=1)) TrainingJob.save_json(train_run, f"{save_path}.json") @@ -443,12 +447,12 @@ def _setup_callbacks(self, train_run: 'TrainingJob', patience=self.reduce_lr_patience, cooldown=self.reduce_lr_cooldown, min_lr=self.reduce_lr_min_lr, - monitor="val_loss", mode="auto", verbose=1, ) + monitor=monitor_metric_name, mode="auto", verbose=1, ) ) # Callbacks: Early stopping callbacks.append( - EarlyStopping(monitor="val_loss", + EarlyStopping(monitor=monitor_metric_name, min_delta=self.early_stopping_min_delta, patience=self.early_stopping_patience, verbose=1)) @@ -542,17 +546,18 @@ def load_json(cls, filename: str): """ # Open and parse the JSON in filename - with open(filename, 'r') as file: + with open(filename, "r") as file: json_str = file.read() dicts = json.loads(json_str) # We have some skeletons to deal with, make sure to setup a Skeleton cattr. my_cattr = Skeleton.make_cattr() - try: - run = my_cattr.structure(dicts, cls) - except: - raise ValueError(f"Failure deserializing {filename} to TrainingJob.") + # Setup structuring hook for unambiguous backbone class resolution. + my_cattr.register_structure_hook(Model, Model._structure_model) + + # Build classes. + run = my_cattr.structure(dicts, cls) # if we can't find save_dir for job, set it to path of json we're loading if run.save_dir is not None: diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 6b4af6a18..131a3e8d4 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -507,7 +507,8 @@ def delete_symmetry(self, node1:str, node2: str): Returns: None """ - node1_node, node1_node = self.find_node(node1), self.find_node(node2) + node1_node = self.find_node(node1) + node2_node = self.find_node(node2) if self.get_symmetry(node1) != node2 or self.get_symmetry(node2) != node1: raise ValueError(f"Nodes {node1}, {node2} are not symmetric.") From 74ede4e9dde5231456bbe2f15a6b027e976953bd Mon Sep 17 00:00:00 2001 From: Talmo Date: Sun, 22 Sep 2019 23:30:08 -0400 Subject: [PATCH 085/176] Minor typo --- sleap/io/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index d3894d303..7741882c5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -280,8 +280,8 @@ def __delitem__(self, key): def remove(self, value: LabeledFrame): self.labeled_frames.remove(value) - self._lf_by_video[new_label.video].remove(value) - del self._frame_idx_map[new_label.video][value.frame_idx] + self._lf_by_video[value.video].remove(value) + del self._frame_idx_map[value.video][value.frame_idx] def find(self, video: Video, frame_idx: Union[int, range] = None, return_new: bool=False) -> List[LabeledFrame]: """ Search for labeled frames given video and/or frame index. 
From 690de300c52ec58f736a3d34aa11af1fd75fcc22 Mon Sep 17 00:00:00 2001 From: Talmo Date: Mon, 23 Sep 2019 09:12:58 -0400 Subject: [PATCH 086/176] Added ResNet50 architecture - Generalized output scale calculation - Generalized backbone typing --- sleap/nn/architectures/__init__.py | 7 ++- sleap/nn/architectures/resnet.py | 92 ++++++++++++++++++++++++++++++ sleap/nn/model.py | 17 +++--- 3 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 sleap/nn/architectures/resnet.py diff --git a/sleap/nn/architectures/__init__.py b/sleap/nn/architectures/__init__.py index 01e9ebacd..86c13db10 100644 --- a/sleap/nn/architectures/__init__.py +++ b/sleap/nn/architectures/__init__.py @@ -1,9 +1,12 @@ from sleap.nn.architectures.leap import LeapCNN from sleap.nn.architectures.unet import UNet, StackedUNet from sleap.nn.architectures.hourglass import StackedHourglass +from sleap.nn.architectures.resnet import ResNet50 +from typing import TypeVar # TODO: We can set this up to find all classes under sleap.nn.architectures -available_archs = [LeapCNN, UNet, StackedUNet, StackedHourglass] +available_archs = [LeapCNN, UNet, StackedUNet, StackedHourglass, ResNet50] available_arch_names = [arch.__name__ for arch in available_archs] +BackboneType = TypeVar("BackboneType", *available_archs) -__all__ = ["available_archs", "available_arch_names"] + [arch.__name__ for arch in available_archs] +__all__ = ["available_archs", "available_arch_names", "BackboneType"] + [arch.__name__ for arch in available_archs] diff --git a/sleap/nn/architectures/resnet.py b/sleap/nn/architectures/resnet.py new file mode 100644 index 000000000..096a0b43e --- /dev/null +++ b/sleap/nn/architectures/resnet.py @@ -0,0 +1,92 @@ +import tensorflow as tf +import keras +from keras import applications + +import attr + +@attr.s(auto_attribs=True) +class ResNet50: + """ResNet50 pretrained backbone. + + Args: + x_in: Input 4-D tf.Tensor or instantiated layer. + num_output_channels: The number of output channels of the block. + upsampling_layers: Use upsampling instead of transposed convolutions. + interp: Method to use for interpolation when upsampling smaller features. + up_blocks: Number of upsampling steps to perform. The backbone reduces + the output scale by 1/32. If set to 5, outputs will be upsampled to the + input resolution. + refine_conv_up: If true, applies a 1x1 conv after each upsampling step. + pretrained: Load pretrained ImageNet weights for transfer learning. If + False, random weights are used for initialization. + """ + + upsampling_layers: bool = True + interp: str = "bilinear" + up_blocks: int = 5 + refine_conv_up: bool = False + pretrained: bool = True + + def output(self, x_in, num_output_channels): + """ + Generate a tensorflow graph for the backbone and return the output tensor. + + Args: + x_in: Input 4-D tf.Tensor or instantiated layer. Must have height and width + that are divisible by `2^down_blocks. + num_output_channels: The number of output channels of the block. These + are the final output tensors on which intermediate supervision may be + applied. + + Returns: + x_out: tf.Tensor of the output of the block of with `num_output_channels` channels. 
+ """ + return resnet50(x_in, num_output_channels, **attr.asdict(self)) + + @property + def output_scale(self): + """Returns relative scaling factor of this backbone.""" + + down_blocks = 5 + return (1 / (2 ** (down_blocks - self.up_blocks))) + + +def preprocess_input(X): + """Rescale input to [-1, 1] and tile if not RGB.""" + X = (X * 2) - 1 + + if tf.shape(X)[-1] != 3: + X = tf.tile(X, [1, 1, 1, 3]) + + return X + + +def resnet50(x_in, num_output_channels, up_blocks=5, upsampling_layers=True, + interp="bilinear", refine_conv_up=False, pretrained=True): + """Build ResNet50 backbone.""" + + # Input should be rescaled from [0, 1] to [-1, 1] and needs to be 3 channels (RGB) + x = keras.layers.Lambda(preprocess_input)(x_in) + + # Automatically downloads weights + resnet_model = applications.ResNet50( + include_top=False, + input_shape=(int(x_in.shape[-3]), int(x_in.shape[-2]), 3), + weights="imagenet" if pretrained else None, + ) + + # Output size is reduced by factor of 32 (2 ** 5) + x = resnet_model(x) + + for i in range(up_blocks): + if upsampling_layers: + x = keras.layers.UpSampling2D(size=(2, 2), interpolation=interp)(x) + else: + x = keras.layers.Conv2DTranspose(2 ** (8 - i), kernel_size=3, strides=2, padding="same", kernel_initializer="glorot_normal")(x) + + if refine_conv_up: + x = keras.layers.Conv2D(2 ** (8 - i), kernel_size=1, padding="same")(x) + + x = keras.layers.Conv2D(num_output_channels, (3, 3), padding="same")(x) + + return x diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 9b2782e89..3babcb818 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -67,7 +67,7 @@ class Model: """ output_type: ModelOutputType - backbone: Union[LeapCNN, UNet, StackedUNet, StackedHourglass] + backbone: BackboneType skeletons: Union[None, List[Skeleton]] = None backbone_name: str = None @@ -128,18 +128,19 @@ def name(self): @property def output_scale(self): """Calculates output scale relative to input.""" - - output_scale = 1 - # TODO: Determine scale within model implementation - if hasattr(self.backbone, "down_blocks") and hasattr(self.backbone, "up_blocks"): + if hasattr(self.backbone, "output_scale"): + return self.backbone.output_scale + + elif hasattr(self.backbone, "down_blocks") and hasattr(self.backbone, "up_blocks"): asym = self.backbone.down_blocks - self.backbone.up_blocks - output_scale = 1 / (2 ** asym) + return (1 / (2 ** asym)) elif hasattr(self.backbone, "initial_stride"): - output_scale = 1 / self.backbone.initial_stride + return (1 / self.backbone.initial_stride) - return output_scale + else: + return 1 @staticmethod From 164664c5f31f0ab719209c9bdd87c01ed0942a6c Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 23 Sep 2019 11:15:28 -0400 Subject: [PATCH 087/176] Fixes for showing/clearing header. Other marks are now re-drawn at correct heights after header is added or removed from slider. 
--- sleap/gui/slider.py | 103 +++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 55 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 4fdb642cf..46765c50e 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -13,7 +13,7 @@ import attr import itertools import numpy as np -from typing import Union +from typing import Dict, Optional, Union @attr.s(auto_attribs=True, cmp=False) class SliderMark: @@ -100,9 +100,9 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self._min_height = 19 + self._header_height # Add border rect - slider_rect = QRect(0, 0, 200, self._min_height-3) - self.slider = self.scene.addRect(slider_rect) - self.slider.setPen(QPen(QColor("black"))) + outline_rect = QRect(0, 0, 200, self._min_height-3) + self.outlineBox = self.scene.addRect(outline_rect) + self.outlineBox.setPen(QPen(QColor("black"))) # Add drag handle rect handle_width = 6 @@ -114,7 +114,7 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.handle.setBrush(QColor(128, 128, 128, 128)) # Add (hidden) rect to highlight selection - self.select_box = self.scene.addRect(QRect(0, 1, 0, slider_rect.height()-2)) + self.select_box = self.scene.addRect(QRect(0, 1, 0, outline_rect.height()-2)) self.select_box.setPen(QPen(QColor(80, 80, 255))) self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() @@ -187,36 +187,24 @@ def setTracksFromLabels(self, labels, video): self.setTracks(track_row) # total number of tracks to show self.setMarks(slider_marks) - # self.setHeaderSeries(lfs) + def setHeaderSeries(self, series:Optional[Dict[int,float]] = None): + """Show header graph with specified series. - self.updatedTracks.emit() - - def setHeaderSeries(self, lfs): - # calculate total point distance for instances from last labeled frame - def inst_velocity(lf, last_lf): - val = 0 - for inst in lf: - if last_lf is not None: - last_inst = last_lf.find(track=inst.track) - if last_inst: - points_a = inst.visible_points_array - points_b = last_inst[0].visible_points_array - point_dist = np.linalg.norm(points_a - points_b, axis=1) - inst_dist = np.sum(point_dist) # np.nanmean(point_dist) - val += inst_dist if not np.isnan(inst_dist) else 0 - return val - - series = dict() - - last_lf = None - for lf in lfs: - val = inst_velocity(lf, last_lf) - last_lf = lf - if not np.isnan(val): - series[lf.frame_idx] = val #len(lf.instances) - - self.headerSeries = series + Args: + series: {frame number: series value} dict. + Returns: + None. + """ + self.headerSeries = [] if series is None else series + self._header_height = 30 self.drawHeader() + self.updateHeight() + + def clearHeader(self): + """Remove header graph from slider.""" + self.headerSeries = [] + self._header_height = 0 + self.updateHeight() def setTracks(self, track_rows): """Set the number of tracks to show in slider. 
@@ -250,7 +238,13 @@ def updateHeight(self): self.setMaximumHeight(max_height) self.setMinimumHeight(min_height) + + # Redraw all marks with new height and y position + marks = self.getMarks() + self.setMarks(marks) + self.resizeEvent() + self.updatedTracks.emit() def _toPos(self, val, center=False): """Convert value to x position on slider.""" @@ -272,7 +266,7 @@ def _toVal(self, x, center=False): return val def _sliderWidth(self): - return self.slider.rect().width()-self.handle.rect().width() + return self.outlineBox.rect().width()-self.handle.rect().width() def value(self): """Get value of slider.""" @@ -358,7 +352,7 @@ def drawSelection(self, a, b): start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) selection_rect = QRect(start_pos, 1, - end_pos-start_pos, self.slider.rect().height()-2) + end_pos-start_pos, self.outlineBox.rect().height()-2) self.select_box.setRect(selection_rect) self.select_box.show() @@ -371,7 +365,7 @@ def moveSelectionAnchor(self, x, y): y: y position of mouse """ x = max(x, 0) - x = min(x, self.slider.rect().width()) + x = min(x, self.outlineBox.rect().width()) anchor_val = self._toVal(x, center=True) if len(self._selection)%2 == 0: @@ -387,7 +381,7 @@ def releaseSelectionAnchor(self, x, y): y: y position of mouse """ x = max(x, 0) - x = min(x, self.slider.rect().width()) + x = min(x, self.outlineBox.rect().width()) anchor_val = self._toVal(x) self.endSelection(anchor_val) @@ -410,7 +404,6 @@ def setMarks(self, marks): for mark in marks: if not isinstance(mark, SliderMark): mark = SliderMark("simple", mark) - print(mark) self.addMark(mark, update=False) self.updatePos() @@ -443,7 +436,7 @@ def addMark(self, new_mark, update=True): height = 1 else: v_offset = v_top_pad - height = self.slider.rect().height()-(v_offset+v_bottom_pad) + height = self.outlineBox.rect().height()-(v_offset+v_bottom_pad) width = 2 if new_mark.type in ("open", "filled") else 0 @@ -484,6 +477,7 @@ def updatePos(self): self._mark_items[mark].setRect(rect) def drawHeader(self): + """Draw the header graph.""" if len(self.headerSeries) == 0 or self._header_height == 0: self.poly.setPath(QPainterPath()) return @@ -499,8 +493,6 @@ def drawHeader(self): sampled = np.max(sampled.reshape(count//step,step), axis=1) series = {i*step:sampled[i] for i in range(count//step)} -# series = {key:self.headerSeries[key] for key in sorted(self.headerSeries.keys())} - series_min = np.min(sampled) - 1 series_max = np.max(sampled) series_scale = (self._header_height-5)/(series_max - series_min) @@ -533,7 +525,7 @@ def moveHandle(self, x, y): """ x -= self.handle.rect().width()/2. 
x = max(x, 0) - x = min(x, self.slider.rect().width()-self.handle.rect().width()) + x = min(x, self.outlineBox.rect().width()-self.handle.rect().width()) val = self._toVal(x) @@ -559,20 +551,21 @@ def resizeEvent(self, event=None): Args: event """ - height = self.size().height() - slider_rect = self.slider.rect() + outline_rect = self.outlineBox.rect() handle_rect = self.handle.rect() select_box_rect = self.select_box.rect() - slider_rect.setHeight(height-3) - if event is not None: slider_rect.setWidth(event.size().width()-1) - handle_rect.setHeight(self._handleHeight()) - select_box_rect.setHeight(self._handleHeight()) + outline_rect.setHeight(height-3) + if event is not None: outline_rect.setWidth(event.size().width()-1) + self.outlineBox.setRect(outline_rect) - self.slider.setRect(slider_rect) + handle_rect.setTop(self._handleTop()) + handle_rect.setHeight(self._handleHeight()) self.handle.setRect(handle_rect) + + select_box_rect.setHeight(self._handleHeight()) self.select_box.setRect(select_box_rect) self.updatePos() @@ -582,12 +575,12 @@ def resizeEvent(self, event=None): def _handleTop(self): return 1 + self._header_height - def _handleHeight(self, slider_rect=None): - if slider_rect is None: - slider_rect = self.slider.rect() + def _handleHeight(self, outline_rect=None): + if outline_rect is None: + outline_rect = self.outlineBox.rect() handle_bottom_offset = 1 - handle_height = slider_rect.height() - (self._handleTop()+handle_bottom_offset) + handle_height = outline_rect.height() - (self._handleTop()+handle_bottom_offset) return handle_height def mousePressEvent(self, event): @@ -601,7 +594,7 @@ def mousePressEvent(self, event): # Do nothing if not enabled if not self.enabled(): return # Do nothing if click outside slider area - if not self.slider.rect().contains(scenePos): return + if not self.outlineBox.rect().contains(scenePos): return move_function = None release_function = None @@ -654,7 +647,7 @@ def keyReleaseEvent(self, event): def boundingRect(self) -> QRectF: """Method required by Qt.""" - return self.slider.rect() + return self.outlineBox.rect() def paint(self, *args, **kwargs): """Method required by Qt.""" From 524c120195b38ca57706c414fb7b3941d55923da Mon Sep 17 00:00:00 2001 From: Talmo Date: Mon, 23 Sep 2019 11:40:13 -0400 Subject: [PATCH 088/176] Clean up imports --- sleap/nn/loadmodel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sleap/nn/loadmodel.py b/sleap/nn/loadmodel.py index e715de99c..446cbcfc3 100644 --- a/sleap/nn/loadmodel.py +++ b/sleap/nn/loadmodel.py @@ -2,14 +2,12 @@ logger = logging.getLogger(__name__) import numpy as np -import keras from time import time, clock from typing import Dict, List, Union, Optional, Tuple import tensorflow as tf import keras -# keras = tf.keras from sleap.skeleton import Skeleton from sleap.nn.model import ModelOutputType From b4dd573b7739b422bdc61135a95468b05d1fc3df Mon Sep 17 00:00:00 2001 From: Talmo Date: Mon, 23 Sep 2019 12:47:29 -0400 Subject: [PATCH 089/176] Proper handling of nested skeletons when structuring TrainingJob - Also fix HDF5Video edge case --- sleap/io/video.py | 4 ++-- sleap/nn/model.py | 4 +++- sleap/nn/training.py | 50 +++++++++++++++++++++++++------------------- sleap/skeleton.py | 2 +- 4 files changed, 35 insertions(+), 25 deletions(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index ae3df6523..de56c034f 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -47,7 +47,7 @@ def __attrs_post_init__(self): self.filename = self.__file_h5.filename elif type(self.filename) is str: try: 
- self.__file_h5 = h5.File(self.filename, 'r') + self.__file_h5 = h5.File(self.filename, "r") except OSError as ex: raise FileNotFoundError(f"Could not find HDF5 file {self.filename}") from ex else: @@ -58,7 +58,7 @@ def __attrs_post_init__(self): self.__dataset_h5 = self.dataset self.__file_h5 = self.__dataset_h5.file self.dataset = self.__dataset_h5.name - elif self.dataset is not None and type(self.dataset) is str: + elif (self.dataset is not None) and isinstance(self.dataset, str) and (self.__file_h5 is not None): self.__dataset_h5 = self.__file_h5[self.dataset] else: self.__dataset_h5 = None diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 3babcb818..4c1680031 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -170,4 +170,6 @@ class to use. backbone_cls = available_archs[arch_idx] return Model(backbone=backbone_cls(**model_dict["backbone"]), - output_type=ModelOutputType(model_dict["output_type"])) + output_type=ModelOutputType(model_dict["output_type"]), + skeletons=model_dict["skeletons"] + ) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 648cfcb27..8c59b1163 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -546,27 +546,35 @@ def load_json(cls, filename: str): """ # Open and parse the JSON in filename - with open(filename, "r") as file: - json_str = file.read() - dicts = json.loads(json_str) - - # We have some skeletons to deal with, make sure to setup a Skeleton cattr. - my_cattr = Skeleton.make_cattr() - - # Setup structuring hook for unambiguous backbone class resolution. - my_cattr.register_structure_hook(Model, Model._structure_model) - - # Build classes. - run = my_cattr.structure(dicts, cls) - - # if we can't find save_dir for job, set it to path of json we're loading - if run.save_dir is not None: - if not os.path.exists(run.save_dir): - run.save_dir = os.path.dirname(filename) - - run.final_model_filename = cls._fix_path(run.final_model_filename) - run.best_model_filename = cls._fix_path(run.best_model_filename) - run.newest_model_filename = cls._fix_path(run.newest_model_filename) + with open(filename, "r") as f: + dicts = json.load(f) + + # We have some skeletons to deal with, make sure to setup a Skeleton cattr. + converter = Skeleton.make_cattr() + + # Structure the nested skeletons if we have any. + if ("model" in dicts) and ("skeletons" in dicts["model"]): + if dicts["model"]["skeletons"]: + dicts["model"]["skeletons"] = converter.structure( + dicts["model"]["skeletons"], List[Skeleton]) + + else: + dicts["model"]["skeletons"] = [] + + # Setup structuring hook for unambiguous backbone class resolution. + converter.register_structure_hook(Model, Model._structure_model) + + # Build classes. 
+ run = converter.structure(dicts, cls) + + # if we can't find save_dir for job, set it to path of json we're loading + if run.save_dir is not None: + if not os.path.exists(run.save_dir): + run.save_dir = os.path.dirname(filename) + + run.final_model_filename = cls._fix_path(run.final_model_filename) + run.best_model_filename = cls._fix_path(run.best_model_filename) + run.newest_model_filename = cls._fix_path(run.newest_model_filename) return run diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 131a3e8d4..0adde62a9 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -175,7 +175,7 @@ def make_cattr(idx_to_node: Dict[int, Node] = None): _cattr = cattr.Converter() _cattr.register_unstructure_hook(Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx)) - _cattr.register_structure_hook(Skeleton, lambda x,type: Skeleton.from_dict(x, idx_to_node)) + _cattr.register_structure_hook(Skeleton, lambda x, cls: Skeleton.from_dict(x, idx_to_node)) return _cattr @property From c41d6587b9001644073580981dc931fef1cc4ebf Mon Sep 17 00:00:00 2001 From: Talmo Date: Mon, 23 Sep 2019 13:11:38 -0400 Subject: [PATCH 090/176] Explicit input_video type handling in inference --- sleap/nn/inference.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 5c3f95ede..10d17250a 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -154,17 +154,14 @@ def predict(self, grayscale = (model_channels == 1) # Open the video if we need it. - - try: - input_video.get_frame(frames[0]) + if isinstance(input_video, Video): vid = input_video - except AttributeError: - if isinstance(input_video, dict): - vid = Video.cattr().structure(input_video, Video) - elif isinstance(input_video, str): - vid = Video.from_filename(input_video, grayscale=grayscale) - else: - raise AttributeError(f"Unable to load input video: {input_video}") + elif isinstance(input_video, dict): + vid = Video.cattr().structure(input_video, Video) + elif isinstance(input_video, str): + vid = Video.from_filename(input_video, grayscale=grayscale) + else: + raise AttributeError(f"Unable to load input video: {input_video}") # List of frames to process (or entire video if not specified) frames = frames or list(range(vid.num_frames)) From 5d06ce85c1dcf3879495978ea34b8ac4af5ac646 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 23 Sep 2019 13:35:33 -0400 Subject: [PATCH 091/176] --all-frames arg includes initial empty frames This ensures that frame index and column index will be aligned; empty frames from end of video are still not included. --- sleap/info/write_tracking_h5.py | 48 +++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 0043911b6..469ee01a7 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -1,3 +1,24 @@ +""" +Generate an HDF5 file with track occupancy and point location data. + +Ignores tracks that are entirely empty. By default will also ignore +empty frames from the beginning and end of video, although +`--all-frames` argument will make it include empty frames from beginning +of video. + +Call from command line as: +> python -m sleap.io.write_tracking_h5 + +Will write file to `.tracking.h5`. 
+ +The HDF5 file has these datasets: + "track_occupancy" shape: tracks * frames + "tracks" shape: frames * nodes * 2 * tracks + "track_names" shape: tracks + +Note: the datasets are stored column-major as expected by MATLAB. +""" + import os import re import h5py as h5 @@ -10,6 +31,9 @@ parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to labels json file") + parser.add_argument('--all-frames', dest='all_frames', action='store_const', + const=True, default=False, + help='include all frames without predictions') args = parser.parse_args() video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) @@ -21,7 +45,10 @@ frame_idxs = [lf.frame_idx for lf in labels] frame_idxs.sort() - frame_count = frame_idxs[-1] - frame_idxs[0] + 1 # count should include unlabeled frames + + first_frame_idx = 0 if args.all_frames else frame_idxs[0] + + frame_count = frame_idxs[-1] - first_frame_idx + 1 # count should include unlabeled frames # Desired MATLAB format: # "track_occupancy" tracks * frames @@ -30,9 +57,9 @@ occupancy_matrix = np.zeros((track_count, frame_count), dtype=np.uint8) prediction_matrix = np.full((frame_count, node_count, 2, track_count), np.nan, dtype=float) - + for lf, inst in [(lf, inst) for lf in labels for inst in lf.instances]: - frame_i = lf.frame_idx - frame_idxs[0] + frame_i = lf.frame_idx - first_frame_idx track_i = labels.tracks.index(inst.track) occupancy_matrix[track_i, frame_i] = 1 @@ -41,13 +68,12 @@ prediction_matrix[frame_i, ..., track_i] = inst_points occupied_track_mask = np.sum(occupancy_matrix, axis=1) > 0 -# print(track_names[occupied_track_mask]) # Ignore unoccupied tracks - if(np.sum(~occupied_track_mask)): + if np.sum(~occupied_track_mask): print(f"ignoring {np.sum(~occupied_track_mask)} empty tracks") occupancy_matrix = occupancy_matrix[occupied_track_mask] - prediction_matrix = prediction_matrix[...,occupied_track_mask] + prediction_matrix = prediction_matrix[..., occupied_track_mask] track_names = [track_names[i] for i in range(len(track_names)) if occupied_track_mask[i]] print(f"track_occupancy: {occupancy_matrix.shape}") @@ -60,10 +86,10 @@ # We have to transpose the arrays since MATLAB expects column-major ds = f.create_dataset("track_names", data=track_names) ds = f.create_dataset( - "track_occupancy", data=np.transpose(occupancy_matrix), - compression="gzip", compression_opts=9) + "track_occupancy", data=np.transpose(occupancy_matrix), + compression="gzip", compression_opts=9) ds = f.create_dataset( - "tracks", data=np.transpose(prediction_matrix), - compression="gzip", compression_opts=9) + "tracks", data=np.transpose(prediction_matrix), + compression="gzip", compression_opts=9) - print(f"Saved as {output_filename}") \ No newline at end of file + print(f"Saved as {output_filename}") From 233bb9d0a25533d2c352e07d9984ccaf5264362b Mon Sep 17 00:00:00 2001 From: Talmo Date: Mon, 23 Sep 2019 13:53:56 -0400 Subject: [PATCH 092/176] Allow for importing Lambda layers with TensorFlow functions --- sleap/nn/loadmodel.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sleap/nn/loadmodel.py b/sleap/nn/loadmodel.py index 446cbcfc3..7159470bf 100644 --- a/sleap/nn/loadmodel.py +++ b/sleap/nn/loadmodel.py @@ -105,12 +105,9 @@ def get_model_skeleton(sleap_models, output_types) -> Skeleton: def load_model_from_job(job: TrainingJob) -> keras.Model: """Load keras Model from a specific TrainingJob.""" - # init = tf.global_variables_initializer() - # keras.backend.get_session().run(init) - # 
logger.info("Initialized TF global variables.") - # Load model from TrainingJob data - keras_model = tf.keras.models.load_model(job_model_path(job)) + keras_model = keras.models.load_model(job_model_path(job), + custom_objects={"tf": tf}) # Rename to prevent layer naming conflict name_prefix = f"{job.model.output_type}_" From 3c1d203a3b1aa0c776f3835ee946ec75b0cd3bf3 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 23 Sep 2019 14:26:04 -0400 Subject: [PATCH 093/176] Maintain order of tracks when updating top level. We were always re-ordering the tracks in the top level list, and this caused problems when loading an hdf5 file (since the track order was getting changed when the Labels object was created from json in the hdf5). --- sleap/io/dataset.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index d3894d303..71e9b2640 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -146,27 +146,29 @@ def _update_from_labels(self, merge=False): # Ditto for tracks, a pattern is emerging here if merge or len(self.tracks) == 0: - tracks = set(self.tracks) - - # Add tracks from any Instances or PredictedInstances - tracks = tracks.union({instance.track - for frame in self.labels - for instance in frame.instances - if instance.track}) + # Get tracks from any Instances or PredictedInstances + other_tracks = {instance.track + for frame in self.labels + for instance in frame.instances + if instance.track} # Add tracks from any PredictedInstance referenced by instance # This fixes things when there's a referenced PredictionInstance # which is no longer in the frame. - tracks = tracks.union({instance.from_predicted.track - for frame in self.labels - for instance in frame.instances - if instance.from_predicted - and instance.from_predicted.track}) + other_tracks = other_tracks.union( + {instance.from_predicted.track + for frame in self.labels + for instance in frame.instances + if instance.from_predicted and instance.from_predicted.track}) + + # Get list of other tracks not already in track list + new_tracks = list(other_tracks - set(self.tracks)) + + # Sort the new tracks by spawned on and then name + new_tracks.sort(key=lambda t:(t.spawned_on, t.name)) - self.tracks = list(tracks) + self.tracks.extend(new_tracks) - # Sort the tracks by spawned on and then name - self.tracks.sort(key=lambda t:(t.spawned_on, t.name)) def _update_lookup_cache(self): # Data structures for caching From 1b8acd9699f216957298a595465b0fae01d9fbcf Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 23 Sep 2019 15:12:46 -0400 Subject: [PATCH 094/176] Add seekbar headers to gui. 
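
The header series added here are per-frame scalar summaries. For example, the point displacement series reduces per-node movement between consecutive frames; with synthetic example values (plain arrays, not SLEAP objects):

    import numpy as np

    # Node (x, y) positions for one instance in two consecutive frames.
    points_prev = np.array([[0.0, 0.0], [1.0, 1.0]])
    points_curr = np.array([[3.0, 4.0], [1.0, 2.0]])

    # Per-node displacement, then reduced to a single value for the frame.
    point_dist = np.linalg.norm(points_curr - points_prev, axis=1)
    print(point_dist.sum())  # "Point Displacement (sum)" -> 6.0
    print(point_dist.max())  # "Point Displacement (max)" -> 5.0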
--- sleap/gui/app.py | 48 +++++++++++++++++++++++++++++++++ sleap/info/summary.py | 62 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 sleap/info/summary.py diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 64d7d1a36..167584130 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -11,6 +11,7 @@ from PySide2.QtWidgets import QFileDialog, QMessageBox import copy +import re import operator import os import sys @@ -26,6 +27,7 @@ from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track from sleap.io.video import Video from sleap.io.dataset import Labels +from sleap.info.summary import Summary from sleap.gui.video import QtVideoPlayer from sleap.gui.dataviews import VideosTable, SkeletonNodesTable, SkeletonEdgesTable, \ LabeledFrameTable, SkeletonNodeModel, SuggestionsTable @@ -197,6 +199,25 @@ def initialize_gui(self): viewMenu.addSeparator() + self.seekbarHeaderMenu = viewMenu.addMenu("Seekbar Header") + headers = ( + "None", + "Point Displacement (sum)", + "Point Displacement (max)", + "Instance Score (sum)", + "Instance Score (min)", + "Point Score (sum)", + "Point Score (min)", + "Number of predicted points" + ) + for header in headers: + menu_item = self.seekbarHeaderMenu.addAction(header, + lambda x=header: self.setSeekbarHeader(x)) + menu_item.setCheckable(True) + self.setSeekbarHeader("None") + + viewMenu.addSeparator() + self._menu_actions["show labels"] = viewMenu.addAction("Show Node Names", self.toggleLabels, shortcuts["show labels"]) self._menu_actions["show edges"] = viewMenu.addAction("Show Edges", self.toggleEdges, shortcuts["show edges"]) self._menu_actions["show trails"] = viewMenu.addAction("Show Trails", self.toggleTrails, shortcuts["show trails"]) @@ -784,6 +805,33 @@ def deleteEdge(self): def updateSeekbarMarks(self): self.player.seekbar.setTracksFromLabels(self.labels, self.video) + def setSeekbarHeader(self, graph_name): + data_obj = Summary(self.labels) + header_functions = { + "Point Displacement (sum)": data_obj.get_point_displacement_series, + "Point Displacement (max)": data_obj.get_point_displacement_series, + "Instance Score (sum)": data_obj.get_instance_score_series, + "Instance Score (min)": data_obj.get_instance_score_series, + "Point Score (sum)": data_obj.get_point_score_series, + "Point Score (min)": data_obj.get_point_score_series, + "Number of predicted points": data_obj.get_point_count_series, + } + + self._menu_check_single(self.seekbarHeaderMenu, graph_name) + + if graph_name == "None": + self.player.seekbar.clearHeader() + else: + if graph_name in header_functions: + kwargs = dict(video=self.video) + reduction_name = re.search("\((sum|max|min)\)", graph_name) + if reduction_name is not None: + kwargs["reduction"] = reduction_name.group(1) + series = header_functions[graph_name](**kwargs) + self.player.seekbar.setHeaderSeries(series) + else: + print(f"Could not find function for {header_functions}") + def generateSuggestions(self, params): new_suggestions = dict() for video in self.labels.videos: diff --git a/sleap/info/summary.py b/sleap/info/summary.py new file mode 100644 index 000000000..7e3106441 --- /dev/null +++ b/sleap/info/summary.py @@ -0,0 +1,62 @@ +import attr +import numpy as np + + +@attr.s(auto_attribs=True) +class Summary: + labels: 'Labels' + + def get_point_count_series(self, video): + series = dict() + + for lf in self.labels.find(video): + val = sum(len(inst.points) for inst in lf if hasattr(inst, "score")) + series[lf.frame_idx] = val + return series 
+ + def get_point_score_series(self, video, reduction="sum"): + reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] + + series = dict() + + for lf in self.labels.find(video): + val = reduce_funct(point.score for inst in lf for point in inst.points if hasattr(inst, "score")) + series[lf.frame_idx] = val + return series + + def get_instance_score_series(self, video, reduction="sum"): + reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] + + series = dict() + + for lf in self.labels.find(video): + val = reduce_funct(inst.score for inst in lf if hasattr(inst, "score")) + series[lf.frame_idx] = val + return series + + def get_point_displacement_series(self, video, reduction="sum"): + reduce_funct = dict(sum=np.sum, mean=np.nanmean, max=np.max)[reduction] + + series = dict() + + last_lf = None + for lf in self.labels.find(video): + val = self._calculate_frame_velocity(lf, last_lf, reduce_funct) + last_lf = lf + if not np.isnan(val): + series[lf.frame_idx] = val #len(lf.instances) + return series + + @staticmethod + def _calculate_frame_velocity(lf, last_lf, reduce_function): + val = 0 + for inst in lf: + if last_lf is not None: + last_inst = last_lf.find(track=inst.track) + if last_inst: + points_a = inst.visible_points_array + points_b = last_inst[0].visible_points_array + point_dist = np.linalg.norm(points_a - points_b, axis=1) + inst_dist = reduce_function(point_dist) + val += inst_dist if not np.isnan(inst_dist) else 0 + return val From 33f21da0ab5728ea487f9111a05e96df572acb8d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 23 Sep 2019 16:11:01 -0400 Subject: [PATCH 095/176] Add score column to suggestions, columns sortable. --- sleap/gui/dataviews.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index cfa07d8ad..ebd40fc14 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -439,9 +439,10 @@ def __init__(self, labels): self.setModel(SuggestionsTableModel(labels)) self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSortingEnabled(True) class SuggestionsTableModel(QtCore.QAbstractTableModel): - _props = ["video", "frame", "labeled",] + _props = ["video", "frame", "labeled", "mean score",] def __init__(self, labels): super(SuggestionsTableModel, self).__init__() @@ -470,7 +471,7 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): frame_idx = self._suggestions_list[idx][1] if prop == "video": - return os.path.basename(video.filename) # just show the name, not full path + return f"{self.labels.videos.index(video)}: {os.path.basename(video.filename)}" elif prop == "frame": return int(frame_idx) + 1 # start at frame 1 rather than 0 elif prop == "labeled": @@ -478,9 +479,30 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): val = self._labels.instance_count(video, frame_idx) val = str(val) if val > 0 else "" return val + elif prop == "mean score": + return self._getScore(video, frame_idx) return None + def _getScore(self, video, frame_idx): + scores = [inst.score for lf in self.labels.find(video, frame_idx) for inst in lf if hasattr(inst, "score")] + return sum(scores) / len(scores) + + def sort(self, column_idx: int, order: Qt.SortOrder): + prop = self._props[column_idx] + if prop in ("video", "frame"): + sort_function = lambda s: s + elif prop == "labeled": + sort_function = lambda s: self._labels.instance_count(*s) + elif prop == "mean 
score": + sort_function = lambda s: self._getScore(*s) + + reverse = (order == Qt.SortOrder.DescendingOrder) + + self.beginResetModel() + self._suggestions_list.sort(key=sort_function, reverse=reverse) + self.endResetModel() + def rowCount(self, *args): return len(self._suggestions_list) From 522f4c31b25f67ceae8f4710b515a888006747c8 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 08:09:21 -0400 Subject: [PATCH 096/176] Docstring improvements. --- sleap/util.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/sleap/util.py b/sleap/util.py index d333d91ff..4ea5dc180 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -10,10 +10,14 @@ import attr import psutil -from typing import Callable - - def attr_to_dtype(cls): + """Convert classes with basic types to numpy composite dtypes. + + Arguments: + cls: class to convert + Returns: + numpy dtype. + """ dtype_list = [] for field in attr.fields(cls): if field.type == str: @@ -29,14 +33,12 @@ def attr_to_dtype(cls): return np.dtype(dtype_list) - def usable_cpu_count() -> int: """Get number of CPUs usable by the current process. Takes into consideration cpusets restrictions. - Returns - ------- + Returns: The number of usable cpus """ try: @@ -89,6 +91,13 @@ def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): raise ValueError('Cannot save %s type'%type(item)) def frame_list(frame_str: str): + """Convert 'n-m' string to list of ints. + + Args: + frame_str: string representing range + Returns: + List of ints, or None if string does not represent valid range. + """ # Handle ranges of frames. Must be of the form "1-200" if '-' in frame_str: @@ -99,7 +108,6 @@ def frame_list(frame_str: str): return [int(x) for x in frame_str.split(",")] if len(frame_str) else None - def uniquify(seq): """ Given a list, return unique elements but preserve order. @@ -120,12 +128,13 @@ def uniquify(seq): def weak_filename_match(filename_a, filename_b): """Check if paths probably point to same file.""" - filename_a = filename_a.replace("\\","/") - filename_b = filename_b.replace("\\","/") + # convert all path separators to / + filename_a = filename_a.replace("\\", "/") + filename_b = filename_b.replace("\\", "/") # remove unique pid so we can match tmp directories for same zip filename_a = re.sub("/tmp_\d+_", "tmp_", filename_a) filename_b = re.sub("/tmp_\d+_", "tmp_", filename_b) # check if last three parts of path match - return filename_a.split("/")[-3:] == filename_b.split("/")[-3:] \ No newline at end of file + return filename_a.split("/")[-3:] == filename_b.split("/")[-3:] From 8b06a86453bf10531bf6ba02e726554a79107c53 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 08:33:06 -0400 Subject: [PATCH 097/176] Better typing and docstrings. --- sleap/rangelist.py | 76 ++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/sleap/rangelist.py b/sleap/rangelist.py index a9a29bb70..6db9d7e6e 100644 --- a/sleap/rangelist.py +++ b/sleap/rangelist.py @@ -1,10 +1,19 @@ -class RangeList(): +""" +Module with RangeList class for manipulating a list of range intervals. + +This is used to cache the track occupancy so we can keep cache updating +when user manipulates tracks for a range of instances. +""" + +from typing import List, Tuple + +class RangeList: """ Class for manipulating a list of range intervals. Each range interval in the list is a [start, end)-tuple. 
""" - def __init__(self, range_list: list=None): + def __init__(self, range_list: List[Tuple[int]] = None): self.list = range_list if range_list is not None else [] def __repr__(self): @@ -19,9 +28,6 @@ def list(self): def list(self, val): """Sets the list of ranges.""" self._list = val -# for i, r in enumerate(self._list): -# if type(r) == tuple: -# self._list[i] = range(r[0], r[1]) @property def is_empty(self): @@ -30,43 +36,43 @@ def is_empty(self): @property def start(self): - """Returns the start value of range (or None if empty).""" + """Return the start value of range (or None if empty).""" if self.is_empty: return None return self.list[0][0] def add(self, val, tolerance=0): - """Adds a single value, merges to last range if contiguous.""" - if len(self.list) and self.list[-1][1] + tolerance >= val: + """Add a single value, merges to last range if contiguous.""" + if self.list and self.list[-1][1] + tolerance >= val: self.list[-1] = (self.list[-1][0], val+1) else: self.list.append((val, val+1)) def insert(self, new_range: tuple): - """Adds a new range, merging to adjacent/overlapping ranges as appropriate.""" + """Add a new range, merging to adjacent/overlapping ranges as appropriate.""" new_range = self._as_tuple(new_range) - pre, within, post = self.cut_range(new_range) + pre, _, post = self.cut_range(new_range) self.list = self.join_([pre, [new_range], post]) return self.list - def insert_list(self, range_list: list): - """Adds each range from a list of ranges.""" + def insert_list(self, range_list: List[Tuple[int]]): + """Add each range from a list of ranges.""" for range_ in range_list: self.insert(range_) return self.list def remove(self, remove: tuple): - """Removes everything that overlaps with given range.""" - pre, within, post = self.cut_range(remove) + """Remove everything that overlaps with given range.""" + pre, _, post = self.cut_range(remove) self.list = pre + post def cut(self, cut: int): - """Returns a pair of lists with everything before/after cut.""" + """Return a pair of lists with everything before/after cut.""" return self.cut_(self.list, cut) def cut_range(self, cut: tuple): - """Returns three lists, everthing before/within/after cut range.""" - if len(self.list) == 0: return [], [], [] + """Return three lists, everthing before/within/after cut range.""" + if not self.list: return [], [], [] cut = self._as_tuple(cut) a, r = self.cut_(self.list, cut[0]) @@ -76,11 +82,19 @@ def cut_range(self, cut: tuple): @staticmethod def _as_tuple(x): - if type(x) == range: return x.start, x.stop + """Return tuple (converting from range if necessary).""" + if isinstance(x, range): return x.start, x.stop return x @staticmethod - def cut_(range_list: list, cut: int): + def cut_(range_list: List[Tuple[int]], cut: int): + """Return a pair of lists with everything before/after cut. + Args: + range_list: the list to cut + cut: the value at which to cut list + Returns: + (pre-cut list, post-cut list)-tuple + """ pre = [] post = [] @@ -89,7 +103,7 @@ def cut_(range_list: list, cut: int): pre.append(range_) elif range_[0] >= cut: post.append(range_) - elif range_[0] < cut and range_[1] > cut: + elif range_[0] < cut < range_[1]: # two new ranges, split at cut a = (range_[0], cut) b = (cut, range_[1]) @@ -98,18 +112,26 @@ def cut_(range_list: list, cut: int): return pre, post @classmethod - def join_(cls, list_list: list): + def join_(cls, list_list: List[List[Tuple[int]]]): + """Return a single list that includes all lists in input list. 
+ + Args: + list_list: a list of range lists + Returns: + range list that joins all of the lists in list_list + """ if len(list_list) == 1: return list_list[0] if len(list_list) == 2: return cls.join_pair_(list_list[0], list_list[1]) - else: return cls.join_pair_(list_list[0], cls.join_(list_list[1:])) + return cls.join_pair_(list_list[0], cls.join_(list_list[1:])) @staticmethod - def join_pair_(list_a: list, list_b: list): - if len(list_a) == 0 or len(list_b) == 0: return list_a + list_b - + def join_pair_(list_a: List[Tuple[int]], list_b: List[Tuple[int]]): + """Return a single pair of lists that joins two input lists.""" + if not list_a or not list_b: return list_a + list_b + last_a = list_a[-1] first_b = list_b[0] if last_a[1] >= first_b[0]: return list_a[:-1] + [(last_a[0], first_b[1])] + list_b[1:] - else: - return list_a + list_b + + return list_a + list_b From 8ac1a027a5430b86adc4f8257d3a4d1d8ca10b94 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 08:47:04 -0400 Subject: [PATCH 098/176] Fix delete_symmetry and add test. --- sleap/skeleton.py | 18 +++++++++--------- tests/test_skeleton.py | 7 +++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 6b4af6a18..4c915f356 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -11,7 +11,6 @@ import numpy as np import jsonpickle import json -import networkx as nx import h5py as h5 import copy @@ -19,8 +18,9 @@ from itertools import count from typing import Iterable, Union, List, Dict +import networkx as nx from networkx.readwrite import json_graph -from scipy.io import loadmat, savemat +from scipy.io import loadmat class EdgeType(Enum): @@ -79,8 +79,8 @@ class Skeleton: """ """ - A index variable used to give skeletons a default name that attemtpts to be - unique across all skeletons. + A index variable used to give skeletons a default name that attempts + to be unique across all skeletons. """ _skeleton_idx = count(0) @@ -137,7 +137,7 @@ def dict_match(dict1, dict2): def graph(self): edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.BODY] # TODO: properly induce subgraph for MultiDiGraph - # Currently, NetworkX will just return the nodes in the subgraph. + # Currently, NetworkX will just return the nodes in the subgraph. # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges return self._graph.edge_subgraph(edges) @@ -497,7 +497,7 @@ def add_symmetry(self, node1:str, node2: str): self._graph.add_edge(node1_node, node2_node, type=EdgeType.SYMMETRY) self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) - def delete_symmetry(self, node1:str, node2: str): + def delete_symmetry(self, node1: str, node2: str): """Deletes a previously established symmetry relationship between two nodes. 
Args: @@ -507,9 +507,9 @@ def delete_symmetry(self, node1:str, node2: str): Returns: None """ - node1_node, node1_node = self.find_node(node1), self.find_node(node2) + node1_node, node2_node = self.find_node(node1), self.find_node(node2) - if self.get_symmetry(node1) != node2 or self.get_symmetry(node2) != node1: + if self.get_symmetry(node1) != node2_node or self.get_symmetry(node2) != node1_node: raise ValueError(f"Nodes {node1}, {node2} are not symmetric.") edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges([node1_node, node2_node], keys=True, data="type") if edge_type == EdgeType.SYMMETRY] @@ -652,7 +652,7 @@ def has_edge(self, source_name: str, dest_name: str) -> bool: True is yes, False if no. """ - source_node, destination_node = self.find_node(source_name), self.find_node(dest_name) + source_node, destination_node = self.find_node(source_name), self.find_node(dest_name) return self._graph.has_edge(source_node, destination_node) @staticmethod diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 1f2080836..e9f39f453 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -153,6 +153,13 @@ def test_symmetry(): with pytest.raises(ValueError): s1.add_symmetry('6', '1') + s1.delete_symmetry('1', '5') + assert s1.get_symmetry("1") is None + + with pytest.raises(ValueError): + s1.delete_symmetry('1', '5') + + def test_json(skeleton, tmpdir): """ Test saving and loading a Skeleton object in JSON. From b164d24543f448e4902bd0d46d6132162878aa63 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 08:59:44 -0400 Subject: [PATCH 099/176] Formatting changes (from pylint). --- sleap/skeleton.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 4c915f356..05fb6c5fe 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -48,6 +48,7 @@ class Node: @staticmethod def from_names(name_list: str): + """Convert list of node names to list of nodes objects.""" nodes = [] for name in name_list: nodes.append(Node(name)) @@ -55,6 +56,7 @@ def from_names(name_list: str): @classmethod def as_node(cls, node): + """Convert given `node` to `Node` object (if not already).""" return node if isinstance(node, cls) else cls(node) def matches(self, other): @@ -94,7 +96,7 @@ def __init__(self, name: str = None): """ # If no skeleton was create, try to create a unique name for this Skeleton. - if name is None or type(name) is not str or len(name) == 0: + if name is None or not isinstance(name, str) or not name: name = "Skeleton-" + str(next(self._skeleton_idx)) @@ -340,7 +342,7 @@ def add_node(self, name: str): Returns: None """ - if type(name) is not str: + if not isinstance(name, str): raise TypeError("Cannot add nodes to the skeleton that are not str") if name in self.node_names: @@ -389,13 +391,16 @@ def find_node(self, name: str): """ if isinstance(name, Node): name = name.name + nodes = [node for node in self.nodes if node.name == name] + if len(nodes) == 1: return nodes[0] - elif len(nodes) > 1: + + if len(nodes) > 1: raise ValueError("Found multiple nodes named ({}).".format(name)) - elif len(nodes) == 0: - return None + + return None def add_edge(self, source: str, destination: str): """Add an edge between two nodes. 
@@ -466,11 +471,11 @@ def delete_edge(self, source: str, destination: str): self._graph.remove_edge(source_node, destination_node) - def add_symmetry(self, node1:str, node2: str): + def add_symmetry(self, node1: str, node2: str): """Specify that two parts (nodes) in the skeleton are symmetrical. - Certain parts of an animal body can be related as symmetrical parts in a pair. For example, - the left and right hands of a person. + Certain parts of an animal body can be related as symmetrical + parts in a pair. For example, the left and right hands of a person. Args: node1: The name of the first part in the symmetric pair @@ -515,7 +520,7 @@ def delete_symmetry(self, node1: str, node2: str): edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges([node1_node, node2_node], keys=True, data="type") if edge_type == EdgeType.SYMMETRY] self._graph.remove_edges_from(edges) - def get_symmetry(self, node:str): + def get_symmetry(self, node: str): """ Returns the node symmetric with the specified node. Args: @@ -535,8 +540,8 @@ def get_symmetry(self, node:str): else: raise ValueError(f"{node} has more than one symmetry.") - def get_symmetry_name(self, node:str): - """ Returns the name of the node symmetric with the specified node. + def get_symmetry_name(self, node: str): + """Returns the name of the node symmetric with the specified node. Args: node: The name of the node to query. @@ -589,7 +594,7 @@ def relabel_node(self, old_name: str, new_name: str): """ self.relabel_nodes({old_name: new_name}) - def relabel_nodes(self, mapping:dict): + def relabel_nodes(self, mapping: Dict[str, str]): """ Relabel the nodes of the skeleton. @@ -600,12 +605,12 @@ def relabel_nodes(self, mapping:dict): None """ existing_nodes = self.nodes - for k, v in mapping.items(): - if self.has_node(v): + for old_name, new_name in mapping.items(): + if self.has_node(new_name): raise ValueError("Cannot relabel a node to an existing name.") - node = self.find_node(k) + node = self.find_node(old_name) if node is not None: - node.name = v + node.name = new_name # self._graph = nx.relabel_nodes(G=self._graph, mapping=mapping) @@ -760,7 +765,7 @@ def load_hdf5(cls, file: Union[str, h5.File], name: str): Returns: The skeleton instance stored in the HDF5 file. """ - if type(file) is str: + if isinstance(file, str): with h5.File(file) as _file: skeletons = Skeleton._load_hdf5(_file) # Load all skeletons else: @@ -783,7 +788,7 @@ def load_all_hdf5(cls, file: Union[str, h5.File], Returns: The skeleton instances stored in the HDF5 file. Either in List or Dict form. """ - if type(file) is str: + if isinstance(file, str): with h5.File(file) as _file: skeletons = Skeleton._load_hdf5(_file) # Load all skeletons else: @@ -791,8 +796,8 @@ def load_all_hdf5(cls, file: Union[str, h5.File], if return_dict: return skeletons - else: - return list(skeletons.values()) + + return list(skeletons.values()) @classmethod def _load_hdf5(cls, file: h5.File): @@ -804,7 +809,7 @@ def _load_hdf5(cls, file: h5.File): return skeletons def save_hdf5(self, file: Union[str, h5.File]): - if type(file) is str: + if isinstance(file, str): with h5.File(file) as _file: self._save_hdf5(_file) else: From bb5704a1f1affa85618bb5004ac18905475a76f4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 09:10:59 -0400 Subject: [PATCH 100/176] Bug fixes to demo functions. 
--- sleap/gui/video.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sleap/gui/video.py b/sleap/gui/video.py index abd430bda..4e01d1d84 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -1374,26 +1374,28 @@ def video_demo(labels, standalone=False): if standalone: app.exec_() def plot_instances(scene, frame_idx, labels, video=None, fixed=True): + from sleap.gui.overlays.tracks import TrackColorManager + video = labels.videos[0] - color_manager = TrackColorManager(labels) - lfs = [label for label in labels.labels if label.video == video and label.frame_idx == frame_idx] + color_manager = TrackColorManager(labels=labels) + lfs = labels.find(video, frame_idx) - if len(lfs) == 0: return + if not lfs: return labeled_frame = lfs[0] count_no_track = 0 for i, instance in enumerate(labeled_frame.instances_to_show): - if instance.track in self.labels.tracks: + if instance.track in labels.tracks: pseudo_track = instance.track else: # Instance without track - pseudo_track = len(self.labels.tracks) + count_no_track + pseudo_track = len(labels.tracks) + count_no_track count_no_track += 1 # Plot instance inst = QtInstance(instance=instance, - color=color_manager(pseudo_track), + color=color_manager.get_color(pseudo_track), predicted=fixed, color_predicted=True, show_non_visible=False) From 2fb086947a3913a77f64fa21356f0ba57cc0c617 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 14:11:31 -0400 Subject: [PATCH 101/176] Methods for complex merging of labels. Merge frames that can be merged cleanly, return conflicts as well as data about merge, plus method for finishes merge (after you resolve conflicts). --- sleap/instance.py | 85 ++++++++++++++++++++++++++++++++ sleap/io/dataset.py | 104 +++++++++++++++++++++++++++++++++++++++ tests/io/test_dataset.py | 69 ++++++++++++++++++++++++-- 3 files changed, 255 insertions(+), 3 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index eb7392781..0666203a5 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -9,6 +9,7 @@ import pandas as pd import cattr +from copy import copy from typing import Dict, List, Optional, Union, Tuple from numpy.lib.recfunctions import structured_to_unstructured @@ -913,3 +914,87 @@ def merge_frames(labeled_frames, video, remove_redundant=True): print(f"skipped {redundant_count} redundant instances") return labeled_frames + @classmethod + def complex_merge_between(cls, base_labels: 'Labels', new_frames: List['LabeledFrame']): + """Merge new_frames into base_labels cleanly when possible, + return conflicts if any. 
+ + Args: + base_labels + new_frames + Returns: + tuple of three items: + * dict with {video: list (per frame) of list of merged instances + * list of conflicting instances in base + * list of conflicting instances in new_frames + """ + merged = dict() + extra_base = [] + extra_new = [] + + for new_frame in new_frames: + base_lfs = base_labels.find(new_frame.video, new_frame.frame_idx) + merged_instances = None + + if not base_lfs: + base_labels.labeled_frames.append(new_frame) + merged_instances = new_frame.instances + else: + merged_instances, extra_base_frame, extra_new_frame = \ + cls.complex_frame_merge(base_lfs[0], new_frame) + if extra_base_frame: + extra_base.append(extra_base_frame) + if extra_new_frame: + extra_new.append(extra_new_frame) + + if merged_instances: + if new_frame.video not in merged: + merged[new_frame.video] = [] + merged[new_frame.video].append(merged_instances) + return merged, extra_base, extra_new + + @classmethod + def complex_frame_merge(cls, base_frame, new_frame): + """Merge two frames, return conflicts if any.""" + merged_instances = [] + redundant_instances = [] + extra_base_instances = copy(base_frame.instances) + extra_new_instances = [] + + for new_inst in new_frame: + redundant = False + for base_inst in base_frame.instances: + if new_inst.matches(base_inst): + base_inst.frame = None + extra_base_instances.remove(base_inst) + redundant_instances.append(base_inst) + redundant = True + continue + if not redundant: + new_inst.frame = None + extra_new_instances.append(new_inst) + + if extra_base_instances and extra_new_instances: + # Conflict, so update base to just include non-conflicting + # instances (perfect matches) + base_frame.instances.clear() + base_frame.instances.extend(redundant_instances) + else: + # No conflict, so include all instances in base + base_frame.instances.extend(extra_new_instances) + merged_instances = copy(extra_new_instances) + extra_base_instances = [] + extra_new_instances = [] + + # Construct frames to hold any conflicting instances + extra_base = cls( + video=base_frame.video, + frame_idx=base_frame.frame_idx, + instances=extra_base_instances) if extra_base_instances else None + + extra_new = cls( + video=new_frame.video, + frame_idx=new_frame.frame_idx, + instances=extra_new_instances) if extra_new_instances else None + + return merged_instances, extra_base, extra_new \ No newline at end of file diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 71e9b2640..c3e5c6c3b 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -707,6 +707,110 @@ def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool return True + @classmethod + def complex_merge_between(cls, base_labels: 'Labels', new_labels: 'Labels', unify:bool = True) -> tuple: + """ + Merge frames and other data that can be merged cleanly, + and return frames that conflict. + + Anything that can be merged cleanly is merged into base_labels. + + Frames conflict just in case each labels object has a matching + frame (same video and frame idx) which instances not in the other. + + Frames can be merged cleanly if + - the frame is in only one of the labels, or + - the frame is in both labels, but all instances perfectly match + (which means they are redundant), or + - the frame is in both labels, maybe there are some redundant + instances, but only one version of the frame has additional + instances not in the other. 
+ + Args: + base_labels: the `Labels` that we're merging into + new_labels: the `Labels` that we're merging from + unify: whether to replace objects (e.g., `Video`s) in + new_labels with *matching* objects from base + + Returns: + tuple of two lists of `LabeledFrame`s + * data from base that conflicts + * data from new that conflicts + """ + # If unify, we want to replace objects in the frames with + # corresponding objects from the current labels. + # We do this by deserializing/serializing with match_to. + if unify: + new_json = new_labels.to_dict() + new_labels = cls.from_json(new_json, match_to=base_labels) + + # Merge anything that can be merged cleanly and get conflicts + merged, extra_base, extra_new = \ + LabeledFrame.complex_merge_between( + base_labels=base_labels, + new_frames=new_labels.labeled_frames) + + # For clean merge, finish merge now by cleaning up base object + if not extra_base and not extra_new: + # Add any new videos (etc) into top level lists in base + base_labels._update_from_labels(merge=True) + # Update caches + base_labels._update_lookup_cache() + + # Merge suggestions and negative anchors + cls.merge_container_dicts(base_labels.suggestions, new_labels.suggestions) + cls.merge_container_dicts(base_labels.negative_anchors, new_labels.negative_anchors) + + return merged, extra_base, extra_new + +# @classmethod +# def merge_predictions_by_score(cls, extra_base: List[LabeledFrame], extra_new: List[LabeledFrame]): +# """ +# Remove all predictions from input lists, return list with only +# the merged predictions. +# +# Args: +# extra_base: list of `LabeledFrame`s +# extra_new: list of `LabeledFrame`s +# Conflicting frames should have same index in both lists. +# Returns: +# list of `LabeledFrame`s with merged predictions +# """ +# pass + + @staticmethod + def finish_complex_merge(base_labels: 'Labels', resolved_frames: List[LabeledFrame]): + """ + Finish conflicted merge from complex_merge_between. + + Args: + base_labels: the `Labels` that we're merging into + resolved_frames: the list of frames to add into base_labels + Returns: + None. + """ + # Add all the resolved frames to base + base_labels.labeled_frames.extend(resolved_frames) + + # Combine instances when there are two LabeledFrames for same + # video and frame index + base_labels.merge_matching_frames() + + # Add any new videos (etc) into top level lists in base + base_labels._update_from_labels(merge=True) + # Update caches + base_labels._update_lookup_cache() + + @staticmethod + def merge_container_dicts(dict_a, dict_b): + """Merge data from dict_b into dict_a.""" + for key in dict_b.keys(): + if key in dict_a: + dict_a[key].extend(dict_b[key]) + uniquify(dict_a[key]) + else: + dict_a[key] = dict_b[key] + def merge_matching_frames(self, video=None): """ Combine all instances from LabeledFrames that have same frame_idx. 
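(A minimal usage sketch of the two-phase merge workflow added above; the file paths are placeholders, and the conflict is resolved in favor of the new labels, as the tests below also do.)

    from sleap.io.dataset import Labels

    base_labels = Labels.load_file("base.json")        # labels we are merging into
    new_labels = Labels.load_file("predictions.json")  # labels we are merging from

    # Phase 1: merge whatever merges cleanly; conflicting frames come back separately.
    merged, extra_base, extra_new = Labels.complex_merge_between(base_labels, new_labels)

    # Phase 2: after deciding how to resolve conflicts, finish the merge.
    # Here we keep the conflicting frames from the new labels.
    Labels.finish_complex_merge(base_labels, extra_new)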
diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 2daaf661c..dd0d83c4e 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -260,22 +260,85 @@ def test_label_mutability(): labels.remove_video(dummy_video) assert len(labels.find(dummy_video)) == 0 - dummy_frames3 = [] +def test_labels_merge(): + dummy_video = Video(backend=MediaVideo) + dummy_skeleton = Skeleton() dummy_skeleton.add_node("node") + labels = Labels() + dummy_frames = [] + # Add 10 instances with different points (so they aren't "redundant") for i in range(10): instance = Instance(skeleton=dummy_skeleton, points=dict(node=Point(i,i))) dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[instance,]) - dummy_frames3.append(dummy_frame) + dummy_frames.append(dummy_frame) - labels.labeled_frames.extend(dummy_frames3) + labels.labeled_frames.extend(dummy_frames) assert len(labels) == 10 assert len(labels.labeled_frames[0].instances) == 1 + labels.merge_matching_frames() assert len(labels) == 1 assert len(labels.labeled_frames[0].instances) == 10 +def test_complex_merge(): + dummy_video_a = Video.from_filename("foo.mp4") + dummy_video_b = Video.from_filename("foo.mp4") + + dummy_skeleton_a = Skeleton() + dummy_skeleton_a.add_node("node") + + dummy_skeleton_b = Skeleton() + dummy_skeleton_b.add_node("node") + + dummy_instances_a = [] + dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1,1)))) + dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2,2)))) + + labels_a = Labels() + labels_a.append(LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a)) + + dummy_instances_b = [] + dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1,1)))) + dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(3,3)))) + + labels_b = Labels() + labels_b.append(LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b)) # conflict + labels_b.append(LabeledFrame(dummy_video_b, frame_idx=1, instances=dummy_instances_b)) # clean + + merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) + + # Check that we have the cleanly merged frame + assert dummy_video_a in merged + assert len(merged[dummy_video_a]) == 1 # one merged frame + assert len(merged[dummy_video_a][0]) == 2 # with two instances + + # Check that labels_a includes redundant and clean + assert len(labels_a.labeled_frames) == 2 + assert len(labels_a.labeled_frames[0].instances) == 1 + assert labels_a.labeled_frames[0].instances[0].points[0].x == 1 + assert len(labels_a.labeled_frames[1].instances) == 2 + assert labels_a.labeled_frames[1].instances[0].points[0].x == 1 + assert labels_a.labeled_frames[1].instances[1].points[0].x == 3 + + # Check that extra_a/b includes the appropriate conflicting instance + assert len(extra_a) == 1 + assert len(extra_b) == 1 + assert len(extra_a[0].instances) == 1 + assert len(extra_b[0].instances) == 1 + assert extra_a[0].instances[0].points[0].x == 2 + assert extra_b[0].instances[0].points[0].x == 3 + + # Check that objects were unified + assert extra_a[0].video == extra_b[0].video + + # Check resolving the conflict using new + Labels.finish_complex_merge(labels_a, extra_b) + assert len(labels_a.labeled_frames) == 2 + assert len(labels_a.labeled_frames[0].instances) == 2 + assert labels_a.labeled_frames[0].instances[1].points[0].x == 3 + def skeleton_ids_from_label_instances(labels): return list(map(id, (lf.instances[0].skeleton for lf in 
labels.labeled_frames))) From e4633203393a1b6f2c4265a9d9633b6fe68ed698 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 20:40:24 -0400 Subject: [PATCH 102/176] Merge conflict only if instances are same type. If there are redundant instances, predicted instances in one file and non-predicted instances in the other, then we can merge without conflict. --- sleap/instance.py | 17 +++++++++++++++++ tests/io/test_dataset.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/sleap/instance.py b/sleap/instance.py index 0666203a5..cd18b3e1b 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -974,7 +974,24 @@ def complex_frame_merge(cls, base_frame, new_frame): new_inst.frame = None extra_new_instances.append(new_inst) + conflict = False if extra_base_instances and extra_new_instances: + base_predictions = list(filter(lambda inst: hasattr(inst, "score"), extra_base_instances)) + new_predictions = list(filter(lambda inst: hasattr(inst, "score"), extra_new_instances)) + + base_has_nonpred = len(extra_base_instances) - len(base_predictions) + new_has_nonpred = len(extra_new_instances) - len(new_predictions) + + # If they both have some predictions or they both have some + # non-predictions, then there is a conflict. + # (Otherwise it's not a conflict since we can cleanly merge + # all the predicted instances with all the non-predicted.) + if base_predictions and new_predictions: + conflict = True + elif base_has_nonpred and new_has_nonpred: + conflict = True + + if conflict: # Conflict, so update base to just include non-conflicting # instances (perfect matches) base_frame.instances.clear() diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index dd0d83c4e..210a432fe 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -339,6 +339,38 @@ def test_complex_merge(): assert len(labels_a.labeled_frames[0].instances) == 2 assert labels_a.labeled_frames[0].instances[1].points[0].x == 3 +def test_merge_predictions(): + dummy_video_a = Video.from_filename("foo.mp4") + dummy_video_b = Video.from_filename("foo.mp4") + + dummy_skeleton_a = Skeleton() + dummy_skeleton_a.add_node("node") + + dummy_skeleton_b = Skeleton() + dummy_skeleton_b.add_node("node") + + dummy_instances_a = [] + dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1,1)))) + dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2,2)))) + + labels_a = Labels() + labels_a.append(LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a)) + + dummy_instances_b = [] + dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1,1)))) + dummy_instances_b.append(PredictedInstance(skeleton=dummy_skeleton_b, points=dict(node=Point(3,3)), score=1)) + + labels_b = Labels() + labels_b.append(LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b)) + + # Frames have one redundant instance (perfect match) and all the + # non-matching instances are different types (one predicted, one not). 
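+ # Expected: a clean merge in which the predicted instance is added to the base frame and no conflicts are returned.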
+ merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) + assert len(merged[dummy_video_a]) == 1 + assert len(merged[dummy_video_a][0]) == 1 # the predicted instance was merged + assert not extra_a + assert not extra_b + def skeleton_ids_from_label_instances(labels): return list(map(id, (lf.instances[0].skeleton for lf in labels.labeled_frames))) From df5f21565a306800350341544a9390bcb01d4f76 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 24 Sep 2019 20:45:19 -0400 Subject: [PATCH 103/176] Gui for merging w/ conflict resolution. --- sleap/gui/app.py | 9 +-- sleap/gui/merge.py | 147 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 6 deletions(-) create mode 100644 sleap/gui/merge.py diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 167584130..db33a9bc9 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -33,6 +33,7 @@ LabeledFrameTable, SkeletonNodeModel, SuggestionsTable from sleap.gui.importvideos import ImportVideos from sleap.gui.formbuilder import YamlFormWidget +from sleap.gui.merge import MergeDialog from sleap.gui.shortcuts import Shortcuts, ShortcutDialog from sleap.gui.suggestions import VideoFrameSuggestions @@ -1110,14 +1111,10 @@ def importPredictions(self): new_labels = Labels.load_file( filename, - match_to=self.labels, video_callback=gui_video_callback) - self.labels.extend_from(new_labels) - - for vid in new_labels.videos: - print(f"Labels imported for {vid.filename}") - print(f" frames labeled: {len(new_labels.find(vid))}") + # Merging data is handled by MergeDialog + MergeDialog(base_labels = self.labels, new_labels = new_labels).exec_() # update display/ui self.plotFrame() diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py new file mode 100644 index 000000000..7d0cffd40 --- /dev/null +++ b/sleap/gui/merge.py @@ -0,0 +1,147 @@ +""" +Gui for merging two labels files with options to resolve conflicts. +""" + +import attr + +from typing import List + +from sleap.instance import LabeledFrame +from sleap.io.dataset import Labels + +from PySide2 import QtWidgets, QtCore + +class MergeDialog(QtWidgets.QDialog): + + def __init__(self, + base_labels: Labels, + new_labels: Labels, + *args, **kwargs): + + super(MergeDialog, self).__init__(*args, **kwargs) + + self.base_labels = base_labels + self.new_labels = new_labels + + merged, self.extra_base, self.extra_new = \ + Labels.complex_merge_between(self.base_labels, self.new_labels) + + merge_total = 0 + merge_frames = 0 + for vid_frame_list in merged.values(): + # number of frames for this video + merge_frames += len(vid_frame_list) + # number of instances across frames for this video + merge_total += sum((map(len, vid_frame_list))) + + buttons = self._make_buttons(conflict=self.extra_base) + + merged_label = QtWidgets.QLabel(f"Cleanly merged {merge_total} instances across {merge_frames} frames.") + + conflict_text = "There are no conflicts." 
if not self.extra_base else "Merge conflicts:" + conflict_label = QtWidgets.QLabel(conflict_text) + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(merged_label) + + layout.addWidget(conflict_label) + if self.extra_base: + conflict_table = ConflictTable(self.base_labels, self.extra_base, self.extra_new) + layout.addWidget(conflict_table) + + layout.addWidget(buttons) + + self.setLayout(layout) + + def _make_buttons(self, conflict: bool): + self.use_base_button = None + self.use_new_button = None + self.okay_button = None + + buttons = QtWidgets.QDialogButtonBox() + if conflict: + self.use_base_button = buttons.addButton("Use Base", QtWidgets.QDialogButtonBox.YesRole) + self.use_new_button = buttons.addButton("Use New", QtWidgets.QDialogButtonBox.NoRole) + else: + self.okay_button = buttons.addButton(QtWidgets.QDialogButtonBox.Ok) + + buttons.clicked.connect(self.finishMerge) + + return buttons + + def finishMerge(self, button): + if button == self.use_base_button: + Labels.finish_complex_merge(self.base_labels, self.extra_base) + elif button == self.use_new_button: + Labels.finish_complex_merge(self.base_labels, self.extra_new) + elif button == self.okay_button: + Labels.finish_complex_merge(self.base_labels, []) + + self.accept() + +class ConflictTable(QtWidgets.QTableView): + def __init__(self, *args, **kwargs): + super(ConflictTable, self).__init__() + self.setModel(ConflictTableModel(*args, **kwargs)) + +class ConflictTableModel(QtCore.QAbstractTableModel): + _props = ["video", "frame", "base", "new"] + + def __init__(self, + base_labels: Labels, + extra_base: List[LabeledFrame], + extra_new: List[LabeledFrame]): + super(ConflictTableModel, self).__init__() + self.base_labels = base_labels + self.extra_base = extra_base + self.extra_new = extra_new + + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + if role == QtCore.Qt.DisplayRole and index.isValid(): + idx = index.row() + prop = self._props[index.column()] + + if idx < self.rowCount(): + if prop == "video": + return self.extra_base[idx].video.filename + if prop == "frame": + return self.extra_base[idx].frame_idx + if prop == "base": + return self._showInstanceCount(self.extra_base[idx]) + if prop == "new": + return self._showInstanceCount(self.extra_new[idx]) + + return None + + @staticmethod + def _showInstanceCount(instance_list): + prediction_count = len(list(filter(lambda inst: hasattr(inst, "score"), instance_list))) + user_count = len(instance_list) - prediction_count + return f"{prediction_count}/{user_count}" + + def rowCount(self, *args): + return len(self.extra_base) + + def columnCount(self, *args): + return len(self._props) + + def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole): + if role == QtCore.Qt.DisplayRole: + if orientation == QtCore.Qt.Horizontal: + return self._props[section] + elif orientation == QtCore.Qt.Vertical: + return section + return None + +if __name__ == "__main__": + + file_a = "tests/data/json_format_v1/centered_pair.json" + file_b = "tests/data/json_format_v2/centered_pair_predictions.json" + + base_labels = Labels.load_file(file_a) + new_labels = Labels.load_file(file_b) + + app = QtWidgets.QApplication() + win = MergeDialog(base_labels, new_labels) + win.show() + app.exec_() \ No newline at end of file From 7e9441459f18af44851a55d8128f313c2b151c8d Mon Sep 17 00:00:00 2001 From: Talmo Date: Wed, 25 Sep 2019 02:41:15 -0400 Subject: [PATCH 104/176] Major refactoring of inference module - Created a InferenceModel class that 
holds all of the metadata required to run models from a TrainingJob. This includes input/output scale, skeletons, Keras model loading and caching, variable input shape, normalization and other convenience properties. - Added some generic attributes to the sleap.nn.Model class to allow for storing or inferring output scaling (depends on architecture details, so it can be overloaded in backbones). - sleap.nn.loadmodel module removed. - Predictor refactored to take advantage of InferenceModel caching and reuse multiprocessing pool (fixes: #155). - Previously, models were being reloaded from disk and copies made every time a new crop size was encountered. - Previously, a new pool was being opened on every predict() call. - Refactored DataTransform usage to account for discrepancies between input scale vs output scale vs relative output scale vs scale across models. - Fix: DataTransform-based scaling would always convert images to float64 dtype. - Refactored TensorFlow-based peak finding for GPU execution. - Account for unspecified confmaps tensor at graph creation time. - Fix subpixel refinement algorithm to enable greater resolution than original data scale. - General rework of offset adjustment. - Exposed supersampling parameters (window size, window batching) - Fix tensors being overwritten or dropping out of context - Fixes: #149 - Exposed peak finding options and other inference functionality as Predictor attributes. - Renamed: Instance.visible_points_array -> Instance.points_array - Sensible cropping defaults in profiles and function kwargs. - Indented default profiles for readability. - General unfucking of syntax errors, anti-patterns, typos, etc. --- sleap/gui/active.py | 6 +- sleap/gui/app.py | 2 +- sleap/gui/overlays/base.py | 1 + sleap/gui/slider.py | 4 +- sleap/info/metrics.py | 10 +- sleap/info/write_tracking_h5.py | 2 +- sleap/instance.py | 10 +- sleap/nn/architectures/resnet.py | 11 +- sleap/nn/datagen.py | 8 +- sleap/nn/inference.py | 710 ++++++++++++------ sleap/nn/loadmodel.py | 147 ---- sleap/nn/model.py | 16 + sleap/nn/peakfinding_tf.py | 163 ++-- sleap/nn/peakmatching.py | 25 +- sleap/nn/training.py | 4 +- sleap/nn/transform.py | 2 +- .../training_profiles/default_centroids.json | 49 +- sleap/training_profiles/default_confmaps.json | 51 +- sleap/training_profiles/default_pafs.json | 49 +- tests/test_instance.py | 2 +- 20 files changed, 771 insertions(+), 501 deletions(-) delete mode 100644 sleap/nn/loadmodel.py diff --git a/sleap/gui/active.py b/sleap/gui/active.py index 21a51924c..ba31d8682 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -738,9 +738,9 @@ def run_active_inference( inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") # Create Predictor from the results of training - predictor = Predictor(sleap_models=training_jobs, - with_tracking=with_tracking, - output_path=inference_output_path) + predictor = Predictor(training_jobs=training_jobs, + with_tracking=with_tracking, + output_path=inference_output_path) if gui: # show message while running inference diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 64d7d1a36..727f6897c 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -959,7 +959,7 @@ def delete_area_callback(x0, y0, x1, y1): max_corner = (x1, y1) def is_bounded(inst): - points_array = inst.visible_points_array + points_array = inst.points_array valid_points = points_array[~np.isnan(points_array).any(axis=1)] is_gt_min = np.all(valid_points >= min_corner) diff --git a/sleap/gui/overlays/base.py 
b/sleap/gui/overlays/base.py index 694ba796f..78a1672b8 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -18,6 +18,7 @@ def __getitem__(self, i): @attr.s(auto_attribs=True) class ModelData: + # TODO: Unify this class with inference.Predictor or InferenceModel model: 'keras.Model' video: Video do_rescale: bool=False diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 4fdb642cf..970221ff5 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -199,8 +199,8 @@ def inst_velocity(lf, last_lf): if last_lf is not None: last_inst = last_lf.find(track=inst.track) if last_inst: - points_a = inst.visible_points_array - points_b = last_inst[0].visible_points_array + points_a = inst.points_array + points_b = last_inst[0].points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) inst_dist = np.sum(point_dist) # np.nanmean(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index ee4e958c7..ca155076c 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -158,8 +158,8 @@ def point_dist( inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: """Given two instances, returns array of distances for corresponding nodes.""" - points_a = inst_a.visible_points_array - points_b = inst_b.visible_points_array + points_a = inst_a.points_array + points_b = inst_b.points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist @@ -171,8 +171,8 @@ def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], matrix_size = (len(inst_a.skeleton.nodes), len(inst_b.skeleton.nodes)) pairwise_distance_matrix = np.full(matrix_size, 0) - points_a = inst_a.visible_points_array - points_b = inst_b.visible_points_array + points_a = inst_a.points_array + points_b = inst_b.points_array # Calculate the distance between any pair of inst A and inst B points for idx_a in range(points_a.shape[0]): @@ -205,7 +205,7 @@ def compare_instance_lists( def list_points_array(instances: List[Union[Instance, PredictedInstance]]) -> np.ndarray: """Given list of Instances, returns (instances * nodes * 2) matrix.""" - points_arrays = list(map(lambda inst: inst.visible_points_array, instances)) + points_arrays = list(map(lambda inst: inst.points_array, instances)) return np.stack(points_arrays) def point_match_count(dist_array: np.ndarray, thresh: float=5) -> int: diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 0043911b6..ea03b6dd3 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -37,7 +37,7 @@ occupancy_matrix[track_i, frame_i] = 1 - inst_points = inst.visible_points_array + inst_points = inst.points_array prediction_matrix[frame_i, ..., track_i] = inst_points occupied_track_mask = np.sum(occupancy_matrix, axis=1) > 0 diff --git a/sleap/instance.py b/sleap/instance.py index eb7392781..448c31184 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -609,13 +609,13 @@ def get_points_array(self, copy: bool = True, return parray @property - def visible_points_array(self) -> np.ndarray: + def points_array(self) -> np.ndarray: return self.get_points_array(invisible_as_nan=True) @property def centroid(self) -> np.ndarray: """Returns instance centroid as (x,y) numpy row vector.""" - points = self.visible_points_array + points = self.points_array centroid = np.nanmedian(points, axis=0) return centroid @@ -836,7 +836,11 @@ def instances(self, instances: List[Instance]): @property def user_instances(self): - return 
[inst for inst in self._instances if type(inst) == Instance] + return [inst for inst in self._instances if not isinstance(inst, PredictedInstance)] + + @property + def predicted_instances(self): + return [inst for inst in self._instances if isinstance(inst, PredictedInstance)] @property def has_user_instances(self): diff --git a/sleap/nn/architectures/resnet.py b/sleap/nn/architectures/resnet.py index 096a0b43e..cef678071 100644 --- a/sleap/nn/architectures/resnet.py +++ b/sleap/nn/architectures/resnet.py @@ -43,12 +43,19 @@ def output(self, x_in, num_output_channels): """ return resnet50(x_in, num_output_channels, **attr.asdict(self)) + @property + def down_blocks(self): + """Returns the number of downsampling steps in the model.""" + + # This is a fixed constant for ResNet50. + return 5 + + @property def output_scale(self): """Returns relative scaling factor of this backbone.""" - down_blocks = 5 - return (1 / (2 ** (down_blocks - self.up_blocks))) + return (1 / (2 ** (self.down_blocks - self.up_blocks))) def preprocess_input(X): diff --git a/sleap/nn/datagen.py b/sleap/nn/datagen.py index 78a550938..cdb38202c 100644 --- a/sleap/nn/datagen.py +++ b/sleap/nn/datagen.py @@ -142,7 +142,7 @@ def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: flo def lf_points_from_singleton(lf_singleton): if len(lf_singleton) == 0: return [] lf = lf_singleton[0] - points = [inst.visible_points_array*scale + points = [inst.points_array*scale for inst in lf.user_instances] return points @@ -491,6 +491,9 @@ def _bb_pad_shape(bbs, min_crop_size, img_shape): Returns: (size, size) tuple """ + + # TODO: Holy hardcoded fuck Batman! This really needs to get cleaned up + # Find a nicely sized box that's large enough to bound all instances max_height = max((y1 - y0 for (x0, y0, x1, y1) in bbs)) max_width = max((x1 - x0 for (x0, y0, x1, y1) in bbs)) @@ -596,7 +599,8 @@ def pad_box_to_multiple(box, pad_factor_box, within): pad_h, pad_w = pad_factor_box # Find multiple of pad_factor_box that's large enough to hold box - multiple_h, multiple_w = ceil(box_h / pad_h), ceil(box_w / pad_w) + multiple_h = ceil(box_h / pad_h) + multiple_w = ceil(box_w / pad_w) # Maintain aspect ratio multiple = max(multiple_h, multiple_w) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 10d17250a..bc666774a 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -18,9 +18,7 @@ from multiprocessing.pool import AsyncResult, ThreadPool from time import time, clock -from typing import Dict, List, Union, Optional, Tuple - -from keras.utils import multi_gpu_model +from typing import Any, Dict, List, Union, Optional, Text, Tuple from sleap.instance import LabeledFrame from sleap.io.dataset import Labels @@ -34,13 +32,230 @@ from sleap.nn.transform import DataTransform from sleap.nn.datagen import merge_boxes_with_overlap_and_padding -from sleap.nn.loadmodel import load_model, get_model_data, get_model_skeleton from sleap.nn.peakfinding import find_all_peaks, find_all_single_peaks from sleap.nn.peakfinding_tf import peak_tf_inference from sleap.nn.peakmatching import match_single_peaks_all, match_peaks_paf, match_peaks_paf_par, instances_nms from sleap.nn.util import batch, batch_count, save_visual_outputs -OVERLAPPING_INSTANCES_NMS = True + +@attr.s(auto_attribs=True) +class InferenceModel: + """This class provides convenience metadata and methods for running inference from a TrainingJob.""" + + job: TrainingJob + _keras_model: keras.Model = None + _model_path: Text = None + _trained_input_shape: 
Tuple[int] = None + _output_channels: int = None + + @property + def skeleton(self) -> Skeleton: + """Returns the skeleton associated with this model.""" + + return self.job.model.skeletons[0] + + + @property + def output_type(self) -> ModelOutputType: + """Returns the output type of this model.""" + + return self.job.model.output_type + + @property + def input_scale(self) -> float: + """Returns the scale of the images that the model was trained on.""" + + return self.job.trainer.scale + + @property + def output_scale(self) -> float: + """Returns the scale of the outputs of the model relative to the original data. + + For a model trained on inputs with scale = 0.5 that outputs predictions that + are half of the size of the inputs, the output scale is 0.25. + """ + return self.input_scale * self.job.model.output_scale + + @property + def output_relative_scale(self) -> float: + """Returns the scale of the outputs relative to the scaled inputs. + + This differs from output_scale in that it is the scaling factor after + applying the input scaling. + """ + + return self.job.model.output_scale + + def compute_output_shape(self, input_shape: Tuple[int], relative=True) -> Tuple[int]: + """Returns the output tensor shape for a given input shape. + + Args: + input_shape: Shape of input images in the form (height, width). + relative: If True, input_shape specifies the shape after input scaling. + + Returns: + A tuple of (height, width, channels) of the output of the model. + """ + + # TODO: Support multi-input/multi-output models. + + scaling_factor = self.output_scale + if relative: + scaling_factor = self.output_relative_scale + + output_shape = ( + int(input_shape[0] * scaling_factor), + int(input_shape[1] * scaling_factor), + self.output_channels) + + return output_shape + + + def load_model(self, model_path: Text = None) -> keras.Model: + """Loads a saved model from disk and caches it. + + Args: + model_path: If not provided, uses the model + paths in the training job. + + Returns: + The loaded Keras model. This model can accept any size + of inputs that are valid. + """ + + if not model_path: + # Try the best model first. + model_path = os.path.join(self.job.save_dir, + self.job.best_model_filename) + + # Try the final model if that didn't exist. + if not os.path.exists(model_path): + model_path = os.path.join(self.job.save_dir, + self.job.final_model_filename) + + # Load from disk. + keras_model = keras.models.load_model(model_path, + custom_objects={"tf": tf}) + logger.info("Loaded model: " + model_path) + + # Store the loaded model path for reference. + self._model_path = model_path + + # TODO: Multi-input/output support + # Find the original data shape from the input shape of the first input node. + self._trained_input_shape = keras_model.get_input_shape_at(0) + + # Save output channels since that should be static. + self._output_channels = keras_model.get_output_shape_at(0)[-1] + + # Create input node with undetermined height/width. + input_tensor = keras.layers.Input((None, None, self.input_channels)) + keras_model = keras.Model( + inputs=input_tensor, + outputs=keras_model(input_tensor)) + + + # Save the modified and loaded model. 
+ self._keras_model = keras_model + + return self.keras_model + + + @property + def keras_model(self) -> keras.Model: + """Returns the underlying Keras model, loading it if necessary.""" + + if self._keras_model is None: + self.load_model() + + return self._keras_model + + + @property + def model_path(self) -> Text: + """Returns the path to the loaded model.""" + + if not self._model_path: + raise AttributeError("No model loaded. Call inference_model.load_model() first.") + + return self._model_path + + + @property + def trained_input_shape(self) -> Tuple[int]: + """Returns the shape of the model when it was loaded.""" + + if not self._trained_input_shape: + raise AttributeError("No model loaded. Call inference_model.load_model() first.") + + return self._trained_input_shape + + @property + def output_channels(self) -> int: + """Returns the number of output channels of the model.""" + if not self._trained_input_shape: + raise AttributeError("No model loaded. Call inference_model.load_model() first.") + + return self._output_channels + + + @property + def input_channels(self) -> int: + """Returns the number of channels expected for the input data.""" + + # TODO: Multi-output support + return self.trained_input_shape[-1] + + + @property + def is_grayscale(self) -> bool: + """Returns True if the model expects grayscale images.""" + + return self.input_channels == 1 + + + @property + def down_blocks(self): + """Returns the number of pooling steps applied during the model. + + Data needs to be of a shape divisible by the number of pooling steps. + """ + + # TODO: Replace this with an explicit calculation that takes stride sizes into account. + return self.job.model.down_blocks + + + def predict(self, X: Union[np.ndarray, List[np.ndarray]], + batch_size: int = 32, + normalize: bool = True + ) -> Union[np.ndarray, List[np.ndarray]]: + """Runs inference on the input data. + + This is a simple wrapper around the keras model predict function. + + Args: + X: The inputs to provide to the model. Can be different height/width as + the data it was trained on. + batch_size: Batch size to perform inference on at a time. + normalize: Applies normalization to the input data if needed + (e.g., if casting or range normalization is required). + + Returns: + The outputs of the model. + """ + + if normalize: + # TODO: Store normalization scheme in the model metadata. + if isinstance(X, np.ndarray): + if X.dtype == np.dtype("uint8"): + X = X.astype("float32") / 255. + elif isinstance(X, list): + for i in range(len(X)): + if X[i].dtype == np.dtype("uint8"): + X[i] = X[i].astype("float32") / 255. 
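+ # Forward pass; the cached model was rebuilt with a (None, None, channels) input node, so input height/width may differ from the training crop size.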
+ + return self.keras_model.predict(X, batch_size=batch_size) + @attr.s(auto_attribs=True) class Predictor: @@ -85,7 +300,9 @@ class Predictor: resize_hack: whether to resize images to power of 2 """ - sleap_models: Dict[ModelOutputType, TrainingJob] = None + training_jobs: Dict[ModelOutputType, TrainingJob] = None + inference_models: Dict[ModelOutputType, InferenceModel] = attr.ib(default=attr.Factory(dict)) + skeleton: Skeleton = None inference_batch_size: int = 2 read_chunk_size: int = 256 @@ -106,15 +323,26 @@ class Predictor: output_path: Optional[str] = None save_confmaps_pafs: bool = False resize_hack: bool = True + pool: multiprocessing.Pool = None + + gpu_peak_finding: bool = True + supersample_window_size: int = 7 # must be odd + supersample_factor: float = 2 # factor to upsample cropped windows by + overlapping_instances_nms: bool = False # suppress overlapping instances + + def __attrs_post_init__(self): + + # Create inference models from the TrainingJob metadata. + for model_output_type, training_job in self.training_jobs.items(): + self.inference_models[model_output_type] = InferenceModel(job=training_job) + self.inference_models[model_output_type].load_model() - _models: Dict = attr.ib(default=attr.Factory(dict)) def predict(self, input_video: Union[dict, Video], frames: Optional[List[int]] = None, is_async: bool = False) -> List[LabeledFrame]: - """ - Run the entire inference pipeline on an input video. + """Run the entire inference pipeline on an input video. Args: input_video: Either a `Video` object or dict that can be @@ -126,13 +354,19 @@ def predict(self, children. Returns: - list of LabeledFrame objects + A list of LabeledFrames with predicted instances. """ + # Check if we have models. + if len(self.inference_models) == 0: + logger.warning("Predictor has no model.") + raise ValueError("Predictor has no model.") + self.is_async = is_async - # Initialize parallel pool - self.pool = None if self.is_async else multiprocessing.Pool(processes=usable_cpu_count()) + # Initialize parallel pool if needed. + if not is_async and self.pool is None: + self.pool = multiprocessing.Pool(processes=usable_cpu_count()) # Fix the number of threads for OpenCV, not that we are using # anything in OpenCV that is actually multi-threaded but maybe @@ -141,19 +375,11 @@ def predict(self, logger.info(f"Predict is async: {is_async}") - # Find out how many channels the model was trained on - - model_channels = 3 # default - - if ModelOutputType.CENTROIDS in self.sleap_models: - centroid_model = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CENTROIDS]) - model_channels = centroid_model["model"].input_shape[-1] + # Find out if the images should be grayscale from the first model. + # TODO: Unify this with input data normalization. + grayscale = list(self.inference_models.values())[0].is_grayscale - grayscale = (model_channels == 1) - - # Open the video if we need it. + # Open the video object if needed. 
if isinstance(input_video, Video): vid = input_video elif isinstance(input_video, dict): @@ -165,19 +391,11 @@ def predict(self, # List of frames to process (or entire video if not specified) frames = frames or list(range(vid.num_frames)) - - vid_h = vid.shape[1] - vid_w = vid.shape[2] - logger.info("Opened video:") logger.info(" Source: " + str(vid.backend)) logger.info(" Frames: %d" % len(frames)) - logger.info(" Frame shape: %d x %d" % (vid_h, vid_w)) + logger.info(" Frame shape (H x W): %d x %d" % (vid.height, vid.width)) - # Check training models - if len(self.sleap_models) == 0: - logger.warning("Predictor has no model.") - raise ValueError("Predictor has no model.") # Initialize tracking if self.with_tracking: @@ -194,8 +412,7 @@ def predict(self, # Process chunk-by-chunk! t0_start = time() - predicted_frames: List[LabeledFrame] = [] - + predicted_frames = [] num_chunks = batch_count(frames, self.read_chunk_size) logger.info("Number of chunks for process: %d" % (num_chunks)) @@ -220,43 +437,41 @@ def predict(self, # Read the next batch of images t0 = time() - mov_full = vid[frames_idx] - logger.info(" Read %d frames [%.1fs]" % (len(mov_full), time() - t0)) + imgs_full = vid[frames_idx] + logger.info(" Read %d frames [%.1fs]" % (len(imgs_full), time() - t0)) # Transform images (crop or scale) t0 = time() - if ModelOutputType.CENTROIDS in self.sleap_models: - # Use centroid predictions to get subchunks of crops + if ModelOutputType.CENTROIDS in self.inference_models: + # Use centroid predictions to get subchunks of crops. subchunks_to_process = self.centroid_crop_inference( - mov_full, frames_idx) + imgs_full, frames_idx) else: # Scale without centroid cropping - - # Get the scale that was used when training models - model_data = get_model_data(self.sleap_models, [ModelOutputType.CONFIDENCE_MAP]) - scale = model_data["scale"] + # TODO: Move this into the processing below to allow for different input scales by model. # Determine scaled image size - scale_to = (int(vid.height//(1/scale)), int(vid.width//(1/scale))) + # cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] + # input_scale = cm_model.input_scale + # scale_to = (int(vid.height // (1 / input_scale)), int(vid.width // (1 / input_scale))) - # FIXME: Adjust to appropriate power of 2 - # It would be better to pad image to a usable size, since - # the resize could affect aspect ratio. - if self.resize_hack: - scale_to = (scale_to[0]//8*8, scale_to[1]//8*8) + # # if self.resize_hack: + # # TODO: Replace this when model-specific divisibility calculation implemented. 
+ # divisor = 2 ** cm_model.down_blocks + # crop_to = ( + # (scale_to[0] // divisor) * divisor, + # (scale_to[1] // divisor) * divisor) # Create transform object - transform = DataTransform( - frame_idxs = frames_idx, - scale = model_data["multiscale"]) - + transform = DataTransform(frame_idxs=frames_idx) + # Scale if target doesn't match current size - mov = transform.scale_to(mov_full, target_size=scale_to) + # imgs_full = transform.scale_to(mov_full, target_size=scale_to) - subchunks_to_process = [(mov, transform)] + subchunks_to_process = [(imgs_full, transform)] logger.info(" Transformed images [%.1fs]" % (time() - t0)) @@ -274,20 +489,20 @@ def predict(self, subchunk_results = [] - for subchunk_mov, subchunk_transform in subchunks_to_process: + for subchunk_imgs_full, subchunk_transform in subchunks_to_process: logger.info(f" Running inference for subchunk:") - logger.info(f" Shape: {subchunk_mov.shape}") - logger.info(f" Prediction Scale: {subchunk_transform.scale}") + logger.info(f" Shape: {subchunk_imgs_full.shape}") + logger.info(f" Scale: {subchunk_transform.scale}") - if ModelOutputType.PART_AFFINITY_FIELD not in self.sleap_models: + if ModelOutputType.PART_AFFINITY_FIELD not in self.inference_models: # Pipeline for predicting a single animal in a frame # This uses only confidence maps logger.warning("No PAF model! Running in SINGLE INSTANCE mode.") subchunk_lfs = self.single_instance_inference( - subchunk_mov, + subchunk_imgs_full, subchunk_transform, vid) @@ -295,7 +510,7 @@ def predict(self, # Pipeline for predicting multiple animals in a frame # This uses confidence maps and part affinity fields subchunk_lfs = self.multi_instance_inference( - subchunk_mov, + subchunk_imgs_full, subchunk_transform, vid) @@ -323,7 +538,7 @@ def predict(self, predicted_frames_chunk.extend(subchunk_frames) predicted_frames_chunk = LabeledFrame.merge_frames(predicted_frames_chunk, video=vid) - logger.info(f" Instances found on {len(predicted_frames_chunk)} out of {len(mov_full)} frames.") + logger.info(f" Instances found on {len(predicted_frames_chunk)} out of {len(imgs_full)} frames.") if len(predicted_frames_chunk): @@ -333,7 +548,7 @@ def predict(self, # Track if self.with_tracking and len(predicted_frames_chunk): t0 = time() - tracker.process(mov_full, predicted_frames_chunk) + tracker.process(imgs_full, predicted_frames_chunk) logger.info(" Tracked IDs via flow shift [%.1fs]" % (time() - t0)) # Save @@ -342,7 +557,7 @@ def predict(self, if chunk % self.save_frequency == 0 or chunk == (num_chunks - 1): t0 = time() - # FIXME: We are re-writing the whole output each time, this is dumb. + # TODO: We are re-writing the whole output each time, this is dumb. # We should save in chunks then combine at the end. labels = Labels(labeled_frames=predicted_frames) if self.output_path is not None: @@ -403,7 +618,6 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: return pool, result - # Methods for running inferring on components of pipeline def centroid_crop_inference(self, imgs: np.ndarray, @@ -424,55 +638,53 @@ def centroid_crop_inference(self, which allows us to merge overlapping crops into larger crops. """ - crop_within = (imgs.shape[1]//8*8, imgs.shape[2]//8*8) + # Get inference models with metadata. 
+ centroid_model = self.inference_models[ModelOutputType.CENTROIDS] + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] - # Fetch centroid model (uses cache if already loaded) + logger.info(" Performing centroid cropping.") - model_package = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CENTROIDS]) + # TODO: Replace this calculation when model-specific divisibility calculation implemented. + divisor = 2 ** centroid_model.down_blocks + crop_within = ((imgs.shape[1] // divisor) * divisor, (imgs.shape[2] // divisor) * divisor) + logger.info(f" crop_within: {crop_within}") # Create transform - # This lets us scale the images before we predict centroids, # and will also let us map the points on the scaled image to # points on the original images so we can crop original images. - centroid_transform = DataTransform() + target_shape = (int(imgs.shape[1] * centroid_model.input_scale), int(imgs.shape[2] * centroid_model.input_scale)) - # Scale to match input size of trained centroid model - # Usually this will be 1/4-scale of original images - - centroid_imgs_scaled = \ - centroid_transform.scale_to( - imgs=imgs, - target_size=model_package["model"].input_shape[1:3]) + # Scale to match input size of trained centroid model. + centroid_imgs_scaled = centroid_transform.scale_to( + imgs=imgs, target_size=target_shape) - # Predict centroid confidence maps, then find peaks - - centroid_confmaps = model_package["model"].predict(centroid_imgs_scaled.astype("float32") / 255, - batch_size=self.inference_batch_size) + # Predict centroid confidence maps, then find peaks. + t0 = time() + centroid_confmaps = centroid_model.predict(centroid_imgs_scaled, + batch_size=self.inference_batch_size) peaks, peak_vals = find_all_peaks(centroid_confmaps, - min_thresh=self.nms_min_thresh, - sigma=self.nms_sigma) + min_thresh=self.nms_min_thresh, sigma=self.nms_sigma) + elapsed = time() - t0 + total_peaks = sum([len(frame_peaks[0]) for frame_peaks in peaks]) + logger.info(f" Found {total_peaks} centroid peaks ({total_peaks / len(peaks):.2f} centroids/frame) [{elapsed:.2f}s].") if box_size is None: - # Get training bounding box size to determine (min) centroid crop size - crop_model_package = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CONFIDENCE_MAP]) - crop_size = crop_model_package["bounding_box_size"] - bb_half = (crop_size + self.crop_padding)//2 + # Get training bounding box size to determine (min) centroid crop size. + # TODO: fix this to use a stored value or move this logic elsewhere + crop_size = int(max(cm_model.trained_input_shape[1:3]) // cm_model.input_scale) + bb_half = crop_size // 2 + # bb_half = (crop_size + self.crop_padding) // 2 else: - bb_half = box_size//2 + bb_half = box_size // 2 - logger.info(f" Centroid crop box size: {bb_half*2}") - - all_boxes = dict() + logger.info(f" Crop box size: {bb_half * 2}") # Iterate over each frame to filter bounding boxes + all_boxes = dict() for frame_i, (frame_peaks, frame_peak_vals) in enumerate(zip(peaks, peak_vals)): # If we found centroids on this frame... 
@@ -484,12 +696,13 @@ def centroid_crop_inference(self, boxes = [] for peak_i in range(frame_peaks[0].shape[0]): + # Rescale peak back onto full-sized image - peak_x = int(frame_peaks[0][peak_i][0] / centroid_transform.scale) - peak_y = int(frame_peaks[0][peak_i][1] / centroid_transform.scale) + peak_x = int(frame_peaks[0][peak_i][0] / centroid_model.output_scale) + peak_y = int(frame_peaks[0][peak_i][1] / centroid_model.output_scale) - boxes.append((peak_x-bb_half, peak_y-bb_half, - peak_x+bb_half, peak_y+bb_half)) + boxes.append((peak_x - bb_half, peak_y - bb_half, + peak_x + bb_half, peak_y + bb_half)) if do_merge: # Merge overlapping boxes and pad to multiple of crop size @@ -497,6 +710,7 @@ def centroid_crop_inference(self, boxes=boxes, pad_factor_box=(self.crop_growth, self.crop_growth), within=crop_within) + else: # Just return the boxes centered around each centroid. # Note that these aren't guaranteed to be within the @@ -506,29 +720,33 @@ def centroid_crop_inference(self, # Keep track of all boxes, grouped by size and frame idx for box in merged_boxes: - box_size = (box[2]-box[0], box[3]-box[1]) + merged_box_size = (box[2] - box[0], box[3] - box[1]) - if box_size not in all_boxes: - all_boxes[box_size] = dict() - if frame_i not in all_boxes[box_size]: - all_boxes[box_size][frame_i] = [] + if merged_box_size not in all_boxes: + all_boxes[merged_box_size] = dict() + logger.info(f" Found box size: {merged_box_size}") - all_boxes[box_size][frame_i].append(box) + if frame_i not in all_boxes[merged_box_size]: + all_boxes[merged_box_size][frame_i] = [] + + all_boxes[merged_box_size][frame_i].append(box) + + logger.info(f" Found {len(all_boxes)} box sizes after merging.") subchunks = [] # Check if we found any boxes for this chunk of frames if len(all_boxes): - model_data = get_model_data(self.sleap_models, [ModelOutputType.CONFIDENCE_MAP]) # We'll make a "subchunk" for each crop size for crop_size in all_boxes: - if crop_size[0] >= 1024: - logger.info(f" Skipping subchunk for size {crop_size}, would have {len(all_boxes[crop_size])} crops.") - for debug_frame_idx in all_boxes[crop_size].keys(): - print(f" frame {frames_idx[debug_frame_idx]}: {all_boxes[crop_size][debug_frame_idx]}") - continue + # TODO: Look into this edge case? + # if crop_size[0] >= 1024: + # logger.info(f" Skipping subchunk for size {crop_size}, would have {len(all_boxes[crop_size])} crops.") + # for debug_frame_idx in all_boxes[crop_size].keys(): + # print(f" frame {frames_idx[debug_frame_idx]}: {all_boxes[crop_size][debug_frame_idx]}") + # continue # Make list of all boxes and corresponding img index. 
subchunk_idxs = [] @@ -536,12 +754,12 @@ def centroid_crop_inference(self, for frame_i, frame_boxes in all_boxes[crop_size].items(): subchunk_boxes.extend(frame_boxes) - subchunk_idxs.extend( [frame_i] * len(frame_boxes) ) + subchunk_idxs.extend([frame_i] * len(frame_boxes)) + # TODO: This should probably be in the main loop # Create transform object - transform = DataTransform( - frame_idxs = frames_idx, - scale = model_data["multiscale"]) + # transform = DataTransform(frame_idxs=frames_idx, scale=cm_model.output_relative_scale) + transform = DataTransform(frame_idxs=frames_idx) # Do the cropping imgs_cropped = transform.crop(imgs, subchunk_boxes, subchunk_idxs) @@ -556,38 +774,60 @@ def centroid_crop_inference(self, return subchunks + def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: - """Run the single instance pipeline for a stack of images.""" + """Run the single instance pipeline for a stack of images. + + Args: + imgs: Subchunk of images to process. + transform: DataTransform object tracking input transformations. + video: Video object for building LabeledFrames with correct reference to source. - # Get confmap model for this image size - model_package = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.CONFIDENCE_MAP]) + Returns: + A list of LabeledFrames with predicted points. + """ - # Run inference - t0 = time() + # Get confmap inference model. + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] + + # Scale to match input size of trained model. + # Images are expected to be at full resolution, but may be cropped. + assert(transform.scale == 1.0) + target_shape = (int(imgs.shape[1] * cm_model.input_scale), int(imgs.shape[2] * cm_model.input_scale)) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=target_shape) - confmaps = model_package["model"].predict(imgs.astype("float32") / 255, batch_size=self.inference_batch_size) + # TODO: Adjust for divisibility + # divisor = 2 ** cm_model.down_blocks + # crop_within = ((imgs.shape[1] // divisor) * divisor, (imgs.shape[2] // divisor) * divisor) + + # Run inference. + t0 = time() + confmaps = cm_model.predict(imgs_scaled, batch_size=self.inference_batch_size) logger.info( " Inferred confmaps [%.1fs]" % (time() - t0)) logger.info(f" confmaps: shape={confmaps.shape}, ptp={np.ptp(confmaps)}") t0 = time() + # TODO: Move this to GPU and add subpixel refinement. # Use single highest peak in channel corresponding node points_arrays = find_all_single_peaks(confmaps, min_thresh=self.nms_min_thresh) + # Adjust for multi-scale such that the points are at the scale of the transform. + points_arrays = [pts / cm_model.output_relative_scale for pts in points_arrays] + + # Create labeled frames and predicted instances from the points. predicted_frames_chunk = match_single_peaks_all( - points_arrays = points_arrays, - skeleton = model_package["skeleton"], - transform = transform, - video = video) + points_arrays=points_arrays, + skeleton=cm_model.skeleton, + transform=transform, + video=video) logger.info(" Used highest peaks to create instances [%.1fs]" % (time() - t0)) # Save confmaps if self.output_path is not None and self.save_confmaps_pafs: - logger.warning("Not saving confmaps because feature currently not working.") + raise NotImplementedError("Not saving confmaps/pafs because feature currently not working.") # Disable save_confmaps_pafs since not currently working. 
# The problem is that we can't put data for different crop sizes # all into a single h5 datasource. It's now possible to view live @@ -598,70 +838,114 @@ def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame return predicted_frames_chunk + def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: - """ - Run the multi-instance inference pipeline for a stack of images. + """Run the multi-instance inference pipeline for a stack of images. + + Args: + imgs: Subchunk of images to process. + transform: DataTransform object tracking input transformations. + video: Video object for building LabeledFrames with correct reference to source. + + Returns: + A list of LabeledFrames with predicted points. """ # Load appropriate models as needed - conf_model = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.CONFIDENCE_MAP]) - - paf_model = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.PART_AFFINITY_FIELD]) + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] + paf_model = self.inference_models[ModelOutputType.PART_AFFINITY_FIELD] # Find peaks t0 = time() - multiscale_diff = paf_model["multiscale"] / conf_model["multiscale"] - - peaks, peak_vals, confmaps = \ - peak_tf_inference( - model = conf_model["model"], - data = imgs.astype("float32")/255, - min_thresh=self.nms_min_thresh, - gaussian_size=self.nms_kernel_size, - gaussian_sigma=self.nms_sigma, - downsample_factor=int(1/multiscale_diff), - upsample_factor=int(1/conf_model["multiscale"]), - return_confmaps=self.save_confmaps_pafs - ) + # Scale to match input resolution of model. + # Images are expected to be at full resolution, but may be cropped. + assert(transform.scale == 1.0) + cm_target_shape = (int(imgs.shape[1] * cm_model.input_scale), int(imgs.shape[2] * cm_model.input_scale)) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=cm_target_shape) + if imgs_scaled.dtype == np.dtype("uint8"): # TODO: Unify normalization. + imgs_scaled = imgs_scaled.astype("float32") / 255. + + # TODO: Unfuck this whole workflow + if self.gpu_peak_finding: + confmaps_shape = cm_model.compute_output_shape((imgs_scaled.shape[1], imgs_scaled.shape[2])) + peaks, peak_vals, confmaps = peak_tf_inference( + model=cm_model.keras_model, + confmaps_shape=confmaps_shape, + data=imgs_scaled, + min_thresh=self.nms_min_thresh, + gaussian_size=self.nms_kernel_size, + gaussian_sigma=self.nms_sigma, + upsample_factor=int(self.supersample_factor / cm_model.output_scale), + win_size=self.supersample_window_size, + return_confmaps=self.save_confmaps_pafs, + batch_size=self.inference_batch_size + ) - transform.scale = transform.scale * multiscale_diff + else: + confmaps = cm_model.predict(imgs_scaled, batch_size=self.inference_batch_size) + peaks, peak_vals = find_all_peaks(confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma) + + # # Undo just the scaling so we're back to full resolution, but possibly cropped. + for t in range(len(peaks)): # frames + for c in range(len(peaks[t])): # channels + peaks[t][c] /= cm_model.output_scale + + # Peaks should be at (refined) full resolution now. + # Keep track of scale adjustment. 
+ transform.scale = 1.0 + + elapsed = time() - t0 + total_peaks = sum([len(channel_peaks) for frame_peaks in peaks for channel_peaks in frame_peaks]) + logger.info(f" Found {total_peaks} peaks ({total_peaks / len(imgs):.2f} peaks/frame) [{elapsed:.2f}s].") + # logger.info(f" peaks: {peaks}") + + # Scale to match input resolution of model. + # Images are expected to be at full resolution, but may be cropped. + paf_target_shape = (int(imgs.shape[1] * paf_model.input_scale), int(imgs.shape[2] * paf_model.input_scale)) + if (imgs_scaled.shape[1] == paf_target_shape[0]) and (imgs_scaled.shape[2] == paf_target_shape[1]): + # No need to scale again if we're already there, so just adjust the stored scale + transform.scale = paf_model.input_scale - logger.info(" Inferred confmaps and found-peaks (gpu) [%.1fs]" % (time() - t0)) - logger.info(f" peaks: {len(peaks)}") + else: + # Adjust scale from full resolution images (avoiding possible resizing up from confmaps input scale) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=paf_target_shape) # Infer pafs t0 = time() - pafs = paf_model["model"].predict(imgs.astype("float32") / 255, batch_size=self.inference_batch_size) - + pafs = paf_model.predict(imgs_scaled, batch_size=self.inference_batch_size) logger.info( " Inferred PAFs [%.1fs]" % (time() - t0)) logger.info(f" pafs: shape={pafs.shape}, ptp={np.ptp(pafs)}") + # Adjust points to the paf output scale so we can invert later (should not incur loss of precision) + # TODO: Check precision + for t in range(len(peaks)): # frames + for c in range(len(peaks[t])): # channels + peaks[t][c] *= paf_model.output_scale + transform.scale = paf_model.output_scale + # Determine whether to use serial or parallel version of peak-finding # Use the serial version is we're already running in a thread pool match_peaks_function = match_peaks_paf_par if not self.is_async else match_peaks_paf # Match peaks via PAFs t0 = time() - predicted_frames_chunk = match_peaks_function( - peaks, peak_vals, pafs, conf_model["skeleton"], - transform=transform, video=video, - min_score_to_node_ratio=self.min_score_to_node_ratio, - min_score_midpts=self.min_score_midpts, - min_score_integral=self.min_score_integral, - add_last_edge=self.add_last_edge, - single_per_crop=self.single_per_crop, - pool=self.pool) - + peaks, peak_vals, pafs, paf_model.skeleton, + transform=transform, video=video, + min_score_to_node_ratio=self.min_score_to_node_ratio, + min_score_midpts=self.min_score_midpts, + min_score_integral=self.min_score_integral, + add_last_edge=self.add_last_edge, + single_per_crop=self.single_per_crop, + pool=self.pool) + + total_instances = sum([len(labeled_frame) for labeled_frame in predicted_frames_chunk]) logger.info(" Matched peaks via PAFs [%.1fs]" % (time() - t0)) + logger.info(f" Found {total_instances} instances ({total_instances / len(imgs):.2f} instances/frame)") # Remove overlapping predicted instances - if OVERLAPPING_INSTANCES_NMS: + if self.overlapping_instances_nms: t0 = clock() for lf in predicted_frames_chunk: n = len(lf.instances) @@ -672,7 +956,7 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Save confmaps and pafs if self.output_path is not None and self.save_confmaps_pafs: - logger.warning("Not saving confmaps/pafs because feature currently not working.") + raise NotImplementedError("Not saving confmaps/pafs because feature currently not working.") # Disable save_confmaps_pafs since not currently working. 
# The problem is that we can't put data for different crop sizes # all into a single h5 datasource. It's now possible to view live @@ -684,60 +968,14 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] return predicted_frames_chunk - def fetch_model(self, - input_size: tuple, - output_types: List[ModelOutputType]) -> dict: - """Loads and returns keras Model with caching.""" - - key = (input_size, tuple(output_types)) - - if key not in self._models: - - # Load model - - keras_model = load_model(self.sleap_models, input_size, output_types) - first_sleap_model = self.sleap_models[output_types[0]] - model_data = get_model_data(self.sleap_models, output_types) - skeleton = get_model_skeleton(self.sleap_models, output_types) - - # logger.info(f"Model multiscale: {model_data['multiscale']}") - - # If no input size was specified, then use the input size - # from original trained model. - - if input_size is None: - input_size = keras_model.input_shape[1:] - - # Get the size of the bounding box from training data - # (or the size of crop that model was trained on if the - # bounding box size wasn't set). - - if first_sleap_model.trainer.instance_crop: - bounding_box_size = \ - first_sleap_model.trainer.bounding_box_size or keras_model.input_shape[1] - else: - bounding_box_size = None - - # Cache the model so we don't have to load it next time - - self._models[key] = dict( - model=keras_model, - skeleton=model_data["skeleton"], - multiscale=model_data["multiscale"], - bounding_box_size=bounding_box_size - ) - - # Return the keras Model - return self._models[key] - def main(): def frame_list(frame_str: str): # Handle ranges of frames. Must be of the form "1-200" - if '-' in frame_str: - min_max = frame_str.split('-') + if "-" in frame_str: + min_max = frame_str.split("-") min_frame = int(min_max[0]) max_frame = int(min_max[1]) return list(range(min_frame, max_frame+1)) @@ -751,28 +989,28 @@ def frame_list(frame_str: str): "Multiple models can be specified, each preceded by " "--model. Confmap and PAF models are required.", required=True) - parser.add_argument('--resize-input', dest='resize_input', action='store_const', + parser.add_argument("--resize-input", dest="resize_input", action="store_const", const=True, default=False, - help='resize the input layer to image size (default False)') - parser.add_argument('--with-tracking', dest='with_tracking', action='store_const', + help="resize the input layer to image size (default False)") + parser.add_argument("--with-tracking", dest="with_tracking", action="store_const", const=True, default=False, - help='just visualize predicted confmaps/pafs (default False)') - parser.add_argument('--frames', type=frame_list, default="", - help='list of frames to predict. Either comma separated list (e.g. 1,2,3) or ' - 'a range separated by hyphen (e.g. 1-3). (default is entire video)') - parser.add_argument('-o', '--output', type=str, default=None, - help='The output filename to use for the predicted data.') - parser.add_argument('--out_format', choices=['hdf5', 'json'], help='The format to use for' - ' the output file. Either hdf5 or json. hdf5 is the default.', - default='hdf5') - parser.add_argument('--save-confmaps-pafs', dest='save_confmaps_pafs', action='store_const', + help="just visualize predicted confmaps/pafs (default False)") + parser.add_argument("--frames", type=frame_list, default="", + help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " + "a range separated by hyphen (e.g. 1-3). 
(default is entire video)") + parser.add_argument("-o", "--output", type=str, default=None, + help="The output filename to use for the predicted data.") + parser.add_argument("--out_format", choices=["hdf5", "json"], help="The format to use for" + " the output file. Either hdf5 or json. hdf5 is the default.", + default="hdf5") + parser.add_argument("--save-confmaps-pafs", dest="save_confmaps_pafs", action="store_const", const=True, default=False, - help='Whether to save the confidence maps or pafs') - parser.add_argument('-v', '--verbose', help='Increase logging output verbosity.', action="store_true") + help="Whether to save the confidence maps or pafs") + parser.add_argument("-v", "--verbose", help="Increase logging output verbosity.", action="store_true") args = parser.parse_args() - if args.out_format == 'json': + if args.out_format == "json": output_suffix = ".predictions.json" else: output_suffix = ".predictions.h5" @@ -805,10 +1043,10 @@ def frame_list(frame_str: str): img_shape = None # Create a predictor to do the work. - predictor = Predictor(sleap_models=sleap_models, - output_path=save_path, - save_confmaps_pafs=args.save_confmaps_pafs, - with_tracking=args.with_tracking) + predictor = Predictor(training_jobs=sleap_models, + output_path=save_path, + save_confmaps_pafs=args.save_confmaps_pafs, + with_tracking=args.with_tracking) # Run the inference pipeline return predictor.predict(input_video=data_path, frames=frames) diff --git a/sleap/nn/loadmodel.py b/sleap/nn/loadmodel.py deleted file mode 100644 index 7159470bf..000000000 --- a/sleap/nn/loadmodel.py +++ /dev/null @@ -1,147 +0,0 @@ -import logging -logger = logging.getLogger(__name__) - -import numpy as np - -from time import time, clock -from typing import Dict, List, Union, Optional, Tuple - -import tensorflow as tf -import keras - -from sleap.skeleton import Skeleton -from sleap.nn.model import ModelOutputType -from sleap.nn.training import TrainingJob - -def load_model( - sleap_models: List[TrainingJob], - input_size: Optional[tuple], - output_types: List[ModelOutputType]) -> keras.Model: - """ - Load keras Model for specified input size and output types. - - Supports centroids, confmaps, and pafs. If output type includes - confmaps and pafs then we'll combine these into a single model. - - Arguments: - sleap_models: dict of the TrainingJobs where we can find models. 
- input_size: (h, w, c) tuple; if None, don't resize input layer - output_types: list of ModelOutputTypes - Returns: - keras Model - """ - - if ModelOutputType.CENTROIDS in output_types: - # Load centroid model - keras_model = load_model_from_job(sleap_models[ModelOutputType.CENTROIDS]) - - logger.info(f"Loaded centroid model trained on shape {keras_model.input_shape}") - - else: - # Load model for confmaps or pafs or both - - models = [] - - new_input_layer = tf.keras.layers.Input(input_size) if input_size is not None else None - - for output_type in output_types: - - # Load the model - job = sleap_models[output_type] - model = load_model_from_job(job) - - logger.info(f"Loaded {output_type} model trained on shape {model.input_shape}") - - # Get input layer if we didn't create one for a specified size - if new_input_layer is None: - new_input_layer = model.input - - # Resize input layer - model.layers.pop(0) - model = model(new_input_layer) - - logger.info(f" Resized input layer to {input_size}") - - # Add to list of models we've just loaded - models.append(model) - - if len(models) == 1: - keras_model = tf.keras.Model(new_input_layer, models[0]) - else: - # Merge multiple models into single model - keras_model = tf.keras.Model(new_input_layer, models) - - logger.info(f" Merged {len(models)} into single model") - - # keras_model = convert_to_gpu_model(keras_model) - - return keras_model - -def get_model_data( - sleap_models: Dict[ModelOutputType,TrainingJob], - output_types: List[ModelOutputType]) -> Dict: - - model_type = output_types[0] - job = sleap_models[model_type] - - # Model input is scaled by to get output - model_properties = dict( - skeleton=job.model.skeletons[0], - scale=job.trainer.scale, - multiscale=job.model.output_scale) - - return model_properties - -def get_model_skeleton(sleap_models, output_types) -> Skeleton: - - skeleton = get_model_data(sleap_models, output_types)["skeleton"] - - if skeleton is None: - logger.warning("Predictor has no skeleton.") - raise ValueError("Predictor has no skeleton.") - - return skeleton - -def load_model_from_job(job: TrainingJob) -> keras.Model: - """Load keras Model from a specific TrainingJob.""" - - # Load model from TrainingJob data - keras_model = keras.models.load_model(job_model_path(job), - custom_objects={"tf": tf}) - - # Rename to prevent layer naming conflict - name_prefix = f"{job.model.output_type}_" - keras_model._name = name_prefix + keras_model.name - for i in range(len(keras_model.layers)): - keras_model.layers[i]._name = name_prefix + keras_model.layers[i].name - - return keras_model - -def job_model_path(job: TrainingJob) -> str: - import os - return os.path.join(job.save_dir, job.best_model_filename) - -def get_available_gpus(): - """ - Get the list of available GPUs - - Returns: - List of available GPU device names - """ - - from tensorflow.python.client import device_lib - local_device_protos = device_lib.list_local_devices() - return [x.name for x in local_device_protos if x.device_type == 'GPU'] - -def convert_to_gpu_model(model: keras.Model) -> keras.Model: - gpu_list = get_available_gpus() - - if len(gpu_list) == 0: - logger.warn('No GPU devices, this is going to be really slow, something is wrong, dont do this!!!') - else: - logger.info(f'Detected {len(gpu_list)} GPU(s) for inference') - - if len(gpu_list) > 1: - model = keras.util.multi_gpu_model(model, gpus=len(gpu_list)) - - return model \ No newline at end of file diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 4c1680031..e9d7b1480 100644 --- 
a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -125,6 +125,22 @@ def name(self): """ return self.backbone_name + @property + def down_blocks(self): + """Returns the number of pooling or striding blocks in the backbone. + + This is useful when computing valid dimensions of the input data. + + If the backbone does not provide enough information to infer this, + this is set to 0. + """ + + if hasattr(self.backbone, "down_blocks"): + return self.backbone.down_blocks + + else: + return 0 + @property def output_scale(self): """Calculates output scale relative to input.""" diff --git a/sleap/nn/peakfinding_tf.py b/sleap/nn/peakfinding_tf.py index f19ee605c..a857ea425 100644 --- a/sleap/nn/peakfinding_tf.py +++ b/sleap/nn/peakfinding_tf.py @@ -2,9 +2,9 @@ import time import h5py +import keras import tensorflow as tf -keras = tf.keras import numpy as np from typing import Generator, Tuple @@ -46,33 +46,34 @@ def impeaksnms_tf(I, min_thresh=0.3): return inds, peak_vals -def find_peaks_tf(confmaps, min_thresh=0.3, upsample_factor: int = 1): - n, h, w, c = confmaps.get_shape().as_list() +def find_peaks_tf(confmaps, confmaps_shape, min_thresh=0.3, upsample_factor: int = 1, win_size: int = 5): + # n, h, w, c = confmaps.get_shape().as_list() - unrolled_confmaps = tf.reshape(tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1]) # nc, h, w, 1 + h, w, c = confmaps_shape + + unrolled_confmaps = tf.reshape(tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1]) # (nc, h, w, 1) peak_inds, peak_vals = impeaksnms_tf(unrolled_confmaps, min_thresh=min_thresh) - channel_sample, y, x, _ = tf.split(peak_inds, 4, axis=1) + channel_sample_ind, y, x, _ = tf.split(peak_inds, 4, axis=1) - channel = tf.floormod(channel_sample, c) - sample = tf.floordiv(channel_sample, c) + channel_ind = tf.floormod(channel_sample_ind, c) + sample_ind = tf.floordiv(channel_sample_ind, c) - peaks = tf.concat([sample, y, x, channel], axis=1) + peaks = tf.concat([sample_ind, y, x, channel_ind], axis=1) # (nc, 4) # If we have run prediction on low res and need to upsample the peaks # to a higher resolution. Compute sub-pixel accurate peaks # from these approximate peaks and return the upsampled sub-pixel peaks. 
if upsample_factor > 1: - win_size = 5 # Must be odd offset = (win_size - 1) / 2 # Get the boxes coordinates centered on the peaks, normalized to image # coordinates - box_ind = tf.squeeze(tf.cast(channel_sample, tf.int32)) - top_left = (tf.to_float(peaks[:, 1:3]) + - tf.constant([-offset, -offset], dtype='float32')) / (h - 1.0) - bottom_right = (tf.to_float(peaks[:, 1:3]) + tf.constant([offset, offset], dtype='float32')) / (w - 1.0) + box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32)) + top_left = (tf.cast(peaks[:, 1:3], tf.float32) + + tf.constant([-offset, -offset], dtype="float32")) / (h - 1.0) + bottom_right = (tf.cast(peaks[:, 1:3], tf.float32) + tf.constant([offset, offset], dtype="float32")) / (w - 1.0) boxes = tf.concat([top_left, bottom_right], axis=1) small_windows = tf.image.crop_and_resize( @@ -81,17 +82,27 @@ def find_peaks_tf(confmaps, min_thresh=0.3, upsample_factor: int = 1): box_ind, crop_size=[win_size, win_size]) + # Upsample cropped windows windows = tf.image.resize_bicubic( small_windows, - [upsample_factor*win_size, upsample_factor*win_size]) + [upsample_factor * win_size, upsample_factor * win_size]) windows = tf.squeeze(windows) - windows_peaks = find_maxima_tf(windows) - windows_peaks = windows_peaks / win_size - else: - windows_peaks = None - return peaks, peak_vals, windows_peaks + # Find global maximum of each window + windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2) + + # Adjust back to resolution before upsampling + windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(upsample_factor, tf.float32) + + # Convert to offsets relative to the original peaks (center of cropped windows) + windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2) + windows_offsets = tf.pad(windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0) # (nc, 4) + + # Apply offsets + peaks = tf.cast(peaks, tf.float32) + windows_offsets + + return peaks, peak_vals # Blurring: # Ref: https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow @@ -103,33 +114,34 @@ def gaussian_kernel(size: int, d = tf.distributions.Normal(mean, std) vals = d.prob(tf.range(start = -size, limit = size + 1, dtype = tf.float32)) - gauss_kernel = tf.einsum('i,j->ij', + gauss_kernel = tf.einsum("i,j->ij", vals, vals) return gauss_kernel / tf.reduce_sum(gauss_kernel) -# Now we can do peak finding on the GPU like this: def peak_tf_inference(model, data, - min_thresh: float = 0.3, - gaussian_size: int = 9, - gaussian_sigma: float = 3.0, - upsample_factor: int = 1, - downsample_factor: int = 1, - return_confmaps: bool = False): + confmaps_shape: Tuple[int], + min_thresh: float = 0.3, + gaussian_size: int = 9, + gaussian_sigma: float = 3.0, + upsample_factor: int = 1, + return_confmaps: bool = False, + batch_size: int = 4, + win_size: int = 7): sess = keras.backend.get_session() + # TODO: Unfuck this. confmaps = model.outputs[-1] - - n, h, w, c = confmaps.get_shape().as_list() - - if gaussian_size and upsample_factor == 1: + h, w, c = confmaps_shape + + if gaussian_size > 0 and gaussian_sigma > 0: # Make Gaussian Kernel with desired specs. gauss_kernel = gaussian_kernel(size=gaussian_size, mean=0.0, std=gaussian_sigma) - # Expand dimensions of `gauss_kernel` for `tf.nn.seprable_conv2d` signature. + # Expand dimensions of `gauss_kernel` for `tf.nn.separable_conv2d` signature. 
gauss_kernel = tf.tile(gauss_kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, c, 1]) # Create a pointwise filter that does nothing, we are using separable convultions to blur @@ -137,91 +149,78 @@ def peak_tf_inference(model, data, pointwise_filter = tf.eye(c, batch_shape=[1, 1]) # Convolve. - blurred_confmaps = tf.nn.separable_conv2d(confmaps, gauss_kernel, pointwise_filter, - strides=[1, 1, 1, 1], padding='SAME') + confmaps = tf.nn.separable_conv2d(confmaps, gauss_kernel, pointwise_filter, + strides=[1, 1, 1, 1], padding="SAME") - inds, peak_vals, windows = find_peaks_tf(blurred_confmaps, min_thresh=min_thresh, - upsample_factor=upsample_factor) - else: - inds, peak_vals, windows = find_peaks_tf(confmaps, min_thresh=min_thresh, - upsample_factor=upsample_factor) + + # Setup peak finding computations. + peaks, peak_vals = find_peaks_tf(confmaps, + confmaps_shape=confmaps_shape, min_thresh=min_thresh, + upsample_factor=upsample_factor, win_size=win_size) # We definitely want to capture the peaks in the output # We will map the tensorflow outputs onto a dict to return - outputs_dict = dict(peaks=inds, peak_vals=peak_vals) - - if upsample_factor > 1: - outputs_dict["windows"] = windows + outputs_dict = dict(peaks=peaks, peak_vals=peak_vals) if return_confmaps: outputs_dict["confmaps"] = confmaps # Convert dict to list of keys and list of tensors (to evaluate) - outputs_keys, outputs_vals = list(outputs_dict.keys()), list(outputs_dict.values()) - - peaks = [] - peak_vals = [] - windows = [] - confmaps = [] + outputs_keys, output_tensors = list(outputs_dict.keys()), list(outputs_dict.values()) - for batch_number, row_offset, data_batch in batch(data, batch_size=2): + # Run the graph and retrieve output arrays. + peaks_arr = [] + peak_vals_arr = [] + confmaps_arr = [] + for batch_number, row_offset, data_batch in batch(data, batch_size=batch_size): # This does the actual evaluation - outputs = sess.run(outputs_vals, feed_dict={ model.input: data_batch }) + outputs_arr = sess.run(output_tensors, feed_dict={model.input: data_batch}) # Convert list of results to dict using saved list of keys - outputs_dict = dict(zip(outputs_keys, outputs)) + outputs_arr_dict = dict(zip(outputs_keys, outputs_arr)) - batch_peaks = outputs_dict["peaks"] + batch_peaks = outputs_arr_dict["peaks"] # First column should match row number in full data matrix, # so we add row offset of batch to row number in batch matrix. 
- batch_peaks[:,0] += row_offset + batch_peaks[:, 0] += row_offset - peaks.append(batch_peaks) - peak_vals.append(outputs_dict["peak_vals"]) - - if "windows" in outputs_dict: - windows.append(outputs_dict["windows"]) + peaks_arr.append(batch_peaks) + peak_vals_arr.append(outputs_arr_dict["peak_vals"]) if "confmaps" in outputs_dict: - confmaps.append(outputs_dict["confmaps"]) + confmaps.append(outputs_arr_dict["confmaps"]) - peaks = np.concatenate(peaks) - peak_vals = np.concatenate(peak_vals) - confmaps = np.concatenate(confmaps) if len(confmaps) else None + peaks_arr = np.concatenate(peaks_arr, axis=0) + peak_vals_arr = np.concatenate(peak_vals_arr, axis=0) + confmaps_arr = np.concatenate(confmaps_arr, axis=0) if len(confmaps_arr) else None # Extract frame and node index columns - frame_node_idx = peaks[:, [0, 3]] + sample_channel_ind = peaks_arr[:, [0, 3]] # (nc, 2) # Extract X and Y columns - peak_points = peaks[:,[1,2]].astype("float") - - # Add offset from upsampling window peak if upsampling - if upsample_factor > 1 and len(windows): - windows = np.concatenate(windows) - peak_points += windows/upsample_factor - - if downsample_factor > 1: - peak_points /= downsample_factor - - # Swap the X and Y columns (order was [row idx, col idx]) - peak_points = peak_points[:,[1,0]] + peak_points = peaks_arr[:, [2, 1]].astype("float") # [x, y] ==> (nc, 2) # Use indices to convert matrices to lists of lists # (this matches the format of cpu-based peak-finding) - peak_list, peak_val_list = split_matrices_by_double_index(frame_node_idx, peak_points, peak_vals) + peak_list, peak_val_list = split_matrices_by_double_index(sample_channel_ind, peak_points, peak_vals_arr, + n_samples=len(data), n_channels=c) return peak_list, peak_val_list, confmaps -def split_matrices_by_double_index(idxs, *data_list): +def split_matrices_by_double_index(idxs, *data_list, n_samples=None, n_channels=None): """Convert data matrices to lists of lists expected by other functions.""" # Return empty array if there are no idxs if len(idxs) == 0: return [], [] # Determine the list length for major and minor indices - max_idx_vals = np.max(idxs, axis=0).astype("int") + 1 + if n_samples is None: + n_samples = np.max(idxs[:, 0]) + 1 + + if n_channels is None: + n_channels = np.max(idxs[:, 1]) + 1 # We can accept a variable number of data matrices data_matrix_count = len(data_list) @@ -230,18 +229,18 @@ def split_matrices_by_double_index(idxs, *data_list): r = [[] for _ in range(data_matrix_count)] # Loop over major index (frame) - for i in range(max_idx_vals[0]): + for t in range(n_samples): # Empty list for this value of major index # for results from each data matrix major = [[] for _ in range(data_matrix_count)] # Loop over minor index (node) - for j in range(max_idx_vals[1]): + for c in range(n_channels): # Use idxs matrix to determine which rows # to retrieve from each data matrix - mask = np.all((idxs == [i,j]), axis = 1) + mask = np.all((idxs == [t, c]), axis=1) # Get rows from each data matrix for data_matrix_idx, matrix in enumerate(data_list): diff --git a/sleap/nn/peakmatching.py b/sleap/nn/peakmatching.py index 5023e81fd..9487cc32a 100644 --- a/sleap/nn/peakmatching.py +++ b/sleap/nn/peakmatching.py @@ -244,9 +244,9 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx score_to_node_ratio = subset[:,-2] / subset[:,-1] subset = subset[score_to_node_ratio > min_score_to_node_ratio, :] - # apply inverse transform to points + # Apply inverse transform to points to return to full resolution, 
uncropped image coordinates if candidate.shape[0] > 0: - candidate[...,0:2] = transform.invert(img_idx, candidate[...,0:2]) + candidate[..., 0:2] = transform.invert(img_idx, candidate[..., 0:2]) # Done with all the matching! Gather the data matched_instances_t = [] @@ -267,18 +267,22 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx score=match[-2])) # For centroid crop just return instance closest to centroid - if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: + # if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: - crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box - crop_centroid = transform.invert(img_idx, crop_centroid) # relative to original image + # crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box + # crop_centroid = transform.invert(img_idx, crop_centroid) # relative to original image - # sort by distance from crop centroid - matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) + # # sort by distance from crop centroid + # matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) - # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") + # # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") - # just use closest - matched_instances_t = matched_instances_t[0:1] + # # just use closest + # matched_instances_t = matched_instances_t[0:1] + + if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: + # Just keep highest scoring instance + matched_instances_t = [matched_instances_t[0]] return matched_instances_t @@ -318,6 +322,7 @@ def match_peaks_paf_par(peaks, peak_vals, pafs, skeleton, """ Parallel version of PAF peak matching """ if pool is None: + import multiprocessing pool = multiprocessing.Pool() futures = [] diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 8c59b1163..93e58ba0a 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -130,8 +130,8 @@ class Trainer: sigma: float = 5.0 instance_crop: bool = False bounding_box_size: int = 0 - min_crop_size: int = 0 - negative_samples: int = 0 + min_crop_size: int = 32 + negative_samples: int = 10 def train(self, model: Model, diff --git a/sleap/nn/transform.py b/sleap/nn/transform.py index 9389d70f5..0f5cd6ab2 100644 --- a/sleap/nn/transform.py +++ b/sleap/nn/transform.py @@ -79,7 +79,7 @@ def _scale(self, imgs, target_size): if (img_h, img_w) != target_size: # build ndarray for new size - scaled_imgs = np.zeros((imgs.shape[0], h, w, imgs.shape[3])) + scaled_imgs = np.zeros((imgs.shape[0], h, w, imgs.shape[3]), dtype=imgs.dtype) for i in range(imgs.shape[0]): # resize using cv2 diff --git a/sleap/training_profiles/default_centroids.json b/sleap/training_profiles/default_centroids.json index 2e0f59b6d..9843f75b0 100644 --- a/sleap/training_profiles/default_centroids.json +++ b/sleap/training_profiles/default_centroids.json @@ -1 +1,48 @@ -{"model": {"output_type": 2, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 16, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 4, "num_epochs": 100, "steps_per_epoch": 200, "shuffle_initially": true, 
"shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 0.25, "sigma": 5.0, "instance_crop": false}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 2, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "convs_per_depth": 2, + "num_filters": 16, + "kernel_size": 5, + "upsampling_layers": true, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "UNet" + }, + "trainer": { + "val_size": 0.1, + "optimizer": "adam", + "learning_rate": 0.0001, + "amsgrad": true, + "batch_size": 4, + "num_epochs": 100, + "steps_per_epoch": 200, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + "save_best_val": true, + "reduce_lr_min_delta": "1e-06", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 5, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 0.25, + "sigma": 5.0, + "instance_crop": false + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/sleap/training_profiles/default_confmaps.json b/sleap/training_profiles/default_confmaps.json index 4503d7e8b..6d3393f0f 100644 --- a/sleap/training_profiles/default_confmaps.json +++ b/sleap/training_profiles/default_confmaps.json @@ -1 +1,50 @@ -{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 0, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "convs_per_depth": 2, + "num_filters": 32, + "kernel_size": 5, + "upsampling_layers": true, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "UNet" + }, + "trainer": { + "val_size": 0.1, + "optimizer": "adam", + "learning_rate": 0.0001, + "amsgrad": true, + "batch_size": 2, + "num_epochs": 150, + "steps_per_epoch": 200, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + 
"save_best_val": true, + "reduce_lr_min_delta": "1e-06", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 5, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 1, + "sigma": 5.0, + "instance_crop": true, + "min_crop_size": 32, + "negative_samples": 10 + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/sleap/training_profiles/default_pafs.json b/sleap/training_profiles/default_pafs.json index 5c04a2acc..8cdd66dde 100644 --- a/sleap/training_profiles/default_pafs.json +++ b/sleap/training_profiles/default_pafs.json @@ -1 +1,48 @@ -{"model": {"output_type": 1, "backbone": {"down_blocks": 3, "up_blocks": 3, "upsampling_layers": true, "num_filters": 32, "interp": "bilinear"}, "skeletons": null, "backbone_name": "LeapCNN"}, "trainer": {"val_size": 0.15, "optimizer": "adam", "learning_rate": 5e-5, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 100, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-6, "reduce_lr_factor": 0.5, "reduce_lr_patience": 8, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 1, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "upsampling_layers": true, + "num_filters": 32, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "LeapCNN" + }, + "trainer": { + "val_size": 0.15, + "optimizer": "adam", + "learning_rate": "5e-5", + "amsgrad": true, + "batch_size": 2, + "num_epochs": 150, + "steps_per_epoch": 100, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + "save_best_val": true, + "reduce_lr_min_delta": "1e-6", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 8, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 1, + "sigma": 5.0, + "instance_crop": true, + "min_crop_size": 32, + "negative_samples": 10 + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/tests/test_instance.py b/tests/test_instance.py index 82e93abf2..e61d07f11 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -151,7 +151,7 @@ def test_points_array(skeleton): pts = instance1.get_points_array() assert not np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() - pts = instance1.visible_points_array + pts = instance1.points_array assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() def test_modifying_skeleton(skeleton): From b4276b3f9d2605e13be3d15bff4f8d3364b9590a Mon Sep 17 00:00:00 2001 From: Talmo Date: Wed, 25 Sep 2019 04:16:21 -0400 Subject: [PATCH 105/176] Active inference changes - Removed disk serialization of predictions (it was 
being overwritten after each video anyway) - Progress dialog works - Canceling works gracefully - Broken: asynchronous call so that the GUI remains responsive --- sleap/gui/active.py | 79 +++++++++++++++++++++++++++++-------------- sleap/nn/inference.py | 36 ++++++-------------- 2 files changed, 63 insertions(+), 52 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index ba31d8682..7babbec2d 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -732,55 +732,82 @@ def run_active_inference( Number of new frames added to labels. """ from sleap.nn.inference import Predictor + # from multiprocessing import Pool - total_new_lf_count = 0 - timestamp = datetime.now().strftime("%y%m%d_%H%M%S") - inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") + # total_new_lf_count = 0 + # timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + # inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") # Create Predictor from the results of training + # pool = Pool(processes=1) predictor = Predictor(training_jobs=training_jobs, with_tracking=with_tracking, - output_path=inference_output_path) + # output_path=inference_output_path, + # pool=pool + ) if gui: # show message while running inference - win = QtWidgets.QProgressDialog() - win.setLabelText(" Running inference on selected frames... ") - win.show() + progress = QtWidgets.QProgressDialog( + f"Running inference on {len(frames_to_predict)} videos...", + "Cancel", + 0, len(frames_to_predict)) + # win.setLabelText(" Running inference on selected frames... ") + progress.show() QtWidgets.QApplication.instance().processEvents() - for video, frames in frames_to_predict.items(): - if len(frames): + new_lfs = [] + for i, (video, frames) in enumerate(frames_to_predict.items()): + QtWidgets.QApplication.instance().processEvents() + if len(frames): # Run inference for desired frames in this video - pool, result = predictor.predict_async( - input_video=video, - frames=frames) + # result = predictor.predict_async( + new_lfs_video = predictor.predict( + input_video=video, frames=frames) + new_lfs.extend(new_lfs_video) - while not result.ready(): - if gui: - QtWidgets.QApplication.instance().processEvents() - result.wait(.01) + if gui: + progress.setValue(i) + if progress.wasCanceled(): + return 0 - if result.successful(): - new_labels_json = result.get() + # while not result.ready(): + # if gui: + # QtWidgets.QApplication.instance().processEvents() + # result.wait(.01) + + # if result.successful(): + # new_labels_json = result.get() # Add new frames to labels # (we're doing this for each video as we go since there was a problem # when we tried to add frames for all videos together.) - new_lf_count = add_frames_from_json(labels, new_labels_json) + # new_lf_count = add_frames_from_json(labels, new_labels_json) - total_new_lf_count += new_lf_count - else: - if gui: - QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() - result.get() + # total_new_lf_count += new_lf_count + # else: + # if gui: + # QtWidgets.QApplication.instance().processEvents() + # QtWidgets.QMessageBox(text=f"An error occured during inference. 
Your command line terminal may have more information about the error.").exec_() + # result.get() + + # predictor.pool.close() + + # Remove any frames without instances + new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) + + # Now add them to labels and merge labeled frames with same video/frame_idx + # labels.extend_from(new_lfs) + labels.extend_from(new_lfs, unify=True) + labels.merge_matching_frames() # close message window if gui: - win.close() + progress.close() - return total_new_lf_count + # return total_new_lf_count + return len(new_lfs) if __name__ == "__main__": import sys diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index bc666774a..9cc6ab14e 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -328,7 +328,7 @@ class Predictor: gpu_peak_finding: bool = True supersample_window_size: int = 7 # must be odd supersample_factor: float = 2 # factor to upsample cropped windows by - overlapping_instances_nms: bool = False # suppress overlapping instances + overlapping_instances_nms: bool = True # suppress overlapping instances def __attrs_post_init__(self): @@ -405,10 +405,12 @@ def predict(self, # Delete the output file if it exists already if os.path.exists(self.output_path): os.unlink(self.output_path) + logger.warning("Deleted existing output: " + self.output_path) # Create output directory if it doesn't exist if not os.path.exists(self.output_path): - os.makedirs(self.output_path) + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + logger.info("Output path: " + self.output_path) # Process chunk-by-chunk! t0_start = time() @@ -450,27 +452,8 @@ def predict(self, imgs_full, frames_idx) else: - # Scale without centroid cropping - # TODO: Move this into the processing below to allow for different input scales by model. - - # Determine scaled image size - # cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] - # input_scale = cm_model.input_scale - # scale_to = (int(vid.height // (1 / input_scale)), int(vid.width // (1 / input_scale))) - - # # if self.resize_hack: - # # TODO: Replace this when model-specific divisibility calculation implemented. - # divisor = 2 ** cm_model.down_blocks - # crop_to = ( - # (scale_to[0] // divisor) * divisor, - # (scale_to[1] // divisor) * divisor) - # Create transform object transform = DataTransform(frame_idxs=frames_idx) - - # Scale if target doesn't match current size - # imgs_full = transform.scale_to(mov_full, target_size=scale_to) - subchunks_to_process = [(imgs_full, transform)] logger.info(" Transformed images [%.1fs]" % (time() - t0)) @@ -561,7 +544,7 @@ def predict(self, # We should save in chunks then combine at the end. 
labels = Labels(labeled_frames=predicted_frames) if self.output_path is not None: - if self.output_path.endswith('json'): + if self.output_path.endswith("json"): Labels.save_json(labels, filename=self.output_path, compress=True) else: Labels.save_hdf5(labels, filename=self.output_path) @@ -610,13 +593,14 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: # unstructure input_video since it won't pickle kwargs["input_video"] = Video.cattr().unstructure(kwargs["input_video"]) - pool = Pool(processes=1) - result = pool.apply_async(self.predict, args=args, kwds=kwargs) + if self.pool is None: + self.pool = Pool(processes=1) + result = self.pool.apply_async(self.predict, args=args, kwds=kwargs) # Tell the pool to accept no new tasks - pool.close() + # pool.close() - return pool, result + return result def centroid_crop_inference(self, From 122ddbd7bac833db4893fdeffd56f70fa4b7d895 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 09:15:22 -0400 Subject: [PATCH 106/176] Show table of clean merge results. Clean merge results data format in different format (dict of video -> dict of frame idx -> list of instances). --- sleap/gui/merge.py | 93 ++++++++++++++++++++++++++++++++-------- sleap/instance.py | 4 +- tests/io/test_dataset.py | 2 +- 3 files changed, 79 insertions(+), 20 deletions(-) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index 7d0cffd40..96203ad43 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -30,24 +30,33 @@ def __init__(self, merge_frames = 0 for vid_frame_list in merged.values(): # number of frames for this video - merge_frames += len(vid_frame_list) + merge_frames += len(vid_frame_list.keys()) # number of instances across frames for this video - merge_total += sum((map(len, vid_frame_list))) + merge_total += sum((map(len, vid_frame_list.values()))) buttons = self._make_buttons(conflict=self.extra_base) - merged_label = QtWidgets.QLabel(f"Cleanly merged {merge_total} instances across {merge_frames} frames.") - - conflict_text = "There are no conflicts." if not self.extra_base else "Merge conflicts:" - conflict_label = QtWidgets.QLabel(conflict_text) layout = QtWidgets.QVBoxLayout() + + merged_text = f"Cleanly merged {merge_total} instances" + if merge_total: + merged_text += f" across {merge_frames} frames" + merged_text += "." + merged_label = QtWidgets.QLabel(merged_text) layout.addWidget(merged_label) + if merge_total: + merge_table = MergeTable(merged) + layout.addWidget(merge_table) + + conflict_text = "There are no conflicts." 
if not self.extra_base else "Merge conflicts:" + conflict_label = QtWidgets.QLabel(conflict_text) layout.addWidget(conflict_label) + if self.extra_base: conflict_table = ConflictTable(self.base_labels, self.extra_base, self.extra_new) - layout.addWidget(conflict_table) + layout.addWidget(conflict_table) layout.addWidget(buttons) @@ -107,18 +116,12 @@ def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): if prop == "frame": return self.extra_base[idx].frame_idx if prop == "base": - return self._showInstanceCount(self.extra_base[idx]) + return show_instance_type_counts(self.extra_base[idx]) if prop == "new": - return self._showInstanceCount(self.extra_new[idx]) + return show_instance_type_counts(self.extra_new[idx]) return None - @staticmethod - def _showInstanceCount(instance_list): - prediction_count = len(list(filter(lambda inst: hasattr(inst, "score"), instance_list))) - user_count = len(instance_list) - prediction_count - return f"{prediction_count}/{user_count}" - def rowCount(self, *args): return len(self.extra_base) @@ -133,10 +136,66 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt return section return None +class MergeTable(QtWidgets.QTableView): + def __init__(self, *args, **kwargs): + super(MergeTable, self).__init__() + self.setModel(MergeTableModel(*args, **kwargs)) + +class MergeTableModel(QtCore.QAbstractTableModel): + _props = ["video", "frame", "merged instances"] + + def __init__(self, merged: List[List['Instance']]): + super(MergeTableModel, self).__init__() + self.merged = merged + + self.data_table = [] + for video in self.merged.keys(): + for frame_idx, frame_instance_list in self.merged[video].items(): + self.data_table.append(dict( + filename=video.filename, + frame_idx=frame_idx, + instances=frame_instance_list)) + + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + if role == QtCore.Qt.DisplayRole and index.isValid(): + idx = index.row() + prop = self._props[index.column()] + + if idx < self.rowCount(): + if prop == "video": + return self.data_table[idx]["filename"] + if prop == "frame": + return self.data_table[idx]["frame_idx"] + if prop == "merged instances": + return show_instance_type_counts(self.data_table[idx]["instances"]) + + return None + + def rowCount(self, *args): + return len(self.data_table) + + def columnCount(self, *args): + return len(self._props) + + def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole): + if role == QtCore.Qt.DisplayRole: + if orientation == QtCore.Qt.Horizontal: + return self._props[section] + elif orientation == QtCore.Qt.Vertical: + return section + return None + +def show_instance_type_counts(instance_list): + prediction_count = len(list(filter(lambda inst: hasattr(inst, "score"), instance_list))) + user_count = len(instance_list) - prediction_count + return f"{prediction_count}/{user_count}" + if __name__ == "__main__": - file_a = "tests/data/json_format_v1/centered_pair.json" - file_b = "tests/data/json_format_v2/centered_pair_predictions.json" +# file_a = "tests/data/json_format_v1/centered_pair.json" +# file_b = "tests/data/json_format_v2/centered_pair_predictions.json" + file_a = "files/merge/a.h5" + file_b = "files/merge/b.h5" base_labels = Labels.load_file(file_a) new_labels = Labels.load_file(file_b) diff --git a/sleap/instance.py b/sleap/instance.py index cd18b3e1b..aed1a078d 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -949,8 +949,8 @@ def complex_merge_between(cls, base_labels: 'Labels', 
new_frames: List['LabeledF if merged_instances: if new_frame.video not in merged: - merged[new_frame.video] = [] - merged[new_frame.video].append(merged_instances) + merged[new_frame.video] = dict() + merged[new_frame.video][new_frame.frame_idx] = merged_instances return merged, extra_base, extra_new @classmethod diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 210a432fe..599d0f476 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -312,7 +312,7 @@ def test_complex_merge(): # Check that we have the cleanly merged frame assert dummy_video_a in merged assert len(merged[dummy_video_a]) == 1 # one merged frame - assert len(merged[dummy_video_a][0]) == 2 # with two instances + assert len(merged[dummy_video_a][1]) == 2 # with two instances # Check that labels_a includes redundant and clean assert len(labels_a.labeled_frames) == 2 From 0eb59db5fed622ddb105872cfbe55cde9bb85a90 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 09:35:43 -0400 Subject: [PATCH 107/176] Combobox (menu) for merge methods. --- sleap/gui/merge.py | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index 96203ad43..22199e8c5 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -11,6 +11,11 @@ from PySide2 import QtWidgets, QtCore +USE_BASE_STRING = "Use base, discard conflicting new instances" +USE_NEW_STRING = "Use new, discard conflicting base instances" +USE_NEITHER_STRING = "Discard all conflicting instances" +CLEAN_STRING = "Accept clean merge" + class MergeDialog(QtWidgets.QDialog): def __init__(self, @@ -34,9 +39,6 @@ def __init__(self, # number of instances across frames for this video merge_total += sum((map(len, vid_frame_list.values()))) - buttons = self._make_buttons(conflict=self.extra_base) - - layout = QtWidgets.QVBoxLayout() merged_text = f"Cleanly merged {merge_total} instances" @@ -58,33 +60,33 @@ def __init__(self, conflict_table = ConflictTable(self.base_labels, self.extra_base, self.extra_new) layout.addWidget(conflict_table) - layout.addWidget(buttons) - - self.setLayout(layout) - - def _make_buttons(self, conflict: bool): - self.use_base_button = None - self.use_new_button = None - self.okay_button = None + self.merge_method = QtWidgets.QComboBox() + if self.extra_base: + self.merge_method.addItem(USE_NEW_STRING) + self.merge_method.addItem(USE_BASE_STRING) + self.merge_method.addItem(USE_NEITHER_STRING) + else: + self.merge_method.addItem(CLEAN_STRING) + layout.addWidget(self.merge_method) buttons = QtWidgets.QDialogButtonBox() - if conflict: - self.use_base_button = buttons.addButton("Use Base", QtWidgets.QDialogButtonBox.YesRole) - self.use_new_button = buttons.addButton("Use New", QtWidgets.QDialogButtonBox.NoRole) - else: - self.okay_button = buttons.addButton(QtWidgets.QDialogButtonBox.Ok) + buttons.addButton("Finish Merge", QtWidgets.QDialogButtonBox.AcceptRole) + buttons.accepted.connect(self.finishMerge) - buttons.clicked.connect(self.finishMerge) + layout.addWidget(buttons) - return buttons + self.setLayout(layout) - def finishMerge(self, button): - if button == self.use_base_button: + def finishMerge(self): + merge_method = self.merge_method.currentText() + if merge_method == USE_BASE_STRING: Labels.finish_complex_merge(self.base_labels, self.extra_base) - elif button == self.use_new_button: + elif merge_method == USE_NEW_STRING: Labels.finish_complex_merge(self.base_labels, self.extra_new) - elif button == self.okay_button: 
+ elif merge_method in (USE_NEITHER_STRING, CLEAN_STRING): Labels.finish_complex_merge(self.base_labels, []) + else: + raise ValueError("No valid merge method selected.") self.accept() From 325852ba607925eb95fc757a2fef81e25fbf401a Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 10:38:17 -0400 Subject: [PATCH 108/176] Use file with user data for demo code. --- sleap/nn/augmentation.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sleap/nn/augmentation.py b/sleap/nn/augmentation.py index 37fcd017e..04654fabc 100644 --- a/sleap/nn/augmentation.py +++ b/sleap/nn/augmentation.py @@ -213,14 +213,9 @@ def demo_augmentation(): from sleap.nn.datagen import generate_training_data from sleap.nn.datagen import generate_confmaps_from_points, generate_pafs_from_points - data_path = "tests/data/json_format_v2/centered_pair_predictions.json" - # data_path = "tests/data/json_format_v2/minimal_instance.json" -# data_path = "tests/data/json_format_v1/test.json" - + data_path = "tests/data/json_format_v1/centered_pair.json" labels = Labels.load_json(data_path) -# labels.labeled_frames = labels.labeled_frames[123:323:10] - # Generate raw training data skeleton = labels.skeletons[0] imgs, points = generate_training_data(labels, params = dict( From 14d57b83dff16a2e44421c744fbfdf4c3365fbdd Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 10:54:56 -0400 Subject: [PATCH 109/176] Add test for sleap.io.visuals. The code to write a video with instances on it uses opencv. --- tests/io/test_visuals.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/io/test_visuals.py diff --git a/tests/io/test_visuals.py b/tests/io/test_visuals.py new file mode 100644 index 000000000..15c55e005 --- /dev/null +++ b/tests/io/test_visuals.py @@ -0,0 +1,11 @@ +import os +from sleap.io.visuals import save_labeled_video + +def test_write_visuals(tmpdir, centered_pair_predictions): + path = os.path.join(tmpdir, 'clip.avi') + save_labeled_video(filename=path, + labels=centered_pair_predictions, + video=centered_pair_predictions.videos[0], + frames=(0,1,2), + fps=15) + assert os.path.exists(path) \ No newline at end of file From cbc8f17411aa99967408f775c661705670670f14 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 10:56:32 -0400 Subject: [PATCH 110/176] For cv2 dependency use opencv-python-headless. imgaug 0.3.0 depends on headless cv package which causes conflict if we also require opencv-python. Everything we use appears to be in the headless package, so we may as well use that. --- .conda/bld.bat | 2 +- environment.yml | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.conda/bld.bat b/.conda/bld.bat index aa46b2eb7..157b97b14 100644 --- a/.conda/bld.bat +++ b/.conda/bld.bat @@ -13,7 +13,7 @@ rem # this out myself, ughhh. set PIP_NO_INDEX=False set PIP_NO_DEPENDENCIES=False set PIP_IGNORE_INSTALLED=False -pip install cattrs==1.0.0rc opencv-python==3.4.1.15 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore +pip install cattrs==1.0.0rc opencv-python-headless==3.4.1.15 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore rem # Use and update environment.yml call to install pip dependencies. This is slick. 
rem # While environment.yml contains the non pip dependencies, the only thing left diff --git a/environment.yml b/environment.yml index d9d3f0684..779e21182 100644 --- a/environment.yml +++ b/environment.yml @@ -17,7 +17,7 @@ dependencies: - python-rapidjson - pip - pip: - - opencv-python==3.4.1.15 + - opencv-python-headless==3.4.1.15 - PySide2==5.12.0 - imgaug - cattrs==1.0.0rc0 diff --git a/requirements.txt b/requirements.txt index 2d2c82a6d..21763ec20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ tensorflow keras h5py python-rapidjson -opencv-python==3.4.1.15 +opencv-python-headless==3.4.1.15 pandas psutil PySide2 From 06bf53d220794ed27393da805643c58f19cb26e0 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 12:22:02 -0400 Subject: [PATCH 111/176] Remove unused (broken) method. --- sleap/gui/training_editor.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index 9a1f4e6dc..ae6960acd 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -80,14 +80,6 @@ def _load_profile(self, profile_filename:str): for name in "datagen,trainer,output".split(","): self.form_widgets[name].set_form_data(job_dict["trainer"]) - def _update_profile(self): - # update training job from params in form - trainer = job.trainer - for key, val in form_data.items(): - # check if form field matches attribute of Trainer object - if key in dir(trainer): - setattr(trainer, key, val) - def _save_as(self): # Show "Save" dialog From 643d8575a0005f0cf81eea6daff3b2513507ec72 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 15:21:38 -0400 Subject: [PATCH 112/176] Coveralls in appveyor script. --- appveyor.yml | 5 +++++ dev_requirements.txt | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index d1133066a..8bc32f422 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -3,6 +3,8 @@ version: '{build}' clone_depth: 5 environment: + COVERALLS_REPO_TOKEN: + secure: VsCyKmdi8x0OFK+Jbzk7ZRAW3EtojWP85TWqWKi+vuGmdiQFX7rLPnuaw3kt++a8 access_token: secure: T7XuBtHDu85Tk/d1AeyfhW3CVyzaoddTWmR4xsPIdQ3di0R6x8ncWqw3KrYXkWJm @@ -68,6 +70,9 @@ test_script: - cmd: where python - cmd: python -m pytest tests/ +after_success: + - coveralls + # here we are going to override common configuration for: diff --git a/dev_requirements.txt b/dev_requirements.txt index f0e899b95..0f005e97b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -3,4 +3,5 @@ pytest-qt pytest-cov ipython sphinx -sphinx_rtd_theme \ No newline at end of file +sphinx_rtd_theme +python-coveralls \ No newline at end of file From bda29a9ba8af09f1a02cedba48ee88bfb834e487 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 15:23:48 -0400 Subject: [PATCH 113/176] Coveralls in appveyor script. --- appveyor.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 8bc32f422..3f19c3e7a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -70,8 +70,8 @@ test_script: - cmd: where python - cmd: python -m pytest tests/ -after_success: - - coveralls +after_test: + - cmd: coveralls # here we are going to override common configuration for: From 582188d8d27ec712cf7c4a0a15828f4dd8298474 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 15:45:18 -0400 Subject: [PATCH 114/176] pytest-cov in appveyor script. 
--- appveyor.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 3f19c3e7a..6220862a5 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -68,9 +68,9 @@ build: off test_script: - cmd: activate sleap_appveyor - cmd: where python - - cmd: python -m pytest tests/ + - cmd: pytest --cov sleap -after_test: +on_success: - cmd: coveralls # here we are going to override common configuration From 0f3fb2e0eb38b8f360446ddc155be371e93eee2e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 15:46:28 -0400 Subject: [PATCH 115/176] pytest-cov in appveyor script. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 6220862a5..70aa5aede 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -68,7 +68,7 @@ build: off test_script: - cmd: activate sleap_appveyor - cmd: where python - - cmd: pytest --cov sleap + - cmd: pytest --cov=sleap tests/ on_success: - cmd: coveralls From f5dfdd90bdacb7c30e8be19b8f2d0540fd140d80 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 16:40:27 -0400 Subject: [PATCH 116/176] Reformat all code with black. --- sleap/gui/active.py | 342 ++++++---- sleap/gui/app.py | 982 +++++++++++++++++++--------- sleap/gui/dataviews.py | 96 ++- sleap/gui/formbuilder.py | 40 +- sleap/gui/importvideos.py | 186 +++--- sleap/gui/merge.py | 67 +- sleap/gui/multicheck.py | 7 +- sleap/gui/overlays/anchors.py | 33 +- sleap/gui/overlays/base.py | 91 ++- sleap/gui/overlays/confmaps.py | 60 +- sleap/gui/overlays/instance.py | 24 +- sleap/gui/overlays/pafs.py | 177 ++--- sleap/gui/overlays/tracks.py | 195 +++--- sleap/gui/shortcuts.py | 84 ++- sleap/gui/slider.py | 162 +++-- sleap/gui/suggestions.py | 149 +++-- sleap/gui/training_editor.py | 89 ++- sleap/gui/video.py | 340 +++++++--- sleap/info/labels.py | 2 +- sleap/info/metrics.py | 101 ++- sleap/info/summary.py | 11 +- sleap/info/write_tracking_h5.py | 37 +- sleap/instance.py | 302 ++++++--- sleap/io/dataset.py | 839 ++++++++++++++++-------- sleap/io/legacy.py | 94 ++- sleap/io/video.py | 144 ++-- sleap/io/visuals.py | 163 +++-- sleap/nn/architectures/__init__.py | 4 +- sleap/nn/architectures/common.py | 42 +- sleap/nn/architectures/densenet.py | 102 ++- sleap/nn/architectures/hourglass.py | 109 ++- sleap/nn/architectures/leap.py | 55 +- sleap/nn/architectures/resnet.py | 115 ++-- sleap/nn/architectures/unet.py | 79 ++- sleap/nn/augmentation.py | 82 ++- sleap/nn/datagen.py | 382 +++++++---- sleap/nn/inference.py | 405 ++++++++---- sleap/nn/model.py | 47 +- sleap/nn/monitor.py | 52 +- sleap/nn/peakfinding.py | 22 +- sleap/nn/peakfinding_tf.py | 111 ++-- sleap/nn/peakmatching.py | 283 +++++--- sleap/nn/tracking.py | 194 ++++-- sleap/nn/training.py | 418 ++++++++---- sleap/nn/transform.py | 23 +- sleap/nn/util.py | 23 +- sleap/rangelist.py | 23 +- sleap/skeleton.py | 208 ++++-- sleap/util.py | 55 +- 49 files changed, 4988 insertions(+), 2663 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index 7babbec2d..ffb0d1074 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -16,15 +16,20 @@ from PySide2 import QtWidgets, QtCore + class ActiveLearningDialog(QtWidgets.QDialog): learningFinished = QtCore.Signal() - def __init__(self, - labels_filename: str, labels: Labels, - mode: str="expert", - only_predict: bool=False, - *args, **kwargs): + def __init__( + self, + labels_filename: str, + labels: Labels, + mode: str = "expert", + only_predict: bool = False, + *args, + **kwargs, + ): 
super(ActiveLearningDialog, self).__init__(*args, **kwargs) @@ -35,27 +40,37 @@ def __init__(self, print(f"Number of frames to train on: {len(labels.user_labeled_frames)}") - title = dict(learning="Active Learning", - inference="Inference", - expert="Inference Pipeline", - ) + title = dict( + learning="Active Learning", + inference="Inference", + expert="Inference Pipeline", + ) - learning_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/active.yaml") + learning_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/active.yaml" + ) self.form_widget = YamlFormWidget( - yaml_file=learning_yaml, - which_form=self.mode, - title=title[self.mode] + " Settings") + yaml_file=learning_yaml, + which_form=self.mode, + title=title[self.mode] + " Settings", + ) # form ui self.training_profile_widgets = dict() if "conf_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.CONFIDENCE_MAP] = self.form_widget.fields["conf_job"] + self.training_profile_widgets[ + ModelOutputType.CONFIDENCE_MAP + ] = self.form_widget.fields["conf_job"] if "paf_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.PART_AFFINITY_FIELD] = self.form_widget.fields["paf_job"] + self.training_profile_widgets[ + ModelOutputType.PART_AFFINITY_FIELD + ] = self.form_widget.fields["paf_job"] if "centroid_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.CENTROIDS] = self.form_widget.fields["centroid_job"] + self.training_profile_widgets[ + ModelOutputType.CENTROIDS + ] = self.form_widget.fields["centroid_job"] self._rebuild_job_options() self._update_job_menus(init=True) @@ -63,8 +78,8 @@ def __init__(self, buttons = QtWidgets.QDialogButtonBox() self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) self.run_button = buttons.addButton( - "Run "+title[self.mode], - QtWidgets.QDialogButtonBox.AcceptRole) + "Run " + title[self.mode], QtWidgets.QDialogButtonBox.AcceptRole + ) self.status_message = QtWidgets.QLabel("hi!") @@ -84,21 +99,29 @@ def __init__(self, # connect actions to buttons def edit_conf_profile(): - self.view_profile(self.form_widget["conf_job"], - model_type=ModelOutputType.CONFIDENCE_MAP) + self.view_profile( + self.form_widget["conf_job"], model_type=ModelOutputType.CONFIDENCE_MAP + ) + def edit_paf_profile(): - self.view_profile(self.form_widget["paf_job"], - model_type=ModelOutputType.PART_AFFINITY_FIELD) + self.view_profile( + self.form_widget["paf_job"], + model_type=ModelOutputType.PART_AFFINITY_FIELD, + ) + def edit_cent_profile(): - self.view_profile(self.form_widget["centroid_job"], - model_type=ModelOutputType.CENTROIDS) + self.view_profile( + self.form_widget["centroid_job"], model_type=ModelOutputType.CENTROIDS + ) if "_view_conf" in self.form_widget.buttons: self.form_widget.buttons["_view_conf"].clicked.connect(edit_conf_profile) if "_view_paf" in self.form_widget.buttons: self.form_widget.buttons["_view_paf"].clicked.connect(edit_paf_profile) if "_view_centoids" in self.form_widget.buttons: - self.form_widget.buttons["_view_centoids"].clicked.connect(edit_cent_profile) + self.form_widget.buttons["_view_centoids"].clicked.connect( + edit_cent_profile + ) if "_view_datagen" in self.form_widget.buttons: self.form_widget.buttons["_view_datagen"].clicked.connect(self.view_datagen) @@ -111,7 +134,9 @@ def edit_cent_profile(): def _rebuild_job_options(self): # load list of job profiles from directory - profile_dir = resource_filename(Requirement.parse("sleap"), 
"sleap/training_profiles") + profile_dir = resource_filename( + Requirement.parse("sleap"), "sleap/training_profiles" + ) labels_dir = os.path.join(os.path.dirname(self.labels_filename), "models") self.job_options = dict() @@ -127,7 +152,9 @@ def _update_job_menus(self, init=False): if model_type not in self.job_options: self.job_options[model_type] = [] if init: - field.currentIndexChanged.connect(lambda idx, mt=model_type: self.select_job(mt, idx)) + field.currentIndexChanged.connect( + lambda idx, mt=model_type: self.select_job(mt, idx) + ) else: # block signals so we can update combobox without overwriting # any user data with the defaults from the profile @@ -148,7 +175,7 @@ def frame_selection(self, frame_selection): prediction_options = [] def count_total_frames(videos_frames): - return reduce(lambda x,y:x+y, map(len, videos_frames.values())) + return reduce(lambda x, y: x + y, map(len, videos_frames.values())) # Determine which options are available given _frame_selection @@ -177,7 +204,9 @@ def count_total_frames(videos_frames): prediction_options.append(f"entire video ({video_length} frames)") - self.form_widget.fields["_predict_frames"].set_options(prediction_options, default_option) + self.form_widget.fields["_predict_frames"].set_options( + prediction_options, default_option + ) def show(self): super(ActiveLearningDialog, self).show() @@ -211,8 +240,9 @@ def update_gui(self): self.form_widget.fields["instance_crop"].setEnabled(True) error_messages = [] - if form_data.get("_use_trained_confmaps", False) and \ - form_data.get("_use_trained_pafs", False): + if form_data.get("_use_trained_confmaps", False) and form_data.get( + "_use_trained_pafs", False + ): # make sure trained models are compatible conf_job, _ = self._get_current_job(ModelOutputType.CONFIDENCE_MAP) paf_job, _ = self._get_current_job(ModelOutputType.PART_AFFINITY_FIELD) @@ -221,19 +251,30 @@ def update_gui(self): if conf_job is not None and paf_job is not None: if conf_job.trainer.scale != paf_job.trainer.scale: can_run = False - error_messages.append(f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})") + error_messages.append( + f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})" + ) if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop: can_run = False - crop_model_name = "confmaps" if conf_job.trainer.instance_crop else "pafs" - error_messages.append(f"exactly one model ({crop_model_name}) was trained on crops") + crop_model_name = ( + "confmaps" if conf_job.trainer.instance_crop else "pafs" + ) + error_messages.append( + f"exactly one model ({crop_model_name}) was trained on crops" + ) if use_centroids and not conf_job.trainer.instance_crop: can_run = False - error_messages.append(f"models used with centroids must be trained on cropped images") + error_messages.append( + f"models used with centroids must be trained on cropped images" + ) message = "" if not can_run: - message = "Unable to run with selected models:\n- " + \ - ";\n- ".join(error_messages) + "." + message = ( + "Unable to run with selected models:\n- " + + ";\n- ".join(error_messages) + + "." 
+ ) self.status_message.setText(message) self.run_button.setEnabled(can_run) @@ -280,7 +321,7 @@ def _get_current_training_jobs(self): form_data = self.form_widget.get_form_data() training_jobs = dict() - default_use_trained = (self.mode == "inference") + default_use_trained = self.mode == "inference" for model_type in self._get_model_types_to_use(): job, _ = self._get_current_job(model_type) @@ -337,19 +378,25 @@ def run(self): # Run active learning pipeline using the TrainingJobs new_counts = run_active_learning_pipeline( - labels_filename = self.labels_filename, - labels = self.labels, - training_jobs = training_jobs, - frames_to_predict = frames_to_predict, - with_tracking = with_tracking) + labels_filename=self.labels_filename, + labels=self.labels, + training_jobs=training_jobs, + frames_to_predict=frames_to_predict, + with_tracking=with_tracking, + ) self.learningFinished.emit() - QtWidgets.QMessageBox(text=f"Active learning has finished. Instances were predicted on {new_counts} frames.").exec_() + QtWidgets.QMessageBox( + text=f"Active learning has finished. Instances were predicted on {new_counts} frames." + ).exec_() def view_datagen(self): - from sleap.nn.datagen import generate_training_data, \ - generate_confmaps_from_points, generate_pafs_from_points + from sleap.nn.datagen import ( + generate_training_data, + generate_confmaps_from_points, + generate_pafs_from_points, + ) from sleap.io.video import Video from sleap.gui.overlays.confmaps import demo_confmaps from sleap.gui.overlays.pafs import demo_pafs @@ -367,19 +414,23 @@ def view_datagen(self): negative_samples = form_data.get("negative_samples", 0) imgs, points = generate_training_data( - self.labels, - params = dict( - frame_limit = 10, - scale = scale, - instance_crop = instance_crop, - min_crop_size = min_crop_size, - negative_samples = negative_samples)) + self.labels, + params=dict( + frame_limit=10, + scale=scale, + instance_crop=instance_crop, + min_crop_size=min_crop_size, + negative_samples=negative_samples, + ), + ) skeleton = self.labels.skeletons[0] img_shape = (imgs.shape[1], imgs.shape[2]) vid = Video.from_numpy(imgs * 255) - confmaps = generate_confmaps_from_points(points, skeleton, img_shape, sigma=sigma_confmaps) + confmaps = generate_confmaps_from_points( + points, skeleton, img_shape, sigma=sigma_confmaps + ) conf_win = demo_confmaps(confmaps, vid) conf_win.activateWindow() conf_win.move(200, 200) @@ -387,7 +438,7 @@ def view_datagen(self): pafs = generate_pafs_from_points(points, skeleton, img_shape, sigma=sigma_pafs) paf_win = demo_pafs(pafs, vid) paf_win.activateWindow() - paf_win.move(220+conf_win.rect().width(), 200) + paf_win.move(220 + conf_win.rect().width(), 200) # FIXME: hide dialog so use can see other windows # can we show these windows without closing dialog? 
@@ -412,14 +463,16 @@ def option_list_from_jobs(self, model_type): def add_job_file(self, model_type): filename, _ = QtWidgets.QFileDialog.getOpenFileName( - None, dir=None, - caption="Select training profile...", - filter="TrainingJob JSON (*.json)") + None, + dir=None, + caption="Select training profile...", + filter="TrainingJob JSON (*.json)", + ) self._add_job_file_to_list(filename, model_type) field = self.training_profile_widgets[model_type] # if we didn't successfully select a new file, then clear selection - if field.currentIndex() == field.count()-1: # subtract 1 for separator + if field.currentIndex() == field.count() - 1: # subtract 1 for separator field.setCurrentIndex(-1) def _add_job_file_to_list(self, filename, model_type): @@ -429,7 +482,9 @@ def _add_job_file_to_list(self, filename, model_type): job = TrainingJob.load_json(filename) except: # but do raise any other type of error - QtWidgets.QMessageBox(text=f"Unable to load a training profile from {filename}.").exec_() + QtWidgets.QMessageBox( + text=f"Unable to load a training profile from {filename}." + ).exec_() raise else: # we loaded the json as a TrainingJob, so see what type of model it's for @@ -441,18 +496,25 @@ def _add_job_file_to_list(self, filename, model_type): # update ui list if model_type in self.training_profile_widgets: field = self.training_profile_widgets[model_type] - field.set_options(self.option_list_from_jobs(model_type), filename) + field.set_options( + self.option_list_from_jobs(model_type), filename + ) else: - QtWidgets.QMessageBox(text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}.").exec_() + QtWidgets.QMessageBox( + text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}." + ).exec_() def select_job(self, model_type, idx): jobs = self.job_options[model_type] - if idx == -1: return + if idx == -1: + return if idx < len(jobs): name, job = jobs[idx] training_params = cattr.unstructure(job.trainer) - training_params_specific = {f"{key}_{str(model_type)}":val for key,val in training_params.items()} + training_params_specific = { + f"{key}_{str(model_type)}": val for key, val in training_params.items() + } # confmap and paf models should share some params shown in dialog (e.g. 
scale) # but centroids does not, so just set any centroid_foo fields from its profile if model_type in [ModelOutputType.CENTROIDS]: @@ -485,45 +547,40 @@ def make_default_training_jobs(): models = dict() models[ModelOutputType.CONFIDENCE_MAP] = Model( - output_type=ModelOutputType.CONFIDENCE_MAP, - backbone=unet.UNet(num_filters=32)) + output_type=ModelOutputType.CONFIDENCE_MAP, backbone=unet.UNet(num_filters=32) + ) models[ModelOutputType.PART_AFFINITY_FIELD] = Model( - output_type=ModelOutputType.PART_AFFINITY_FIELD, - backbone=leap.LeapCNN(num_filters=64)) + output_type=ModelOutputType.PART_AFFINITY_FIELD, + backbone=leap.LeapCNN(num_filters=64), + ) # Build Trainers defaults = dict() defaults["shared"] = dict( - instance_crop = True, - val_size = 0.1, - augment_rotation=180, - batch_size=4, - learning_rate = 1e-4, - reduce_lr_factor=0.5, - reduce_lr_cooldown=3, - reduce_lr_min_delta=1e-6, - reduce_lr_min_lr = 1e-10, - amsgrad = True, - shuffle_every_epoch=True, - save_every_epoch = False, -# val_batches_per_epoch = 10, -# upsampling_layers = True, -# depth = 3, + instance_crop=True, + val_size=0.1, + augment_rotation=180, + batch_size=4, + learning_rate=1e-4, + reduce_lr_factor=0.5, + reduce_lr_cooldown=3, + reduce_lr_min_delta=1e-6, + reduce_lr_min_lr=1e-10, + amsgrad=True, + shuffle_every_epoch=True, + save_every_epoch=False, + # val_batches_per_epoch = 10, + # upsampling_layers = True, + # depth = 3, ) defaults[ModelOutputType.CONFIDENCE_MAP] = dict( - **defaults["shared"], - num_epochs=100, - steps_per_epoch=200, - reduce_lr_patience=5, - ) + **defaults["shared"], num_epochs=100, steps_per_epoch=200, reduce_lr_patience=5 + ) defaults[ModelOutputType.PART_AFFINITY_FIELD] = dict( - **defaults["shared"], - num_epochs=75, - steps_per_epoch = 100, - reduce_lr_patience=8, - ) + **defaults["shared"], num_epochs=75, steps_per_epoch=100, reduce_lr_patience=8 + ) trainers = dict() for type in models.keys(): @@ -537,6 +594,7 @@ def make_default_training_jobs(): return training_jobs + def find_saved_jobs(job_dir, jobs=None): """Find all the TrainingJob json files in a given directory. @@ -574,6 +632,7 @@ def find_saved_jobs(job_dir, jobs=None): return jobs + def add_frames_from_json(labels: Labels, new_labels_json: str): # Deserialize the new frames, matching to the existing videos/skeletons if possible new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames @@ -587,12 +646,14 @@ def add_frames_from_json(labels: Labels, new_labels_json: str): return len(new_lfs) + def run_active_learning_pipeline( - labels_filename: str, - labels: Labels, - training_jobs: Dict['ModelOutputType', 'TrainingJob']=None, - frames_to_predict: Dict[Video, List[int]]=None, - with_tracking: bool=False) -> int: + labels_filename: str, + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"] = None, + frames_to_predict: Dict[Video, List[int]] = None, + with_tracking: bool = False, +) -> int: """Run training (as needed) and inference. 
Args: @@ -628,16 +689,19 @@ def run_active_learning_pipeline( return 0 # Run the Predictor for suggested frames - new_labeled_frame_count = \ - run_active_inference(labels, trained_jobs, save_dir, frames_to_predict, with_tracking) + new_labeled_frame_count = run_active_inference( + labels, trained_jobs, save_dir, frames_to_predict, with_tracking + ) return new_labeled_frame_count + def run_active_training( - labels: Labels, - training_jobs: Dict['ModelOutputType', 'TrainingJob'], - save_dir:str, - gui:bool = True) -> Dict['ModelOutputType', 'TrainingJob']: + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"], + save_dir: str, + gui: bool = True, +) -> Dict["ModelOutputType", "TrainingJob"]: """ Run training for each training job. @@ -678,14 +742,15 @@ def run_active_training( # Start training in separate process # This makes it easier to ensure that tensorflow released memory when done - pool, result = job.trainer.train_async(model=job.model, labels=labels, - save_dir=save_dir) + pool, result = job.trainer.train_async( + model=job.model, labels=labels, save_dir=save_dir + ) # Wait for training results while not result.ready(): if gui: QtWidgets.QApplication.instance().processEvents() - result.wait(.01) + result.wait(0.01) if result.successful(): # get the path to the resulting TrainingJob file @@ -694,7 +759,9 @@ def run_active_training( else: if gui: win.close() - QtWidgets.QMessageBox(text=f"An error occured while training {str(model_type)}. Your command line terminal may have more information about the error.").exec_() + QtWidgets.QMessageBox( + text=f"An error occured while training {str(model_type)}. Your command line terminal may have more information about the error." + ).exec_() trained_jobs[model_type] = None result.get() @@ -710,13 +777,15 @@ def run_active_training( return trained_jobs + def run_active_inference( - labels: Labels, - training_jobs: Dict['ModelOutputType', 'TrainingJob'], - save_dir:str, - frames_to_predict: Dict[Video, List[int]], - with_tracking: bool, - gui: bool=True) -> int: + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"], + save_dir: str, + frames_to_predict: Dict[Video, List[int]], + with_tracking: bool, + gui: bool = True, +) -> int: """Run inference on specified frames using models from training_jobs. Args: @@ -732,6 +801,7 @@ def run_active_inference( Number of new frames added to labels. """ from sleap.nn.inference import Predictor + # from multiprocessing import Pool # total_new_lf_count = 0 @@ -740,31 +810,32 @@ def run_active_inference( # Create Predictor from the results of training # pool = Pool(processes=1) - predictor = Predictor(training_jobs=training_jobs, - with_tracking=with_tracking, - # output_path=inference_output_path, - # pool=pool - ) + predictor = Predictor( + training_jobs=training_jobs, + with_tracking=with_tracking, + # output_path=inference_output_path, + # pool=pool + ) if gui: # show message while running inference progress = QtWidgets.QProgressDialog( f"Running inference on {len(frames_to_predict)} videos...", "Cancel", - 0, len(frames_to_predict)) + 0, + len(frames_to_predict), + ) # win.setLabelText(" Running inference on selected frames... 
") progress.show() QtWidgets.QApplication.instance().processEvents() - new_lfs = [] for i, (video, frames) in enumerate(frames_to_predict.items()): QtWidgets.QApplication.instance().processEvents() if len(frames): # Run inference for desired frames in this video # result = predictor.predict_async( - new_lfs_video = predictor.predict( - input_video=video, frames=frames) + new_lfs_video = predictor.predict(input_video=video, frames=frames) new_lfs.extend(new_lfs_video) if gui: @@ -778,19 +849,19 @@ def run_active_inference( # result.wait(.01) # if result.successful(): - # new_labels_json = result.get() + # new_labels_json = result.get() - # Add new frames to labels - # (we're doing this for each video as we go since there was a problem - # when we tried to add frames for all videos together.) - # new_lf_count = add_frames_from_json(labels, new_labels_json) + # Add new frames to labels + # (we're doing this for each video as we go since there was a problem + # when we tried to add frames for all videos together.) + # new_lf_count = add_frames_from_json(labels, new_labels_json) - # total_new_lf_count += new_lf_count + # total_new_lf_count += new_lf_count # else: - # if gui: - # QtWidgets.QApplication.instance().processEvents() - # QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() - # result.get() + # if gui: + # QtWidgets.QApplication.instance().processEvents() + # QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() + # result.get() # predictor.pool.close() @@ -809,15 +880,16 @@ def run_active_inference( # return total_new_lf_count return len(new_lfs) + if __name__ == "__main__": import sys -# labels_filename = "/Volumes/fileset-mmurthy/nat/shruthi/labels-mac.json" + # labels_filename = "/Volumes/fileset-mmurthy/nat/shruthi/labels-mac.json" labels_filename = sys.argv[1] labels = Labels.load_json(labels_filename) app = QtWidgets.QApplication() - win = ActiveLearningDialog(labels=labels,labels_filename=labels_filename) + win = ActiveLearningDialog(labels=labels, labels_filename=labels_filename) win.show() app.exec_() diff --git a/sleap/gui/app.py b/sleap/gui/app.py index e1cb090a2..d63f15b51 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -5,7 +5,15 @@ from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import QLabel, QPushButton, QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox +from PySide2.QtWidgets import ( + QLabel, + QPushButton, + QLineEdit, + QSpinBox, + QDoubleSpinBox, + QComboBox, + QCheckBox, +) from PySide2.QtWidgets import QTableWidget, QTableView, QTableWidgetItem from PySide2.QtWidgets import QMenu, QAction from PySide2.QtWidgets import QFileDialog, QMessageBox @@ -29,20 +37,31 @@ from sleap.io.dataset import Labels from sleap.info.summary import Summary from sleap.gui.video import QtVideoPlayer -from sleap.gui.dataviews import VideosTable, SkeletonNodesTable, SkeletonEdgesTable, \ - LabeledFrameTable, SkeletonNodeModel, SuggestionsTable +from sleap.gui.dataviews import ( + VideosTable, + SkeletonNodesTable, + SkeletonEdgesTable, + LabeledFrameTable, + SkeletonNodeModel, + SuggestionsTable, +) from sleap.gui.importvideos import ImportVideos from sleap.gui.formbuilder import YamlFormWidget from sleap.gui.merge import MergeDialog from sleap.gui.shortcuts 
import Shortcuts, ShortcutDialog from sleap.gui.suggestions import VideoFrameSuggestions -from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay, TrackListOverlay +from sleap.gui.overlays.tracks import ( + TrackColorManager, + TrackTrailOverlay, + TrackListOverlay, +) from sleap.gui.overlays.instance import InstanceOverlay from sleap.gui.overlays.anchors import NegativeAnchorOverlay OPEN_IN_NEW = True + class MainWindow(QMainWindow): labels: Labels skeleton: Skeleton @@ -87,7 +106,7 @@ def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs def event(self, e): if e.type() == QEvent.StatusTip: - if e.tip() == '': + if e.tip() == "": return True return super().event(e) @@ -112,8 +131,10 @@ def changestack_end_atomic(self): def changestack_has_changes(self) -> bool: # True iff there are no unsaved changed - if len(self._change_stack) == 0: return False - if self._change_stack[-1] == "SAVE": return False + if len(self._change_stack) == 0: + return False + if self._change_stack[-1] == "SAVE": + return False return True @property @@ -123,7 +144,8 @@ def filename(self): @filename.setter def filename(self, x): self._filename = x - if x is not None: self.setWindowTitle(x) + if x is not None: + self.setWindowTitle(x) def initialize_gui(self): @@ -132,13 +154,15 @@ def initialize_gui(self): ####### Video player ####### self.player = QtVideoPlayer(color_manager=self._color_manager) self.player.changedPlot.connect(self.newFrame) - self.player.changedData.connect(lambda inst: self.changestack_push("viewer change")) + self.player.changedData.connect( + lambda inst: self.changestack_push("viewer change") + ) self.player.view.instanceDoubleClicked.connect(self.doubleClickInstance) self.player.seekbar.selectionChanged.connect(lambda: self.updateStatusMessage()) self.setCentralWidget(self.player) ####### Status bar ####### - self.statusBar() # Initialize status bar + self.statusBar() # Initialize status bar self.load_overlays() @@ -147,54 +171,104 @@ def initialize_gui(self): ### File Menu ### fileMenu = self.menuBar().addMenu("File") - self._menu_actions["new"] = fileMenu.addAction("New Project", self.newProject, shortcuts["new"]) - self._menu_actions["open"] = fileMenu.addAction("Open Project...", self.openProject, shortcuts["open"]) - self._menu_actions["import predictions"] = fileMenu.addAction("Import Labels...", self.importPredictions) + self._menu_actions["new"] = fileMenu.addAction( + "New Project", self.newProject, shortcuts["new"] + ) + self._menu_actions["open"] = fileMenu.addAction( + "Open Project...", self.openProject, shortcuts["open"] + ) + self._menu_actions["import predictions"] = fileMenu.addAction( + "Import Labels...", self.importPredictions + ) fileMenu.addSeparator() - self._menu_actions["add videos"] = fileMenu.addAction("Add Videos...", self.addVideo, shortcuts["add videos"]) + self._menu_actions["add videos"] = fileMenu.addAction( + "Add Videos...", self.addVideo, shortcuts["add videos"] + ) fileMenu.addSeparator() - self._menu_actions["save"] = fileMenu.addAction("Save", self.saveProject, shortcuts["save"]) - self._menu_actions["save as"] = fileMenu.addAction("Save As...", self.saveProjectAs, shortcuts["save as"]) + self._menu_actions["save"] = fileMenu.addAction( + "Save", self.saveProject, shortcuts["save"] + ) + self._menu_actions["save as"] = fileMenu.addAction( + "Save As...", self.saveProjectAs, shortcuts["save as"] + ) fileMenu.addSeparator() - self._menu_actions["close"] = fileMenu.addAction("Quit", self.close, 
shortcuts["close"]) + self._menu_actions["close"] = fileMenu.addAction( + "Quit", self.close, shortcuts["close"] + ) ### Go Menu ### goMenu = self.menuBar().addMenu("Go") - self._menu_actions["goto next labeled"] = goMenu.addAction("Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next labeled"]) - self._menu_actions["goto prev labeled"] = goMenu.addAction("Previous Labeled Frame", self.previousLabeledFrame, shortcuts["goto prev labeled"]) - - self._menu_actions["goto next user"] = goMenu.addAction("Next User Labeled Frame", self.nextUserLabeledFrame, shortcuts["goto next user"]) - - self._menu_actions["goto next suggestion"] = goMenu.addAction("Next Suggestion", self.nextSuggestedFrame, shortcuts["goto next suggestion"]) - self._menu_actions["goto prev suggestion"] = goMenu.addAction("Previous Suggestion", lambda:self.nextSuggestedFrame(-1), shortcuts["goto prev suggestion"]) - - self._menu_actions["goto next track spawn"] = goMenu.addAction("Next Track Spawn Frame", self.nextTrackFrame, shortcuts["goto next track spawn"]) + self._menu_actions["goto next labeled"] = goMenu.addAction( + "Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next labeled"] + ) + self._menu_actions["goto prev labeled"] = goMenu.addAction( + "Previous Labeled Frame", + self.previousLabeledFrame, + shortcuts["goto prev labeled"], + ) + + self._menu_actions["goto next user"] = goMenu.addAction( + "Next User Labeled Frame", + self.nextUserLabeledFrame, + shortcuts["goto next user"], + ) + + self._menu_actions["goto next suggestion"] = goMenu.addAction( + "Next Suggestion", + self.nextSuggestedFrame, + shortcuts["goto next suggestion"], + ) + self._menu_actions["goto prev suggestion"] = goMenu.addAction( + "Previous Suggestion", + lambda: self.nextSuggestedFrame(-1), + shortcuts["goto prev suggestion"], + ) + + self._menu_actions["goto next track spawn"] = goMenu.addAction( + "Next Track Spawn Frame", + self.nextTrackFrame, + shortcuts["goto next track spawn"], + ) goMenu.addSeparator() - self._menu_actions["next video"] = goMenu.addAction("Next Video", self.nextVideo, shortcuts["next video"]) - self._menu_actions["prev video"] = goMenu.addAction("Previous Video", self.previousVideo, shortcuts["prev video"]) + self._menu_actions["next video"] = goMenu.addAction( + "Next Video", self.nextVideo, shortcuts["next video"] + ) + self._menu_actions["prev video"] = goMenu.addAction( + "Previous Video", self.previousVideo, shortcuts["prev video"] + ) goMenu.addSeparator() - self._menu_actions["goto frame"] = goMenu.addAction("Go to Frame...", self.gotoFrame, shortcuts["goto frame"]) - self._menu_actions["mark frame"] = goMenu.addAction("Mark Frame", self.markFrame, shortcuts["mark frame"]) - self._menu_actions["goto marked"] = goMenu.addAction("Go to Marked Frame", self.goMarkedFrame, shortcuts["goto marked"]) - + self._menu_actions["goto frame"] = goMenu.addAction( + "Go to Frame...", self.gotoFrame, shortcuts["goto frame"] + ) + self._menu_actions["mark frame"] = goMenu.addAction( + "Mark Frame", self.markFrame, shortcuts["mark frame"] + ) + self._menu_actions["goto marked"] = goMenu.addAction( + "Go to Marked Frame", self.goMarkedFrame, shortcuts["goto marked"] + ) ### View Menu ### viewMenu = self.menuBar().addMenu("View") viewMenu.addSeparator() - self._menu_actions["color predicted"] = viewMenu.addAction("Color Predicted Instances", self.toggleColorPredicted, shortcuts["color predicted"]) + self._menu_actions["color predicted"] = viewMenu.addAction( + "Color Predicted Instances", + 
self.toggleColorPredicted, + shortcuts["color predicted"], + ) self.paletteMenu = viewMenu.addMenu("Color Palette") for palette_name in self._color_manager.palette_names: - menu_item = self.paletteMenu.addAction(f"{palette_name}", - lambda x=palette_name: self.setPalette(x)) + menu_item = self.paletteMenu.addAction( + f"{palette_name}", lambda x=palette_name: self.setPalette(x) + ) menu_item.setCheckable(True) self.setPalette("standard") @@ -202,85 +276,149 @@ def initialize_gui(self): self.seekbarHeaderMenu = viewMenu.addMenu("Seekbar Header") headers = ( - "None", - "Point Displacement (sum)", - "Point Displacement (max)", - "Instance Score (sum)", - "Instance Score (min)", - "Point Score (sum)", - "Point Score (min)", - "Number of predicted points" - ) + "None", + "Point Displacement (sum)", + "Point Displacement (max)", + "Instance Score (sum)", + "Instance Score (min)", + "Point Score (sum)", + "Point Score (min)", + "Number of predicted points", + ) for header in headers: - menu_item = self.seekbarHeaderMenu.addAction(header, - lambda x=header: self.setSeekbarHeader(x)) + menu_item = self.seekbarHeaderMenu.addAction( + header, lambda x=header: self.setSeekbarHeader(x) + ) menu_item.setCheckable(True) self.setSeekbarHeader("None") viewMenu.addSeparator() - self._menu_actions["show labels"] = viewMenu.addAction("Show Node Names", self.toggleLabels, shortcuts["show labels"]) - self._menu_actions["show edges"] = viewMenu.addAction("Show Edges", self.toggleEdges, shortcuts["show edges"]) - self._menu_actions["show trails"] = viewMenu.addAction("Show Trails", self.toggleTrails, shortcuts["show trails"]) + self._menu_actions["show labels"] = viewMenu.addAction( + "Show Node Names", self.toggleLabels, shortcuts["show labels"] + ) + self._menu_actions["show edges"] = viewMenu.addAction( + "Show Edges", self.toggleEdges, shortcuts["show edges"] + ) + self._menu_actions["show trails"] = viewMenu.addAction( + "Show Trails", self.toggleTrails, shortcuts["show trails"] + ) self.trailLengthMenu = viewMenu.addMenu("Trail Length") for length_option in (4, 10, 20): - menu_item = self.trailLengthMenu.addAction(f"{length_option}", - lambda x=length_option: self.setTrailLength(x)) + menu_item = self.trailLengthMenu.addAction( + f"{length_option}", lambda x=length_option: self.setTrailLength(x) + ) menu_item.setCheckable(True) viewMenu.addSeparator() - self._menu_actions["fit"] = viewMenu.addAction("Fit Instances to View", self.toggleAutoZoom, shortcuts["fit"]) + self._menu_actions["fit"] = viewMenu.addAction( + "Fit Instances to View", self.toggleAutoZoom, shortcuts["fit"] + ) viewMenu.addSeparator() # set menu checkmarks - self._menu_actions["show labels"].setCheckable(True); self._menu_actions["show labels"].setChecked(self._show_labels) - self._menu_actions["show edges"].setCheckable(True); self._menu_actions["show edges"].setChecked(self._show_edges) - self._menu_actions["show trails"].setCheckable(True); self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) - self._menu_actions["color predicted"].setCheckable(True); self._menu_actions["color predicted"].setChecked(self.overlays["instance"].color_predicted) + self._menu_actions["show labels"].setCheckable(True) + self._menu_actions["show labels"].setChecked(self._show_labels) + self._menu_actions["show edges"].setCheckable(True) + self._menu_actions["show edges"].setChecked(self._show_edges) + self._menu_actions["show trails"].setCheckable(True) + self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) + 
self._menu_actions["color predicted"].setCheckable(True) + self._menu_actions["color predicted"].setChecked( + self.overlays["instance"].color_predicted + ) self._menu_actions["fit"].setCheckable(True) ### Label Menu ### labelMenu = self.menuBar().addMenu("Labels") - self._menu_actions["add instance"] = labelMenu.addAction("Add Instance", self.newInstance, shortcuts["add instance"]) - self._menu_actions["delete instance"] = labelMenu.addAction("Delete Instance", self.deleteSelectedInstance, shortcuts["delete instance"]) + self._menu_actions["add instance"] = labelMenu.addAction( + "Add Instance", self.newInstance, shortcuts["add instance"] + ) + self._menu_actions["delete instance"] = labelMenu.addAction( + "Delete Instance", self.deleteSelectedInstance, shortcuts["delete instance"] + ) labelMenu.addSeparator() self.track_menu = labelMenu.addMenu("Set Instance Track") - self._menu_actions["transpose"] = labelMenu.addAction("Transpose Instance Tracks", self.transposeInstance, shortcuts["transpose"]) - self._menu_actions["delete track"] = labelMenu.addAction("Delete Instance and Track", self.deleteSelectedInstanceTrack, shortcuts["delete track"]) + self._menu_actions["transpose"] = labelMenu.addAction( + "Transpose Instance Tracks", self.transposeInstance, shortcuts["transpose"] + ) + self._menu_actions["delete track"] = labelMenu.addAction( + "Delete Instance and Track", + self.deleteSelectedInstanceTrack, + shortcuts["delete track"], + ) labelMenu.addSeparator() - self._menu_actions["select next"] = labelMenu.addAction("Select Next Instance", self.player.view.nextSelection, shortcuts["select next"]) - self._menu_actions["clear selection"] = labelMenu.addAction("Clear Selection", self.player.view.clearSelection, shortcuts["clear selection"]) + self._menu_actions["select next"] = labelMenu.addAction( + "Select Next Instance", + self.player.view.nextSelection, + shortcuts["select next"], + ) + self._menu_actions["clear selection"] = labelMenu.addAction( + "Clear Selection", + self.player.view.clearSelection, + shortcuts["clear selection"], + ) labelMenu.addSeparator() ### Predict Menu ### predictionMenu = self.menuBar().addMenu("Predict") - self._menu_actions["active learning"] = predictionMenu.addAction("Run Active Learning...", self.runActiveLearning, shortcuts["learning"]) - self._menu_actions["inference"] = predictionMenu.addAction("Run Inference...", self.runInference) - self._menu_actions["learning expert"] = predictionMenu.addAction("Expert Controls...", self.runLearningExpert) + self._menu_actions["active learning"] = predictionMenu.addAction( + "Run Active Learning...", self.runActiveLearning, shortcuts["learning"] + ) + self._menu_actions["inference"] = predictionMenu.addAction( + "Run Inference...", self.runInference + ) + self._menu_actions["learning expert"] = predictionMenu.addAction( + "Expert Controls...", self.runLearningExpert + ) predictionMenu.addSeparator() - self._menu_actions["negative sample"] = predictionMenu.addAction("Mark Negative Training Sample...", self.markNegativeAnchor) - self._menu_actions["clear negative samples"] = predictionMenu.addAction("Clear Current Frame Negative Samples", self.clearFrameNegativeAnchors) + self._menu_actions["negative sample"] = predictionMenu.addAction( + "Mark Negative Training Sample...", self.markNegativeAnchor + ) + self._menu_actions["clear negative samples"] = predictionMenu.addAction( + "Clear Current Frame Negative Samples", self.clearFrameNegativeAnchors + ) predictionMenu.addSeparator() - self._menu_actions["visualize 
models"] = predictionMenu.addAction("Visualize Model Outputs...", self.visualizeOutputs) + self._menu_actions["visualize models"] = predictionMenu.addAction( + "Visualize Model Outputs...", self.visualizeOutputs + ) predictionMenu.addSeparator() - self._menu_actions["remove predictions"] = predictionMenu.addAction("Delete All Predictions...", self.deletePredictions) - self._menu_actions["remove clip predictions"] = predictionMenu.addAction("Delete Predictions from Clip...", self.deleteClipPredictions, shortcuts["delete clip"]) - self._menu_actions["remove area predictions"] = predictionMenu.addAction("Delete Predictions from Area...", self.deleteAreaPredictions, shortcuts["delete area"]) - self._menu_actions["remove score predictions"] = predictionMenu.addAction("Delete Predictions with Low Score...", self.deleteLowScorePredictions) - self._menu_actions["remove frame limit predictions"] = predictionMenu.addAction("Delete Predictions beyond Frame Limit...", self.deleteFrameLimitPredictions) + self._menu_actions["remove predictions"] = predictionMenu.addAction( + "Delete All Predictions...", self.deletePredictions + ) + self._menu_actions["remove clip predictions"] = predictionMenu.addAction( + "Delete Predictions from Clip...", + self.deleteClipPredictions, + shortcuts["delete clip"], + ) + self._menu_actions["remove area predictions"] = predictionMenu.addAction( + "Delete Predictions from Area...", + self.deleteAreaPredictions, + shortcuts["delete area"], + ) + self._menu_actions["remove score predictions"] = predictionMenu.addAction( + "Delete Predictions with Low Score...", self.deleteLowScorePredictions + ) + self._menu_actions["remove frame limit predictions"] = predictionMenu.addAction( + "Delete Predictions beyond Frame Limit...", self.deleteFrameLimitPredictions + ) predictionMenu.addSeparator() - self._menu_actions["export frames"] = predictionMenu.addAction("Export Training Package...", self.exportLabeledFrames) - self._menu_actions["export clip"] = predictionMenu.addAction("Export Labeled Clip...", self.exportLabeledClip, shortcuts["export clip"]) + self._menu_actions["export frames"] = predictionMenu.addAction( + "Export Training Package...", self.exportLabeledFrames + ) + self._menu_actions["export clip"] = predictionMenu.addAction( + "Export Labeled Clip...", self.exportLabeledClip, shortcuts["export clip"] + ) ############ @@ -311,20 +449,26 @@ def _make_dock(name, widgets=[], tab_with=None): videos_layout.addWidget(self.videosTable) hb = QHBoxLayout() btn = QPushButton("Show video") - btn.clicked.connect(self.activateSelectedVideo); hb.addWidget(btn) + btn.clicked.connect(self.activateSelectedVideo) + hb.addWidget(btn) self._buttons["show video"] = btn btn = QPushButton("Add videos") - btn.clicked.connect(self.addVideo); hb.addWidget(btn) + btn.clicked.connect(self.addVideo) + hb.addWidget(btn) btn = QPushButton("Remove video") - btn.clicked.connect(self.removeVideo); hb.addWidget(btn) + btn.clicked.connect(self.removeVideo) + hb.addWidget(btn) self._buttons["remove video"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) videos_layout.addWidget(hbw) self.videosTable.doubleClicked.connect(self.activateSelectedVideo) ####### Skeleton ####### - skeleton_layout = _make_dock("Skeleton", tab_with=videos_layout.parent().parent()) + skeleton_layout = _make_dock( + "Skeleton", tab_with=videos_layout.parent().parent() + ) gb = QGroupBox("Nodes") vb = QVBoxLayout() @@ -332,11 +476,14 @@ def _make_dock(name, widgets=[], tab_with=None): 
vb.addWidget(self.skeletonNodesTable) hb = QHBoxLayout() btn = QPushButton("New node") - btn.clicked.connect(self.newNode); hb.addWidget(btn) + btn.clicked.connect(self.newNode) + hb.addWidget(btn) btn = QPushButton("Delete node") - btn.clicked.connect(self.deleteNode); hb.addWidget(btn) + btn.clicked.connect(self.deleteNode) + hb.addWidget(btn) self._buttons["delete node"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) vb.addWidget(hbw) gb.setLayout(vb) skeleton_layout.addWidget(gb) @@ -346,30 +493,43 @@ def _make_dock(name, widgets=[], tab_with=None): self.skeletonEdgesTable = SkeletonEdgesTable(self.skeleton) vb.addWidget(self.skeletonEdgesTable) hb = QHBoxLayout() - self.skeletonEdgesSrc = QComboBox(); self.skeletonEdgesSrc.setEditable(False); self.skeletonEdgesSrc.currentIndexChanged.connect(self.selectSkeletonEdgeSrc) + self.skeletonEdgesSrc = QComboBox() + self.skeletonEdgesSrc.setEditable(False) + self.skeletonEdgesSrc.currentIndexChanged.connect(self.selectSkeletonEdgeSrc) self.skeletonEdgesSrc.setModel(SkeletonNodeModel(self.skeleton)) hb.addWidget(self.skeletonEdgesSrc) hb.addWidget(QLabel("to")) - self.skeletonEdgesDst = QComboBox(); self.skeletonEdgesDst.setEditable(False) + self.skeletonEdgesDst = QComboBox() + self.skeletonEdgesDst.setEditable(False) hb.addWidget(self.skeletonEdgesDst) - self.skeletonEdgesDst.setModel(SkeletonNodeModel(self.skeleton, lambda: self.skeletonEdgesSrc.currentText())) + self.skeletonEdgesDst.setModel( + SkeletonNodeModel( + self.skeleton, lambda: self.skeletonEdgesSrc.currentText() + ) + ) btn = QPushButton("Add edge") - btn.clicked.connect(self.newEdge); hb.addWidget(btn) + btn.clicked.connect(self.newEdge) + hb.addWidget(btn) self._buttons["add edge"] = btn btn = QPushButton("Delete edge") - btn.clicked.connect(self.deleteEdge); hb.addWidget(btn) + btn.clicked.connect(self.deleteEdge) + hb.addWidget(btn) self._buttons["delete edge"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) vb.addWidget(hbw) gb.setLayout(vb) skeleton_layout.addWidget(gb) hb = QHBoxLayout() btn = QPushButton("Load Skeleton") - btn.clicked.connect(self.openSkeleton); hb.addWidget(btn) + btn.clicked.connect(self.openSkeleton) + hb.addWidget(btn) btn = QPushButton("Save Skeleton") - btn.clicked.connect(self.saveSkeleton); hb.addWidget(btn) - hbw = QWidget(); hbw.setLayout(hb) + btn.clicked.connect(self.saveSkeleton) + hb.addWidget(btn) + hbw = QWidget() + hbw.setLayout(hb) skeleton_layout.addWidget(hbw) # update edge UI when change to nodes @@ -382,17 +542,21 @@ def _make_dock(name, widgets=[], tab_with=None): instances_layout.addWidget(self.instancesTable) hb = QHBoxLayout() btn = QPushButton("New instance") - btn.clicked.connect(lambda x: self.newInstance()); hb.addWidget(btn) + btn.clicked.connect(lambda x: self.newInstance()) + hb.addWidget(btn) btn = QPushButton("Delete instance") - btn.clicked.connect(self.deleteSelectedInstance); hb.addWidget(btn) + btn.clicked.connect(self.deleteSelectedInstance) + hb.addWidget(btn) self._buttons["delete instance"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) instances_layout.addWidget(hbw) def update_instance_table_selection(): inst_selected = self.player.view.getSelectionInstance() - if not inst_selected: return + if not inst_selected: + return idx = -1 if inst_selected in self.labeled_frame.instances_to_show: @@ -401,7 +565,9 @@ def update_instance_table_selection(): table_row_idx = self.instancesTable.model().createIndex(idx, 0) 
self.instancesTable.setCurrentIndex(table_row_idx) - self.instancesTable.selectionChangedSignal.connect(lambda inst: self.player.view.selectInstance(inst, signal=False)) + self.instancesTable.selectionChangedSignal.connect( + lambda inst: self.player.view.selectInstance(inst, signal=False) + ) self.player.view.updatedSelection.connect(update_instance_table_selection) # update track UI when change to track name @@ -415,20 +581,31 @@ def update_instance_table_selection(): hb = QHBoxLayout() btn = QPushButton("Prev") - btn.clicked.connect(lambda:self.nextSuggestedFrame(-1)); hb.addWidget(btn) + btn.clicked.connect(lambda: self.nextSuggestedFrame(-1)) + hb.addWidget(btn) self.suggested_count_label = QLabel() hb.addWidget(self.suggested_count_label) btn = QPushButton("Next") - btn.clicked.connect(lambda:self.nextSuggestedFrame()); hb.addWidget(btn) - hbw = QWidget(); hbw.setLayout(hb) + btn.clicked.connect(lambda: self.nextSuggestedFrame()) + hb.addWidget(btn) + hbw = QWidget() + hbw.setLayout(hb) suggestions_layout.addWidget(hbw) - suggestions_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/suggestions.yaml") - form_wid = YamlFormWidget(yaml_file=suggestions_yaml, title="Generate Suggestions") + suggestions_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/suggestions.yaml" + ) + form_wid = YamlFormWidget( + yaml_file=suggestions_yaml, title="Generate Suggestions" + ) form_wid.mainAction.connect(self.generateSuggestions) suggestions_layout.addWidget(form_wid) - self.suggestionsTable.doubleClicked.connect(lambda table_idx: self.gotoVideoAndFrame(*self.labels.get_suggestions()[table_idx.row()])) + self.suggestionsTable.doubleClicked.connect( + lambda table_idx: self.gotoVideoAndFrame( + *self.labels.get_suggestions()[table_idx.row()] + ) + ) # # Set timer to update state of gui at regular intervals @@ -440,35 +617,40 @@ def update_instance_table_selection(): def load_overlays(self): self.overlays["track_labels"] = TrackListOverlay( - labels = self.labels, - view = self.player.view, - color_manager = self._color_manager) + labels=self.labels, view=self.player.view, color_manager=self._color_manager + ) self.overlays["negative"] = NegativeAnchorOverlay( - labels = self.labels, - scene = self.player.view.scene) + labels=self.labels, scene=self.player.view.scene + ) self.overlays["trails"] = TrackTrailOverlay( - labels = self.labels, - scene = self.player.view.scene, - color_manager = self._color_manager) + labels=self.labels, + scene=self.player.view.scene, + color_manager=self._color_manager, + ) self.overlays["instance"] = InstanceOverlay( - labels = self.labels, - player = self.player, - color_manager = self._color_manager) + labels=self.labels, player=self.player, color_manager=self._color_manager + ) def update_gui_state(self): - has_selected_instance = (self.player.view.getSelection() is not None) + has_selected_instance = self.player.view.getSelection() is not None has_unsaved_changes = self.changestack_has_changes() - has_multiple_videos = (self.labels is not None and len(self.labels.videos) > 1) - has_labeled_frames = self.labels is not None and any((lf.video == self.video for lf in self.labels)) + has_multiple_videos = self.labels is not None and len(self.labels.videos) > 1 + has_labeled_frames = self.labels is not None and any( + (lf.video == self.video for lf in self.labels) + ) has_suggestions = self.labels is not None and (len(self.labels.suggestions) > 0) has_tracks = self.labels is not None and (len(self.labels.tracks) > 0) - 
has_multiple_instances = (self.labeled_frame is not None and len(self.labeled_frame.instances) > 1) + has_multiple_instances = ( + self.labeled_frame is not None and len(self.labeled_frame.instances) > 1 + ) # todo: exclude predicted instances from count - has_nodes_selected = (self.skeletonEdgesSrc.currentIndex() > -1 and - self.skeletonEdgesDst.currentIndex() > -1) + has_nodes_selected = ( + self.skeletonEdgesSrc.currentIndex() > -1 + and self.skeletonEdgesDst.currentIndex() > -1 + ) control_key_down = QApplication.queryKeyboardModifiers() == Qt.ControlModifier # Update menus @@ -495,20 +677,32 @@ def update_gui_state(self): # Update buttons self._buttons["add edge"].setEnabled(has_nodes_selected) - self._buttons["delete edge"].setEnabled(self.skeletonEdgesTable.currentIndex().isValid()) - self._buttons["delete node"].setEnabled(self.skeletonNodesTable.currentIndex().isValid()) - self._buttons["show video"].setEnabled(self.videosTable.currentIndex().isValid()) - self._buttons["remove video"].setEnabled(self.videosTable.currentIndex().isValid()) - self._buttons["delete instance"].setEnabled(self.instancesTable.currentIndex().isValid()) + self._buttons["delete edge"].setEnabled( + self.skeletonEdgesTable.currentIndex().isValid() + ) + self._buttons["delete node"].setEnabled( + self.skeletonNodesTable.currentIndex().isValid() + ) + self._buttons["show video"].setEnabled( + self.videosTable.currentIndex().isValid() + ) + self._buttons["remove video"].setEnabled( + self.videosTable.currentIndex().isValid() + ) + self._buttons["delete instance"].setEnabled( + self.instancesTable.currentIndex().isValid() + ) # Update overlays - self.overlays["track_labels"].visible = control_key_down and has_selected_instance + self.overlays["track_labels"].visible = ( + control_key_down and has_selected_instance + ) def update_data_views(self, *update): update = update or ("video", "skeleton", "labels", "frame", "suggestions") if len(self.skeleton.nodes) == 0 and len(self.labels.skeletons): - self.skeleton = self.labels.skeletons[0] + self.skeleton = self.labels.skeletons[0] if "video" in update: self.videosTable.model().videos = self.labels.videos @@ -533,15 +727,20 @@ def update_data_views(self, *update): suggestion_status_text = "" suggestion_list = self.labels.get_suggestions() if len(suggestion_list): - suggestion_label_counts = [self.labels.instance_count(video, frame_idx) - for (video, frame_idx) in suggestion_list] + suggestion_label_counts = [ + self.labels.instance_count(video, frame_idx) + for (video, frame_idx) in suggestion_list + ] labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) - suggestion_status_text = f"{labeled_count}/{len(suggestion_list)} labeled" + suggestion_status_text = ( + f"{labeled_count}/{len(suggestion_list)} labeled" + ) self.suggested_count_label.setText(suggestion_status_text) def plotFrame(self, *args, **kwargs): """Wrap call to player.plot so we can redraw/update things.""" - if self.video is None: return + if self.video is None: + return self.player.plot(*args, **kwargs) self.player.showLabels(self._show_labels) @@ -552,10 +751,12 @@ def plotFrame(self, *args, **kwargs): def importData(self, filename=None, do_load=True): show_msg = False - if len(filename) == 0: return + if len(filename) == 0: + return gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)]) + search_paths=[os.path.dirname(filename)] + ) has_loaded = False labels = None @@ -586,7 +787,9 @@ def importData(self, filename=None, do_load=True): 
self.setTrailLength(self.overlays["trails"].trail_length) if show_msg: - msgBox = QMessageBox(text=f"Imported {len(self.labels)} labeled frames.") + msgBox = QMessageBox( + text=f"Imported {len(self.labels)} labeled frames." + ) msgBox.exec_() if len(self.labels.skeletons): @@ -611,13 +814,16 @@ def updateTrackMenu(self): key_command = "" if self.labels.tracks.index(track) < 9: key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 - self.track_menu.addAction(f"{track.name}", lambda x=track:self.setInstanceTrack(x), key_command) + self.track_menu.addAction( + f"{track.name}", lambda x=track: self.setInstanceTrack(x), key_command + ) self.track_menu.addAction("New Track", self.addTrack, Qt.CTRL + Qt.Key_0) def activateSelectedVideo(self, x): # Get selected video idx = self.videosTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return self.loadVideo(self.labels.videos[idx.row()], idx.row()) def addVideo(self, filename=None): @@ -638,7 +844,7 @@ def addVideo(self, filename=None): # Load if no video currently loaded if self.video is None: - self.loadVideo(video, len(self.labels.videos)-1) + self.loadVideo(video, len(self.labels.videos) - 1) # Update data model/view self.update_data_views("video") @@ -646,7 +852,8 @@ def addVideo(self, filename=None): def removeVideo(self): # Get selected video idx = self.videosTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return video = self.labels.videos[idx.row()] # Count labeled frames for this video @@ -654,7 +861,13 @@ def removeVideo(self): # Warn if there are labels that will be deleted if n > 0: - response = QMessageBox.critical(self, "Removing video with labels", f"{n} labeled frames in this video will be deleted, are you sure you want to remove this video?", QMessageBox.Yes, QMessageBox.No) + response = QMessageBox.critical( + self, + "Removing video with labels", + f"{n} labeled frames in this video will be deleted, are you sure you want to remove this video?", + QMessageBox.Yes, + QMessageBox.No, + ) if response == QMessageBox.No: return @@ -674,13 +887,15 @@ def removeVideo(self): new_idx = min(idx.row(), len(self.labels.videos) - 1) self.loadVideo(self.labels.videos[new_idx], new_idx) - def loadVideo(self, video:Video, video_idx: int = None): + def loadVideo(self, video: Video, video_idx: int = None): # Clear video frame mark self.mark_idx = None # Update current video instance self.video = video - self.video_idx = video_idx if video_idx is not None else self.labels.videos.index(video) + self.video_idx = ( + video_idx if video_idx is not None else self.labels.videos.index(video) + ) # Load video in player widget self.player.load_video(self.video) @@ -695,9 +910,12 @@ def loadVideo(self, video:Video, video_idx: int = None): def openSkeleton(self): filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=None, caption="Open skeleton...", filter=";;".join(filters)) + filename, selected_filter = QFileDialog.getOpenFileName( + self, dir=None, caption="Open skeleton...", filter=";;".join(filters) + ) - if len(filename) == 0: return + if len(filename) == 0: + return if filename.endswith(".json"): self.skeleton = Skeleton.load_json(filename) @@ -716,9 +934,12 @@ def openSkeleton(self): def saveSkeleton(self): default_name = "skeleton.json" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getSaveFileName(self, caption="Save As...", 
dir=default_name, filter=";;".join(filters)) + filename, selected_filter = QFileDialog.getSaveFileName( + self, caption="Save As...", dir=default_name, filter=";;".join(filters) + ) - if len(filename) == 0: return + if len(filename) == 0: + return if filename.endswith(".json"): self.skeleton.save_json(filename) @@ -745,7 +966,8 @@ def newNode(self): def deleteNode(self): # Get selected node idx = self.skeletonNodesTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return node = self.skeleton.nodes[idx.row()] # Remove @@ -785,13 +1007,13 @@ def newEdge(self): self.plotFrame() - def deleteEdge(self): # TODO: Move this to unified data model # Get selected edge idx = self.skeletonEdgesTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return edge = self.skeleton.edges[idx.row()] # Delete edge @@ -816,7 +1038,7 @@ def setSeekbarHeader(self, graph_name): "Point Score (sum)": data_obj.get_point_score_series, "Point Score (min)": data_obj.get_point_score_series, "Number of predicted points": data_obj.get_point_count_series, - } + } self._menu_check_single(self.seekbarHeaderMenu, graph_name) @@ -837,9 +1059,8 @@ def generateSuggestions(self, params): new_suggestions = dict() for video in self.labels.videos: new_suggestions[video] = VideoFrameSuggestions.suggest( - video=video, - labels=self.labels, - params=params) + video=video, labels=self.labels, params=params + ) self.labels.set_suggestions(new_suggestions) @@ -847,25 +1068,32 @@ def generateSuggestions(self, params): self.updateSeekbarMarks() def _frames_for_prediction(self): - - def remove_user_labeled(video, frames, user_labeled_frames=self.labels.user_labeled_frames): - if len(frames) == 0: return frames - video_user_labeled_frame_idxs = [lf.frame_idx for lf in user_labeled_frames - if lf.video == video] + def remove_user_labeled( + video, frames, user_labeled_frames=self.labels.user_labeled_frames + ): + if len(frames) == 0: + return frames + video_user_labeled_frame_idxs = [ + lf.frame_idx for lf in user_labeled_frames if lf.video == video + ] return list(set(frames) - set(video_user_labeled_frame_idxs)) selection = dict() selection["frame"] = {self.video: [self.player.frame_idx]} - selection["clip"] = {self.video: list(range(*self.player.seekbar.getSelection()))} + selection["clip"] = { + self.video: list(range(*self.player.seekbar.getSelection())) + } selection["video"] = {self.video: list(range(self.video.num_frames))} selection["suggestions"] = { - video:remove_user_labeled(video, self.labels.get_video_suggestions(video)) - for video in self.labels.videos} + video: remove_user_labeled(video, self.labels.get_video_suggestions(video)) + for video in self.labels.videos + } selection["random"] = { video: remove_user_labeled(video, VideoFrameSuggestions.random(video=video)) - for video in self.labels.videos} + for video in self.labels.videos + } return selection @@ -873,11 +1101,15 @@ def _show_learning_window(self, mode): from sleap.gui.active import ActiveLearningDialog if "inference" in self.overlays: - QMessageBox(text=f"In order to use this function you must first quit and re-open sLEAP to release resources used by visualizing model outputs.").exec_() - return + QMessageBox( + text=f"In order to use this function you must first quit and re-open sLEAP to release resources used by visualizing model outputs." 
+ ).exec_() + return if self._child_windows.get(mode, None) is None: - self._child_windows[mode] = ActiveLearningDialog(self.filename, self.labels, mode) + self._child_windows[mode] = ActiveLearningDialog( + self.filename, self.labels, mode + ) self._child_windows[mode].learningFinished.connect(self.learningFinished) self._child_windows[mode].frame_selection = self._frames_for_prediction() @@ -908,15 +1140,22 @@ def visualizeOutputs(self): models_dir = os.path.join(os.path.dirname(self.filename), "models/") # Show dialog - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=models_dir, caption="Import model outputs...", filter=";;".join(filters)) - - if len(filename) == 0: return + filename, selected_filter = QFileDialog.getOpenFileName( + self, + dir=models_dir, + caption="Import model outputs...", + filter=";;".join(filters), + ) + + if len(filename) == 0: + return if selected_filter == filters[0]: # Model as overlay datasource # This will show live inference results from sleap.gui.overlays.base import DataOverlay + overlay = DataOverlay.from_model(filename, self.video, player=self.player) self.overlays["inference"] = overlay @@ -930,27 +1169,42 @@ def visualizeOutputs(self): if show_confmaps: from sleap.gui.overlays.confmaps import ConfmapOverlay + confmap_overlay = ConfmapOverlay.from_h5(filename, player=self.player) - self.player.changedPlot.connect(lambda parent, idx: confmap_overlay.add_to_scene(None, idx)) + self.player.changedPlot.connect( + lambda parent, idx: confmap_overlay.add_to_scene(None, idx) + ) if show_pafs: from sleap.gui.overlays.pafs import PafOverlay + paf_overlay = PafOverlay.from_h5(filename, player=self.player) - self.player.changedPlot.connect(lambda parent, idx: paf_overlay.add_to_scene(None, idx)) + self.player.changedPlot.connect( + lambda parent, idx: paf_overlay.add_to_scene(None, idx) + ) self.plotFrame() def deletePredictions(self): - predicted_instances = [(lf, inst) for lf in self.labels for inst in lf if type(inst) == PredictedInstance] - - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) - - if resp == QMessageBox.No: return + predicted_instances = [ + (lf, inst) + for lf in self.labels + for inst in lf + if type(inst) == PredictedInstance + ] + + resp = QMessageBox.critical( + self, + "Removing predicted instances", + f"There are {len(predicted_instances)} predicted instances. " + "Are you sure you want to delete these?", + QMessageBox.Yes, + QMessageBox.No, + ) + + if resp == QMessageBox.No: + return for lf, inst in predicted_instances: self.labels.remove_instance(lf, inst) @@ -961,10 +1215,14 @@ def deletePredictions(self): def deleteClipPredictions(self): - predicted_instances = [(lf, inst) - for lf in self.labels.find(self.video, frame_idx = range(*self.player.seekbar.getSelection())) - for inst in lf - if type(inst) == PredictedInstance] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find( + self.video, frame_idx=range(*self.player.seekbar.getSelection()) + ) + for inst in lf + if type(inst) == PredictedInstance + ] # If user selected an instance, then only delete for that track. 
selected_inst = self.player.view.getSelectionInstance() @@ -976,15 +1234,21 @@ def deleteClipPredictions(self): predicted_instances = [(self.labeled_frame, selected_inst)] else: # Filter by track - predicted_instances = list(filter(lambda x: x[1].track == track, predicted_instances)) - - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) - - if resp == QMessageBox.No: return + predicted_instances = list( + filter(lambda x: x[1].track == track, predicted_instances) + ) + + resp = QMessageBox.critical( + self, + "Removing predicted instances", + f"There are {len(predicted_instances)} predicted instances. " + "Are you sure you want to delete these?", + QMessageBox.Yes, + QMessageBox.No, + ) + + if resp == QMessageBox.No: + return # Delete the instances for lf, inst in predicted_instances: @@ -1002,7 +1266,8 @@ def delete_area_callback(x0, y0, x1, y1): self.updateStatusMessage() # Make sure there was an area selected - if x0==x1 or y0==y1: return + if x0 == x1 or y0 == y1: + return min_corner = (x0, y0) max_corner = (x1, y1) @@ -1016,49 +1281,54 @@ def is_bounded(inst): return is_gt_min and is_lt_max # Find all instances contained in selected area - predicted_instances = [(lf, inst) for lf in self.labels.find(self.video) - for inst in lf - if type(inst) == PredictedInstance - and is_bounded(inst)] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find(self.video) + for inst in lf + if type(inst) == PredictedInstance and is_bounded(inst) + ] self._delete_confirm(predicted_instances) # Prompt the user to select area - self.updateStatusMessage(f"Please select the area from which to remove instances. This will be applied to all frames.") + self.updateStatusMessage( + f"Please select the area from which to remove instances. This will be applied to all frames." 
+ ) self.player.onAreaSelection(delete_area_callback) def deleteLowScorePredictions(self): score_thresh, okay = QtWidgets.QInputDialog.getDouble( - self, - "Delete Instances with Low Score...", - "Score Below:", - 1, - 0, 100) + self, "Delete Instances with Low Score...", "Score Below:", 1, 0, 100 + ) if okay: # Find all instances contained in selected area - predicted_instances = [(lf, inst) for lf in self.labels.find(self.video) - for inst in lf - if type(inst) == PredictedInstance - and inst.score < score_thresh] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find(self.video) + for inst in lf + if type(inst) == PredictedInstance and inst.score < score_thresh + ] self._delete_confirm(predicted_instances) def deleteFrameLimitPredictions(self): count_thresh, okay = QtWidgets.QInputDialog.getInt( - self, - "Limit Instances in Frame...", - "Maximum instances in a frame:", - 3, - 1, 100) + self, + "Limit Instances in Frame...", + "Maximum instances in a frame:", + 3, + 1, + 100, + ) if okay: predicted_instances = [] # Find all instances contained in selected area for lf in self.labels.find(self.video): if len(lf.instances) > count_thresh: # Get all but the count_thresh many instances with the highest score - extra_instances = sorted(lf.instances, - key=operator.attrgetter('score') - )[:-count_thresh] + extra_instances = sorted( + lf.instances, key=operator.attrgetter("score") + )[:-count_thresh] predicted_instances.extend([(lf, inst) for inst in extra_instances]) self._delete_confirm(predicted_instances) @@ -1066,13 +1336,17 @@ def deleteFrameLimitPredictions(self): def _delete_confirm(self, lf_inst_list): # Confirm that we want to delete - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(lf_inst_list)} predicted instances that would be deleted. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) - - if resp == QMessageBox.No: return + resp = QMessageBox.critical( + self, + "Removing predicted instances", + f"There are {len(lf_inst_list)} predicted instances that would be deleted. 
" + "Are you sure you want to delete these?", + QMessageBox.Yes, + QMessageBox.No, + ) + + if resp == QMessageBox.No: + return # Delete the instances for lf, inst in lf_inst_list: @@ -1101,20 +1375,22 @@ def clearFrameNegativeAnchors(self): def importPredictions(self): filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] - filenames, selected_filter = QFileDialog.getOpenFileNames(self, dir=None, caption="Import labeled data...", filter=";;".join(filters)) + filenames, selected_filter = QFileDialog.getOpenFileNames( + self, dir=None, caption="Import labeled data...", filter=";;".join(filters) + ) - if len(filenames) == 0: return + if len(filenames) == 0: + return for filename in filenames: gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)]) + search_paths=[os.path.dirname(filename)] + ) - new_labels = Labels.load_file( - filename, - video_callback=gui_video_callback) + new_labels = Labels.load_file(filename, video_callback=gui_video_callback) # Merging data is handled by MergeDialog - MergeDialog(base_labels = self.labels, new_labels = new_labels).exec_() + MergeDialog(base_labels=self.labels, new_labels=new_labels).exec_() # update display/ui self.plotFrame() @@ -1125,20 +1401,28 @@ def importPredictions(self): def doubleClickInstance(self, instance): # When a predicted instance is double-clicked, add a new instance if hasattr(instance, "score"): - self.newInstance(copy_instance = instance) + self.newInstance(copy_instance=instance) # When a regular instance is double-clicked, add any missing points else: # the rect that's currently visibile in the window view - in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() + in_view_rect = self.player.view.mapToScene( + self.player.view.rect() + ).boundingRect() for node in self.skeleton.nodes: if node.name not in instance.node_names or instance[node].isnan(): # pick random points within currently zoomed view - x = in_view_rect.x() + (in_view_rect.width() * 0.1) \ + x = ( + in_view_rect.x() + + (in_view_rect.width() * 0.1) + (np.random.rand() * in_view_rect.width() * 0.8) - y = in_view_rect.y() + (in_view_rect.height() * 0.1) \ + ) + y = ( + in_view_rect.y() + + (in_view_rect.height() * 0.1) + (np.random.rand() * in_view_rect.height() * 0.8) + ) # set point for node instance[node] = Point(x=x, y=y, visible=False) @@ -1174,7 +1458,9 @@ def newInstance(self, copy_instance=None): prev_idx = self.previousLabeledFrameIndex() if prev_idx is not None: - prev_instances = self.labels.find(self.video, prev_idx, return_new=True)[0].instances + prev_instances = self.labels.find( + self.video, prev_idx, return_new=True + )[0].instances if len(prev_instances) > len(self.labeled_frame.instances): # If more instances in previous frame than current, then use the # first unmatched instance. 
@@ -1195,26 +1481,41 @@ def newInstance(self, copy_instance=None): new_instance = Instance(skeleton=self.skeleton, from_predicted=from_predicted) # Get the rect that's currently visibile in the window view - in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() + in_view_rect = self.player.view.mapToScene( + self.player.view.rect() + ).boundingRect() # go through each node in skeleton for node in self.skeleton.node_names: # if we're copying from a skeleton that has this node - if copy_instance is not None and node in copy_instance and not copy_instance[node].isnan(): + if ( + copy_instance is not None + and node in copy_instance + and not copy_instance[node].isnan() + ): # just copy x, y, and visible # we don't want to copy a PredictedPoint or score attribute new_instance[node] = Point( - x=copy_instance[node].x, - y=copy_instance[node].y, - visible=copy_instance[node].visible) + x=copy_instance[node].x, + y=copy_instance[node].y, + visible=copy_instance[node].visible, + ) else: # pick random points within currently zoomed view - x = in_view_rect.x() + (in_view_rect.width() * 0.1) \ + x = ( + in_view_rect.x() + + (in_view_rect.width() * 0.1) + (np.random.rand() * in_view_rect.width() * 0.8) - y = in_view_rect.y() + (in_view_rect.height() * 0.1) \ + ) + y = ( + in_view_rect.y() + + (in_view_rect.height() * 0.1) + (np.random.rand() * in_view_rect.height() * 0.8) + ) # mark the node as not "visible" if we're copying from a predicted instance without this node - is_visible = copy_instance is None or not hasattr(copy_instance, "score") + is_visible = copy_instance is None or not hasattr( + copy_instance, "score" + ) # set point for node new_instance[node] = Point(x=x, y=y, visible=is_visible) @@ -1237,7 +1538,8 @@ def newInstance(self, copy_instance=None): def deleteSelectedInstance(self): selected_inst = self.player.view.getSelectionInstance() - if selected_inst is None: return + if selected_inst is None: + return self.labels.remove_instance(self.labeled_frame, selected_inst) self.changestack_push("delete instance") @@ -1247,7 +1549,8 @@ def deleteSelectedInstance(self): def deleteSelectedInstanceTrack(self): selected_inst = self.player.view.getSelectionInstance() - if selected_inst is None: return + if selected_inst is None: + return # to do: range of frames? 
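The doubleClickInstance and newInstance hunks above both place missing nodes at a random position inside the middle 80% of the currently visible view rect. A minimal standalone sketch of that placement arithmetic, assuming a plain (x, y, width, height) tuple in place of the QRectF that the GUI code gets from mapToScene():

    import numpy as np

    def random_point_in_view(rect):
        """Return a random (x, y) inside the middle 80% of a view rect."""
        x0, y0, w, h = rect  # assumed plain tuple, not a Qt QRectF
        x = x0 + (w * 0.1) + (np.random.rand() * w * 0.8)
        y = y0 + (h * 0.1) + (np.random.rand() * h * 0.8)
        return x, y

    # For a 1024x1024 view at the origin, points land in [102.4, 921.6).
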
@@ -1267,9 +1570,9 @@ def deleteSelectedInstanceTrack(self): self.updateSeekbarMarks() def addTrack(self): - track_numbers_used = [int(track.name) - for track in self.labels.tracks - if track.name.isnumeric()] + track_numbers_used = [ + int(track.name) for track in self.labels.tracks if track.name.isnumeric() + ] next_number = max(track_numbers_used, default=0) + 1 new_track = Track(spawned_on=self.player.frame_idx, name=next_number) @@ -1286,7 +1589,8 @@ def addTrack(self): def setInstanceTrack(self, new_track): vis_idx = self.player.view.getSelection() - if vis_idx is None: return + if vis_idx is None: + return selected_instance = self.labeled_frame.instances_to_show[vis_idx] idx = self.labeled_frame.index(selected_instance) @@ -1298,13 +1602,16 @@ def setInstanceTrack(self, new_track): if old_track is None: # Move anything already in the new track out of it new_track_instances = self.labels.find_track_instances( - video = self.video, - track = new_track, - frame_range = (self.player.frame_idx, self.player.frame_idx+1)) + video=self.video, + track=new_track, + frame_range=(self.player.frame_idx, self.player.frame_idx + 1), + ) for instance in new_track_instances: instance.track = None # Move selected instance into new track - self.labels.track_set_instance(self.labeled_frame, selected_instance, new_track) + self.labels.track_set_instance( + self.labeled_frame, selected_instance, new_track + ) # When the instance does already have a track, then we want to update # the track for a range of frames. @@ -1336,24 +1643,27 @@ def transposeInstance(self): # as the second instance in some other frame. # For the present, we can only "transpose" if there are multiple instances. - if len(self.labeled_frame.instances) < 2: return + if len(self.labeled_frame.instances) < 2: + return # If there are just two instances, transpose them. if len(self.labeled_frame.instances) == 2: - self._transpose_instances((0,1)) + self._transpose_instances((0, 1)) # If there are more than two, then we need the user to select the instances. 
else: - self.player.onSequenceSelect(seq_len = 2, - on_success = self._transpose_instances, - on_each = self._transpose_message, - on_failure = lambda x:self.updateStatusMessage() - ) - - def _transpose_message(self, instance_ids:list): + self.player.onSequenceSelect( + seq_len=2, + on_success=self._transpose_instances, + on_each=self._transpose_message, + on_failure=lambda x: self.updateStatusMessage(), + ) + + def _transpose_message(self, instance_ids: list): word = "next" if len(instance_ids) else "first" self.updateStatusMessage(f"Please select the {word} instance to transpose...") - def _transpose_instances(self, instance_ids:list): - if len(instance_ids) != 2: return + def _transpose_instances(self, instance_ids: list): + if len(instance_ids) != 2: + return idx_0 = instance_ids[0] idx_1 = instance_ids[1] @@ -1379,10 +1689,18 @@ def newProject(self): window.showMaximized() def openProject(self, first_open=False): - filters = ["JSON labels (*.json *.json.zip)", "HDF5 dataset (*.h5 *.hdf5)", "Matlab dataset (*.mat)", "DeepLabCut csv (*.csv)"] - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=None, caption="Import labeled data...", filter=";;".join(filters)) - - if len(filename) == 0: return + filters = [ + "JSON labels (*.json *.json.zip)", + "HDF5 dataset (*.h5 *.hdf5)", + "Matlab dataset (*.mat)", + "DeepLabCut csv (*.csv)", + ] + filename, selected_filter = QFileDialog.getOpenFileName( + self, dir=None, caption="Import labeled data...", filter=";;".join(filters) + ) + + if len(filename) == 0: + return if OPEN_IN_NEW and not first_open: new_window = MainWindow() @@ -1405,13 +1723,17 @@ def saveProjectAs(self): p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) - filters = ["JSON labels (*.json)", "Compressed JSON (*.zip)", "HDF5 dataset (*.h5)"] - filename, selected_filter = QFileDialog.getSaveFileName(self, - caption="Save As...", - dir=default_name, - filter=";;".join(filters)) + filters = [ + "JSON labels (*.json)", + "Compressed JSON (*.zip)", + "HDF5 dataset (*.h5)", + ] + filename, selected_filter = QFileDialog.getSaveFileName( + self, caption="Save As...", dir=default_name, filter=";;".join(filters) + ) - if len(filename) == 0: return + if len(filename) == 0: + return if self._trySave(filename): # If save was successful @@ -1420,7 +1742,7 @@ def saveProjectAs(self): def _trySave(self, filename): success = False try: - Labels.save_file(labels = self.labels, filename = filename) + Labels.save_file(labels=self.labels, filename=filename) success = True # Mark savepoint in change stack self.changestack_savepoint() @@ -1443,7 +1765,9 @@ def closeEvent(self, event): msgBox = QMessageBox() msgBox.setText("Do you want to save the changes to this project?") msgBox.setInformativeText("If you don't save, your changes will be lost.") - msgBox.setStandardButtons(QMessageBox.Save | QMessageBox.Discard | QMessageBox.Cancel) + msgBox.setStandardButtons( + QMessageBox.Save | QMessageBox.Discard | QMessageBox.Cancel + ) msgBox.setDefaultButton(QMessageBox.Save) ret_val = msgBox.exec_() @@ -1461,24 +1785,26 @@ def closeEvent(self, event): event.accept() def nextVideo(self): - new_idx = self.video_idx+1 + new_idx = self.video_idx + 1 new_idx = 0 if new_idx >= len(self.labels.videos) else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def previousVideo(self): - new_idx = self.video_idx-1 - new_idx = len(self.labels.videos)-1 if new_idx < 0 else new_idx + new_idx = self.video_idx - 1 + new_idx = len(self.labels.videos) - 1 if new_idx 
< 0 else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def gotoFrame(self): frame_number, okay = QtWidgets.QInputDialog.getInt( - self, - "Go To Frame...", - "Frame Number:", - self.player.frame_idx+1, - 1, self.video.frames) + self, + "Go To Frame...", + "Frame Number:", + self.player.frame_idx + 1, + 1, + self.video.frames, + ) if okay: - self.plotFrame(frame_number-1) + self.plotFrame(frame_number - 1) def markFrame(self): self.mark_idx = self.player.frame_idx @@ -1488,32 +1814,45 @@ def goMarkedFrame(self): def exportLabeledClip(self): from sleap.io.visuals import save_labeled_video + if self.player.seekbar.hasSelection(): fps, okay = QtWidgets.QInputDialog.getInt( - self, - "Frames per second", - "Frames per second:", - getattr(self.video, "fps", 30), - 1, 300) - if not okay: return + self, + "Frames per second", + "Frames per second:", + getattr(self.video, "fps", 30), + 1, + 300, + ) + if not okay: + return - filename, _ = QFileDialog.getSaveFileName(self, caption="Save Video As...", dir=self.filename + ".avi", filter="AVI Video (*.avi)") + filename, _ = QFileDialog.getSaveFileName( + self, + caption="Save Video As...", + dir=self.filename + ".avi", + filter="AVI Video (*.avi)", + ) - if len(filename) == 0: return + if len(filename) == 0: + return save_labeled_video( - labels=self.labels, - video=self.video, - filename=filename, - frames=list(range(*self.player.seekbar.getSelection())), - fps=fps, - gui_progress=True - ) + labels=self.labels, + video=self.video, + filename=filename, + frames=list(range(*self.player.seekbar.getSelection())), + fps=fps, + gui_progress=True, + ) def exportLabeledFrames(self): - filename, _ = QFileDialog.getSaveFileName(self, caption="Save Labeled Frames As...", dir=self.filename) - if len(filename) == 0: return + filename, _ = QFileDialog.getSaveFileName( + self, caption="Save Labeled Frames As...", dir=self.filename + ) + if len(filename) == 0: + return Labels.save_json(self.labels, filename, save_frame_data=True) def previousLabeledFrameIndex(self): @@ -1559,21 +1898,28 @@ def nextUserLabeledFrame(self): self.plotFrame(next_idx) def nextSuggestedFrame(self, seek_direction=1): - next_video, next_frame = self.labels.get_next_suggestion(self.video, self.player.frame_idx, seek_direction) + next_video, next_frame = self.labels.get_next_suggestion( + self.video, self.player.frame_idx, seek_direction + ) if next_video is not None: self.gotoVideoAndFrame(next_video, next_frame) if next_frame is not None: - selection_idx = self.labels.get_suggestions().index((next_video, next_frame)) + selection_idx = self.labels.get_suggestions().index( + (next_video, next_frame) + ) self.suggestionsTable.selectRow(selection_idx) def nextTrackFrame(self): cur_idx = self.player.frame_idx track_ranges = self.labels.get_track_occupany(self.video) - next_idx = min([track_range.start - for track_range in track_ranges.values() - if track_range.start is not None - and track_range.start > cur_idx], - default=-1) + next_idx = min( + [ + track_range.start + for track_range in track_ranges.values() + if track_range.start is not None and track_range.start > cur_idx + ], + default=-1, + ) if next_idx > -1: self.plotFrame(next_idx) @@ -1602,12 +1948,14 @@ def setTrailLength(self, trail_length): self.overlays["trails"].trail_length = trail_length self._menu_check_single(self.trailLengthMenu, trail_length) - if self.video is not None: self.plotFrame() + if self.video is not None: + self.plotFrame() def setPalette(self, palette): self._color_manager.set_palette(palette) 
self._menu_check_single(self.paletteMenu, palette) - if self.video is not None: self.plotFrame() + if self.video is not None: + self.plotFrame() self.updateSeekbarMarks() def _menu_check_single(self, menu, item_text): @@ -1618,8 +1966,12 @@ def _menu_check_single(self, menu, item_text): menu_item.setChecked(False) def toggleColorPredicted(self): - self.overlays["instance"].color_predicted = not self.overlays["instance"].color_predicted - self._menu_actions["color predicted"].setChecked(self.overlays["instance"].color_predicted) + self.overlays["instance"].color_predicted = not self.overlays[ + "instance" + ].color_predicted + self._menu_actions["color predicted"].setChecked( + self.overlays["instance"].color_predicted + ) self.plotFrame() def toggleAutoZoom(self): @@ -1659,7 +2011,7 @@ def newFrame(self, player, frame_idx, selected_inst): # Trigger event after the overlays have been added player.view.updatedViewer.emit() - def updateStatusMessage(self, message = None): + def updateStatusMessage(self, message=None): if message is None: message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" if self.player.seekbar.hasSelection(): @@ -1671,7 +2023,9 @@ def updateStatusMessage(self, message = None): message += f" Labeled Frames: " if self.video is not None: - message += f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + message += ( + f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + ) if len(self.labels.videos) > 1: message += " in video, " if len(self.labels.videos) > 1: @@ -1679,6 +2033,7 @@ def updateStatusMessage(self, message = None): self.statusBar().showMessage(message) + def main(*args, **kwargs): app = QApplication([]) app.setApplicationName("sLEAP Label") @@ -1691,6 +2046,7 @@ def main(*args, **kwargs): app.exec_() + if __name__ == "__main__": kwargs = dict() diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index ebd40fc14..fe8613587 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -5,8 +5,21 @@ from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import QLabel, QPushButton, QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox -from PySide2.QtWidgets import QTableWidget, QTableView, QTableWidgetItem, QAbstractItemView +from PySide2.QtWidgets import ( + QLabel, + QPushButton, + QLineEdit, + QSpinBox, + QDoubleSpinBox, + QComboBox, + QCheckBox, +) +from PySide2.QtWidgets import ( + QTableWidget, + QTableView, + QTableWidgetItem, + QAbstractItemView, +) from PySide2.QtWidgets import QTreeView, QTreeWidget, QTreeWidgetItem from PySide2.QtWidgets import QMenu, QAction from PySide2.QtWidgets import QFileDialog, QMessageBox @@ -28,14 +41,16 @@ class VideosTable(QTableView): """Table view widget backed by a custom data model for displaying lists of Video instances. 
""" + def __init__(self, videos: list = []): super(VideosTable, self).__init__() self.setModel(VideosTableModel(videos)) self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionMode(QAbstractItemView.SingleSelection) + class VideosTableModel(QtCore.QAbstractTableModel): - _props = ["filename", "frames", "height", "width", "channels",] + _props = ["filename", "frames", "height", "width", "channels"] def __init__(self, videos: list): super(VideosTableModel, self).__init__() @@ -51,11 +66,12 @@ def videos(self, val): self._cache = [] for video in val: row_data = dict( - filename=video.filename, - frames=video.frames, - height=video.height, - width=video.width, - channels=video.channels) + filename=video.filename, + frames=video.frames, + height=video.height, + width=video.width, + channels=video.channels, + ) self._cache.append(row_data) self.endResetModel() @@ -78,7 +94,9 @@ def rowCount(self, parent): def columnCount(self, parent): return len(VideosTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -94,12 +112,14 @@ def flags(self, index: QtCore.QModelIndex): class SkeletonNodesTable(QTableView): """Table view widget backed by a custom data model for displaying and editing Skeleton nodes. """ + def __init__(self, skeleton: Skeleton): super(SkeletonNodesTable, self).__init__() self.setModel(SkeletonNodesTableModel(skeleton)) self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionMode(QAbstractItemView.SingleSelection) + class SkeletonNodesTableModel(QtCore.QAbstractTableModel): _props = ["name", "symmetry"] @@ -121,7 +141,9 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): if role == Qt.DisplayRole and index.isValid(): node_idx = index.row() prop = self._props[index.column()] - node = self.skeleton.nodes[node_idx] # FIXME? can we assume order is stable? + node = self.skeleton.nodes[ + node_idx + ] # FIXME? can we assume order is stable? node_name = node.name if prop == "name": @@ -137,7 +159,9 @@ def rowCount(self, parent): def columnCount(self, parent): return len(SkeletonNodesTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -156,12 +180,14 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): if len(value) > 0: self._skeleton.relabel_node(node_name, value) # else: - # self._skeleton.delete_node(node_name) + # self._skeleton.delete_node(node_name) elif prop == "symmetry": if len(value) > 0: self._skeleton.add_symmetry(node_name, value) else: - self._skeleton.delete_symmetry(node_name, self._skeleton.get_symmetry(node_name)) + self._skeleton.delete_symmetry( + node_name, self._skeleton.get_symmetry(node_name) + ) # send signal that data has changed self.dataChanged.emit(index, index) @@ -180,12 +206,14 @@ def flags(self, index: QtCore.QModelIndex): class SkeletonEdgesTable(QTableView): """Table view widget backed by a custom data model for displaying and editing Skeleton edges. 
""" + def __init__(self, skeleton: Skeleton): super(SkeletonEdgesTable, self).__init__() self.setModel(SkeletonEdgesTableModel(skeleton)) self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionMode(QAbstractItemView.SingleSelection) + class SkeletonEdgesTableModel(QtCore.QAbstractTableModel): _props = ["source", "destination"] @@ -222,7 +250,9 @@ def rowCount(self, parent): def columnCount(self, parent): return len(SkeletonNodesTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -235,8 +265,6 @@ def flags(self, index: QtCore.QModelIndex): return Qt.ItemIsEnabled | Qt.ItemIsSelectable - - class LabeledFrameTable(QTableView): """Table view widget backed by a custom data model for displaying lists of Video instances. """ @@ -330,12 +358,18 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return None def rowCount(self, parent): - return len(self.labeled_frame.instances_to_show) if self.labeled_frame is not None else 0 + return ( + len(self.labeled_frame.instances_to_show) + if self.labeled_frame is not None + else 0 + ) def columnCount(self, parent): return len(LabeledFrameTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -372,7 +406,6 @@ def flags(self, index: QtCore.QModelIndex): class SkeletonNodeModel(QtCore.QStringListModel): - def __init__(self, skeleton: Skeleton, src_node: Callable = None): super(SkeletonNodeModel, self).__init__() self._src_node = src_node @@ -434,6 +467,7 @@ def flags(self, index: QtCore.QModelIndex): class SuggestionsTable(QTableView): """Table view widget backed by a custom data model for displaying lists of Video instances. 
""" + def __init__(self, labels): super(SuggestionsTable, self).__init__() self.setModel(SuggestionsTableModel(labels)) @@ -441,8 +475,9 @@ def __init__(self, labels): self.setSelectionMode(QAbstractItemView.SingleSelection) self.setSortingEnabled(True) + class SuggestionsTableModel(QtCore.QAbstractTableModel): - _props = ["video", "frame", "labeled", "mean score",] + _props = ["video", "frame", "labeled", "mean score"] def __init__(self, labels): super(SuggestionsTableModel, self).__init__() @@ -473,7 +508,7 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): if prop == "video": return f"{self.labels.videos.index(video)}: {os.path.basename(video.filename)}" elif prop == "frame": - return int(frame_idx) + 1 # start at frame 1 rather than 0 + return int(frame_idx) + 1 # start at frame 1 rather than 0 elif prop == "labeled": # show how many labeled instances are in this frame val = self._labels.instance_count(video, frame_idx) @@ -485,7 +520,12 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return None def _getScore(self, video, frame_idx): - scores = [inst.score for lf in self.labels.find(video, frame_idx) for inst in lf if hasattr(inst, "score")] + scores = [ + inst.score + for lf in self.labels.find(video, frame_idx) + for inst in lf + if hasattr(inst, "score") + ] return sum(scores) / len(scores) def sort(self, column_idx: int, order: Qt.SortOrder): @@ -497,7 +537,7 @@ def sort(self, column_idx: int, order: Qt.SortOrder): elif prop == "mean score": sort_function = lambda s: self._getScore(*s) - reverse = (order == Qt.SortOrder.DescendingOrder) + reverse = order == Qt.SortOrder.DescendingOrder self.beginResetModel() self._suggestions_list.sort(key=sort_function, reverse=reverse) @@ -509,7 +549,9 @@ def rowCount(self, *args): def columnCount(self, *args): return len(self._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -524,7 +566,9 @@ def flags(self, index: QtCore.QModelIndex): if __name__ == "__main__": - labels = Labels.load_json("tests/data/json_format_v2/centered_pair_predictions.json") + labels = Labels.load_json( + "tests/data/json_format_v2/centered_pair_predictions.json" + ) skeleton = labels.labels[0].instances[0].skeleton Labels.save_json(labels, "test.json") @@ -538,4 +582,4 @@ def flags(self, index: QtCore.QModelIndex): table = LabeledFrameTable(labels.labels[0], labels) table.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/formbuilder.py b/sleap/gui/formbuilder.py index 0ae8bcd46..875996e6a 100644 --- a/sleap/gui/formbuilder.py +++ b/sleap/gui/formbuilder.py @@ -13,6 +13,7 @@ from PySide2 import QtWidgets, QtCore + class YamlFormWidget(QtWidgets.QGroupBox): """ Custom QWidget which creates form based on yaml file. 
@@ -25,10 +26,10 @@ class YamlFormWidget(QtWidgets.QGroupBox): mainAction = QtCore.Signal(dict) valueChanged = QtCore.Signal() - def __init__(self, yaml_file, which_form: str="main", *args, **kwargs): + def __init__(self, yaml_file, which_form: str = "main", *args, **kwargs): super(YamlFormWidget, self).__init__(*args, **kwargs) - with open(yaml_file, 'r') as form_yaml: + with open(yaml_file, "r") as form_yaml: items_to_create = yaml.load(form_yaml, Loader=yaml.SafeLoader) self.which_form = which_form @@ -71,6 +72,7 @@ def trigger_main_action(self): """Emit mainAction signal with form data.""" self.mainAction.emit(self.get_form_data()) + class FormBuilderLayout(QtWidgets.QFormLayout): """ Custom QFormLayout which populates itself from list of form fields. @@ -96,10 +98,12 @@ def get_form_data(self) -> dict: Dict with key:value for each user-editable widget in layout """ widgets = self.fields.values() - data = {w.objectName(): self.get_widget_value(w) - for w in widgets - if len(w.objectName()) - and type(w) not in (QtWidgets.QLabel, QtWidgets.QPushButton)} + data = { + w.objectName(): self.get_widget_value(w) + for w in widgets + if len(w.objectName()) + and type(w) not in (QtWidgets.QLabel, QtWidgets.QPushButton) + } stacked_data = [w.get_data() for w in widgets if type(w) == StackBuilderWidget] for stack in stacked_data: data.update(stack) @@ -118,7 +122,8 @@ def set_form_data(self, data: dict): self.set_widget_value(widgets[name], val) else: pass -# print(f"no {name} widget found") + + # print(f"no {name} widget found") @staticmethod def set_widget_value(widget, val): @@ -260,21 +265,27 @@ def build_form(self, items_to_create): self.addRow("", self._make_file_button(item, field)) def _make_file_button(self, item, field): - file_button = QtWidgets.QPushButton("Select "+item["label"]) + file_button = QtWidgets.QPushButton("Select " + item["label"]) if item["type"].split("_")[-1] == "open": # Define function for button to trigger def select_file(*args, x=field): filter = item.get("filter", "Any File (*.*)") - filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, directory=None, caption="Open File", filter=filter) - if len(filename): x.setText(filename) + filename, _ = QtWidgets.QFileDialog.getOpenFileName( + None, directory=None, caption="Open File", filter=filter + ) + if len(filename): + x.setText(filename) self.valueChanged.emit() elif item["type"].split("_")[-1] == "dir": # Define function for button to trigger def select_file(*args, x=field): - filename = QtWidgets.QFileDialog.getExistingDirectory(None, directory=None, caption="Open File") - if len(filename): x.setText(filename) + filename = QtWidgets.QFileDialog.getExistingDirectory( + None, directory=None, caption="Open File" + ) + if len(filename): + x.setText(filename) self.valueChanged.emit() else: @@ -283,6 +294,7 @@ def select_file(*args, x=field): file_button.clicked.connect(select_file) return file_button + class StackBuilderWidget(QtWidgets.QWidget): def __init__(self, stack_data, *args, **kwargs): super(StackBuilderWidget, self).__init__(*args, **kwargs) @@ -291,7 +303,9 @@ def __init__(self, stack_data, *args, **kwargs): self.combo_box = QtWidgets.QComboBox() self.stacked_widget = ResizingStackedWidget() - self.combo_box.activated.connect(lambda x: self.stacked_widget.setCurrentIndex(x)) + self.combo_box.activated.connect( + lambda x: self.stacked_widget.setCurrentIndex(x) + ) self.page_layouts = dict() diff --git a/sleap/gui/importvideos.py b/sleap/gui/importvideos.py index dce9cfa11..3cbd145e2 100644 --- 
a/sleap/gui/importvideos.py +++ b/sleap/gui/importvideos.py @@ -22,7 +22,14 @@ from PySide2.QtCore import Qt, QRectF, Signal from PySide2.QtWidgets import QApplication, QLayout, QVBoxLayout, QHBoxLayout, QFrame from PySide2.QtWidgets import QFileDialog, QDialog, QWidget, QLabel, QScrollArea -from PySide2.QtWidgets import QPushButton, QButtonGroup, QRadioButton, QCheckBox, QComboBox, QStackedWidget +from PySide2.QtWidgets import ( + QPushButton, + QButtonGroup, + QRadioButton, + QCheckBox, + QComboBox, + QStackedWidget, +) from sleap.gui.video import GraphicsView from sleap.io.video import Video, HDF5Video, MediaVideo @@ -30,12 +37,13 @@ import h5py import qimage2ndarray + class ImportVideos: """Class to handle video importing UI.""" - + def __init__(self): self.result = [] - + def go(self): """Runs the import UI. @@ -48,20 +56,21 @@ def go(self): List with dict of the parameters for each file to import. """ dialog = QFileDialog() - #dialog.setOption(QFileDialog.Option.DontUseNativeDialogs, True) + # dialog.setOption(QFileDialog.Option.DontUseNativeDialogs, True) file_names, filter = dialog.getOpenFileNames( - None, - "Select videos to import...", # dialogue title - ".", # initial path - "Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)", # filters - #options=QFileDialog.DontUseNativeDialog - ) + None, + "Select videos to import...", # dialogue title + ".", # initial path + "Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)", # filters + # options=QFileDialog.DontUseNativeDialog + ) if len(file_names) > 0: importer = ImportParamDialog(file_names) - importer.accepted.connect(lambda:importer.get_data(self.result)) + importer.accepted.connect(lambda: importer.get_data(self.result)) importer.exec_() return self.result + class ImportParamDialog(QDialog): """Dialog for selecting parameters with preview when importing video. @@ -69,13 +78,13 @@ class ImportParamDialog(QDialog): file_names (list): List of files we want to import. 
""" - def __init__(self, file_names:list, *args, **kwargs): + def __init__(self, file_names: list, *args, **kwargs): super(ImportParamDialog, self).__init__(*args, **kwargs) - + self.import_widgets = [] - + self.setWindowTitle("Video Import Options") - + self.import_types = [ { "video_type": "hdf5", @@ -86,55 +95,52 @@ def __init__(self, file_names:list, *args, **kwargs): "name": "dataset", "type": "function_menu", "options": "_get_h5_dataset_options", - "required": True + "required": True, }, { "name": "input_format", "type": "radio", - "options": "channels_first,channels_last" - } - ] + "options": "channels_first,channels_last", + }, + ], }, { "video_type": "mp4", "match": "mp4,avi", "video_class": Video.from_media, - "params": [ - { - "name": "grayscale", - "type": "check" - } - ] + "params": [{"name": "grayscale", "type": "check"}], }, { "video_type": "numpy", "match": "npy", "video_class": Video.from_numpy, - "params": [] + "params": [], }, { "video_type": "imgstore", "match": "json", "video_class": Video.from_filename, - "params": [] - } + "params": [], + }, ] - + outer_layout = QVBoxLayout() - + scroll_widget = QScrollArea() - #scroll_widget.setWidgetResizable(False) + # scroll_widget.setWidgetResizable(False) scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - + scroll_items_widget = QWidget() scroll_layout = QVBoxLayout() for file_name in file_names: if file_name: this_type = None for import_type in self.import_types: - if import_type.get("match",None) is not None: - if file_name.lower().endswith(tuple(import_type["match"].split(","))): + if import_type.get("match", None) is not None: + if file_name.lower().endswith( + tuple(import_type["match"].split(",")) + ): this_type = import_type break if this_type is not None: @@ -146,7 +152,7 @@ def __init__(self, file_names:list, *args, **kwargs): scroll_items_widget.setLayout(scroll_layout) scroll_widget.setWidget(scroll_items_widget) outer_layout.addWidget(scroll_widget) - + button_layout = QHBoxLayout() cancel_button = QPushButton("Cancel") import_button = QPushButton("Import") @@ -154,15 +160,15 @@ def __init__(self, file_names:list, *args, **kwargs): button_layout.addStretch() button_layout.addWidget(cancel_button) button_layout.addWidget(import_button) - + outer_layout.addLayout(button_layout) - + self.setLayout(outer_layout) - + import_button.clicked.connect(self.accept) cancel_button.clicked.connect(self.reject) - def get_data(self, import_result = None): + def get_data(self, import_result=None): """Method to get results from import. Args: @@ -187,6 +193,7 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + class ImportItemWidget(QFrame): """Widget for selecting parameters with preview when importing video. @@ -194,37 +201,39 @@ class ImportItemWidget(QFrame): file_path (str): Full path to selected video file. import_type (dict): Data about user-selectable import parameters. 
""" - + def __init__(self, file_path: str, import_type: dict, *args, **kwargs): super(ImportItemWidget, self).__init__(*args, **kwargs) - + self.file_path = file_path self.import_type = import_type self.video = None - + import_item_layout = QVBoxLayout() - + self.enabled_checkbox_widget = QCheckBox(self.file_path) self.enabled_checkbox_widget.setChecked(True) import_item_layout.addWidget(self.enabled_checkbox_widget) - - #import_item_layout.addWidget(QLabel(self.file_path)) + + # import_item_layout.addWidget(QLabel(self.file_path)) inner_layout = QHBoxLayout() - self.options_widget = ImportParamWidget(parent=self, file_path = self.file_path, import_type = self.import_type) + self.options_widget = ImportParamWidget( + parent=self, file_path=self.file_path, import_type=self.import_type + ) self.preview_widget = VideoPreviewWidget(parent=self) self.preview_widget.setFixedSize(200, 200) - + self.enabled_checkbox_widget.stateChanged.connect( - lambda state:self.options_widget.setEnabled(state == Qt.Checked) + lambda state: self.options_widget.setEnabled(state == Qt.Checked) ) - + inner_layout.addWidget(self.options_widget) inner_layout.addWidget(self.preview_widget) import_item_layout.addLayout(inner_layout) self.setLayout(import_item_layout) - + self.setFrameStyle(QFrame.Panel) - + self.options_widget.changed.connect(self.update_video) self.update_video(initial=True) @@ -245,12 +254,12 @@ def get_data(self) -> dict: Returns: Dict with data for this video. """ - + video_data = { - "params": self.options_widget.get_values(), - "video_type": self.import_type["video_type"], - "video_class": self.import_type["video_class"], - } + "params": self.options_widget.get_values(), + "video_type": self.import_type["video_type"], + "video_class": self.import_type["video_class"], + } return video_data def update_video(self, initial=False): @@ -289,6 +298,7 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + class ImportParamWidget(QWidget): """Widget for allowing user to select video parameters. 
@@ -302,29 +312,29 @@ class ImportParamWidget(QWidget): changed = Signal() - def __init__(self, file_path:str, import_type:dict, *args, **kwargs): + def __init__(self, file_path: str, import_type: dict, *args, **kwargs): super(ImportParamWidget, self).__init__(*args, **kwargs) - + self.file_path = file_path self.import_type = import_type self.widget_elements = {} self.video_params = {} - + option_layout = self.make_layout() - #self.changed.connect( lambda: print(self.get_values()) ) - + # self.changed.connect( lambda: print(self.get_values()) ) + self.setLayout(option_layout) - + def make_layout(self) -> QLayout: """Builds the layout of widgets for user-selected import parameters.""" - + param_list = self.import_type["params"] widget_layout = QVBoxLayout() widget_elements = dict() for param_item in param_list: name = param_item["name"] type = param_item["type"] - options = param_item.get("options",None) + options = param_item.get("options", None) if type == "radio": radio_group = QButtonGroup(parent=self) option_list = options.split(",") @@ -335,11 +345,11 @@ def make_layout(self) -> QLayout: btn_widget.setChecked(True) widget_layout.addWidget(btn_widget) radio_group.addButton(btn_widget) - radio_group.buttonToggled.connect(lambda:self.changed.emit()) + radio_group.buttonToggled.connect(lambda: self.changed.emit()) widget_elements[name] = radio_group elif type == "check": check_widget = QCheckBox(name) - check_widget.stateChanged.connect(lambda:self.changed.emit()) + check_widget.stateChanged.connect(lambda: self.changed.emit()) widget_layout.addWidget(check_widget) widget_elements[name] = check_widget elif type == "function_menu": @@ -348,12 +358,12 @@ def make_layout(self) -> QLayout: option_list = getattr(self, options)() for option in option_list: list_widget.addItem(option) - list_widget.currentIndexChanged.connect(lambda:self.changed.emit()) + list_widget.currentIndexChanged.connect(lambda: self.changed.emit()) widget_layout.addWidget(list_widget) widget_elements[name] = list_widget self.widget_elements = widget_elements return widget_layout - + def get_values(self, only_required=False): """Method to get current user-selected values for import parameters. @@ -394,10 +404,10 @@ def set_values_from_video(self, video): for param in param_list: name = param["name"] type = param["type"] - print(name,type) + print(name, type) if hasattr(video, name): val = getattr(video, name) - print(name,val) + print(name, val) widget = self.widget_elements[name] if hasattr(widget, "isChecked"): widget.setChecked(val) @@ -421,12 +431,12 @@ def _get_h5_dataset_options(self) -> list: This is used to populate the "function_menu"-type param. 
""" try: - with h5py.File(self.file_path,"r") as f: - options = self._find_h5_datasets("",f) + with h5py.File(self.file_path, "r") as f: + options = self._find_h5_datasets("", f) except Exception as e: options = [] return options - + def _find_h5_datasets(self, data_path, data_object) -> list: """Recursively find datasets in hdf5 file.""" options = [] @@ -435,7 +445,9 @@ def _find_h5_datasets(self, data_path, data_object) -> list: if len(data_object[key].shape) == 4: options.append(data_path + "/" + key) elif isinstance(data_object[key], h5py._hl.group.Group): - options.extend(self._find_h5_datasets(data_path + "/" + key, data_object[key])) + options.extend( + self._find_h5_datasets(data_path + "/" + key, data_object[key]) + ) return options def boundingRect(self) -> QRectF: @@ -471,28 +483,33 @@ def __init__(self, video: Video = None, *args, **kwargs): self.layout.addWidget(self.video_label) self.setLayout(self.layout) self.view.show() - + if video is not None: self.load_video(video) - + def clear_video(self): """Clear the video preview.""" self.view.clear() - + def load_video(self, video: Video, initial_frame=0, plot=True): """Load the video preview and display label text.""" self.video = video self.frame_idx = initial_frame - label = "(%d, %d), %d f, %d c" % (self.video.width, self.video.height, self.video.frames, self.video.channels) + label = "(%d, %d), %d f, %d c" % ( + self.video.width, + self.video.height, + self.video.frames, + self.video.channels, + ) self.video_label.setText(label) if plot: self.plot(initial_frame) - + def plot(self, idx=0): """Show the video preview.""" if self.video is None: return - + # Get image data frame = self.video.get_frame(idx) # Clear existing objects @@ -514,9 +531,12 @@ def paint(self, painter, option, widget=None): if __name__ == "__main__": app = QApplication([]) - + import_list = ImportVideos().go() - + for import_item in import_list: vid = import_item["video_class"](**import_item["params"]) - print("Imported video data: (%d, %d), %d f, %d c" % (vid.width, vid.height, vid.frames, vid.channels)) + print( + "Imported video data: (%d, %d), %d f, %d c" + % (vid.width, vid.height, vid.frames, vid.channels) + ) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index 22199e8c5..8aa687241 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -16,20 +16,18 @@ USE_NEITHER_STRING = "Discard all conflicting instances" CLEAN_STRING = "Accept clean merge" -class MergeDialog(QtWidgets.QDialog): - def __init__(self, - base_labels: Labels, - new_labels: Labels, - *args, **kwargs): +class MergeDialog(QtWidgets.QDialog): + def __init__(self, base_labels: Labels, new_labels: Labels, *args, **kwargs): super(MergeDialog, self).__init__(*args, **kwargs) self.base_labels = base_labels self.new_labels = new_labels - merged, self.extra_base, self.extra_new = \ - Labels.complex_merge_between(self.base_labels, self.new_labels) + merged, self.extra_base, self.extra_new = Labels.complex_merge_between( + self.base_labels, self.new_labels + ) merge_total = 0 merge_frames = 0 @@ -52,12 +50,16 @@ def __init__(self, merge_table = MergeTable(merged) layout.addWidget(merge_table) - conflict_text = "There are no conflicts." if not self.extra_base else "Merge conflicts:" + conflict_text = ( + "There are no conflicts." 
if not self.extra_base else "Merge conflicts:" + ) conflict_label = QtWidgets.QLabel(conflict_text) layout.addWidget(conflict_label) if self.extra_base: - conflict_table = ConflictTable(self.base_labels, self.extra_base, self.extra_new) + conflict_table = ConflictTable( + self.base_labels, self.extra_base, self.extra_new + ) layout.addWidget(conflict_table) self.merge_method = QtWidgets.QComboBox() @@ -90,18 +92,22 @@ def finishMerge(self): self.accept() + class ConflictTable(QtWidgets.QTableView): def __init__(self, *args, **kwargs): super(ConflictTable, self).__init__() self.setModel(ConflictTableModel(*args, **kwargs)) + class ConflictTableModel(QtCore.QAbstractTableModel): _props = ["video", "frame", "base", "new"] - def __init__(self, - base_labels: Labels, - extra_base: List[LabeledFrame], - extra_new: List[LabeledFrame]): + def __init__( + self, + base_labels: Labels, + extra_base: List[LabeledFrame], + extra_new: List[LabeledFrame], + ): super(ConflictTableModel, self).__init__() self.base_labels = base_labels self.extra_base = extra_base @@ -130,7 +136,9 @@ def rowCount(self, *args): def columnCount(self, *args): return len(self._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -138,25 +146,30 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt return section return None + class MergeTable(QtWidgets.QTableView): def __init__(self, *args, **kwargs): super(MergeTable, self).__init__() self.setModel(MergeTableModel(*args, **kwargs)) + class MergeTableModel(QtCore.QAbstractTableModel): _props = ["video", "frame", "merged instances"] - def __init__(self, merged: List[List['Instance']]): + def __init__(self, merged: List[List["Instance"]]): super(MergeTableModel, self).__init__() self.merged = merged self.data_table = [] for video in self.merged.keys(): for frame_idx, frame_instance_list in self.merged[video].items(): - self.data_table.append(dict( - filename=video.filename, - frame_idx=frame_idx, - instances=frame_instance_list)) + self.data_table.append( + dict( + filename=video.filename, + frame_idx=frame_idx, + instances=frame_instance_list, + ) + ) def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): if role == QtCore.Qt.DisplayRole and index.isValid(): @@ -179,7 +192,9 @@ def rowCount(self, *args): def columnCount(self, *args): return len(self._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole): + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -187,15 +202,19 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt return section return None + def show_instance_type_counts(instance_list): - prediction_count = len(list(filter(lambda inst: hasattr(inst, "score"), instance_list))) + prediction_count = len( + list(filter(lambda inst: hasattr(inst, "score"), instance_list)) + ) user_count = len(instance_list) - prediction_count return f"{prediction_count}/{user_count}" + if __name__ == "__main__": -# file_a = "tests/data/json_format_v1/centered_pair.json" -# file_b = "tests/data/json_format_v2/centered_pair_predictions.json" + # file_a = 
"tests/data/json_format_v1/centered_pair.json" + # file_b = "tests/data/json_format_v2/centered_pair_predictions.json" file_a = "files/merge/a.h5" file_b = "files/merge/b.h5" @@ -205,4 +224,4 @@ def show_instance_type_counts(instance_list): app = QtWidgets.QApplication() win = MergeDialog(base_labels, new_labels) win.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/multicheck.py b/sleap/gui/multicheck.py index 082df6911..60d34db9d 100644 --- a/sleap/gui/multicheck.py +++ b/sleap/gui/multicheck.py @@ -9,6 +9,7 @@ from PySide2.QtCore import QRectF, Signal from PySide2.QtWidgets import QGridLayout, QGroupBox, QButtonGroup, QCheckBox + class MultiCheckWidget(QGroupBox): """Qt Widget to show multiple checkboxes for selecting from a sequence of numbers. @@ -25,7 +26,7 @@ def __init__(self, *args, count, title="", selected=None, default=False, **kwarg # QButtonGroup is the logical container # it allows us to get list of checked boxes more easily self.check_group = QButtonGroup() - self.check_group.setExclusive(False) # more than one can be checked + self.check_group.setExclusive(False) # more than one can be checked if title != "": self.setTitle(title) @@ -39,11 +40,11 @@ def __init__(self, *args, count, title="", selected=None, default=False, **kwarg check_layout = QGridLayout() self.setLayout(check_layout) for i in range(count): - check = QCheckBox("%d"%(i)) + check = QCheckBox("%d" % (i)) # call signal/slot on self when one of the checkboxes is changed check.stateChanged.connect(lambda e: self.selectionChanged.emit()) self.check_group.addButton(check, i) - check_layout.addWidget(check, i//8, i%8) + check_layout.addWidget(check, i // 8, i % 8) self.setSelected(selected) """ diff --git a/sleap/gui/overlays/anchors.py b/sleap/gui/overlays/anchors.py index 04a2306f3..3025f0e3f 100644 --- a/sleap/gui/overlays/anchors.py +++ b/sleap/gui/overlays/anchors.py @@ -5,23 +5,38 @@ from sleap.gui.video import QtVideoPlayer from sleap.io.dataset import Labels + @attr.s(auto_attribs=True) class NegativeAnchorOverlay: - labels: Labels=None - scene: QtWidgets.QGraphicsScene=None + labels: Labels = None + scene: QtWidgets.QGraphicsScene = None pen = QtGui.QPen(QtGui.QColor("red")) - line_len: int=3 + line_len: int = 3 def add_to_scene(self, video, frame_idx): - if self.labels is None: return - if video not in self.labels.negative_anchors: return - + if self.labels is None: + return + if video not in self.labels.negative_anchors: + return + anchors = self.labels.negative_anchors[video] for idx, x, y in anchors: if frame_idx == idx: - self._add(x,y) + self._add(x, y) def _add(self, x, y): - self.scene.addLine(x-self.line_len, y-self.line_len, x+self.line_len, y+self.line_len, self.pen) - self.scene.addLine(x+self.line_len, y-self.line_len, x-self.line_len, y+self.line_len, self.pen) \ No newline at end of file + self.scene.addLine( + x - self.line_len, + y - self.line_len, + x + self.line_len, + y + self.line_len, + self.pen, + ) + self.scene.addLine( + x + self.line_len, + y - self.line_len, + x - self.line_len, + y + self.line_len, + self.pen, + ) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 78a1672b8..e521a2df7 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -10,34 +10,38 @@ from sleap.gui.video import QtVideoPlayer from sleap.nn.transform import DataTransform + class HDF5Data(HDF5Video): def __getitem__(self, i): """Get data for frame i from `HDF5Video` object.""" x = self.get_frame(i) - return np.clip(x,0,1) + 
return np.clip(x, 0, 1) + @attr.s(auto_attribs=True) class ModelData: # TODO: Unify this class with inference.Predictor or InferenceModel - model: 'keras.Model' + model: "keras.Model" video: Video - do_rescale: bool=False - output_scale: float=1.0 - adjust_vals: bool=True + do_rescale: bool = False + output_scale: float = 1.0 + adjust_vals: bool = True def __getitem__(self, i): """Data data for frame i from predictor.""" frame_img = self.video[i] # Trim to size that works for model - frame_img = frame_img[:, :self.video.height//8*8, :self.video.width//8*8, :] + frame_img = frame_img[ + :, : self.video.height // 8 * 8, : self.video.width // 8 * 8, : + ] inference_transform = DataTransform() if self.do_rescale: # Scale input image if model trained on scaled images frame_img = inference_transform.scale_to( - imgs=frame_img, - target_size=self.model.input_shape[1:3]) + imgs=frame_img, target_size=self.model.input_shape[1:3] + ) # Get predictions frame_result = self.model.predict(frame_img.astype("float32") / 255) @@ -55,13 +59,14 @@ def __getitem__(self, i): # even though this model may not give us adequate predictions. max_val = np.max(frame_result) if max_val < 1: - frame_result = frame_result/np.max(frame_result) + frame_result = frame_result / np.max(frame_result) # Clip values to ensure that they're within [0, 1] frame_result = np.clip(frame_result, 0, 1) return frame_result + @attr.s(auto_attribs=True) class DataOverlay: @@ -71,12 +76,17 @@ class DataOverlay: transform: DataTransform = None def add_to_scene(self, video, frame_idx): - if self.data is None: return + if self.data is None: + return # Check if video matches video for ModelData object if hasattr(self.data, "video") and self.data.video != video: video_shape = (video.height, video.width, video.channels) - prior_shape = (self.data.video.height, self.data.video.width, self.data.video.channels) + prior_shape = ( + self.data.video.height, + self.data.video.width, + self.data.video.channels, + ) # Check if the videos are both compatible with the loaded model if video_shape == prior_shape: # Shapes match so we can apply model to this video @@ -91,8 +101,11 @@ def add_to_scene(self, video, frame_idx): else: # If data indices are different than frame indices, use data # index; otherwise just use frame index. 
- idxs = self.transform.get_data_idxs(frame_idx) \ - if self.transform.frame_idxs else [frame_idx] + idxs = ( + self.transform.get_data_idxs(frame_idx) + if self.transform.frame_idxs + else [frame_idx] + ) # Loop over indices, in case there's more than one for frame for idx in idxs: @@ -102,12 +115,17 @@ def add_to_scene(self, video, frame_idx): x, y = 0, 0 overlay_object = self.overlay_class( - self.data[idx], - scale=self.transform.scale) + self.data[idx], scale=self.transform.scale + ) - self._add(self.player.view.scene, overlay_object, (x,y)) + self._add(self.player.view.scene, overlay_object, (x, y)) - def _add(self, to: QtWidgets.QGraphicsScene, what: QtWidgets.QGraphicsObject, where: tuple=(0,0)): + def _add( + self, + to: QtWidgets.QGraphicsScene, + what: QtWidgets.QGraphicsObject, + where: tuple = (0, 0), + ): to.addItem(what) what.setPos(*where) @@ -115,13 +133,15 @@ def _add(self, to: QtWidgets.QGraphicsScene, what: QtWidgets.QGraphicsObject, wh def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): import h5py as h5 - with h5.File(filename, 'r') as f: + with h5.File(filename, "r") as f: frame_idxs = np.asarray(f["frame_idxs"], dtype="int") bounding_boxes = np.asarray(f["bounds"]) transform = DataTransform(frame_idxs=frame_idxs, bounding_boxes=bounding_boxes) - data_object = HDF5Data(filename, dataset, input_format=input_format, convert_range=False) + data_object = HDF5Data( + filename, dataset, input_format=input_format, convert_range=False + ) return cls(data=data_object, transform=transform, **kwargs) @@ -135,17 +155,19 @@ def from_model(cls, filename, video, **kwargs): trainingjob = TrainingJob.load_json(filename) - input_size = (video.height//8*8, video.width//8*8, video.channels) + input_size = (video.height // 8 * 8, video.width // 8 * 8, video.channels) model_output_type = trainingjob.model.output_type model = load_model( - sleap_models={model_output_type:trainingjob}, - input_size=input_size, - output_types=[model_output_type]) + sleap_models={model_output_type: trainingjob}, + input_size=input_size, + output_types=[model_output_type], + ) model_data = get_model_data( - sleap_models={model_output_type:trainingjob}, - output_types=[model_output_type]) + sleap_models={model_output_type: trainingjob}, + output_types=[model_output_type], + ) # Here we determine if the input should be scaled. If so, then # the output of the model will also be rescaled accordingly. 
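The // 8 * 8 expressions above (in ModelData.__getitem__ and in the input_size passed to load_model) crop the frame so its height and width are exact multiples of 8, presumably so the network's downsampling and upsampling steps divide evenly. A minimal numpy sketch of the same crop, with the stride of 8 generalized to a parameter:

    import numpy as np

    def crop_to_stride(imgs: np.ndarray, stride: int = 8) -> np.ndarray:
        """Crop a (frames, height, width, channels) batch so that height and
        width are multiples of stride, mirroring the // 8 * 8 trim above."""
        _, h, w, _ = imgs.shape
        return imgs[:, : h // stride * stride, : w // stride * stride, :]

    # e.g. a (1, 487, 645, 1) frame becomes (1, 480, 640, 1) with stride 8.
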
@@ -153,8 +175,8 @@ def from_model(cls, filename, video, **kwargs): do_rescale = model_data["scale"] < 1 # Determine how the output from the model should be scaled - img_output_scale = 1.0 # image rescaling - obj_output_scale = 1.0 # scale to pass to overlay object + img_output_scale = 1.0 # image rescaling + obj_output_scale = 1.0 # scale to pass to overlay object if model_output_type == ModelOutputType.PART_AFFINITY_FIELD: obj_output_scale = model_data["multiscale"] @@ -163,7 +185,9 @@ def from_model(cls, filename, video, **kwargs): # Construct the ModelData object that runs inference - data_object = ModelData(model, video, do_rescale=do_rescale, output_scale=img_output_scale) + data_object = ModelData( + model, video, do_rescale=do_rescale, output_scale=img_output_scale + ) # Determine whether to use confmap or paf overlay @@ -184,10 +208,9 @@ def from_model(cls, filename, video, **kwargs): transform = DataTransform(scale=obj_output_scale) return cls( - data=data_object, - transform=transform, - overlay_class=overlay_class, - **kwargs) + data=data_object, transform=transform, overlay_class=overlay_class, **kwargs + ) + h5_colors = [ [204, 81, 81], @@ -239,5 +262,5 @@ def from_model(cls, filename, video, **kwargs): [81, 204, 181], [51, 127, 113], [81, 181, 204], - [51, 113, 127] -] \ No newline at end of file + [51, 113, 127], +] diff --git a/sleap/gui/overlays/confmaps.py b/sleap/gui/overlays/confmaps.py index 5629cbd40..49c06c149 100644 --- a/sleap/gui/overlays/confmaps.py +++ b/sleap/gui/overlays/confmaps.py @@ -17,15 +17,20 @@ from sleap.gui.video import QtVideoPlayer from sleap.gui.overlays.base import DataOverlay, h5_colors -class ConfmapOverlay(DataOverlay): +class ConfmapOverlay(DataOverlay): @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): - return DataOverlay.from_h5(filename, "/confmaps", input_format, overlay_class=ConfMapsPlot, **kwargs) + return DataOverlay.from_h5( + filename, "/confmaps", input_format, overlay_class=ConfMapsPlot, **kwargs + ) @classmethod def from_model(cls, filename, video, **kwargs): - return DataOverlay.from_model(filename, video, overlay_class=ConfMapsPlot, **kwargs) + return DataOverlay.from_model( + filename, video, overlay_class=ConfMapsPlot, **kwargs + ) + class ConfMapsPlot(QtWidgets.QGraphicsObject): """QGraphicsObject to display multiple confidence maps in a QGraphicsView. @@ -42,7 +47,9 @@ class ConfMapsPlot(QtWidgets.QGraphicsObject): When initialized, creates one child ConfMapPlot item for each channel. 
""" - def __init__(self, frame: np.array = None, show=None, show_box=True, *args, **kwargs): + def __init__( + self, frame: np.array = None, show=None, show_box=True, *args, **kwargs + ): super(ConfMapsPlot, self).__init__(*args, **kwargs) self.frame = frame self.show_box = show_box @@ -50,16 +57,18 @@ def __init__(self, frame: np.array = None, show=None, show_box=True, *args, **kw self.rect = QtCore.QRectF(0, 0, self.frame.shape[1], self.frame.shape[0]) if self.show_box: - QtWidgets.QGraphicsRectItem(self.rect, parent=self).setPen(QtGui.QPen("yellow")) + QtWidgets.QGraphicsRectItem(self.rect, parent=self).setPen( + QtGui.QPen("yellow") + ) for channel in range(self.frame.shape[2]): if show is None or channel in show: color_map = h5_colors[channel % len(h5_colors)] # Add QGraphicsPixmapItem as child object - ConfMapPlot(confmap=self.frame[..., channel], - color=color_map, - parent=self) + ConfMapPlot( + confmap=self.frame[..., channel], color=color_map, parent=self + ) def boundingRect(self) -> QtCore.QRectF: """Method required by Qt. @@ -71,6 +80,7 @@ def paint(self, painter, option, widget=None): """ pass + class ConfMapPlot(QtWidgets.QGraphicsPixmapItem): """QGraphicsPixmapItem object for drawing single channel of confidence map. @@ -85,7 +95,9 @@ class ConfMapPlot(QtWidgets.QGraphicsPixmapItem): In most cases this should only be called by ConfMapsPlot. """ - def __init__(self, confmap: np.array = None, color=[255, 255, 255], *args, **kwargs): + def __init__( + self, confmap: np.array = None, color=[255, 255, 255], *args, **kwargs + ): super(ConfMapPlot, self).__init__(*args, **kwargs) self.color_map = color @@ -108,16 +120,16 @@ def get_conf_image(self) -> QtGui.QImage: frame = self.confmap # Colorize single-channel overlap - if np.ptp(frame) <= 1.: + if np.ptp(frame) <= 1.0: frame_a = (frame * 255).astype(np.uint8) frame_r = (frame * self.color_map[0]).astype(np.uint8) frame_g = (frame * self.color_map[1]).astype(np.uint8) frame_b = (frame * self.color_map[2]).astype(np.uint8) else: frame_a = (frame).astype(np.uint8) - frame_r = (frame * (self.color_map[0]/255.)).astype(np.uint8) - frame_g = (frame * (self.color_map[1]/255.)).astype(np.uint8) - frame_b = (frame * (self.color_map[2]/255.)).astype(np.uint8) + frame_r = (frame * (self.color_map[0] / 255.0)).astype(np.uint8) + frame_g = (frame * (self.color_map[1] / 255.0)).astype(np.uint8) + frame_b = (frame * (self.color_map[2] / 255.0)).astype(np.uint8) frame_composite = np.dstack((frame_r, frame_g, frame_b, frame_a)) @@ -126,20 +138,25 @@ def get_conf_image(self) -> QtGui.QImage: return image + def show_confmaps_from_h5(filename, input_format="channels_last", standalone=False): video = HDF5Video(filename, "/box", input_format=input_format) - conf_data = HDF5Video(filename, "/confmaps", input_format=input_format, convert_range=False) + conf_data = HDF5Video( + filename, "/confmaps", input_format=input_format, convert_range=False + ) - confmaps_ = [np.clip(conf_data.get_frame(i),0,1) for i in range(conf_data.frames)] + confmaps_ = [np.clip(conf_data.get_frame(i), 0, 1) for i in range(conf_data.frames)] confmaps = np.stack(confmaps_) return demo_confmaps(confmaps=confmaps, video=video, standalone=standalone) + def demo_confmaps(confmaps, video, standalone=False, callback=None): from PySide2 import QtWidgets from sleap.gui.video import QtVideoPlayer - if standalone: app = QtWidgets.QApplication([]) + if standalone: + app = QtWidgets.QApplication([]) win = QtVideoPlayer(video=video) win.setWindowTitle("confmaps") @@ -147,21 +164,24 @@ def 
demo_confmaps(confmaps, video, standalone=False, callback=None): def plot_confmaps(parent, item_idx): if parent.frame_idx < confmaps.shape[0]: - frame_conf_map = ConfMapsPlot(confmaps[parent.frame_idx,...]) + frame_conf_map = ConfMapsPlot(confmaps[parent.frame_idx, ...]) win.view.scene.addItem(frame_conf_map) win.changedPlot.connect(plot_confmaps) - if callback: win.changedPlot.connect(callback) + if callback: + win.changedPlot.connect(callback) win.plot() - if standalone: app.exec_() + if standalone: + app.exec_() return win + if __name__ == "__main__": data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" show_confmaps_from_h5(data_path, input_format="channels_first", standalone=True) # data_path = "/Users/tabris/Documents/predictions.h5" -# show_confmaps_from_h5(data_path, input_format="channels_last", standalone=True) \ No newline at end of file +# show_confmaps_from_h5(data_path, input_format="channels_last", standalone=True) diff --git a/sleap/gui/overlays/instance.py b/sleap/gui/overlays/instance.py index c99322a3f..14dde6bdd 100644 --- a/sleap/gui/overlays/instance.py +++ b/sleap/gui/overlays/instance.py @@ -6,16 +6,18 @@ from sleap.io.dataset import Labels from sleap.gui.overlays.tracks import TrackColorManager + @attr.s(auto_attribs=True) class InstanceOverlay: - labels: Labels=None - player: QtVideoPlayer=None - color_manager: TrackColorManager=TrackColorManager(labels) - color_predicted: bool=False + labels: Labels = None + player: QtVideoPlayer = None + color_manager: TrackColorManager = TrackColorManager(labels) + color_predicted: bool = False def add_to_scene(self, video, frame_idx): - if self.labels is None: return + if self.labels is None: + return lf = self.labels.find(video, frame_idx, return_new=True)[0] @@ -29,9 +31,11 @@ def add_to_scene(self, video, frame_idx): pseudo_track = len(self.labels.tracks) + count_no_track count_no_track += 1 - is_predicted = hasattr(instance,"score") + is_predicted = hasattr(instance, "score") - self.player.addInstance(instance=instance, - color=self.color_manager.get_color(pseudo_track), - predicted=is_predicted, - color_predicted=self.color_predicted) \ No newline at end of file + self.player.addInstance( + instance=instance, + color=self.color_manager.get_color(pseudo_track), + predicted=is_predicted, + color_predicted=self.color_predicted, + ) diff --git a/sleap/gui/overlays/pafs.py b/sleap/gui/overlays/pafs.py index a145483e6..4ff930886 100644 --- a/sleap/gui/overlays/pafs.py +++ b/sleap/gui/overlays/pafs.py @@ -9,11 +9,14 @@ from sleap.gui.overlays.base import DataOverlay, h5_colors -class PafOverlay(DataOverlay): +class PafOverlay(DataOverlay): @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): - return DataOverlay.from_h5(filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs) + return DataOverlay.from_h5( + filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs + ) + class MultiQuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView. @@ -33,12 +36,15 @@ class MultiQuiverPlot(QtWidgets.QGraphicsObject): When initialized, creates one child QuiverPlot item for each channel. 
""" - def __init__(self, - frame: np.array = None, - show: list = None, - decimation: int = 5, - scale: float = 1.0, - *args, **kwargs): + def __init__( + self, + frame: np.array = None, + show: list = None, + decimation: int = 5, + scale: float = 1.0, + *args, + **kwargs, + ): super(MultiQuiverPlot, self).__init__(*args, **kwargs) self.frame = frame self.affinity_field = [] @@ -47,23 +53,23 @@ def __init__(self, # if data range is outside [-1, 1], assume it's [-255, 255] and scale if np.ptp(self.frame) > 4: - self.frame = self.frame.astype(np.float64)/255 + self.frame = self.frame.astype(np.float64) / 255 if show is None: - self.show_list = range(self.frame.shape[2]//2) + self.show_list = range(self.frame.shape[2] // 2) else: self.show_list = show for channel in self.show_list: - if channel < self.frame.shape[-1]//2: + if channel < self.frame.shape[-1] // 2: color_map = h5_colors[channel % len(h5_colors)] aff_field_item = QuiverPlot( - field_x=self.frame[..., channel*2], - field_y=self.frame[..., channel*2+1], + field_x=self.frame[..., channel * 2], + field_y=self.frame[..., channel * 2 + 1], color=color_map, decimation=self.decimation, scale=self.scale, - parent=self - ) + parent=self, + ) self.affinity_field.append(aff_field_item) def boundingRect(self) -> QtCore.QRectF: @@ -76,6 +82,7 @@ def paint(self, painter, option, widget=None): """ pass + class QuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject for drawing single quiver plot. @@ -89,20 +96,23 @@ class QuiverPlot(QtWidgets.QGraphicsObject): None. """ - def __init__(self, - field_x: np.array = None, - field_y: np.array = None, - color=[255, 255, 255], - decimation=1, - scale=1, - *args, **kwargs): + def __init__( + self, + field_x: np.array = None, + field_y: np.array = None, + color=[255, 255, 255], + decimation=1, + scale=1, + *args, + **kwargs, + ): super(QuiverPlot, self).__init__(*args, **kwargs) self.field_x, self.field_y = None, None self.color = color self.decimation = decimation self.scale = scale - pen_width = min(4, max(.1, math.log(self.decimation, 20))) + pen_width = min(4, max(0.1, math.log(self.decimation, 20))) self.pen = QtGui.QPen(QtGui.QColor(*self.color), pen_width) self.points = [] self.rect = QtCore.QRectF() @@ -111,7 +121,7 @@ def __init__(self, self.field_x, self.field_y = field_x, field_y h, w = self.field_x.shape - h, w = int(h/self.scale), int(w/self.scale) + h, w = int(h / self.scale), int(w / self.scale) self.rect = QtCore.QRectF(0, 0, w, h) @@ -121,63 +131,66 @@ def _add_arrows(self, min_length=0.01): points = [] if self.field_x is not None and self.field_y is not None: - raw_delta_yx = np.stack((self.field_y,self.field_x),axis=-1) + raw_delta_yx = np.stack((self.field_y, self.field_x), axis=-1) - dim_0 = self.field_x.shape[0]//self.decimation*self.decimation - dim_1 = self.field_x.shape[1]//self.decimation*self.decimation + dim_0 = self.field_x.shape[0] // self.decimation * self.decimation + dim_1 = self.field_x.shape[1] // self.decimation * self.decimation - grid = np.mgrid[0:dim_0:self.decimation, 0:dim_1:self.decimation] - loc_yx = np.moveaxis(grid,0,-1) + grid = np.mgrid[0 : dim_0 : self.decimation, 0 : dim_1 : self.decimation] + loc_yx = np.moveaxis(grid, 0, -1) # Adjust by scaling factor - loc_yx = loc_yx * (1/self.scale) + loc_yx = loc_yx * (1 / self.scale) if self.decimation > 1: delta_yx = self._decimate(raw_delta_yx, self.decimation) # Shift locations to midpoint of decimation square - loc_yx += self.decimation//2 + loc_yx += self.decimation // 2 else: delta_yx = 
raw_delta_yx # Split into x,y matrices - loc_y, loc_x = loc_yx[...,0], loc_yx[...,1] - delta_y, delta_x = delta_yx[...,0], delta_yx[...,1] + loc_y, loc_x = loc_yx[..., 0], loc_yx[..., 1] + delta_y, delta_x = delta_yx[..., 0], delta_yx[..., 1] # Determine vector endpoint - x2 = delta_x*self.decimation + loc_x - y2 = delta_y*self.decimation + loc_y - line_length = (delta_x**2 + delta_y**2)**.5 + x2 = delta_x * self.decimation + loc_x + y2 = delta_y * self.decimation + loc_y + line_length = (delta_x ** 2 + delta_y ** 2) ** 0.5 # Determine points for arrow arrow_head_size = line_length / 4 - u_dx = np.divide(delta_x, line_length, out=np.zeros_like(delta_x), where=line_length!=0) - u_dy = np.divide(delta_y, line_length, out=np.zeros_like(delta_y), where=line_length!=0) - p1_x = x2 - u_dx*arrow_head_size - u_dy*arrow_head_size - p1_y = y2 - u_dy*arrow_head_size + u_dx*arrow_head_size + u_dx = np.divide( + delta_x, line_length, out=np.zeros_like(delta_x), where=line_length != 0 + ) + u_dy = np.divide( + delta_y, line_length, out=np.zeros_like(delta_y), where=line_length != 0 + ) + p1_x = x2 - u_dx * arrow_head_size - u_dy * arrow_head_size + p1_y = y2 - u_dy * arrow_head_size + u_dx * arrow_head_size - p2_x = x2 - u_dx*arrow_head_size + u_dy*arrow_head_size - p2_y = y2 - u_dy*arrow_head_size - u_dx*arrow_head_size + p2_x = x2 - u_dx * arrow_head_size + u_dy * arrow_head_size + p2_y = y2 - u_dy * arrow_head_size - u_dx * arrow_head_size # Build list of QPointF objects for faster drawing y_x_pairs = itertools.product( - range(delta_yx.shape[0]), - range(delta_yx.shape[1]) - ) + range(delta_yx.shape[0]), range(delta_yx.shape[1]) + ) for y, x in y_x_pairs: - x1, y1 = loc_x[y,x], loc_y[y,x] + x1, y1 = loc_x[y, x], loc_y[y, x] - if line_length[y,x] > min_length: + if line_length[y, x] > min_length: points.append((x1, y1)) - points.append((x2[y,x],y2[y,x])) - points.append((p1_x[y,x],p1_y[y,x])) - points.append((x2[y,x],y2[y,x])) - points.append((p2_x[y,x],p2_y[y,x])) - points.append((x2[y,x],y2[y,x])) - self.points = list(itertools.starmap(QtCore.QPointF,points)) - - def _decimate(self, image:np.array, box:int): + points.append((x2[y, x], y2[y, x])) + points.append((p1_x[y, x], p1_y[y, x])) + points.append((x2[y, x], y2[y, x])) + points.append((p2_x[y, x], p2_y[y, x])) + points.append((x2[y, x], y2[y, x])) + self.points = list(itertools.starmap(QtCore.QPointF, points)) + + def _decimate(self, image: np.array, box: int): height = width = box # Source: https://stackoverflow.com/questions/48482317/slice-an-image-into-tiles-using-numpy _nrows, _ncols, depth = image.shape @@ -188,21 +201,21 @@ def _decimate(self, image:np.array, box:int): ncols, _n = divmod(_ncols, width) if _m != 0 or _n != 0: # if we can't tile whole image, forget about bottom/right edges - image = image[:(nrows+1)*box,:(ncols+1)*box] + image = image[: (nrows + 1) * box, : (ncols + 1) * box] - tiles = np.lib.stride_tricks.as_strided( + tiles = np.lib.stride_tricks.as_strided( np.ravel(image), shape=(nrows, ncols, height, width, depth), strides=(height * _strides[0], width * _strides[1], *_strides), - writeable=False + writeable=False, ) # Since strides accesses the ndarray by memory, we need to swap axes if # the array is stored column-major (Fortran), which it is from h5py. 
if _strides[0] < _strides[1]: - tiles = np.swapaxes(tiles,0,1) + tiles = np.swapaxes(tiles, 0, 1) - return np.mean(tiles, axis=(2,3)) + return np.mean(tiles, axis=(2, 3)) def boundingRect(self) -> QtCore.QRectF: """Method called by Qt in order to determine whether object is in visible frame.""" @@ -215,19 +228,24 @@ def paint(self, painter, option, widget=None): painter.drawLines(self.points) pass + def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): video = HDF5Video(filename, "/box", input_format=input_format) - paf_data = HDF5Video(filename, "/pafs", input_format=input_format, convert_range=False) + paf_data = HDF5Video( + filename, "/pafs", input_format=input_format, convert_range=False + ) pafs_ = [paf_data.get_frame(i) for i in range(paf_data.frames)] pafs = np.stack(pafs_) return demo_pafs(pafs, video, standalone=standalone) + def demo_pafs(pafs, video, decimation=4, standalone=False): from sleap.gui.video import QtVideoPlayer - if standalone: app = QtWidgets.QApplication([]) + if standalone: + app = QtWidgets.QApplication([]) win = QtVideoPlayer(video=video) win.setWindowTitle("pafs") @@ -246,47 +264,54 @@ def plot_fields(parent, i): if parent.frame_idx < pafs.shape[0]: frame_pafs = pafs[parent.frame_idx, ...] decimation = decimation_size_bar.value() - aff_fields_item = MultiQuiverPlot(frame_pafs, show=None, decimation=decimation) + aff_fields_item = MultiQuiverPlot( + frame_pafs, show=None, decimation=decimation + ) win.view.scene.addItem(aff_fields_item) win.changedPlot.connect(plot_fields) win.plot() - if standalone: app.exec_() + if standalone: + app.exec_() return win + if __name__ == "__main__": from video import * - #data_path = "training.scale=1.00,sigma=5.h5" + # data_path = "training.scale=1.00,sigma=5.h5" data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" - input_format="channels_first" + input_format = "channels_first" data_path = "/Volumes/fileset-mmurthy/nat/nyu-mouse/predict.h5" input_format = "channels_last" show_pafs_from_h5(data_path, input_format=input_format, standalone=True) + def foo(): vid = HDF5Video(data_path, "/box", input_format=input_format) - overlay_data = HDF5Video(data_path, "/pafs", input_format=input_format, convert_range=False) - print(f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}") + overlay_data = HDF5Video( + data_path, "/pafs", input_format=input_format, convert_range=False + ) + print( + f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}" + ) app = QtWidgets.QApplication([]) window = QtVideoPlayer(video=vid) - field_count = overlay_data.get_frame(1).shape[-1]//2 - 1 + field_count = overlay_data.get_frame(1).shape[-1] // 2 - 1 # show the first, middle, and last fields - show_fields = [0, field_count//2, field_count] + show_fields = [0, field_count // 2, field_count] field_check_groupbox = MultiCheckWidget( - count=field_count, - selected=show_fields, - title="Affinity Field Channel" - ) + count=field_count, selected=show_fields, title="Affinity Field Channel" + ) field_check_groupbox.selectionChanged.connect(window.plot) window.layout.addWidget(field_check_groupbox) @@ -301,7 +326,7 @@ def foo(): decimation_size_bar.setEnabled(True) window.layout.addWidget(decimation_size_bar) - def plot_fields(parent,i): + def plot_fields(parent, i): # build list of checked boxes to determine which affinity fields to show selected = field_check_groupbox.getSelected() # get decimation size from slider @@ -317,4 +342,4 @@ 
def plot_fields(parent,i): window.show() window.plot() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index b698e3d0d..ba2e537ea 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -9,6 +9,7 @@ from PySide2 import QtCore, QtWidgets, QtGui + class TrackColorManager: """Class to determine color to use for track. The color depends on the order of the tracks in `Labels` object, so we need to initialize with `Labels`. @@ -17,7 +18,7 @@ class TrackColorManager: labels: `Labels` object which contains the tracks for which we want colors """ - def __init__(self, labels: Labels=None, palette="standard"): + def __init__(self, labels: Labels = None, palette="standard"): self.labels = labels # alphabet @@ -28,81 +29,78 @@ def __init__(self, labels: Labels=None, palette="standard"): # http://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 self._palettes = { - "standard" : [ - [0, 114, 189], - [217, 83, 25], - [237, 177, 32], - [126, 47, 142], - [119, 172, 48], - [77, 190, 238], - [162, 20, 47], - ], - "five+" : [ - [228,26,28], - [55,126,184], - [77,175,74], - [152,78,163], - [255,127,0], - ], - "solarized" : [ - [181, 137, 0], - [203, 75, 22], - [220, 50, 47], - [211, 54, 130], - [108, 113, 196], - [38, 139, 210], - [42, 161, 152], - [133, 153, 0], - ], - "alphabet" : [ - [240,163,255], - [0,117,220], - [153,63,0], - [76,0,92], - [25,25,25], - [0,92,49], - [43,206,72], - [255,204,153], - [128,128,128], - [148,255,181], - [143,124,0], - [157,204,0], - [194,0,136], - [0,51,128], - [255,164,5], - [255,168,187], - [66,102,0], - [255,0,16], - [94,241,242], - [0,153,143], - [224,255,102], - [116,10,255], - [153,0,0], - [255,255,128], - [255,255,0], - [255,80,5], - ], - "twelve" : [ - [31,120,180], - [51,160,44], - [227,26,28], - [255,127,0], - [106,61,154], - [177,89,40], - [166,206,227], - [178,223,138], - [251,154,153], - [253,191,111], - [202,178,214], - [255,255,153], - ] + "standard": [ + [0, 114, 189], + [217, 83, 25], + [237, 177, 32], + [126, 47, 142], + [119, 172, 48], + [77, 190, 238], + [162, 20, 47], + ], + "five+": [ + [228, 26, 28], + [55, 126, 184], + [77, 175, 74], + [152, 78, 163], + [255, 127, 0], + ], + "solarized": [ + [181, 137, 0], + [203, 75, 22], + [220, 50, 47], + [211, 54, 130], + [108, 113, 196], + [38, 139, 210], + [42, 161, 152], + [133, 153, 0], + ], + "alphabet": [ + [240, 163, 255], + [0, 117, 220], + [153, 63, 0], + [76, 0, 92], + [25, 25, 25], + [0, 92, 49], + [43, 206, 72], + [255, 204, 153], + [128, 128, 128], + [148, 255, 181], + [143, 124, 0], + [157, 204, 0], + [194, 0, 136], + [0, 51, 128], + [255, 164, 5], + [255, 168, 187], + [66, 102, 0], + [255, 0, 16], + [94, 241, 242], + [0, 153, 143], + [224, 255, 102], + [116, 10, 255], + [153, 0, 0], + [255, 255, 128], + [255, 255, 0], + [255, 80, 5], + ], + "twelve": [ + [31, 120, 180], + [51, 160, 44], + [227, 26, 28], + [255, 127, 0], + [106, 61, 154], + [177, 89, 40], + [166, 206, 227], + [178, 223, 138], + [251, 154, 153], + [253, 191, 111], + [202, 178, 214], + [255, 255, 153], + ], } self.mode = "cycle" - self._modes = dict( - cycle=lambda i, c: i%c, - clip=lambda i, c: min(i,c-1), - ) + self._modes = dict(cycle=lambda i, c: i % c, clip=lambda i, c: min(i, c - 1)) self.set_palette(palette) @@ -137,11 +135,14 @@ def get_color(self, track: Union[Track, int]): Returns: (r, g, b)-tuple """ - track_idx = self.labels.tracks.index(track) if isinstance(track, Track) else track + track_idx = ( + 
self.labels.tracks.index(track) if isinstance(track, Track) else track + ) color_idx = self._modes[self.mode](track_idx, len(self._color_map)) color = self._color_map[color_idx] return color + @attr.s(auto_attribs=True) class TrackTrailOverlay: """Class to show track trails. You initialize this object with both its data source @@ -158,11 +159,11 @@ class TrackTrailOverlay: to plot the trails in scene. """ - labels: Labels=None - scene: QtWidgets.QGraphicsScene=None - color_manager: TrackColorManager=TrackColorManager(labels) - trail_length: int=4 - show: bool=False + labels: Labels = None + scene: QtWidgets.QGraphicsScene = None + color_manager: TrackColorManager = TrackColorManager(labels) + trail_length: int = 4 + show: bool = False def get_track_trails(self, frame_selection, track: Track): """Get data needed to draw track trail. @@ -203,15 +204,17 @@ def get_track_trails(self, frame_selection, track: Track): def get_frame_selection(self, video: Video, frame_idx: int): """Return list of `LabeledFrame`s to include in trail for specified frame.""" - frame_selection = self.labels.find(video, range(0, frame_idx+1)) + frame_selection = self.labels.find(video, range(0, frame_idx + 1)) frame_selection.sort(key=lambda x: x.frame_idx) - return frame_selection[-self.trail_length:] + return frame_selection[-self.trail_length :] def get_tracks_in_frame(self, video: Video, frame_idx: int): """Return list of tracks that have instance in specified frame.""" - tracks_in_frame = [inst.track for lf in self.labels.find(video, frame_idx) for inst in lf] + tracks_in_frame = [ + inst.track for lf in self.labels.find(video, frame_idx) for inst in lf + ] return tracks_in_frame def add_to_scene(self, video: Video, frame_idx: int): @@ -222,7 +225,8 @@ def add_to_scene(self, video: Video, frame_idx: int): frame_idx: index of the frame to which the trail is attached """ - if not self.show: return + if not self.show: + return frame_selection = self.get_frame_selection(video, frame_idx) tracks_in_frame = self.get_tracks_in_frame(video, frame_idx) @@ -236,14 +240,14 @@ def add_to_scene(self, video: Video, frame_idx: int): pen.setCosmetic(True) for trail in trails: - half = len(trail)//2 + half = len(trail) // 2 color.setAlphaF(1) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[:half]) self.scene.addPolygon(polygon, pen) - color.setAlphaF(.5) + color.setAlphaF(0.5) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[half:]) self.scene.addPolygon(polygon, pen) @@ -259,9 +263,9 @@ class TrackListOverlay: """Class to show track number and names in overlay. """ - labels: Labels=None - view: QtWidgets.QGraphicsView=None - color_manager: TrackColorManager=TrackColorManager(labels) + labels: Labels = None + view: QtWidgets.QGraphicsView = None + color_manager: TrackColorManager = TrackColorManager(labels) text_box = None def add_to_scene(self, video: Video, frame_idx: int): @@ -271,9 +275,10 @@ def add_to_scene(self, video: Video, frame_idx: int): num_to_show = min(9, len(self.labels.tracks)) for i, track in enumerate(self.labels.tracks[:num_to_show]): - idx = i+1 + idx = i + 1 - if html: html += "
" + if html: + html += "
" color = self.color_manager.get_color(track) html_color = f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}" track_text = f"{track.name}" @@ -284,7 +289,7 @@ def add_to_scene(self, video: Video, frame_idx: int): text_box = QtTextWithBackground() text_box.setDefaultTextColor(QtGui.QColor("white")) text_box.setHtml(html) - text_box.setOpacity(.7) + text_box.setOpacity(0.7) self.text_box = text_box self.visible = False @@ -293,16 +298,18 @@ def add_to_scene(self, video: Video, frame_idx: int): @property def visible(self): - if self.text_box is None: return False + if self.text_box is None: + return False return self.text_box.isVisible() @visible.setter def visible(self, val): - if self.text_box is None: return + if self.text_box is None: + return if val: pos = self.view.mapToScene(10, 10) if pos.x() > 0: self.text_box.setPos(pos) else: self.text_box.setPos(10, 10) - self.text_box.setVisible(val) \ No newline at end of file + self.text_box.setVisible(val) diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py index 0bf7b797d..3b155e732 100644 --- a/sleap/gui/shortcuts.py +++ b/sleap/gui/shortcuts.py @@ -7,6 +7,7 @@ from pkg_resources import Requirement, resource_filename + class ShortcutDialog(QtWidgets.QDialog): _column_len = 13 @@ -29,16 +30,22 @@ def load_shortcuts(self): self.shortcuts = Shortcuts() def make_form(self): - self.key_widgets = dict() # dict to store QKeySequenceEdit widgets + self.key_widgets = dict() # dict to store QKeySequenceEdit widgets layout = QtWidgets.QVBoxLayout() layout.addWidget(self.make_shortcuts_widget()) - layout.addWidget(QtWidgets.QLabel("Any changes to keyboard shortcuts will not take effect until you quit and re-open the application.")) + layout.addWidget( + QtWidgets.QLabel( + "Any changes to keyboard shortcuts will not take effect until you quit and re-open the application." 
+ ) + ) layout.addWidget(self.make_buttons_widget()) self.setLayout(layout) - + def make_buttons_widget(self): - buttons = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel) + buttons = QtWidgets.QDialogButtonBox( + QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel + ) buttons.accepted.connect(self.accept) buttons.rejected.connect(self.reject) return buttons @@ -71,46 +78,74 @@ def make_column_widget(self, shortcuts): def dict_cut(d, a, b): return dict(list(d.items())[a:b]) + class Shortcuts: _shortcuts = None - _names = ("new", "open", "save", "save as", "close", - "add videos", "next video", "prev video", - "goto frame", "mark frame", "goto marked", - "add instance", "delete instance", "delete track", - "transpose", "select next", "clear selection", - "goto next labeled", "goto prev labeled", "goto next user", - "goto next suggestion", "goto prev suggestion", - "goto next track spawn", - "show labels", "show edges", "show trails", - "color predicted", "fit", "learning", - "export clip", "delete clip", "delete area") + _names = ( + "new", + "open", + "save", + "save as", + "close", + "add videos", + "next video", + "prev video", + "goto frame", + "mark frame", + "goto marked", + "add instance", + "delete instance", + "delete track", + "transpose", + "select next", + "clear selection", + "goto next labeled", + "goto prev labeled", + "goto next user", + "goto next suggestion", + "goto prev suggestion", + "goto next track spawn", + "show labels", + "show edges", + "show trails", + "color predicted", + "fit", + "learning", + "export clip", + "delete clip", + "delete area", + ) def __init__(self): - shortcut_yaml = resource_filename(Requirement.parse("sleap"), "sleap/config/shortcuts.yaml") - with open(shortcut_yaml, 'r') as f: + shortcut_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/shortcuts.yaml" + ) + with open(shortcut_yaml, "r") as f: shortcuts = yaml.load(f, Loader=yaml.SafeLoader) - + for action in shortcuts: key_string = shortcuts.get(action, None) key_string = "" if key_string is None else key_string - + try: shortcuts[action] = eval(key_string) except: shortcuts[action] = QKeySequence.fromString(key_string) - + self._shortcuts = shortcuts def save(self): - shortcut_yaml = resource_filename(Requirement.parse("sleap"), "sleap/config/shortcuts.yaml") - with open(shortcut_yaml, 'w') as f: + shortcut_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/shortcuts.yaml" + ) + with open(shortcut_yaml, "w") as f: yaml.dump(self._shortcuts, f) def __getitem__(self, idx): if isinstance(idx, slice): # dict with names and values - return {self._names[i]:self[i] for i in range(*idx.indices(len(self)))} + return {self._names[i]: self[i] for i in range(*idx.indices(len(self)))} elif isinstance(idx, int): # value idx = self._names[idx] @@ -131,8 +166,9 @@ def __setitem__(self, idx, val): def __len__(self): return len(self._names) + if __name__ == "__main__": app = QtWidgets.QApplication() win = ShortcutDialog() win.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 46765c50e..407755fc2 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -5,7 +5,15 @@ from PySide2.QtWidgets import QApplication, QWidget, QLayout, QAbstractSlider from PySide2.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem from PySide2.QtWidgets import QSizePolicy, QLabel, QGraphicsRectItem -from PySide2.QtGui import QPainter, QPen, QBrush, 
QColor, QKeyEvent, QPolygonF, QPainterPath +from PySide2.QtGui import ( + QPainter, + QPen, + QBrush, + QColor, + QKeyEvent, + QPolygonF, + QPainterPath, +) from PySide2.QtCore import Qt, Signal, QRect, QRectF, QPointF from sleap.gui.overlays.tracks import TrackColorManager @@ -15,21 +23,19 @@ import numpy as np from typing import Dict, Optional, Union + @attr.s(auto_attribs=True, cmp=False) class SliderMark: type: str val: float - end_val: float=None - row: int=None - track: 'Track'=None - _color: Union[tuple,str]="black" + end_val: float = None + row: int = None + track: "Track" = None + _color: Union[tuple, str] = "black" @property def color(self): - colors = dict(simple="black", - filled="blue", - open="blue", - predicted="red") + colors = dict(simple="black", filled="blue", open="blue", predicted="red") if self.type in colors: return colors[self.type] @@ -55,6 +61,7 @@ def filled(self): else: return True + class VideoSlider(QGraphicsView): """Drop-in replacement for QSlider with additional features. @@ -78,10 +85,18 @@ class VideoSlider(QGraphicsView): selectionChanged = Signal(int, int) updatedTracks = Signal() - def __init__(self, orientation=-1, min=0, max=100, val=0, - marks=None, tracks=0, - color_manager=None, - *args, **kwargs): + def __init__( + self, + orientation=-1, + min=0, + max=100, + val=0, + marks=None, + tracks=0, + color_manager=None, + *args, + **kwargs + ): super(VideoSlider, self).__init__(*args, **kwargs) self.scene = QGraphicsScene() @@ -90,7 +105,7 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # ScrollBarAsNeeded + self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # ScrollBarAsNeeded self._color_manager = color_manager @@ -100,7 +115,7 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self._min_height = 19 + self._header_height # Add border rect - outline_rect = QRect(0, 0, 200, self._min_height-3) + outline_rect = QRect(0, 0, 200, self._min_height - 3) self.outlineBox = self.scene.addRect(outline_rect) self.outlineBox.setPen(QPen(QColor("black"))) @@ -114,7 +129,7 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.handle.setBrush(QColor(128, 128, 128, 128)) # Add (hidden) rect to highlight selection - self.select_box = self.scene.addRect(QRect(0, 1, 0, outline_rect.height()-2)) + self.select_box = self.scene.addRect(QRect(0, 1, 0, outline_rect.height() - 2)) self.select_box.setPen(QPen(QColor(80, 80, 255))) self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() @@ -128,7 +143,7 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.setValue(val) self.setMarks(marks) - pen = QPen(QColor(80, 80, 255), .5) + pen = QPen(QColor(80, 80, 255), 0.5) pen.setCosmetic(True) self.poly = self.scene.addPath(QPainterPath(), pen, self.select_box.brush()) self.headerSeries = dict() @@ -162,7 +177,15 @@ def setTracksFromLabels(self, labels, video): for track in labels.tracks: if track in track_occupancy and not track_occupancy[track].is_empty: for occupancy_range in track_occupancy[track].list: - slider_marks.append(SliderMark("track", val=occupancy_range[0], end_val=occupancy_range[1], row=track_row, color=self._color_manager.get_color(track))) + slider_marks.append( + SliderMark( + "track", + val=occupancy_range[0], + end_val=occupancy_range[1], + row=track_row, + color=self._color_manager.get_color(track), + ) + ) track_row += 1 
# Add marks without track @@ -184,10 +207,10 @@ def setTracksFromLabels(self, labels, video): mark_type = "open" slider_marks.append(SliderMark(mark_type, val=frame_idx)) - self.setTracks(track_row) # total number of tracks to show + self.setTracks(track_row) # total number of tracks to show self.setMarks(slider_marks) - def setHeaderSeries(self, series:Optional[Dict[int,float]] = None): + def setHeaderSeries(self, series: Optional[Dict[int, float]] = None): """Show header graph with specified series. Args: @@ -250,23 +273,23 @@ def _toPos(self, val, center=False): """Convert value to x position on slider.""" x = val x -= self._val_min - x /= max(1, self._val_max-self._val_min) + x /= max(1, self._val_max - self._val_min) x *= self._sliderWidth() if center: - x += self.handle.rect().width()/2. + x += self.handle.rect().width() / 2.0 return x def _toVal(self, x, center=False): """Convert x position to value.""" val = x val /= self._sliderWidth() - val *= max(1, self._val_max-self._val_min) + val *= max(1, self._val_max - self._val_min) val += self._val_min val = round(val) return val def _sliderWidth(self): - return self.outlineBox.rect().width()-self.handle.rect().width() + return self.outlineBox.rect().width() - self.handle.rect().width() def value(self): """Get value of slider.""" @@ -314,7 +337,7 @@ def endSelection(self, val, update=False): val: value of endpoint """ # If we want to update endpoint and there's already one, remove it - if update and len(self._selection)%2==0: + if update and len(self._selection) % 2 == 0: self._selection.pop() # Add the selection endpoint self._selection.append(val) @@ -334,7 +357,7 @@ def hasSelection(self) -> bool: def getSelection(self): """Return start and end value of current selection endpoints.""" a, b = 0, 0 - if len(self._selection)%2 == 0 and len(self._selection) > 0: + if len(self._selection) % 2 == 0 and len(self._selection) > 0: a, b = self._selection[-2:] start = min(a, b) end = max(a, b) @@ -351,8 +374,9 @@ def drawSelection(self, a, b): end = max(a, b) start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) - selection_rect = QRect(start_pos, 1, - end_pos-start_pos, self.outlineBox.rect().height()-2) + selection_rect = QRect( + start_pos, 1, end_pos - start_pos, self.outlineBox.rect().height() - 2 + ) self.select_box.setRect(selection_rect) self.select_box.show() @@ -368,7 +392,7 @@ def moveSelectionAnchor(self, x, y): x = min(x, self.outlineBox.rect().width()) anchor_val = self._toVal(x, center=True) - if len(self._selection)%2 == 0: + if len(self._selection) % 2 == 0: self.startSelection(anchor_val) self.drawSelection(anchor_val, self._selection[-1]) @@ -390,8 +414,8 @@ def clearMarks(self): if hasattr(self, "_mark_items"): for item in self._mark_items.values(): self.scene.removeItem(item) - self._marks = set() # holds mark position - self._mark_items = dict() # holds visual Qt object for plotting mark + self._marks = set() # holds mark position + self._mark_items = dict() # holds visual Qt object for plotting mark def setMarks(self, marks): """Set all marked values for the slider. 
@@ -421,8 +445,10 @@ def addMark(self, new_mark, update=True): new_mark: value to mark """ # check if mark is within slider range - if new_mark.val > self._val_max: return - if new_mark.val < self._val_min: return + if new_mark.val > self._val_max: + return + if new_mark.val < self._val_min: + return self._marks.add(new_mark) @@ -436,19 +462,19 @@ def addMark(self, new_mark, update=True): height = 1 else: v_offset = v_top_pad - height = self.outlineBox.rect().height()-(v_offset+v_bottom_pad) + height = self.outlineBox.rect().height() - (v_offset + v_bottom_pad) width = 2 if new_mark.type in ("open", "filled") else 0 color = new_mark.QColor - pen = QPen(color, .5) + pen = QPen(color, 0.5) pen.setCosmetic(True) brush = QBrush(color) if new_mark.filled else QBrush() - line = self.scene.addRect(-width//2, v_offset, width, height, - pen, brush) + line = self.scene.addRect(-width // 2, v_offset, width, height, pen, brush) self._mark_items[new_mark] = line - if update: self.updatePos() + if update: + self.updatePos() def _mark_val(self, mark): return mark.val @@ -482,36 +508,38 @@ def drawHeader(self): self.poly.setPath(QPainterPath()) return - step = max(self.headerSeries.keys())//int(self._sliderWidth()) + step = max(self.headerSeries.keys()) // int(self._sliderWidth()) step = max(step, 1) - count = max(self.headerSeries.keys())//step*step + count = max(self.headerSeries.keys()) // step * step sampled = np.full((count), 0.0) for key, val in self.headerSeries.items(): if key < count: sampled[key] = val - sampled = np.max(sampled.reshape(count//step,step), axis=1) - series = {i*step:sampled[i] for i in range(count//step)} + sampled = np.max(sampled.reshape(count // step, step), axis=1) + series = {i * step: sampled[i] for i in range(count // step)} series_min = np.min(sampled) - 1 series_max = np.max(sampled) - series_scale = (self._header_height-5)/(series_max - series_min) + series_scale = (self._header_height - 5) / (series_max - series_min) def toYPos(val): - return self._header_height-((val-series_min)*series_scale) + return self._header_height - ((val - series_min) * series_scale) - step_chart = False # use steps rather than smooth line + step_chart = False # use steps rather than smooth line points = [] points.append((self._toPos(0, center=True), toYPos(series_min))) for idx, val in series.items(): points.append((self._toPos(idx, center=True), toYPos(val))) if step_chart: - points.append((self._toPos(idx+step, center=True), toYPos(val))) - points.append((self._toPos(max(series.keys()) + 1, center=True), toYPos(series_min))) + points.append((self._toPos(idx + step, center=True), toYPos(val))) + points.append( + (self._toPos(max(series.keys()) + 1, center=True), toYPos(series_min)) + ) # Convert to list of QPointF objects - points = list(itertools.starmap(QPointF,points)) + points = list(itertools.starmap(QPointF, points)) self.poly.setPath(self._pointsToPath(points)) def moveHandle(self, x, y): @@ -523,20 +551,21 @@ def moveHandle(self, x, y): x: x position of mouse y: y position of mouse """ - x -= self.handle.rect().width()/2. 
+ x -= self.handle.rect().width() / 2.0 x = max(x, 0) - x = min(x, self.outlineBox.rect().width()-self.handle.rect().width()) + x = min(x, self.outlineBox.rect().width() - self.handle.rect().width()) val = self._toVal(x) # snap to nearby mark within handle mark_vals = [self._mark_val(mark) for mark in self._marks] - handle_left = self._toVal(x - self.handle.rect().width()/2) - handle_right = self._toVal(x + self.handle.rect().width()/2) - marks_in_handle = [mark for mark in mark_vals - if handle_left < mark < handle_right] + handle_left = self._toVal(x - self.handle.rect().width() / 2) + handle_right = self._toVal(x + self.handle.rect().width() / 2) + marks_in_handle = [ + mark for mark in mark_vals if handle_left < mark < handle_right + ] if marks_in_handle: - marks_in_handle.sort(key=lambda m: (abs(m-val), m>val)) + marks_in_handle.sort(key=lambda m: (abs(m - val), m > val)) val = marks_in_handle[0] old = self.value() @@ -557,8 +586,9 @@ def resizeEvent(self, event=None): handle_rect = self.handle.rect() select_box_rect = self.select_box.rect() - outline_rect.setHeight(height-3) - if event is not None: outline_rect.setWidth(event.size().width()-1) + outline_rect.setHeight(height - 3) + if event is not None: + outline_rect.setWidth(event.size().width() - 1) self.outlineBox.setRect(outline_rect) handle_rect.setTop(self._handleTop()) @@ -580,7 +610,9 @@ def _handleHeight(self, outline_rect=None): outline_rect = self.outlineBox.rect() handle_bottom_offset = 1 - handle_height = outline_rect.height() - (self._handleTop()+handle_bottom_offset) + handle_height = outline_rect.height() - ( + self._handleTop() + handle_bottom_offset + ) return handle_height def mousePressEvent(self, event): @@ -592,9 +624,11 @@ def mousePressEvent(self, event): scenePos = self.mapToScene(event.pos()) # Do nothing if not enabled - if not self.enabled(): return + if not self.enabled(): + return # Do nothing if click outside slider area - if not self.outlineBox.rect().contains(scenePos): return + if not self.outlineBox.rect().contains(scenePos): + return move_function = None release_function = None @@ -658,9 +692,11 @@ def paint(self, *args, **kwargs): app = QApplication([]) window = VideoSlider( - min=0, max=20, val=15, - marks=(10,15)#((0,10),(0,15),(1,10),(1,11),(2,12)), tracks=3 - ) + min=0, + max=20, + val=15, + marks=(10, 15), # ((0,10),(0,15),(1,10),(1,11),(2,12)), tracks=3 + ) window.valueChanged.connect(lambda x: print(x)) window.show() diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 0666fc0ce..5e99dd4d2 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -9,13 +9,14 @@ from sleap.io.video import Video + class VideoFrameSuggestions: - rescale=True - rescale_below=512 + rescale = True + rescale_below = 512 @classmethod - def suggest(cls, video:Video, params:dict, labels: 'Labels'=None) -> list: + def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> list: """ This is the main entry point. 
@@ -31,13 +32,13 @@ def suggest(cls, video:Video, params:dict, labels: 'Labels'=None) -> list: # map from method param value to corresponding class method method_functions = dict( - strides=cls.strides, - random=cls.random, - pca=cls.pca_cluster, - hog=cls.hog, - brisk=cls.brisk, - proofreading=cls.proofreading - ) + strides=cls.strides, + random=cls.random, + pca=cls.pca_cluster, + hog=cls.hog, + brisk=cls.brisk, + proofreading=cls.proofreading, + ) method = params["method"] if method_functions.get(method, None) is not None: @@ -49,7 +50,7 @@ def suggest(cls, video:Video, params:dict, labels: 'Labels'=None) -> list: @classmethod def strides(cls, video, per_video=20, **kwargs): - suggestions = list(range(0, video.frames, video.frames//per_video)) + suggestions = list(range(0, video.frames, video.frames // per_video)) suggestions = suggestions[:per_video] return suggestions @@ -63,49 +64,57 @@ def random(cls, video, per_video=20, **kwargs): @classmethod def pca_cluster(cls, video, initial_samples, **kwargs): - sample_step = video.frames//initial_samples + sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.frame_feature_stack(video, sample_step) result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, **kwargs) + feature_stack, frame_idx_map, **kwargs + ) return result @classmethod def brisk(cls, video, initial_samples, **kwargs): - sample_step = video.frames//initial_samples + sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.brisk_feature_stack(video, sample_step) result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, **kwargs) + feature_stack, frame_idx_map, **kwargs + ) return result @classmethod def hog( - cls, video, - clusters=5, - per_cluster=5, - sample_step=5, - pca_components=50, - interleave=True, - **kwargs): + cls, + video, + clusters=5, + per_cluster=5, + sample_step=5, + pca_components=50, + interleave=True, + **kwargs, + ): feature_stack, frame_idx_map = cls.hog_feature_stack(video, sample_step) result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, - clusters=clusters, per_cluster=per_cluster, - pca_components=pca_components, - interleave=interleave, - **kwargs) + feature_stack, + frame_idx_map, + clusters=clusters, + per_cluster=per_cluster, + pca_components=pca_components, + interleave=interleave, + **kwargs, + ) return result @classmethod def proofreading( - cls, video: Video, labels: 'Labels', score_limit, instance_limit, **kwargs): + cls, video: Video, labels: "Labels", score_limit, instance_limit, **kwargs + ): score_limit = float(score_limit) instance_limit = int(instance_limit) @@ -124,7 +133,7 @@ def proofreading( if len(frame_scores) > instance_limit: frame_scores = sorted(frame_scores)[:instance_limit] # Add to matrix - scores[i,:len(frame_scores)] = frame_scores + scores[i, : len(frame_scores)] = frame_scores idxs[i] = lf.frame_idx # Find instances below score of @@ -139,8 +148,8 @@ def proofreading( # These are specific to the suggestion method @classmethod - def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: - sample_count = video.frames//sample_step + def frame_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: + sample_count = video.frames // sample_step factor = cls.get_scale_factor(video) @@ -151,8 +160,8 @@ def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: frame_idx = i * sample_step img = video[frame_idx].squeeze() - multichannel = (video.channels > 1) - img = rescale(img, 
scale=.5, anti_aliasing=True, multichannel=multichannel) + multichannel = video.channels > 1 + img = rescale(img, scale=0.5, anti_aliasing=True, multichannel=multichannel) flat_img = img.flatten() @@ -163,7 +172,7 @@ def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: return (flat_stack, frame_idx_map) @classmethod - def brisk_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: + def brisk_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: brisk = cv2.BRISK_create() factor = cls.get_scale_factor(video) @@ -186,8 +195,8 @@ def brisk_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: return (feature_stack, frame_idx_map) @classmethod - def hog_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: - sample_count = video.frames//sample_step + def hog_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: + sample_count = video.frames // sample_step hog = cv2.HOGDescriptor() @@ -211,16 +220,14 @@ def hog_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: # These are common for all suggestion methods @staticmethod - def to_frame_idx_list(selected_list:list, frame_idx_map:dict) -> list: + def to_frame_idx_list(selected_list: list, frame_idx_map: dict) -> list: """Convert list of row indexes to list of frame indexes.""" return list(map(lambda x: frame_idx_map[x], selected_list)) @classmethod def feature_stack_to_suggestions( - cls, - feature_stack, frame_idx_map, - return_clusters=False, - **kwargs): + cls, feature_stack, frame_idx_map, return_clusters=False, **kwargs + ): """ Turns a feature stack matrix into a list of suggested frames. @@ -232,27 +239,28 @@ def feature_stack_to_suggestions( """ selected_by_cluster = cls.feature_stack_to_clusters( - feature_stack=feature_stack, - frame_idx_map=frame_idx_map, - **kwargs) + feature_stack=feature_stack, frame_idx_map=frame_idx_map, **kwargs + ) - if return_clusters: return selected_by_cluster + if return_clusters: + return selected_by_cluster selected_list = cls.clusters_to_list( - selected_by_cluster=selected_by_cluster, - **kwargs) + selected_by_cluster=selected_by_cluster, **kwargs + ) return selected_list @classmethod def feature_stack_to_clusters( - cls, - feature_stack, - frame_idx_map, - clusters=5, - per_cluster=5, - pca_components=50, - **kwargs): + cls, + feature_stack, + frame_idx_map, + clusters=5, + per_cluster=5, + pca_components=50, + **kwargs, + ): """ Turns feature stack matrix into list (per cluster) of list of frame indexes. 
@@ -284,7 +292,7 @@ def feature_stack_to_clusters( selected_by_cluster = [] selected_set = set() for i in range(clusters): - cluster_items, = np.where(row_labels==i) + cluster_items, = np.where(row_labels == i) # convert from row indexes to frame indexes cluster_items = cls.to_frame_idx_list(cluster_items, frame_idx_map) @@ -294,15 +302,19 @@ def feature_stack_to_clusters( cluster_items = list(set(cluster_items) - selected_set) # pick [per_cluster] items from this cluster - samples_from_bin = np.random.choice(cluster_items, min(len(cluster_items), per_cluster), False) + samples_from_bin = np.random.choice( + cluster_items, min(len(cluster_items), per_cluster), False + ) samples_from_bin.sort() selected_by_cluster.append(samples_from_bin) - selected_set = selected_set.union( set(samples_from_bin) ) + selected_set = selected_set.union(set(samples_from_bin)) return selected_by_cluster @classmethod - def clusters_to_list(cls, selected_by_cluster, interleave:bool = True, **kwargs) -> list: + def clusters_to_list( + cls, selected_by_cluster, interleave: bool = True, **kwargs + ) -> list: """ Turns list (per cluster) of lists of frame index into single list of frame indexes. @@ -317,9 +329,11 @@ def clusters_to_list(cls, selected_by_cluster, interleave:bool = True, **kwargs) if interleave: # cycle clusters - all_selected = itertools.chain.from_iterable(itertools.zip_longest(*selected_by_cluster)) + all_selected = itertools.chain.from_iterable( + itertools.zip_longest(*selected_by_cluster) + ) # remove Nones and convert back to list - all_selected = list(filter(lambda x:x is not None, all_selected)) + all_selected = list(filter(lambda x: x is not None, all_selected)) else: all_selected = list(itertools.chain.from_iterable(selected_by_cluster)) all_selected.sort() @@ -336,7 +350,7 @@ def get_scale_factor(cls, video) -> int: if cls.rescale: largest_dim = max(video.height, video.width) factor = 1 - while largest_dim/factor > cls.rescale_below: + while largest_dim / factor > cls.rescale_below: factor += 1 return factor @@ -344,21 +358,22 @@ def get_scale_factor(cls, video) -> int: def resize(cls, img, factor) -> np.ndarray: h, w, _ = img.shape if factor != 1: - return cv2.resize(img, (h//factor, w//factor)) + return cv2.resize(img, (h // factor, w // factor)) else: return img + if __name__ == "__main__": # load some images filename = "tests/data/videos/centered_pair_small.mp4" filename = "files/190605_1509_frame23000_24000.sf.mp4" video = Video.from_filename(filename) - debug=False + debug = False - x = VideoFrameSuggestions.hog(video=video, sample_step=20, - clusters=5, per_cluster=5, - return_clusters=debug) + x = VideoFrameSuggestions.hog( + video=video, sample_step=20, clusters=5, per_cluster=5, return_clusters=debug + ) print(x) if debug: @@ -410,4 +425,4 @@ def resize(cls, img, factor) -> np.ndarray: # print(len(kp)) # print(des.shape) -# print(VideoFrameSuggestions.suggest(video, dict(method="pca"))) \ No newline at end of file +# print(VideoFrameSuggestions.suggest(video, dict(method="pca"))) diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index ae6960acd..4e79ee30d 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -10,18 +10,30 @@ from sleap.io.dataset import Labels from sleap.gui.formbuilder import YamlFormWidget -class TrainingEditor(QtWidgets.QDialog): - def __init__(self, profile_filename: Optional[str]=None, saved_files: list=[], *args, **kwargs): +class TrainingEditor(QtWidgets.QDialog): + def __init__( + self, + profile_filename: 
Optional[str] = None, + saved_files: list = [], + *args, + **kwargs + ): super(TrainingEditor, self).__init__() - form_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/training_editor.yaml") + form_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/training_editor.yaml" + ) self.form_widgets = dict() - self.form_widgets["model"] = YamlFormWidget(form_yaml, "model", "Network Architecture") - self.form_widgets["datagen"] = YamlFormWidget(form_yaml, "datagen", "Data Generation/Preprocessing") + self.form_widgets["model"] = YamlFormWidget( + form_yaml, "model", "Network Architecture" + ) + self.form_widgets["datagen"] = YamlFormWidget( + form_yaml, "datagen", "Data Generation/Preprocessing" + ) self.form_widgets["trainer"] = YamlFormWidget(form_yaml, "trainer", "Trainer") - self.form_widgets["output"] = YamlFormWidget(form_yaml, "output",) + self.form_widgets["output"] = YamlFormWidget(form_yaml, "output") self.form_widgets["buttons"] = YamlFormWidget(form_yaml, "buttons") self.form_widgets["buttons"].mainAction.connect(self._save_as) @@ -64,7 +76,7 @@ def _layout_widget(layout): widget.setLayout(layout) return widget - def _load_profile(self, profile_filename:str): + def _load_profile(self, profile_filename: str): from sleap.nn.model import ModelOutputType from sleap.nn.training import TrainingJob @@ -83,7 +95,9 @@ def _load_profile(self, profile_filename:str): def _save_as(self): # Show "Save" dialog - save_filename, _ = QtWidgets.QFileDialog.getSaveFileName(self, caption="Save As...", dir=None, filter="Profile JSON (*.json)") + save_filename, _ = QtWidgets.QFileDialog.getSaveFileName( + self, caption="Save As...", dir=None, filter="Profile JSON (*.json)" + ) if len(save_filename): from sleap.nn.model import Model, ModelOutputType @@ -92,35 +106,47 @@ def _save_as(self): # Construct Model model_data = self.form_widgets["model"].get_form_data() - arch = dict(LeapCNN=leap.LeapCNN, - StackedHourglass=hourglass.StackedHourglass, - UNet=unet.UNet, - StackedUNet=unet.StackedUNet, - )[model_data["arch"]] - - output_type = dict(confmaps=ModelOutputType.CONFIDENCE_MAP, - pafs=ModelOutputType.PART_AFFINITY_FIELD, - centroids=ModelOutputType.CENTROIDS - )[model_data["output_type"]] - - backbone_kwargs = {key:val for key, val in model_data.items() - if key in attr.fields_dict(arch).keys()} + arch = dict( + LeapCNN=leap.LeapCNN, + StackedHourglass=hourglass.StackedHourglass, + UNet=unet.UNet, + StackedUNet=unet.StackedUNet, + )[model_data["arch"]] + + output_type = dict( + confmaps=ModelOutputType.CONFIDENCE_MAP, + pafs=ModelOutputType.PART_AFFINITY_FIELD, + centroids=ModelOutputType.CENTROIDS, + )[model_data["output_type"]] + + backbone_kwargs = { + key: val + for key, val in model_data.items() + if key in attr.fields_dict(arch).keys() + } model = Model(output_type=output_type, backbone=arch(**backbone_kwargs)) # Construct Trainer - trainer_data = {**self.form_widgets["datagen"].get_form_data(), - **self.form_widgets["output"].get_form_data(), - **self.form_widgets["trainer"].get_form_data(), - } - - trainer_kwargs = {key:val for key, val in trainer_data.items() - if key in attr.fields_dict(Trainer).keys()} + trainer_data = { + **self.form_widgets["datagen"].get_form_data(), + **self.form_widgets["output"].get_form_data(), + **self.form_widgets["trainer"].get_form_data(), + } + + trainer_kwargs = { + key: val + for key, val in trainer_data.items() + if key in attr.fields_dict(Trainer).keys() + } trainer = Trainer(**trainer_kwargs) # Construct TrainingJob - 
training_job_kwargs = {key:val for key, val in trainer_data.items() - if key in attr.fields_dict(TrainingJob).keys()} + training_job_kwargs = { + key: val + for key, val in trainer_data.items() + if key in attr.fields_dict(TrainingJob).keys() + } training_job = TrainingJob(model, trainer, **training_job_kwargs) # Write the file @@ -132,6 +158,7 @@ def _save_as(self): self.close() + if __name__ == "__main__": import sys @@ -142,4 +169,4 @@ def _save_as(self): app = QtWidgets.QApplication([]) win = TrainingEditor(profile_filename) win.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 4e01d1d84..3876b0f67 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -31,8 +31,14 @@ from typing import Callable, Union from PySide2.QtWidgets import QGraphicsItem, QGraphicsObject + # The PySide2.QtWidgets.QGraphicsObject class provides a base class for all graphics items that require signals, slots and properties. -from PySide2.QtWidgets import QGraphicsEllipseItem, QGraphicsLineItem, QGraphicsTextItem, QGraphicsRectItem +from PySide2.QtWidgets import ( + QGraphicsEllipseItem, + QGraphicsLineItem, + QGraphicsTextItem, + QGraphicsRectItem, +) from sleap.skeleton import Skeleton from sleap.instance import Instance, Point @@ -142,14 +148,14 @@ def addInstance(self, instance, **kwargs): # Check if instance is an Instance (or subclass of Instance) if issubclass(type(instance), Instance): instance = QtInstance(instance=instance, **kwargs) - if type(instance) != QtInstance: return + if type(instance) != QtInstance: + return self.view.scene.addItem(instance) # connect signal from instance instance.changedData.connect(self.changedData) - # connect signal so we can adjust QtNodeLabel positions after zoom self.view.updatedViewer.connect(instance.updatePoints) @@ -237,8 +243,9 @@ def zoomToFit(self): if not zoom_rect.size().isEmpty(): self.view.zoomToRect(zoom_rect) - def onSequenceSelect(self, seq_len: int, on_success: Callable, - on_each = None, on_failure = None): + def onSequenceSelect( + self, seq_len: int, on_success: Callable, on_each=None, on_failure=None + ): """ Collect a sequence of instances (through user selection) and call `on_success`. If the user cancels (by unselecting without new selection), call `on_failure`. 
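As a usage sketch of the sequence-selection callback API touched above (not part of this patch; the caller `player` and the callback bodies are hypothetical, and it is assumed the callbacks receive the list of selected instance indexes):

    def on_pair_selected(instance_idxs):
        # Called once the user has picked two instances in order.
        a, b = instance_idxs
        print(f"selected instances {a} and {b}")

    player.onSequenceSelect(
        seq_len=2,
        on_success=on_pair_selected,
        on_each=lambda idxs: print(f"{len(idxs)} of 2 selected"),
        on_failure=lambda idxs: print("selection cancelled"),
    )
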
@@ -258,11 +265,13 @@ def onSequenceSelect(self, seq_len: int, on_success: Callable, indexes.append(self.view.getSelection()) # Define function that will be called when user selects another instance - def handle_selection(seq_len=seq_len, - indexes=indexes, - on_success=on_success, - on_each=on_each, - on_failure=on_failure): + def handle_selection( + seq_len=seq_len, + indexes=indexes, + on_success=on_success, + on_each=on_each, + on_failure=on_failure, + ): # Get the index of the currently selected instance new_idx = self.view.getSelection() # If something is selected, add it to the list @@ -296,6 +305,7 @@ def _signal_once(signal, callback): def call_once(*args): signal.disconnect(call_once) callback(*args) + signal.connect(call_once) def onPointSelection(self, callback: Callable): @@ -343,9 +353,9 @@ def keyPressEvent(self, event: QKeyEvent): self.view.nextSelection() elif event.key() < 128 and chr(event.key()).isnumeric(): # decrement by 1 since instances are 0-indexed - self.view.selectInstance(int(chr(event.key()))-1) + self.view.selectInstance(int(chr(event.key())) - 1) else: - event.ignore() # Kicks the event up to parent + event.ignore() # Kicks the event up to parent # print(event.key()) # If user is holding down shift and action resulted in moving to another frame @@ -358,6 +368,7 @@ def keyPressEvent(self, event: QKeyEvent): # Set endpoint to frame after action self.seekbar.endSelection(self.frame_idx, update=True) + class GraphicsView(QGraphicsView): """ QGraphicsView used by QtVideoPlayer. @@ -440,7 +451,9 @@ def setImage(self, image): # pixmap = QPixmap.fromImage(image) pixmap = QPixmap(image) else: - raise RuntimeError("ImageViewer.setImage: Argument must be a QImage or QPixmap.") + raise RuntimeError( + "ImageViewer.setImage: Argument must be a QImage or QPixmap." + ) if self.hasImage(): self._pixmapHandle.setPixmap(pixmap) else: @@ -470,13 +483,19 @@ def instances(self): Order in list should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and not item.predicted] + return [ + item + for item in self.scene.items(Qt.SortOrder.AscendingOrder) + if type(item) == QtInstance and not item.predicted + ] @property def selectable_instances(self): - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.selectable] + return [ + item + for item in self.scene.items(Qt.SortOrder.AscendingOrder) + if type(item) == QtInstance and item.selectable + ] @property def predicted_instances(self): @@ -485,8 +504,11 @@ def predicted_instances(self): Order in list should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.predicted] + return [ + item + for item in self.scene.items(Qt.SortOrder.AscendingOrder) + if type(item) == QtInstance and item.predicted + ] @property def all_instances(self): @@ -495,8 +517,11 @@ def all_instances(self): Order in list should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance] + return [ + item + for item in self.scene.items(Qt.SortOrder.AscendingOrder) + if type(item) == QtInstance + ] def clearSelection(self, signal=True): """ Clear instance skeleton selection. 
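The selection helpers reformatted in the next hunk can be exercised roughly as follows (a hypothetical `view`, i.e. a GraphicsView with instances already plotted; not part of this patch):

    view.selectInstance(0)     # select the first instance; emits updatedSelection
    idx = view.getSelection()  # index of the selected instance, or None
    view.nextSelection()       # cycle to the next selectable instance
    view.clearSelection()      # deselect everything
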
@@ -504,19 +529,21 @@ def clearSelection(self, signal=True): for instance in self.all_instances: instance.selected = False # signal that the selection has changed (so we can update visual display) - if signal: self.updatedSelection.emit() + if signal: + self.updatedSelection.emit() def nextSelection(self): """ Select next instance (or first, if none currently selected). """ instances = self.selectable_instances - if len(instances) == 0: return - select_inst = instances[0] # default to selecting first instance + if len(instances) == 0: + return + select_inst = instances[0] # default to selecting first instance select_idx = 0 for idx, instance in enumerate(instances): if instance.selected: instance.selected = False - select_idx = (idx+1)%len(instances) + select_idx = (idx + 1) % len(instances) select_inst = instances[select_idx] break select_inst.selected = True @@ -535,17 +562,19 @@ def selectInstance(self, select: Union[Instance, int], signal=True): self.clearSelection(signal=False) for idx, instance in enumerate(self.all_instances): - instance.selected = (select == idx or select == instance.instance) + instance.selected = select == idx or select == instance.instance # signal that the selection has changed (so we can update visual display) - if signal: self.updatedSelection.emit() + if signal: + self.updatedSelection.emit() def getSelection(self): """ Returns the index of the currently selected instance. If no instance selected, returns None. """ instances = self.all_instances - if len(instances) == 0: return None + if len(instances) == 0: + return None for idx, instance in enumerate(instances): if instance.selected: return idx @@ -555,7 +584,8 @@ def getSelectionInstance(self): If no instance selected, returns None. """ instances = self.all_instances - if len(instances) == 0: return None + if len(instances) == 0: + return None for idx, instance in enumerate(instances): if instance.selected: return instance.instance @@ -597,7 +627,7 @@ def mouseReleaseEvent(self, event): QGraphicsView.mouseReleaseEvent(self, event) scenePos = self.mapToScene(event.pos()) # check if mouse moved during click - has_moved = (event.pos() != self._down_pos) + has_moved = event.pos() != self._down_pos if event.button() == Qt.LeftButton: if self.click_mode == "": @@ -605,12 +635,17 @@ def mouseReleaseEvent(self, event): if not has_moved: # When just a tap, see if there's an item underneath to select clicked = self.scene.items(scenePos, Qt.IntersectsItemBoundingRect) - clicked_instances = [item for item in clicked - if type(item) == QtInstance and item.selectable] + clicked_instances = [ + item + for item in clicked + if type(item) == QtInstance and item.selectable + ] # We only handle single instance selection so pick at most one from list - clicked_instance = clicked_instances[0] if len(clicked_instances) else None + clicked_instance = ( + clicked_instances[0] if len(clicked_instances) else None + ) for idx, instance in enumerate(self.selectable_instances): - instance.selected = (instance == clicked_instance) + instance.selected = instance == clicked_instance # If we want to allow selection of multiple instances, do this: # instance.selected = (instance in clicked) self.updatedSelection.emit() @@ -620,10 +655,11 @@ def mouseReleaseEvent(self, event): selection_rect = self.scene.selectionArea().boundingRect() self.areaSelected.emit( - selection_rect.left(), - selection_rect.top(), - selection_rect.right(), - selection_rect.bottom()) + selection_rect.left(), + selection_rect.top(), + selection_rect.right(), + 
selection_rect.bottom(), + ) elif self.click_mode == "point": selection_point = scenePos self.pointSelected.emit(scenePos.x(), scenePos.y()) @@ -638,7 +674,7 @@ def mouseReleaseEvent(self, event): elif event.button() == Qt.RightButton: if self.canZoom: zoom_rect = self.scene.selectionArea().boundingRect() - self.scene.setSelectionArea(QPainterPath()) # clear selection + self.scene.setSelectionArea(QPainterPath()) # clear selection self.zoomToRect(zoom_rect) self.setDragMode(QGraphicsView.NoDrag) self.rightMouseButtonReleased.emit(scenePos.x(), scenePos.y()) @@ -656,10 +692,11 @@ def zoomToRect(self, zoom_rect: QRectF): relative: Controls whether rect is relative to current zoom. """ - if zoom_rect.isNull(): return + if zoom_rect.isNull(): + return - scale_h = self.scene.height()/zoom_rect.height() - scale_w = self.scene.width()/zoom_rect.width() + scale_h = self.scene.height() / zoom_rect.height() + scale_w = self.scene.width() / zoom_rect.width() scale = min(scale_h, scale_w) self.zoomFactor = scale @@ -725,10 +762,10 @@ def wheelEvent(self, event): pass def keyPressEvent(self, event): - event.ignore() # Kicks the event up to parent + event.ignore() # Kicks the event up to parent def keyReleaseEvent(self, event): - event.ignore() # Kicks the event up to parent + event.ignore() # Kicks the event up to parent class QtNodeLabel(QGraphicsTextItem): @@ -776,20 +813,24 @@ def adjustPos(self, *args, **kwargs): if len(node.edges): edge_angles = sorted([edge.angle_to(node) for edge in node.edges]) - edge_angles.append(edge_angles[0] + math.pi*2) + edge_angles.append(edge_angles[0] + math.pi * 2) # Calculate size and bisector for each arc between adjacent edges - edge_arcs = [(edge_angles[i+1]-edge_angles[i], - edge_angles[i+1]/2+edge_angles[i]/2) - for i in range(len(edge_angles)-1)] + edge_arcs = [ + ( + edge_angles[i + 1] - edge_angles[i], + edge_angles[i + 1] / 2 + edge_angles[i] / 2, + ) + for i in range(len(edge_angles) - 1) + ] max_arc = sorted(edge_arcs)[-1] - shift_angle = max_arc[1] # this is the angle of the bisector - shift_angle %= 2*math.pi + shift_angle = max_arc[1] # this is the angle of the bisector + shift_angle %= 2 * math.pi # Use the _shift_factor to control how the label is positioned # relative to the node. # Shift factor of -1 means we shift label up/left by its height/width. - self._shift_factor_x = (math.cos(shift_angle)*.6) -.5 - self._shift_factor_y = (math.sin(shift_angle)*.6) -.5 + self._shift_factor_x = (math.cos(shift_angle) * 0.6) - 0.5 + self._shift_factor_y = (math.sin(shift_angle) * 0.6) - 0.5 # Since item doesn't scale when view is transformed (i.e., zoom) # we need to calculate bounding size in view manually. @@ -803,8 +844,10 @@ def adjustPos(self, *args, **kwargs): height = height / view.viewportTransform().m11() width = width / view.viewportTransform().m22() - self.setPos(self._anchor_x + width*self._shift_factor_x, - self._anchor_y + height*self._shift_factor_y) + self.setPos( + self._anchor_x + width * self._shift_factor_x, + self._anchor_y + height * self._shift_factor_y, + ) # Now apply these changes to the visual display self.adjustStyle() @@ -813,7 +856,9 @@ def adjustStyle(self): """ Update visual display of the label and its node. 
""" - complete_color = QColor(80, 194, 159) if self.node.point.complete else QColor(232, 45, 32) + complete_color = ( + QColor(80, 194, 159) if self.node.point.complete else QColor(232, 45, 32) + ) if self.predicted: self._base_font.setBold(False) @@ -827,13 +872,13 @@ def adjustStyle(self): elif self.node.point.complete: self._base_font.setBold(True) self.setFont(self._base_font) - self.setDefaultTextColor(complete_color) # greenish + self.setDefaultTextColor(complete_color) # greenish # FIXME: Adjust style of node here as well? # self.node.setBrush(complete_color) else: self._base_font.setBold(False) self.setFont(self._base_font) - self.setDefaultTextColor(complete_color) # redish + self.setDefaultTextColor(complete_color) # redish def boundingRect(self): """ Method required by Qt. @@ -881,9 +926,21 @@ class QtNode(QGraphicsEllipseItem): color: Color of the visual node item. callbacks: List of functions to call after we update to the `Point`. """ - def __init__(self, parent, point:Point, radius:float, color:list, node_name:str = None, - predicted=False, color_predicted=False, show_non_visible=True, - callbacks = None, *args, **kwargs): + + def __init__( + self, + parent, + point: Point, + radius: float, + color: list, + node_name: str = None, + predicted=False, + color_predicted=False, + show_non_visible=True, + callbacks=None, + *args, + **kwargs, + ): self._parent = parent self.point = point self.radius = radius @@ -896,7 +953,15 @@ def __init__(self, parent, point:Point, radius:float, color:list, node_name:str self.callbacks = [] if callbacks is None else callbacks self.dragParent = False - super(QtNode, self).__init__(-self.radius, -self.radius, self.radius*2, self.radius*2, parent=parent, *args, **kwargs) + super(QtNode, self).__init__( + -self.radius, + -self.radius, + self.radius * 2, + self.radius * 2, + parent=parent, + *args, + **kwargs, + ) if node_name is not None: self.setToolTip(node_name) @@ -920,7 +985,9 @@ def __init__(self, parent, point:Point, radius:float, color:list, node_name:str self.setFlag(QGraphicsItem.ItemIsMovable) self.pen_default = QPen(col_line, 1) - self.pen_default.setCosmetic(True) # https://stackoverflow.com/questions/13120486/adjusting-qpen-thickness-when-scaling-qgraphicsview + self.pen_default.setCosmetic( + True + ) # https://stackoverflow.com/questions/13120486/adjusting-qpen-thickness-when-scaling-qgraphicsview self.pen_missing = QPen(col_line, 1) self.pen_missing.setCosmetic(True) self.brush = QBrush(QColor(*self.color, a=128)) @@ -951,12 +1018,12 @@ def updatePoint(self, user_change=True): self.setPen(self.pen_default) self.setBrush(self.brush) else: - radius = self.radius / 2. + radius = self.radius / 2.0 self.setPen(self.pen_missing) self.setBrush(self.brush_missing) if not self.show_non_visible: self.hide() - self.setRect(-radius, -radius, radius*2, radius*2) + self.setRect(-radius, -radius, radius * 2, radius * 2) for edge in self.edges: edge.updateEdge(self) @@ -967,13 +1034,15 @@ def updatePoint(self, user_change=True): self.calls() # Emit event if we're updating from a user change - if user_change: self._parent.changedData.emit(self._parent.instance) + if user_change: + self._parent.changedData.emit(self._parent.instance) def mousePressEvent(self, event): """ Custom event handler for mouse press. 
""" # Do nothing if node is from predicted instance - if self.parentObject().predicted: return + if self.parentObject().predicted: + return self.setCursor(Qt.ArrowCursor) @@ -1005,17 +1074,19 @@ def mousePressEvent(self, event): def mouseMoveEvent(self, event): """ Custom event handler for mouse move. """ - #print(event) + # print(event) if self.dragParent: self.parentObject().mouseMoveEvent(event) else: super(QtNode, self).mouseMoveEvent(event) - self.updatePoint(user_change=False) # don't count change until mouse release + self.updatePoint( + user_change=False + ) # don't count change until mouse release def mouseReleaseEvent(self, event): """ Custom event handler for mouse release. """ - #print(event) + # print(event) self.unsetCursor() if self.dragParent: self.parentObject().mouseReleaseEvent(event) @@ -1040,6 +1111,7 @@ def mouseDoubleClickEvent(self, event): view = scene.views()[0] view.instanceDoubleClicked.emit(self.parentObject().instance) + class QtEdge(QGraphicsLineItem): """ QGraphicsLineItem to handle display of edge between skeleton instance nodes. @@ -1048,14 +1120,30 @@ class QtEdge(QGraphicsLineItem): src: The `QtNode` source node for the edge. dst: The `QtNode` destination node for the edge. """ - def __init__(self, parent, src:QtNode, dst:QtNode, color, - show_non_visible=True, - *args, **kwargs): + + def __init__( + self, + parent, + src: QtNode, + dst: QtNode, + color, + show_non_visible=True, + *args, + **kwargs, + ): self.src = src self.dst = dst self.show_non_visible = show_non_visible - super(QtEdge, self).__init__(self.src.point.x, self.src.point.y, self.dst.point.x, self.dst.point.y, parent=parent, *args, **kwargs) + super(QtEdge, self).__init__( + self.src.point.x, + self.src.point.y, + self.dst.point.x, + self.dst.point.y, + parent=parent, + *args, + **kwargs, + ) pen = QPen(QColor(*color), 1) pen.setCosmetic(True) @@ -1103,7 +1191,7 @@ def updateEdge(self, node): if self.src.point.visible and self.dst.point.visible: self.full_opacity = 1 else: - self.full_opacity = .5 if self.show_non_visible else 0 + self.full_opacity = 0.5 if self.show_non_visible else 0 self.setOpacity(self.full_opacity) if node == self.src: @@ -1134,11 +1222,18 @@ class QtInstance(QGraphicsObject): changedData = Signal(Instance) - def __init__(self, skeleton:Skeleton = None, instance: Instance = None, - predicted=False, color_predicted=False, - color=(0, 114, 189), markerRadius=4, - show_non_visible=True, - *args, **kwargs): + def __init__( + self, + skeleton: Skeleton = None, + instance: Instance = None, + predicted=False, + color_predicted=False, + color=(0, 114, 189), + markerRadius=4, + show_non_visible=True, + *args, + **kwargs, + ): super(QtInstance, self).__init__(*args, **kwargs) self.skeleton = skeleton if instance is None else instance.skeleton self.instance = instance @@ -1156,8 +1251,8 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, self.labels_shown = True self._selected = False self._bounding_rect = QRectF() - #self.setFlag(QGraphicsItem.ItemIsMovable) - #self.setFlag(QGraphicsItem.ItemIsSelectable) + # self.setFlag(QGraphicsItem.ItemIsMovable) + # self.setFlag(QGraphicsItem.ItemIsSelectable) if self.predicted: self.setZValue(0) @@ -1183,15 +1278,23 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, track_name = "[none]" instance_label_text += f"Track: {track_name}" if hasattr(self.instance, "score"): - instance_label_text += f"
Prediction Score: {round(self.instance.score, 2)}" + instance_label_text += ( + f"
Prediction Score: {round(self.instance.score, 2)}" + ) self.track_label.setHtml(instance_label_text) # Add nodes for (node, point) in self.instance.nodes_points: - node_item = QtNode(parent=self, point=point, node_name=node.name, - predicted=self.predicted, color_predicted=self.color_predicted, - color=self.color, radius=self.markerRadius, - show_non_visible=self.show_non_visible) + node_item = QtNode( + parent=self, + point=point, + node_name=node.name, + predicted=self.predicted, + color_predicted=self.color_predicted, + color=self.color, + radius=self.markerRadius, + show_non_visible=self.show_non_visible, + ) self.nodes[node.name] = node_item @@ -1199,8 +1302,13 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, for (src, dst) in self.skeleton.edge_names: # Make sure that both nodes are present in this instance before drawing edge if src in self.nodes and dst in self.nodes: - edge_item = QtEdge(parent=self, src=self.nodes[src], dst=self.nodes[dst], - color=self.color, show_non_visible=self.show_non_visible) + edge_item = QtEdge( + parent=self, + src=self.nodes[src], + dst=self.nodes[dst], + color=self.color, + show_non_visible=self.show_non_visible, + ) self.nodes[src].edges.append(edge_item) self.nodes[dst].edges.append(edge_item) self.edges.append(edge_item) @@ -1220,7 +1328,7 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, # Update size of box so it includes all the nodes/edges self.updateBox() - def updatePoints(self, complete:bool = False, user_change:bool = False): + def updatePoints(self, complete: bool = False, user_change: bool = False): """ Updates data and display for all points in skeleton. @@ -1239,7 +1347,8 @@ def updatePoints(self, complete:bool = False, user_change:bool = False): node_item.point.x = node_item.scenePos().x() node_item.point.y = node_item.scenePos().y() node_item.setPos(node_item.point.x, node_item.point.y) - if complete: node_item.point.complete = True + if complete: + node_item.point.complete = True # Wait to run callbacks until all nodes are updated # Otherwise the label positions aren't correct since # they depend on the edge vectors to old node positions. @@ -1255,13 +1364,18 @@ def updatePoints(self, complete:bool = False, user_change:bool = False): # Update box for instance selection self.updateBox() # Emit event if we're updating from a user change - if user_change: self.changedData.emit(self.instance) + if user_change: + self.changedData.emit(self.instance) def getPointsBoundingRect(self): """Returns a rect which contains all the nodes in the skeleton.""" rect = None for item in self.edges: - rect = item.boundingRect() if rect is None else rect.united(item.boundingRect()) + rect = ( + item.boundingRect() + if rect is None + else rect.united(item.boundingRect()) + ) return rect def updateBox(self, *args, **kwargs): @@ -1273,7 +1387,7 @@ def updateBox(self, *args, **kwargs): select this instance. 
""" # Only show box if instance is selected - op = .7 if self._selected else 0 + op = 0.7 if self._selected else 0 self.box.setOpacity(op) # Update the position for the box rect = self.getPointsBoundingRect() @@ -1289,7 +1403,7 @@ def selected(self): return self._selected @selected.setter - def selected(self, selected:bool): + def selected(self, selected: bool): self._selected = selected # Update the selection box for this skeleton instance self.updateBox() @@ -1316,7 +1430,7 @@ def toggleEdges(self): """ self.showEdges(not self.edges_shown) - def showEdges(self, show = True): + def showEdges(self, show=True): """ Draws/hides the edges for this skeleton instance. @@ -1338,8 +1452,8 @@ def paint(self, painter, option, widget=None): """ pass -class QtTextWithBackground(QGraphicsTextItem): +class QtTextWithBackground(QGraphicsTextItem): def __init__(self, *args, **kwargs): super(QtTextWithBackground, self).__init__(*args, **kwargs) self.setFlag(QGraphicsItem.ItemIgnoresTransformations) @@ -1354,24 +1468,30 @@ def paint(self, painter, option, *args, **kwargs): """ text_color = self.defaultTextColor() brush = painter.brush() - background_color = "white" if text_color.lightnessF() < .4 else "black" - background_color = QColor(background_color, a=.5) + background_color = "white" if text_color.lightnessF() < 0.4 else "black" + background_color = QColor(background_color, a=0.5) painter.setBrush(QBrush(background_color)) painter.drawRect(self.boundingRect()) painter.setBrush(brush) super(QtTextWithBackground, self).paint(painter, option, *args, **kwargs) + def video_demo(labels, standalone=False): video = labels.videos[0] - if standalone: app = QApplication([]) + if standalone: + app = QApplication([]) window = QtVideoPlayer(video=video) - window.changedPlot.connect(lambda vp, idx, select_idx: plot_instances(vp.view.scene, idx, labels, video)) + window.changedPlot.connect( + lambda vp, idx, select_idx: plot_instances(vp.view.scene, idx, labels, video) + ) window.show() window.plot() - if standalone: app.exec_() + if standalone: + app.exec_() + def plot_instances(scene, frame_idx, labels, video=None, fixed=True): from sleap.gui.overlays.tracks import TrackColorManager @@ -1380,7 +1500,8 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True): color_manager = TrackColorManager(labels=labels) lfs = labels.find(video, frame_idx) - if not lfs: return + if not lfs: + return labeled_frame = lfs[0] @@ -1394,15 +1515,18 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True): count_no_track += 1 # Plot instance - inst = QtInstance(instance=instance, - color=color_manager.get_color(pseudo_track), - predicted=fixed, - color_predicted=True, - show_non_visible=False) + inst = QtInstance( + instance=instance, + color=color_manager.get_color(pseudo_track), + predicted=fixed, + color_predicted=True, + show_non_visible=False, + ) inst.showLabels(False) scene.addItem(inst) inst.updatePoints() + if __name__ == "__main__": import argparse @@ -1413,4 +1537,4 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True): args = parser.parse_args() labels = Labels.load_json(args.data_path) - video_demo(labels, standalone=True) \ No newline at end of file + video_demo(labels, standalone=True) diff --git a/sleap/info/labels.py b/sleap/info/labels.py index 20c169a6f..5d5b8c7eb 100644 --- a/sleap/info/labels.py +++ b/sleap/info/labels.py @@ -38,4 +38,4 @@ print(f" tracks: {len(tracks)}") print(f" max instances in frame: {concurrent_count}") - print(f"Total user labeled frames: 
{total_user_frames}") \ No newline at end of file + print(f"Total user labeled frames: {total_user_frames}") diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index ca155076c..46969895f 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -7,12 +7,13 @@ from sleap.instance import Instance, PredictedInstance from sleap.io.dataset import Labels + def matched_instance_distances( - labels_gt: Labels, - labels_pr: Labels, - match_lists_function: Callable, - frame_range: Optional[range]=None) -> Tuple[ - List[int], np.ndarray, np.ndarray, np.ndarray]: + labels_gt: Labels, + labels_pr: Labels, + match_lists_function: Callable, + frame_range: Optional[range] = None, +) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: """ Distances between ground truth and predicted nodes over a set of frames. @@ -62,7 +63,7 @@ def matched_instance_distances( points_gt.append(sorted_gt) points_pr.append(sorted_pr) - frame_idxs.extend([frame_idx]*len(sorted_gt)) + frame_idxs.extend([frame_idx] * len(sorted_gt)) # Convert arrays to numpy matrixes # instances * nodes * (x,y) @@ -75,26 +76,32 @@ def matched_instance_distances( return frame_idxs, D, points_gt, points_pr + def match_instance_lists( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - cost_function: Callable) -> Tuple[ - List[Union[Instance, PredictedInstance]], - List[Union[Instance, PredictedInstance]]]: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + cost_function: Callable, +) -> Tuple[ + List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] +]: """Sorts two lists of Instances to find best overall correspondence for a given cost function (e.g., total distance between points).""" - pairwise_distance_matrix = calculate_pairwise_cost(instances_a, instances_b, cost_function) + pairwise_distance_matrix = calculate_pairwise_cost( + instances_a, instances_b, cost_function + ) match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) sorted_a = list(map(lambda idx: instances_a[idx], match_a)) sorted_b = list(map(lambda idx: instances_b[idx], match_b)) return sorted_a, sorted_b + def calculate_pairwise_cost( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - cost_function: Callable) -> np.ndarray: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + cost_function: Callable, +) -> np.ndarray: """Calculate (a * b) matrix of pairwise costs using cost function.""" matrix_size = (len(instances_a), len(instances_b)) @@ -114,12 +121,14 @@ def calculate_pairwise_cost( pairwise_cost_matrix[idx_a, idx_b] = cost return pairwise_cost_matrix + def match_instance_lists_nodewise( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - thresh: float=5) -> Tuple[ - List[Union[Instance, PredictedInstance]], - List[Union[Instance, PredictedInstance]]]: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + thresh: float = 5, +) -> Tuple[ + List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] +]: """For each node for each instance in the first list, pairs it with the closest corresponding node from *any* instance in the second list.""" @@ -141,8 +150,8 @@ def match_instance_lists_nodewise( for node_idx in 
range(node_count): # Make sure there's some prediction for this node - if any(~np.isnan(dist_array[:,node_idx])): - best_idx = np.nanargmin(dist_array[:,node_idx]) + if any(~np.isnan(dist_array[:, node_idx])): + best_idx = np.nanargmin(dist_array[:, node_idx]) # Ignore closest point if distance is beyond threshold if dist_array[best_idx, node_idx] <= thresh: @@ -153,9 +162,11 @@ def match_instance_lists_nodewise( return instances_a, best_points_array + def point_dist( - inst_a: Union[Instance, PredictedInstance], - inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: + inst_a: Union[Instance, PredictedInstance], + inst_b: Union[Instance, PredictedInstance], +) -> np.ndarray: """Given two instances, returns array of distances for corresponding nodes.""" points_a = inst_a.points_array @@ -163,8 +174,11 @@ def point_dist( point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist -def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], - inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: + +def nodeless_point_dist( + inst_a: Union[Instance, PredictedInstance], + inst_b: Union[Instance, PredictedInstance], +) -> np.ndarray: """Given two instances, returns array of distances for closest points ignoring node identities.""" @@ -185,15 +199,17 @@ def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) # Sort points by this match and calculate overall distance - sorted_points_a = points_a[match_a,:] - sorted_points_b = points_b[match_b,:] + sorted_points_a = points_a[match_a, :] + sorted_points_b = points_b[match_b, :] point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist + def compare_instance_lists( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]]) -> np.ndarray: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], +) -> np.ndarray: """Given two lists of corresponding Instances, returns (instances * nodes) matrix of distances between corresponding nodes.""" @@ -203,29 +219,38 @@ def compare_instance_lists( return np.stack(paired_points_array_distances) -def list_points_array(instances: List[Union[Instance, PredictedInstance]]) -> np.ndarray: + +def list_points_array( + instances: List[Union[Instance, PredictedInstance]] +) -> np.ndarray: """Given list of Instances, returns (instances * nodes * 2) matrix.""" points_arrays = list(map(lambda inst: inst.points_array, instances)) return np.stack(points_arrays) -def point_match_count(dist_array: np.ndarray, thresh: float=5) -> int: + +def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are <= threshold.""" return np.sum(dist_array[~np.isnan(dist_array)] <= thresh) -def point_nonmatch_count(dist_array: np.ndarray, thresh: float=5) -> int: + +def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) + def foo(labels_gt, labels_pr, frame_idx=1092): list_a = labels_gt.find(labels_gt.videos[0], frame_idx=frame_idx)[0].instances list_b = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)[0].instances match_instance_lists_nodewise(list_a, list_b) + if __name__ == "__main__": labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - 
labels_pr = Labels.load_json("tests/data/json_format_v2/centered_pair_predictions.json") + labels_pr = Labels.load_json( + "tests/data/json_format_v2/centered_pair_predictions.json" + ) # OPTION 1 @@ -241,7 +266,9 @@ def foo(labels_gt, labels_pr, frame_idx=1092): # where "match" means the points are within some threshold distance. # Note that each sorted list will be as long as the shorted input list. - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists(gt_list, pr_list, point_nonmatch_count) + instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( + gt_list, pr_list, point_nonmatch_count + ) # PICK THE FUNCTION @@ -249,7 +276,9 @@ def foo(labels_gt, labels_pr, frame_idx=1092): # inst_matching_func = instwise_matching_func # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances(labels_gt, labels_pr, inst_matching_func) + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) # Show mean difference for each node node_names = labels_gt.skeletons[0].node_names diff --git a/sleap/info/summary.py b/sleap/info/summary.py index 7e3106441..9f25b2305 100644 --- a/sleap/info/summary.py +++ b/sleap/info/summary.py @@ -4,7 +4,7 @@ @attr.s(auto_attribs=True) class Summary: - labels: 'Labels' + labels: "Labels" def get_point_count_series(self, video): series = dict() @@ -20,7 +20,12 @@ def get_point_score_series(self, video, reduction="sum"): series = dict() for lf in self.labels.find(video): - val = reduce_funct(point.score for inst in lf for point in inst.points if hasattr(inst, "score")) + val = reduce_funct( + point.score + for inst in lf + for point in inst.points + if hasattr(inst, "score") + ) series[lf.frame_idx] = val return series @@ -44,7 +49,7 @@ def get_point_displacement_series(self, video, reduction="sum"): val = self._calculate_frame_velocity(lf, last_lf, reduce_funct) last_lf = lf if not np.isnan(val): - series[lf.frame_idx] = val #len(lf.instances) + series[lf.frame_idx] = val # len(lf.instances) return series @staticmethod diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index b46b4ec2c..2232823a1 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -31,9 +31,14 @@ parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to labels json file") - parser.add_argument('--all-frames', dest='all_frames', action='store_const', - const=True, default=False, - help='include all frames without predictions') + parser.add_argument( + "--all-frames", + dest="all_frames", + action="store_const", + const=True, + default=False, + help="include all frames without predictions", + ) args = parser.parse_args() video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) @@ -48,7 +53,9 @@ first_frame_idx = 0 if args.all_frames else frame_idxs[0] - frame_count = frame_idxs[-1] - first_frame_idx + 1 # count should include unlabeled frames + frame_count = ( + frame_idxs[-1] - first_frame_idx + 1 + ) # count should include unlabeled frames # Desired MATLAB format: # "track_occupancy" tracks * frames @@ -56,7 +63,9 @@ # "track_names" tracks occupancy_matrix = np.zeros((track_count, frame_count), dtype=np.uint8) - prediction_matrix = np.full((frame_count, node_count, 2, track_count), np.nan, dtype=float) + prediction_matrix = np.full( + (frame_count, node_count, 2, track_count), np.nan, dtype=float + ) for lf, inst in [(lf, inst) for lf in labels for inst in lf.instances]: frame_i = 
lf.frame_idx - first_frame_idx @@ -74,7 +83,9 @@ print(f"ignoring {np.sum(~occupied_track_mask)} empty tracks") occupancy_matrix = occupancy_matrix[occupied_track_mask] prediction_matrix = prediction_matrix[..., occupied_track_mask] - track_names = [track_names[i] for i in range(len(track_names)) if occupied_track_mask[i]] + track_names = [ + track_names[i] for i in range(len(track_names)) if occupied_track_mask[i] + ] print(f"track_occupancy: {occupancy_matrix.shape}") print(f"tracks: {prediction_matrix.shape}") @@ -86,10 +97,16 @@ # We have to transpose the arrays since MATLAB expects column-major ds = f.create_dataset("track_names", data=track_names) ds = f.create_dataset( - "track_occupancy", data=np.transpose(occupancy_matrix), - compression="gzip", compression_opts=9) + "track_occupancy", + data=np.transpose(occupancy_matrix), + compression="gzip", + compression_opts=9, + ) ds = f.create_dataset( - "tracks", data=np.transpose(prediction_matrix), - compression="gzip", compression_opts=9) + "tracks", + data=np.transpose(prediction_matrix), + compression="gzip", + compression_opts=9, + ) print(f"Saved as {output_filename}") diff --git a/sleap/instance.py b/sleap/instance.py index 6ffc15ea8..62c0ff0e2 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -39,14 +39,15 @@ class Point(np.record): # Define the dtype from the point class attributes plus some # additional fields we will use to relate point to instances and # nodes. - dtype = np.dtype( - [('x', 'f8'), - ('y', 'f8'), - ('visible', '?'), - ('complete', '?')]) + dtype = np.dtype([("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?")]) - def __new__(cls, x: float = math.nan, y: float = math.nan, - visible: bool = True, complete: bool = False): + def __new__( + cls, + x: float = math.nan, + y: float = math.nan, + visible: bool = True, + complete: bool = False, + ): # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -78,9 +79,7 @@ def isnan(self): # This turns PredictedPoint into an attrs class. Defines comparators for # us and generaly makes it behave better. Crazy that this works! -Point = attr.s(these={name: attr.ib() - for name in Point.dtype.names}, - init=False)(Point) +Point = attr.s(these={name: attr.ib() for name in Point.dtype.names}, init=False)(Point) class PredictedPoint(Point): @@ -100,15 +99,17 @@ class PredictedPoint(Point): # additional fields we will use to relate point to instances and # nodes. dtype = np.dtype( - [('x', 'f8'), - ('y', 'f8'), - ('visible', '?'), - ('complete', '?'), - ('score', 'f8')]) - - def __new__(cls, x: float = math.nan, y: float = math.nan, - visible: bool = True, complete: bool = False, - score: float = 0.0): + [("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?"), ("score", "f8")] + ) + + def __new__( + cls, + x: float = math.nan, + y: float = math.nan, + visible: bool = True, + complete: bool = False, + score: float = 0.0, + ): # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -138,14 +139,14 @@ def from_point(cls, point: Point, score: float = 0.0): Returns: A scored point based on the point passed in. """ - return cls(**{**Point.asdict(point), 'score': score}) + return cls(**{**Point.asdict(point), "score": score}) # This turns PredictedPoint into an attrs class. Defines comparators for # us and generaly makes it behave better. Crazy that this works! 
-PredictedPoint = attr.s(these={name: attr.ib() - for name in PredictedPoint.dtype.names}, - init=False)(PredictedPoint) +PredictedPoint = attr.s( + these={name: attr.ib() for name in PredictedPoint.dtype.names}, init=False +)(PredictedPoint) class PointArray(np.recarray): @@ -156,9 +157,19 @@ class PointArray(np.recarray): _record_type = Point - def __new__(subtype, shape, buf=None, offset=0, strides=None, - formats=None, names=None, titles=None, - byteorder=None, aligned=False, order='C'): + def __new__( + subtype, + shape, + buf=None, + offset=0, + strides=None, + formats=None, + names=None, + titles=None, + byteorder=None, + aligned=False, + order="C", + ): dtype = subtype._record_type.dtype @@ -168,11 +179,19 @@ def __new__(subtype, shape, buf=None, offset=0, strides=None, descr = np.format_parser(formats, names, titles, aligned, byteorder)._descr if buf is None: - self = np.ndarray.__new__(subtype, shape, (subtype._record_type, descr), order=order) + self = np.ndarray.__new__( + subtype, shape, (subtype._record_type, descr), order=order + ) else: - self = np.ndarray.__new__(subtype, shape, (subtype._record_type, descr), - buffer=buf, offset=offset, - strides=strides, order=order) + self = np.ndarray.__new__( + subtype, + shape, + (subtype._record_type, descr), + buffer=buf, + offset=offset, + strides=strides, + order=order, + ) return self def __array_finalize__(self, obj): @@ -206,7 +225,7 @@ def __getitem__(self, indx): if isinstance(obj, np.ndarray): if obj.dtype.fields: obj = obj.view(type(self)) - #if issubclass(obj.dtype.type, numpy.void): + # if issubclass(obj.dtype.type, numpy.void): # return obj.view(dtype=(self.dtype.type, obj.dtype)) return obj else: @@ -216,7 +235,7 @@ def __getitem__(self, indx): return obj @classmethod - def from_array(cls, a: 'PointArray'): + def from_array(cls, a: "PointArray"): """ Convert a PointArray to a new PointArray (or child class, i.e., PredictedPointArray), @@ -235,15 +254,17 @@ def from_array(cls, a: 'PointArray'): return v + class PredictedPointArray(PointArray): """ PredictedPointArray is analogous to PointArray except for predicted points. """ + _record_type = PredictedPoint @classmethod - def to_array(cls, a: 'PredictedPointArray'): + def to_array(cls, a: "PredictedPointArray"): """ Convert a PredictedPointArray to a normal PointArray. @@ -272,10 +293,11 @@ class Track: spawned_on: The frame of the video that this track was spawned on. name: A name given to this track for identifying purposes. """ + spawned_on: int = attr.ib(converter=int) name: str = attr.ib(default="", converter=str) - def matches(self, other: 'Track'): + def matches(self, other: "Track"): """ Check if two tracks match by value. @@ -293,6 +315,7 @@ def matches(self, other: 'Track'): # attributes _frame and _point_array_cache after init. These are private variables # that are created in post init so they are not serialized. + @attr.s(cmp=False, slots=True) class Instance: """ @@ -311,10 +334,10 @@ class Instance: skeleton: Skeleton = attr.ib() track: Track = attr.ib(default=None) - from_predicted: Optional['PredictedInstance'] = attr.ib(default=None) + from_predicted: Optional["PredictedInstance"] = attr.ib(default=None) _points: PointArray = attr.ib(default=None) _nodes: List = attr.ib(default=None) - frame: Union['LabeledFrame', None] = attr.ib(default=None) + frame: Union["LabeledFrame", None] = attr.ib(default=None) # The underlying Point array type that this instances point array should be. 
_point_array_type = PointArray @@ -322,7 +345,9 @@ class Instance: @from_predicted.validator def _validate_from_predicted_(self, attribute, from_predicted): if from_predicted is not None and type(from_predicted) != PredictedInstance: - raise TypeError(f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})") + raise TypeError( + f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})" + ) @_points.validator def _validate_all_points(self, attribute, points): @@ -340,10 +365,14 @@ def _validate_all_points(self, attribute, points): if is_string_dict: for node_name in points.keys(): if not self.skeleton.has_node(node_name): - raise KeyError(f"There is no node named {node_name} in {self.skeleton}") + raise KeyError( + f"There is no node named {node_name} in {self.skeleton}" + ) elif isinstance(points, PointArray): if len(points) != len(self.skeleton.nodes): - raise ValueError("PointArray does not have the same number of rows as skeleton nodes.") + raise ValueError( + "PointArray does not have the same number of rows as skeleton nodes." + ) def __attrs_post_init__(self): @@ -385,14 +414,18 @@ def _points_dict_to_array(points, parray, skeleton): points = {skeleton.find_node(name): point for name, point in points.items()} if not is_string_dict and not is_node_dict: - raise ValueError("points dictionary must be keyed by either strings " + - "(node names) or Nodes.") + raise ValueError( + "points dictionary must be keyed by either strings " + + "(node names) or Nodes." + ) # Get rid of the points dict and replace with equivalent point array. for node, point in points.items(): # Convert PredictedPoint to Point if Instance if type(parray) == PointArray and type(point) == PredictedPoint: - point = Point(x=point.x, y=point.y, visible=point.visible, complete=point.complete) + point = Point( + x=point.x, y=point.y, visible=point.visible, complete=point.complete + ) try: parray[skeleton.node_to_index(node)] = point # parray[skeleton.node_to_index(node.name)] = point @@ -435,7 +468,9 @@ def __getitem__(self, node): node = self._node_to_index(node) return self._points[node] except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) def __contains__(self, node): """ @@ -463,10 +498,14 @@ def __setitem__(self, node, value): # Make sure node and value, if either are lists, are of compatible size if type(node) is not list and type(value) is list and len(value) != 1: - raise IndexError("Node list for indexing must be same length and value list.") + raise IndexError( + "Node list for indexing must be same length and value list." + ) if type(node) is list and type(value) is not list and len(node) != 1: - raise IndexError("Node list for indexing must be same length and value list.") + raise IndexError( + "Node list for indexing must be same length and value list." + ) # If we are dealing with lists, do multiple assignment recursively, this should be ok because # skeletons and instances are small. @@ -478,7 +517,9 @@ def __setitem__(self, node, value): node_idx = self._node_to_index(node) self._points[node_idx] = value except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) def __delitem__(self, node): """ Delete node key and points associated with that node. 
""" @@ -487,7 +528,9 @@ def __delitem__(self, node): self._points[node_idx].x = math.nan self._points[node_idx].y = math.nan except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) def matches(self, other): """ @@ -530,8 +573,11 @@ def nodes(self): """ self.fix_array() - return tuple(self._nodes[i] for i, point in enumerate(self._points) - if not point.isnan() and self._nodes[i] in self.skeleton.nodes) + return tuple( + self._nodes[i] + for i, point in enumerate(self._points) + if not point.isnan() and self._nodes[i] in self.skeleton.nodes + ) @property def nodes_points(self): @@ -574,9 +620,9 @@ def fix_array(self): self._points = new_array self._nodes = self.skeleton.nodes - def get_points_array(self, copy: bool = True, - invisible_as_nan: bool = False, - full: bool = False) -> np.ndarray: + def get_points_array( + self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False + ) -> np.ndarray: """ Return the instance's points in array form. @@ -600,9 +646,9 @@ def get_points_array(self, copy: bool = True, return self._points if not copy and not invisible_as_nan: - return self._points[['x', 'y']] + return self._points[["x", "y"]] else: - parray = structured_to_unstructured(self._points[['x', 'y']]) + parray = structured_to_unstructured(self._points[["x", "y"]]) if invisible_as_nan: parray[~self._points.visible] = math.nan @@ -644,6 +690,7 @@ class PredictedInstance(Instance): Args: score: The instance level prediction score. """ + score: float = attr.ib(default=0.0, converter=float) # The underlying Point array type that this instances point array should be. @@ -670,9 +717,13 @@ def from_instance(cls, instance: Instance, score): Returns: A PredictedInstance for the given Instance. 
""" - kw_args = attr.asdict(instance, recurse=False, filter=lambda attr, value: attr.name not in ("_points", "_nodes")) - kw_args['points'] = PredictedPointArray.from_array(instance._points) - kw_args['score'] = score + kw_args = attr.asdict( + instance, + recurse=False, + filter=lambda attr, value: attr.name not in ("_points", "_nodes"), + ) + kw_args["points"] = PredictedPointArray.from_array(instance._points) + kw_args["score"] = score return cls(**kw_args) @@ -695,15 +746,18 @@ def make_instance_cattr(): converter.register_unstructure_hook(PointArray, lambda x: None) converter.register_unstructure_hook(PredictedPointArray, lambda x: None) + def unstructure_instance(x: Instance): # Unstructure everything but the points array, nodes, and frame attribute - d = {field.name: converter.unstructure(x.__getattribute__(field.name)) - for field in attr.fields(x.__class__) - if field.name not in ['_points', '_nodes', 'frame']} + d = { + field.name: converter.unstructure(x.__getattribute__(field.name)) + for field in attr.fields(x.__class__) + if field.name not in ["_points", "_nodes", "frame"] + } # Replace the point array with a dict - d['_points'] = converter.unstructure({k: v for k, v in x.nodes_points}) + d["_points"] = converter.unstructure({k: v for k, v in x.nodes_points}) return d @@ -713,7 +767,7 @@ def unstructure_instance(x: Instance): ## STRUCTURE HOOKS def structure_points(x, type): - if 'score' in x.keys(): + if "score" in x.keys(): return cattr.structure(x, PredictedPoint) else: return cattr.structure(x, Point) @@ -723,7 +777,7 @@ def structure_points(x, type): def structure_instances_list(x, type): inst_list = [] for inst_data in x: - if 'score' in inst_data.keys(): + if "score" in inst_data.keys(): inst = converter.structure(inst_data, PredictedInstance) else: inst = converter.structure(inst_data, Instance) @@ -731,11 +785,14 @@ def structure_instances_list(x, type): return inst_list - converter.register_structure_hook(Union[List[Instance], List[PredictedInstance]], - structure_instances_list) + converter.register_structure_hook( + Union[List[Instance], List[PredictedInstance]], structure_instances_list + ) - converter.register_structure_hook(ForwardRef('PredictedInstance'), - lambda x, type: converter.structure(x, PredictedInstance)) + converter.register_structure_hook( + ForwardRef("PredictedInstance"), + lambda x, type: converter.structure(x, PredictedInstance), + ) # We can register structure hooks for point arrays that do nothing # because Instance can have a dict of points passed to it in place of @@ -743,7 +800,7 @@ def structure_instances_list(x, type): def structure_point_array(x, t): if x: point1 = x[list(x.keys())[0]] - if 'score' in point1.keys(): + if "score" in point1.keys(): return converter.structure(x, Dict[Node, PredictedPoint]) else: return converter.structure(x, Dict[Node, Point]) @@ -760,7 +817,9 @@ def structure_point_array(x, t): class LabeledFrame: video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) - _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib(default=attr.Factory(list)) + _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( + default=attr.Factory(list) + ) def __attrs_post_init__(self): @@ -801,7 +860,7 @@ def find(self, track=-1, user=False): instances = self.instances if user: instances = list(filter(lambda inst: type(inst) == Instance, instances)) - if track != -1: # use -1 since we want to accept None as possible value + if track != -1: # use -1 since we want to accept None as possible value 
instances = list(filter(lambda inst: inst.track == track, instances)) return instances @@ -837,7 +896,9 @@ def instances(self, instances: List[Instance]): @property def user_instances(self): - return [inst for inst in self._instances if not isinstance(inst, PredictedInstance)] + return [ + inst for inst in self._instances if not isinstance(inst, PredictedInstance) + ] @property def predicted_instances(self): @@ -845,7 +906,7 @@ def predicted_instances(self): @property def has_user_instances(self): - return (len(self.user_instances) > 0) + return len(self.user_instances) > 0 @property def unused_predictions(self): @@ -853,22 +914,30 @@ def unused_predictions(self): any_tracks = [inst.track for inst in self._instances if inst.track is not None] if len(any_tracks): # use tracks to determine which predicted instances have been used - used_tracks = [inst.track for inst in self._instances - if type(inst) == Instance and inst.track is not None - ] - unused_predictions = [inst for inst in self._instances - if inst.track not in used_tracks - and type(inst) == PredictedInstance - ] + used_tracks = [ + inst.track + for inst in self._instances + if type(inst) == Instance and inst.track is not None + ] + unused_predictions = [ + inst + for inst in self._instances + if inst.track not in used_tracks and type(inst) == PredictedInstance + ] else: # use from_predicted to determine which predicted instances have been used # TODO: should we always do this instead of using tracks? - used_instances = [inst.from_predicted for inst in self._instances - if inst.from_predicted is not None] - unused_predictions = [inst for inst in self._instances - if type(inst) == PredictedInstance - and inst not in used_instances] + used_instances = [ + inst.from_predicted + for inst in self._instances + if inst.from_predicted is not None + ] + unused_predictions = [ + inst + for inst in self._instances + if type(inst) == PredictedInstance and inst not in used_instances + ] return unused_predictions @@ -879,9 +948,16 @@ def instances_to_show(self): predicted instances for which there's a corresponding regular instance. """ unused_predictions = self.unused_predictions - inst_to_show = [inst for inst in self._instances - if type(inst) == Instance or inst in unused_predictions] - inst_to_show.sort(key=lambda inst: inst.track.spawned_on if inst.track is not None else math.inf) + inst_to_show = [ + inst + for inst in self._instances + if type(inst) == Instance or inst in unused_predictions + ] + inst_to_show.sort( + key=lambda inst: inst.track.spawned_on + if inst.track is not None + else math.inf + ) return inst_to_show @staticmethod @@ -912,14 +988,15 @@ def merge_frames(labeled_frames, video, remove_redundant=True): # note first lf with this frame_idx frames_found[lf.frame_idx] = idx # remove labeled frames with no instances - labeled_frames = list(filter(lambda lf: len(lf.instances), - labeled_frames)) + labeled_frames = list(filter(lambda lf: len(lf.instances), labeled_frames)) if redundant_count: print(f"skipped {redundant_count} redundant instances") return labeled_frames @classmethod - def complex_merge_between(cls, base_labels: 'Labels', new_frames: List['LabeledFrame']): + def complex_merge_between( + cls, base_labels: "Labels", new_frames: List["LabeledFrame"] + ): """Merge new_frames into base_labels cleanly when possible, return conflicts if any. 
@@ -944,8 +1021,9 @@ def complex_merge_between(cls, base_labels: 'Labels', new_frames: List['LabeledF base_labels.labeled_frames.append(new_frame) merged_instances = new_frame.instances else: - merged_instances, extra_base_frame, extra_new_frame = \ - cls.complex_frame_merge(base_lfs[0], new_frame) + merged_instances, extra_base_frame, extra_new_frame = cls.complex_frame_merge( + base_lfs[0], new_frame + ) if extra_base_frame: extra_base.append(extra_base_frame) if extra_new_frame: @@ -980,8 +1058,12 @@ def complex_frame_merge(cls, base_frame, new_frame): conflict = False if extra_base_instances and extra_new_instances: - base_predictions = list(filter(lambda inst: hasattr(inst, "score"), extra_base_instances)) - new_predictions = list(filter(lambda inst: hasattr(inst, "score"), extra_new_instances)) + base_predictions = list( + filter(lambda inst: hasattr(inst, "score"), extra_base_instances) + ) + new_predictions = list( + filter(lambda inst: hasattr(inst, "score"), extra_new_instances) + ) base_has_nonpred = len(extra_base_instances) - len(base_predictions) new_has_nonpred = len(extra_new_instances) - len(new_predictions) @@ -1008,14 +1090,24 @@ def complex_frame_merge(cls, base_frame, new_frame): extra_new_instances = [] # Construct frames to hold any conflicting instances - extra_base = cls( - video=base_frame.video, - frame_idx=base_frame.frame_idx, - instances=extra_base_instances) if extra_base_instances else None - - extra_new = cls( - video=new_frame.video, - frame_idx=new_frame.frame_idx, - instances=extra_new_instances) if extra_new_instances else None - - return merged_instances, extra_base, extra_new \ No newline at end of file + extra_base = ( + cls( + video=base_frame.video, + frame_idx=base_frame.frame_idx, + instances=extra_base_instances, + ) + if extra_base_instances + else None + ) + + extra_new = ( + cls( + video=new_frame.video, + frame_idx=new_frame.frame_idx, + instances=extra_new_instances, + ) + if extra_new_instances + else None + ) + + return merged_instances, extra_base, extra_new diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index d5feb88f8..f2e3665c8 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -35,9 +35,17 @@ import pandas as pd from sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, Point, LabeledFrame, \ - Track, PredictedPoint, PredictedInstance, \ - make_instance_cattr, PointArray, PredictedPointArray +from sleap.instance import ( + Instance, + Point, + LabeledFrame, + Track, + PredictedPoint, + PredictedInstance, + make_instance_cattr, + PointArray, + PredictedPointArray, +) from sleap.rangelist import RangeList from sleap.io.video import Video from sleap.util import uniquify, weak_filename_match @@ -49,6 +57,7 @@ def json_loads(json_str: str): except: return json.loads(json_str) + def json_dumps(d: Dict, filename: str = None): """ A simple wrapper around the JSON encoder we are using. @@ -61,14 +70,16 @@ def json_dumps(d: Dict, filename: str = None): None """ import codecs + encoder = rapidjson if filename: - with open(filename, 'w') as f: + with open(filename, "w") as f: encoder.dump(d, f, ensure_ascii=False) else: return encoder.dumps(d) + """ The version number to put in the Labels JSON format. 
""" @@ -135,41 +146,54 @@ def _update_from_labels(self, merge=False): # Ditto for skeletons if merge or len(self.skeletons) == 0: - self.skeletons = list(set(self.skeletons).union( - {instance.skeleton - for label in self.labels - for instance in label.instances})) + self.skeletons = list( + set(self.skeletons).union( + { + instance.skeleton + for label in self.labels + for instance in label.instances + } + ) + ) # Ditto for nodes if merge or len(self.nodes) == 0: - self.nodes = list(set(self.nodes).union({node for skeleton in self.skeletons for node in skeleton.nodes})) + self.nodes = list( + set(self.nodes).union( + {node for skeleton in self.skeletons for node in skeleton.nodes} + ) + ) # Ditto for tracks, a pattern is emerging here if merge or len(self.tracks) == 0: # Get tracks from any Instances or PredictedInstances - other_tracks = {instance.track - for frame in self.labels - for instance in frame.instances - if instance.track} + other_tracks = { + instance.track + for frame in self.labels + for instance in frame.instances + if instance.track + } # Add tracks from any PredictedInstance referenced by instance # This fixes things when there's a referenced PredictionInstance # which is no longer in the frame. other_tracks = other_tracks.union( - {instance.from_predicted.track - for frame in self.labels - for instance in frame.instances - if instance.from_predicted and instance.from_predicted.track}) + { + instance.from_predicted.track + for frame in self.labels + for instance in frame.instances + if instance.from_predicted and instance.from_predicted.track + } + ) # Get list of other tracks not already in track list new_tracks = list(other_tracks - set(self.tracks)) # Sort the new tracks by spawned on and then name - new_tracks.sort(key=lambda t:(t.spawned_on, t.name)) + new_tracks.sort(key=lambda t: (t.spawned_on, t.name)) self.tracks.extend(new_tracks) - def _update_lookup_cache(self): # Data structures for caching self._lf_by_video = dict() @@ -177,7 +201,9 @@ def _update_lookup_cache(self): self._track_occupancy = dict() for video in self.videos: self._lf_by_video[video] = [lf for lf in self.labels if lf.video == video] - self._frame_idx_map[video] = {lf.frame_idx: lf for lf in self._lf_by_video[video]} + self._frame_idx_map[video] = { + lf.frame_idx: lf for lf in self._lf_by_video[video] + } self._track_occupancy[video] = self._make_track_occupany(video) # Below are convenience methods for working with Labels as list. 
@@ -206,7 +232,12 @@ def __contains__(self, item): return item in self.skeletons elif isinstance(item, Node): return item in self.nodes - elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Video) and isinstance(item[1], int): + elif ( + isinstance(item, tuple) + and len(item) == 2 + and isinstance(item[0], Video) + and isinstance(item[1], int) + ): return self.find_first(*item) is not None def __getitem__(self, key): @@ -218,7 +249,12 @@ def __getitem__(self, key): raise KeyError("Video not found in labels.") return self.find(video=key) - elif isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], Video) and isinstance(key[1], int): + elif ( + isinstance(key, tuple) + and len(key) == 2 + and isinstance(key[0], Video) + and isinstance(key[1], int) + ): if key[0] not in self.videos: raise KeyError("Video not found in labels.") @@ -285,7 +321,12 @@ def remove(self, value: LabeledFrame): self._lf_by_video[value.video].remove(value) del self._frame_idx_map[value.video][value.frame_idx] - def find(self, video: Video, frame_idx: Union[int, range] = None, return_new: bool=False) -> List[LabeledFrame]: + def find( + self, + video: Video, + frame_idx: Union[int, range] = None, + return_new: bool = False, + ) -> List[LabeledFrame]: """ Search for labeled frames given video and/or frame index. Args: @@ -297,19 +338,28 @@ def find(self, video: Video, frame_idx: Union[int, range] = None, return_new: bo List of `LabeledFrame`s that match the criteria. Empty if no matches found. """ - null_result = [LabeledFrame(video=video, frame_idx=frame_idx)] if return_new else [] + null_result = ( + [LabeledFrame(video=video, frame_idx=frame_idx)] if return_new else [] + ) if frame_idx is not None: - if video not in self._frame_idx_map: return null_result + if video not in self._frame_idx_map: + return null_result if type(frame_idx) == range: - return [self._frame_idx_map[video][idx] for idx in frame_idx if idx in self._frame_idx_map[video]] + return [ + self._frame_idx_map[video][idx] + for idx in frame_idx + if idx in self._frame_idx_map[video] + ] - if frame_idx not in self._frame_idx_map[video]: return null_result + if frame_idx not in self._frame_idx_map[video]: + return null_result return [self._frame_idx_map[video][frame_idx]] else: - if video not in self._lf_by_video: return null_result + if video not in self._lf_by_video: + return null_result return self._lf_by_video[video] def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): @@ -317,16 +367,21 @@ def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): Iterator over all frames in a video, starting with first frame after specified frame_idx (or first frame in video if none specified). 
""" - if video not in self._frame_idx_map: return None + if video not in self._frame_idx_map: + return None # Get sorted list of frame indexes for this video frame_idxs = sorted(self._frame_idx_map[video].keys()) # Find the next frame index after (before) the specified frame if not reverse: - next_frame_idx = min(filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0]) + next_frame_idx = min( + filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0] + ) else: - next_frame_idx = max(filter(lambda x: x < from_frame_idx, frame_idxs), default=frame_idxs[-1]) + next_frame_idx = max( + filter(lambda x: x < from_frame_idx, frame_idxs), default=frame_idxs[-1] + ) cut_list_idx = frame_idxs.index(next_frame_idx) # Shift list of frame indices to start with specified frame @@ -349,7 +404,9 @@ def find_first(self, video: Video, frame_idx: int = None) -> LabeledFrame: if video in self.videos: for label in self.labels: - if label.video == video and (frame_idx is None or (label.frame_idx == frame_idx)): + if label.video == video and ( + frame_idx is None or (label.frame_idx == frame_idx) + ): return label def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: @@ -365,7 +422,9 @@ def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: if video in self.videos: for label in reversed(self.labels): - if label.video == video and (frame_idx is None or (label.frame_idx == frame_idx)): + if label.video == video and ( + frame_idx is None or (label.frame_idx == frame_idx) + ): return label @property @@ -373,7 +432,11 @@ def user_labeled_frames(self): return [lf for lf in self.labeled_frames if lf.has_user_instances] def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: - return [lf for lf in self.labeled_frames if lf.has_user_instances and lf.video == video] + return [ + lf + for lf in self.labeled_frames + if lf.has_user_instances and lf.video == video + ] # Methods for instances @@ -381,10 +444,11 @@ def instance_count(self, video: Video, frame_idx: int) -> int: count = 0 labeled_frame = self.find_first(video, frame_idx) if labeled_frame is not None: - count = len([inst for inst in labeled_frame.instances if type(inst)==Instance]) + count = len( + [inst for inst in labeled_frame.instances if type(inst) == Instance] + ) return count - @property def all_instances(self): return list(self.instances()) @@ -421,17 +485,30 @@ def add_track(self, video: Video, track: Track): self.tracks.append(track) self._track_occupancy[video][track] = RangeList() - def track_set_instance(self, frame: LabeledFrame, instance: Instance, new_track: Track): - self.track_swap(frame.video, new_track, instance.track, (frame.frame_idx, frame.frame_idx+1)) + def track_set_instance( + self, frame: LabeledFrame, instance: Instance, new_track: Track + ): + self.track_swap( + frame.video, + new_track, + instance.track, + (frame.frame_idx, frame.frame_idx + 1), + ) if instance.track is None: self._track_remove_instance(frame, instance) instance.track = new_track - def track_swap(self, video: Video, new_track: Track, old_track: Track, frame_range: tuple): + def track_swap( + self, video: Video, new_track: Track, old_track: Track, frame_range: tuple + ): # Get ranges in track occupancy cache - _, within_old, _ = self._track_occupancy[video][old_track].cut_range(frame_range) - _, within_new, _ = self._track_occupancy[video][new_track].cut_range(frame_range) + _, within_old, _ = self._track_occupancy[video][old_track].cut_range( + frame_range + ) + _, within_new, 
_ = self._track_occupancy[video][new_track].cut_range( + frame_range + ) if old_track is not None: # Instances that didn't already have track can't be handled here. @@ -459,11 +536,19 @@ def track_swap(self, video: Video, new_track: Track, old_track: Track, frame_ran instance.track = old_track def _track_remove_instance(self, frame: LabeledFrame, instance: Instance): - if instance.track not in self._track_occupancy[frame.video]: return + if instance.track not in self._track_occupancy[frame.video]: + return # If this is only instance in track in frame, then remove frame from track. - if len(list(filter(lambda inst: inst.track == instance.track, frame.instances))) == 1: - self._track_occupancy[frame.video][instance.track].remove((frame.frame_idx, frame.frame_idx+1)) + if ( + len( + list(filter(lambda inst: inst.track == instance.track, frame.instances)) + ) + == 1 + ): + self._track_occupancy[frame.video][instance.track].remove( + (frame.frame_idx, frame.frame_idx + 1) + ) def remove_instance(self, frame: LabeledFrame, instance: Instance): self._track_remove_instance(frame, instance) @@ -474,8 +559,11 @@ def add_instance(self, frame: LabeledFrame, instance: Instance): self._track_occupancy[frame.video] = dict() # Ensure that there isn't already an Instance with this track - tracks_in_frame = [inst.track for inst in frame - if type(inst) == Instance and inst.track is not None] + tracks_in_frame = [ + inst.track + for inst in frame + if type(inst) == Instance and inst.track is not None + ] if instance.track in tracks_in_frame: instance.track = None @@ -483,7 +571,9 @@ def add_instance(self, frame: LabeledFrame, instance: Instance): if instance.track not in self._track_occupancy[frame.video]: self._track_occupancy[frame.video][instance.track] = RangeList() - self._track_occupancy[frame.video][instance.track].insert((frame.frame_idx, frame.frame_idx+1)) + self._track_occupancy[frame.video][instance.track].insert( + (frame.frame_idx, frame.frame_idx + 1) + ) frame.instances.append(instance) def _make_track_occupany(self, video): @@ -499,7 +589,9 @@ def _make_track_occupany(self, video): tracks[instance.track].add(frame_idx) return tracks - def find_track_occupancy(self, video: Video, track: Union[Track, int], frame_range=None) -> List[Tuple[LabeledFrame, Instance]]: + def find_track_occupancy( + self, video: Video, track: Union[Track, int], frame_range=None + ) -> List[Tuple[LabeledFrame, Instance]]: """Get instances for a given track. 
Args: @@ -518,25 +610,29 @@ def does_track_match(inst, tr, labeled_frame): match = False if type(tr) == Track and inst.track is tr: match = True - elif (type(tr) == int and labeled_frame.instances.index(inst) == tr - and inst.track is None): + elif ( + type(tr) == int + and labeled_frame.instances.index(inst) == tr + and inst.track is None + ): match = True return match - track_frame_inst = [(lf, instance) - for lf in self.find(video) - for instance in lf.instances - if does_track_match(instance, track, lf) - and (frame_range is None or lf.frame_idx in frame_range)] + track_frame_inst = [ + (lf, instance) + for lf in self.find(video) + for instance in lf.instances + if does_track_match(instance, track, lf) + and (frame_range is None or lf.frame_idx in frame_range) + ] return track_frame_inst - def find_track_instances(self, *args, **kwargs) -> List[Instance]: return [inst for lf, inst in self.find_track_occupancy(*args, **kwargs)] # Methods for suggestions - def get_video_suggestions(self, video:Video) -> list: + def get_video_suggestions(self, video: Video) -> list: """ Returns the list of suggested frames for the specified video or suggestions for all videos (if no video specified). @@ -545,46 +641,62 @@ def get_video_suggestions(self, video:Video) -> list: def get_suggestions(self) -> list: """Return all suggestions as a list of (video, frame) tuples.""" - suggestion_list = [(video, frame_idx) + suggestion_list = [ + (video, frame_idx) for video in self.videos for frame_idx in self.get_video_suggestions(video) - ] + ] return suggestion_list def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: """Returns a (video, frame_idx) tuple.""" # make sure we have valid seek_direction - if seek_direction not in (-1, 1): return (None, None) + if seek_direction not in (-1, 1): + return (None, None) # make sure the video belongs to this Labels object - if video not in self.videos: return (None, None) + if video not in self.videos: + return (None, None) all_suggestions = self.get_suggestions() # If we're currently on a suggestion, then follow order of list if (video, frame_idx) in all_suggestions: suggestion_idx = all_suggestions.index((video, frame_idx)) - new_idx = (suggestion_idx+seek_direction)%len(all_suggestions) + new_idx = (suggestion_idx + seek_direction) % len(all_suggestions) video, frame_suggestion = all_suggestions[new_idx] # Otherwise, find the prev/next suggestion sorted by frame order else: # look for next (or previous) suggestion in current video if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video) if i > frame_idx), default=None) + frame_suggestion = min( + (i for i in self.get_video_suggestions(video) if i > frame_idx), + default=None, + ) else: - frame_suggestion = max((i for i in self.get_video_suggestions(video) if i < frame_idx), default=None) - if frame_suggestion is not None: return (video, frame_suggestion) + frame_suggestion = max( + (i for i in self.get_video_suggestions(video) if i < frame_idx), + default=None, + ) + if frame_suggestion is not None: + return (video, frame_suggestion) # if we didn't find suggestion in current video, # then we want earliest frame in next video with suggestions - next_video_idx = (self.videos.index(video) + seek_direction) % len(self.videos) + next_video_idx = (self.videos.index(video) + seek_direction) % len( + self.videos + ) video = self.videos[next_video_idx] if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video)), default=None) + 
frame_suggestion = min( + (i for i in self.get_video_suggestions(video)), default=None + ) else: - frame_suggestion = max((i for i in self.get_video_suggestions(video)), default=None) + frame_suggestion = max( + (i for i in self.get_video_suggestions(video)), default=None + ) return (video, frame_suggestion) - def set_suggestions(self, suggestions:Dict[Video, list]): + def set_suggestions(self, suggestions: Dict[Video, list]): """Sets the suggested frames.""" self.suggestions = suggestions @@ -639,7 +751,7 @@ def remove_video(self, video: Video): # Methods for negative anchors - def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): + def add_negative_anchor(self, video: Video, frame_idx: int, where: tuple): """Adds a location for a negative training sample. Args: @@ -651,7 +763,7 @@ def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): self.negative_anchors[video] = [] self.negative_anchors[video].append((frame_idx, *where)) - def remove_negative_anchors(self, video:Video, frame_idx: int): + def remove_negative_anchors(self, video: Video, frame_idx: int): """Removes negative training samples for given video and frame. Args: @@ -660,16 +772,21 @@ def remove_negative_anchors(self, video:Video, frame_idx: int): Returns: None """ - if video not in self.negative_anchors: return + if video not in self.negative_anchors: + return - anchors = [(idx, x, y) - for idx, x, y in self.negative_anchors[video] - if idx != frame_idx] + anchors = [ + (idx, x, y) + for idx, x, y in self.negative_anchors[video] + if idx != frame_idx + ] self.negative_anchors[video] = anchors # Methods for saving/loading - def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool=False): + def extend_from( + self, new_frames: Union["Labels", List[LabeledFrame]], unify: bool = False + ): """ Merge data from another Labels object or list of LabeledFrames into self. @@ -681,11 +798,14 @@ def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool bool, True if we added frames, False otherwise """ # allow either Labels or list of LabeledFrames - if isinstance(new_frames, Labels): new_frames = new_frames.labeled_frames + if isinstance(new_frames, Labels): + new_frames = new_frames.labeled_frames # return if this isn't non-empty list of labeled frames - if not isinstance(new_frames, list) or len(new_frames) == 0: return False - if not isinstance(new_frames[0], LabeledFrame): return False + if not isinstance(new_frames, list) or len(new_frames) == 0: + return False + if not isinstance(new_frames[0], LabeledFrame): + return False # If unify, we want to replace objects in the frames with # corresponding objects from the current labels. @@ -708,7 +828,9 @@ def extend_from(self, new_frames: Union['Labels',List[LabeledFrame]], unify:bool return True @classmethod - def complex_merge_between(cls, base_labels: 'Labels', new_labels: 'Labels', unify:bool = True) -> tuple: + def complex_merge_between( + cls, base_labels: "Labels", new_labels: "Labels", unify: bool = True + ) -> tuple: """ Merge frames and other data that can be merged cleanly, and return frames that conflict. 
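complex_merge_between above is the entry point for merging a second project (for example, new predictions) into an existing one: whatever merges cleanly is applied to base_labels, and per-frame conflicts are handed back for the caller to resolve. A compact sketch of that flow, with both project filenames as placeholders and conflicts resolved, for the sake of the example, by keeping the base version of each conflicted frame:

    from sleap.io.dataset import Labels

    base = Labels.load_file("base.json")              # placeholder paths
    incoming = Labels.load_file("predictions.json")

    # Clean merges are applied to base in place; conflicting frames come
    # back as parallel lists of LabeledFrames (base vs. new versions).
    merged, extra_base, extra_new = Labels.complex_merge_between(base, incoming)

    if extra_base or extra_new:
        # Crude resolution: put the conflicted base instances back and
        # discard the conflicting new ones.
        Labels.finish_complex_merge(base, extra_base)

    Labels.save_file(base, "merged.json")
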
@@ -745,10 +867,9 @@ def complex_merge_between(cls, base_labels: 'Labels', new_labels: 'Labels', unif new_labels = cls.from_json(new_json, match_to=base_labels) # Merge anything that can be merged cleanly and get conflicts - merged, extra_base, extra_new = \ - LabeledFrame.complex_merge_between( - base_labels=base_labels, - new_frames=new_labels.labeled_frames) + merged, extra_base, extra_new = LabeledFrame.complex_merge_between( + base_labels=base_labels, new_frames=new_labels.labeled_frames + ) # For clean merge, finish merge now by cleaning up base object if not extra_base and not extra_new: @@ -759,27 +880,31 @@ def complex_merge_between(cls, base_labels: 'Labels', new_labels: 'Labels', unif # Merge suggestions and negative anchors cls.merge_container_dicts(base_labels.suggestions, new_labels.suggestions) - cls.merge_container_dicts(base_labels.negative_anchors, new_labels.negative_anchors) + cls.merge_container_dicts( + base_labels.negative_anchors, new_labels.negative_anchors + ) return merged, extra_base, extra_new -# @classmethod -# def merge_predictions_by_score(cls, extra_base: List[LabeledFrame], extra_new: List[LabeledFrame]): -# """ -# Remove all predictions from input lists, return list with only -# the merged predictions. -# -# Args: -# extra_base: list of `LabeledFrame`s -# extra_new: list of `LabeledFrame`s -# Conflicting frames should have same index in both lists. -# Returns: -# list of `LabeledFrame`s with merged predictions -# """ -# pass + # @classmethod + # def merge_predictions_by_score(cls, extra_base: List[LabeledFrame], extra_new: List[LabeledFrame]): + # """ + # Remove all predictions from input lists, return list with only + # the merged predictions. + # + # Args: + # extra_base: list of `LabeledFrame`s + # extra_new: list of `LabeledFrame`s + # Conflicting frames should have same index in both lists. + # Returns: + # list of `LabeledFrame`s with merged predictions + # """ + # pass @staticmethod - def finish_complex_merge(base_labels: 'Labels', resolved_frames: List[LabeledFrame]): + def finish_complex_merge( + base_labels: "Labels", resolved_frames: List[LabeledFrame] + ): """ Finish conflicted merge from complex_merge_between. @@ -824,7 +949,9 @@ def merge_matching_frames(self, video=None): for vid in {lf.video for lf in self.labeled_frames}: self.merge_matching_frames(video=vid) else: - self.labeled_frames = LabeledFrame.merge_frames(self.labeled_frames, video=video) + self.labeled_frames = LabeledFrame.merge_frames( + self.labeled_frames, video=video + ) def to_dict(self, skip_labels: bool = False): """ @@ -851,17 +978,27 @@ def to_dict(self, skip_labels: bool = False): # FIXME: Update list of nodes # We shouldn't have to do this here, but for some reason we're missing nodes # which are in the skeleton but don't have points (in the first instance?). - self.nodes = list(set(self.nodes).union({node for skeleton in self.skeletons for node in skeleton.nodes})) + self.nodes = list( + set(self.nodes).union( + {node for skeleton in self.skeletons for node in skeleton.nodes} + ) + ) # Register some unstructure hooks since we don't want complete deserialization # of video and skeleton objects present in the labels. 
We will serialize these # as references to the above constructed lists to limit redundant data in the # json label_cattr = make_instance_cattr() - label_cattr.register_unstructure_hook(Skeleton, lambda x: str(self.skeletons.index(x))) - label_cattr.register_unstructure_hook(Video, lambda x: str(self.videos.index(x))) + label_cattr.register_unstructure_hook( + Skeleton, lambda x: str(self.skeletons.index(x)) + ) + label_cattr.register_unstructure_hook( + Video, lambda x: str(self.videos.index(x)) + ) label_cattr.register_unstructure_hook(Node, lambda x: str(self.nodes.index(x))) - label_cattr.register_unstructure_hook(Track, lambda x: str(self.tracks.index(x))) + label_cattr.register_unstructure_hook( + Track, lambda x: str(self.tracks.index(x)) + ) # Make a converter for the top level skeletons list. idx_to_node = {i: self.nodes[i] for i in range(len(self.nodes))} @@ -870,17 +1007,17 @@ def to_dict(self, skip_labels: bool = False): # Serialize the skeletons, videos, and labels dicts = { - 'version': LABELS_JSON_FILE_VERSION, - 'skeletons': skeleton_cattr.unstructure(self.skeletons), - 'nodes': cattr.unstructure(self.nodes), - 'videos': Video.cattr().unstructure(self.videos), - 'tracks': cattr.unstructure(self.tracks), - 'suggestions': label_cattr.unstructure(self.suggestions), - 'negative_anchors': label_cattr.unstructure(self.negative_anchors) - } + "version": LABELS_JSON_FILE_VERSION, + "skeletons": skeleton_cattr.unstructure(self.skeletons), + "nodes": cattr.unstructure(self.nodes), + "videos": Video.cattr().unstructure(self.videos), + "tracks": cattr.unstructure(self.tracks), + "suggestions": label_cattr.unstructure(self.suggestions), + "negative_anchors": label_cattr.unstructure(self.negative_anchors), + } if not skip_labels: - dicts['labels'] = label_cattr.unstructure(self.labeled_frames) + dicts["labels"] = label_cattr.unstructure(self.labeled_frames) return dicts @@ -897,10 +1034,13 @@ def to_json(self): return json_dumps(self.to_dict()) @staticmethod - def save_json(labels: 'Labels', filename: str, - compress: bool = False, - save_frame_data: bool = False, - frame_data_format: str = 'png'): + def save_json( + labels: "Labels", + filename: str, + compress: bool = False, + save_frame_data: bool = False, + frame_data_format: str = "png", + ): """ Save a Labels instance to a JSON format. @@ -947,20 +1087,26 @@ def save_json(labels: 'Labels', filename: str, # Create a set of new Video objects with imgstore backends. One for each # of the videos. We will only include the labeled frames though. We will # then replace each video with this new video - new_videos = labels.save_frame_data_imgstore(output_dir=tmp_dir, format=frame_data_format) + new_videos = labels.save_frame_data_imgstore( + output_dir=tmp_dir, format=frame_data_format + ) # Make video paths relative for vid in new_videos: tmp_path = vid.filename # Get the parent dir of the YAML file. 
# Use "/" since this works on Windows and posix - img_store_dir = os.path.basename(os.path.split(tmp_path)[0]) + "/" + os.path.basename(tmp_path) + img_store_dir = ( + os.path.basename(os.path.split(tmp_path)[0]) + + "/" + + os.path.basename(tmp_path) + ) # Change to relative path vid.backend.filename = img_store_dir # Convert to a dict, not JSON yet, because we need to patch up the videos d = labels.to_dict() - d['videos'] = Video.cattr().unstructure(new_videos) + d["videos"] = Video.cattr().unstructure(new_videos) else: d = labels.to_dict() @@ -976,14 +1122,16 @@ def save_json(labels: 'Labels', filename: str, json_dumps(d, full_out_filename) # Create the archive - shutil.make_archive(base_name=filename, root_dir=tmp_dir, format='zip') + shutil.make_archive(base_name=filename, root_dir=tmp_dir, format="zip") # If the user doesn't want to compress, then just write the json to the filename else: json_dumps(d, filename) @classmethod - def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) -> 'Labels': + def from_json( + cls, data: Union[str, dict], match_to: Optional["Labels"] = None + ) -> "Labels": # Parse the json string if needed. if type(data) is str: @@ -991,16 +1139,20 @@ def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) else: dicts = data - dicts['tracks'] = dicts.get('tracks', []) # don't break if json doesn't include tracks + dicts["tracks"] = dicts.get( + "tracks", [] + ) # don't break if json doesn't include tracks # First, deserialize the skeletons, videos, and nodes lists. # The labels reference these so we will need them while deserializing. - nodes = cattr.structure(dicts['nodes'], List[Node]) + nodes = cattr.structure(dicts["nodes"], List[Node]) - idx_to_node = {i:nodes[i] for i in range(len(nodes))} - skeletons = Skeleton.make_cattr(idx_to_node).structure(dicts['skeletons'], List[Skeleton]) - videos = Video.cattr().structure(dicts['videos'], List[Video]) - tracks = cattr.structure(dicts['tracks'], List[Track]) + idx_to_node = {i: nodes[i] for i in range(len(nodes))} + skeletons = Skeleton.make_cattr(idx_to_node).structure( + dicts["skeletons"], List[Skeleton] + ) + videos = Video.cattr().structure(dicts["videos"], List[Video]) + tracks = cattr.structure(dicts["tracks"], List[Track]) # if we're given a Labels object to match, use its objects when they match if match_to is not None: @@ -1017,49 +1169,67 @@ def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) for idx, vid in enumerate(videos): for old_vid in match_to.videos: # compare last three parts of path - if vid.filename == old_vid.filename or weak_filename_match(vid.filename, old_vid.filename): + if vid.filename == old_vid.filename or weak_filename_match( + vid.filename, old_vid.filename + ): # use video from match videos[idx] = old_vid break if "suggestions" in dicts: suggestions_cattr = cattr.Converter() - suggestions_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - suggestions = suggestions_cattr.structure(dicts['suggestions'], Dict[Video, List]) + suggestions_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + suggestions = suggestions_cattr.structure( + dicts["suggestions"], Dict[Video, List] + ) else: suggestions = dict() if "negative_anchors" in dicts: negative_anchors_cattr = cattr.Converter() - negative_anchors_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - negative_anchors = negative_anchors_cattr.structure(dicts['negative_anchors'], Dict[Video, List]) + 
negative_anchors_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + negative_anchors = negative_anchors_cattr.structure( + dicts["negative_anchors"], Dict[Video, List] + ) else: negative_anchors = dict() # If there is actual labels data, get it. - if 'labels' in dicts: + if "labels" in dicts: label_cattr = make_instance_cattr() - label_cattr.register_structure_hook(Skeleton, lambda x,type: skeletons[int(x)]) - label_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - label_cattr.register_structure_hook(Node, lambda x,type: x if isinstance(x,Node) else nodes[int(x)]) - label_cattr.register_structure_hook(Track, lambda x, type: None if x is None else tracks[int(x)]) - - labels = label_cattr.structure(dicts['labels'], List[LabeledFrame]) + label_cattr.register_structure_hook( + Skeleton, lambda x, type: skeletons[int(x)] + ) + label_cattr.register_structure_hook(Video, lambda x, type: videos[int(x)]) + label_cattr.register_structure_hook( + Node, lambda x, type: x if isinstance(x, Node) else nodes[int(x)] + ) + label_cattr.register_structure_hook( + Track, lambda x, type: None if x is None else tracks[int(x)] + ) + + labels = label_cattr.structure(dicts["labels"], List[LabeledFrame]) else: labels = [] - return cls(labeled_frames=labels, - videos=videos, - skeletons=skeletons, - nodes=nodes, - suggestions=suggestions, - negative_anchors=negative_anchors, - tracks=tracks) + return cls( + labeled_frames=labels, + videos=videos, + skeletons=skeletons, + nodes=nodes, + suggestions=suggestions, + negative_anchors=negative_anchors, + tracks=tracks, + ) @classmethod - def load_json(cls, filename: str, - video_callback=None, - match_to: Optional['Labels'] = None): + def load_json( + cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None + ): tmp_dir = None @@ -1068,8 +1238,10 @@ def load_json(cls, filename: str, # Make a tmpdir, located in the directory that the file exists, to unzip # its contents. - tmp_dir = os.path.join(os.path.dirname(filename), - f"tmp_{os.getpid()}_{os.path.basename(filename)}") + tmp_dir = os.path.join( + os.path.dirname(filename), + f"tmp_{os.getpid()}_{os.path.basename(filename)}", + ) if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir, ignore_errors=True) try: @@ -1077,7 +1249,7 @@ def load_json(cls, filename: str, except FileExistsError: pass - #tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) + # tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) try: @@ -1090,10 +1262,16 @@ def load_json(cls, filename: str, # We can now open the JSON file, save the zip file and # replace file with the first JSON file we find in the archive. - json_files = [os.path.join(tmp_dir, file) for file in os.listdir(tmp_dir) if file.endswith(".json")] + json_files = [ + os.path.join(tmp_dir, file) + for file in os.listdir(tmp_dir) + if file.endswith(".json") + ] if len(json_files) == 0: - raise ValueError(f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset.") + raise ValueError( + f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset." + ) filename = json_files[0] @@ -1103,7 +1281,7 @@ def load_json(cls, filename: str, raise # Open and parse the JSON in filename - with open(filename, 'r') as file: + with open(filename, "r") as file: # FIXME: Peek into the json to see if there is version string. 
# We do this to tell apart old JSON data from leap_dev vs the @@ -1119,7 +1297,9 @@ def load_json(cls, filename: str, # Replace local video paths (for imagestore) if tmp_dir: for vid in dicts["videos"]: - vid["backend"]["filename"] = os.path.join(tmp_dir, vid["backend"]["filename"]) + vid["backend"]["filename"] = os.path.join( + tmp_dir, vid["backend"]["filename"] + ) # Use the callback if given to handle missing videos if callable(video_callback): @@ -1145,7 +1325,7 @@ def load_json(cls, filename: str, except Exception as ex: # Ok, we give up, where the hell are these videos! - raise # Re-raise. + raise # Re-raise. finally: os.chdir(cwd) # Make sure to change back if we have problems. @@ -1155,9 +1335,12 @@ def load_json(cls, filename: str, return load_labels_json_old(data_path=filename, parsed_json=dicts) @staticmethod - def save_hdf5(labels: 'Labels', filename: str, - append: bool = False, - save_frame_data: bool = False): + def save_hdf5( + labels: "Labels", + filename: str, + append: bool = False, + save_frame_data: bool = False, + ): """ Serialize the labels dataset to an HDF5 file. @@ -1177,7 +1360,9 @@ def save_hdf5(labels: 'Labels', filename: str, # FIXME: Need to implement this. if save_frame_data: - raise NotImplementedError('Saving frame data is not implemented yet with HDF5 Labels datasets.') + raise NotImplementedError( + "Saving frame data is not implemented yet with HDF5 Labels datasets." + ) # Delete the file if it exists, we want to start from scratch since # h5py truncates the file which seems to not actually delete data @@ -1188,16 +1373,18 @@ def save_hdf5(labels: 'Labels', filename: str, # Serialize all the meta-data to JSON. d = labels.to_dict(skip_labels=True) - with h5.File(filename, 'a') as f: + with h5.File(filename, "a") as f: # Add all the JSON metadata - meta_group = f.require_group('metadata') + meta_group = f.require_group("metadata") # If we are appending and there already exists JSON metadata - if append and 'json' in meta_group.attrs: + if append and "json" in meta_group.attrs: # Otherwise, we need to read the JSON and append to the lists - old_labels = Labels.from_json(meta_group.attrs['json'].tostring().decode()) + old_labels = Labels.from_json( + meta_group.attrs["json"].tostring().decode() + ) # A function to join to list but only include new non-dupe entries # from the right hand list. @@ -1234,50 +1421,67 @@ def append_unique(old, new): d = labels.to_dict(skip_labels=True) # Output the dict to JSON - meta_group.attrs['json'] = np.string_(json_dumps(d)) + meta_group.attrs["json"] = np.string_(json_dumps(d)) # FIXME: We can probably construct these from attrs fields # We will store Instances and PredcitedInstances in the same # table. instance_type=0 or Instance and instance_type=1 for # PredictedInstance, score will be ignored for Instances. 
- instance_dtype = np.dtype([('instance_id', 'i8'), - ('instance_type', 'u1'), - ('frame_id', 'u8'), - ('skeleton', 'u4'), - ('track', 'i4'), - ('from_predicted', 'i8'), - ('score', 'f4'), - ('point_id_start', 'u8'), - ('point_id_end', 'u8')]) - frame_dtype = np.dtype([('frame_id', 'u8'), - ('video', 'u4'), - ('frame_idx', 'u8'), - ('instance_id_start', 'u8'), - ('instance_id_end', 'u8')]) + instance_dtype = np.dtype( + [ + ("instance_id", "i8"), + ("instance_type", "u1"), + ("frame_id", "u8"), + ("skeleton", "u4"), + ("track", "i4"), + ("from_predicted", "i8"), + ("score", "f4"), + ("point_id_start", "u8"), + ("point_id_end", "u8"), + ] + ) + frame_dtype = np.dtype( + [ + ("frame_id", "u8"), + ("video", "u4"), + ("frame_idx", "u8"), + ("instance_id_start", "u8"), + ("instance_id_end", "u8"), + ] + ) num_instances = len(labels.all_instances) max_skeleton_size = max([len(s.nodes) for s in labels.skeletons]) # Initialize data arrays for serialization points = np.zeros(num_instances * max_skeleton_size, dtype=Point.dtype) - pred_points = np.zeros(num_instances * max_skeleton_size, dtype=PredictedPoint.dtype) + pred_points = np.zeros( + num_instances * max_skeleton_size, dtype=PredictedPoint.dtype + ) instances = np.zeros(num_instances, dtype=instance_dtype) frames = np.zeros(len(labels), dtype=frame_dtype) # Pre compute some structures to make serialization faster - skeleton_to_idx = {skeleton: labels.skeletons.index(skeleton) for skeleton in labels.skeletons} - track_to_idx = {track: labels.tracks.index(track) for track in labels.tracks} + skeleton_to_idx = { + skeleton: labels.skeletons.index(skeleton) + for skeleton in labels.skeletons + } + track_to_idx = { + track: labels.tracks.index(track) for track in labels.tracks + } track_to_idx[None] = -1 - video_to_idx = {video: labels.videos.index(video) for video in labels.videos} + video_to_idx = { + video: labels.videos.index(video) for video in labels.videos + } instance_type_to_idx = {Instance: 0, PredictedInstance: 1} # If we are appending, we need look inside to see what frame, instance, and point # ids we need to start from. This gives us offsets to use. 
- if append and 'points' in f: - point_id_offset = f['points'].shape[0] - pred_point_id_offset = f['pred_points'].shape[0] - instance_id_offset = f['instances'][-1]['instance_id'] + 1 - frame_id_offset = int(f['frames'][-1]['frame_id']) + 1 + if append and "points" in f: + point_id_offset = f["points"].shape[0] + pred_point_id_offset = f["pred_points"].shape[0] + instance_id_offset = f["instances"][-1]["instance_id"] + 1 + frame_id_offset = int(f["frames"][-1]["frame_id"]) + 1 else: point_id_offset = 0 pred_point_id_offset = 0 @@ -1291,8 +1495,13 @@ def append_unique(old, new): all_from_predicted = [] from_predicted_id = 0 for frame_id, label in enumerate(labels): - frames[frame_id] = (frame_id+frame_id_offset, video_to_idx[label.video], label.frame_idx, - instance_id+instance_id_offset, instance_id+instance_id_offset+len(label.instances)) + frames[frame_id] = ( + frame_id + frame_id_offset, + video_to_idx[label.video], + label.frame_idx, + instance_id + instance_id_offset, + instance_id + instance_id_offset + len(label.instances), + ) for instance in label.instances: parray = instance.get_points_array(copy=False, full=True) instance_type = type(instance) @@ -1312,22 +1521,27 @@ def append_unique(old, new): from_predicted_id = from_predicted_id + 1 # Copy all the data - instances[instance_id] = (instance_id+instance_id_offset, - instance_type_to_idx[instance_type], - frame_id, - skeleton_to_idx[instance.skeleton], - track_to_idx[instance.track], - -1, - score, - pid, pid + len(parray)) + instances[instance_id] = ( + instance_id + instance_id_offset, + instance_type_to_idx[instance_type], + frame_id, + skeleton_to_idx[instance.skeleton], + track_to_idx[instance.track], + -1, + score, + pid, + pid + len(parray), + ) # If these are predicted points, copy them to the predicted point array # otherwise, use the normal point array if type(parray) is PredictedPointArray: - pred_points[pred_point_id:pred_point_id + len(parray)] = parray + pred_points[ + pred_point_id : pred_point_id + len(parray) + ] = parray pred_point_id = pred_point_id + len(parray) else: - points[point_id:point_id + len(parray)] = parray + points[point_id : point_id + len(parray)] = parray point_id = point_id + len(parray) instance_id = instance_id + 1 @@ -1338,31 +1552,48 @@ def append_unique(old, new): pred_points = pred_points[0:pred_point_id] # Create datasets if we need to - if append and 'points' in f: - f['points'].resize((f["points"].shape[0] + points.shape[0]), axis = 0) - f['points'][-points.shape[0]:] = points - f['pred_points'].resize((f["pred_points"].shape[0] + pred_points.shape[0]), axis=0) - f['pred_points'][-pred_points.shape[0]:] = pred_points - f['instances'].resize((f["instances"].shape[0] + instances.shape[0]), axis=0) - f['instances'][-instances.shape[0]:] = instances - f['frames'].resize((f["frames"].shape[0] + frames.shape[0]), axis=0) - f['frames'][-frames.shape[0]:] = frames + if append and "points" in f: + f["points"].resize((f["points"].shape[0] + points.shape[0]), axis=0) + f["points"][-points.shape[0] :] = points + f["pred_points"].resize( + (f["pred_points"].shape[0] + pred_points.shape[0]), axis=0 + ) + f["pred_points"][-pred_points.shape[0] :] = pred_points + f["instances"].resize( + (f["instances"].shape[0] + instances.shape[0]), axis=0 + ) + f["instances"][-instances.shape[0] :] = instances + f["frames"].resize((f["frames"].shape[0] + frames.shape[0]), axis=0) + f["frames"][-frames.shape[0] :] = frames else: - f.create_dataset("points", data=points, maxshape=(None,), dtype=Point.dtype) - 
f.create_dataset("pred_points", data=pred_points, maxshape=(None,), dtype=PredictedPoint.dtype) - f.create_dataset("instances", data=instances, maxshape=(None,), dtype=instance_dtype) - f.create_dataset("frames", data=frames, maxshape=(None,), dtype=frame_dtype) + f.create_dataset( + "points", data=points, maxshape=(None,), dtype=Point.dtype + ) + f.create_dataset( + "pred_points", + data=pred_points, + maxshape=(None,), + dtype=PredictedPoint.dtype, + ) + f.create_dataset( + "instances", data=instances, maxshape=(None,), dtype=instance_dtype + ) + f.create_dataset( + "frames", data=frames, maxshape=(None,), dtype=frame_dtype + ) @classmethod - def load_hdf5(cls, filename: str, - video_callback=None, - match_to: Optional['Labels'] = None): + def load_hdf5( + cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None + ): - with h5.File(filename, 'r') as f: + with h5.File(filename, "r") as f: # Extract the Labels JSON metadata and create Labels object with just # this metadata. - dicts = json_loads(f.require_group('metadata').attrs['json'].tostring().decode()) + dicts = json_loads( + f.require_group("metadata").attrs["json"].tostring().decode() + ) # Use the callback if given to handle missing videos if callable(video_callback): @@ -1370,16 +1601,18 @@ def load_hdf5(cls, filename: str, labels = cls.from_json(dicts, match_to=match_to) - frames_dset = f['frames'][:] - instances_dset = f['instances'][:] - points_dset = f['points'][:] - pred_points_dset = f['pred_points'][:] + frames_dset = f["frames"][:] + instances_dset = f["instances"][:] + points_dset = f["points"][:] + pred_points_dset = f["pred_points"][:] # Rather than instantiate a bunch of Point\PredictedPoint objects, we will # use inplace numpy recarrays. This will save a lot of time and memory # when reading things in. points = PointArray(buf=points_dset, shape=len(points_dset)) - pred_points = PredictedPointArray(buf=pred_points_dset, shape=len(pred_points_dset)) + pred_points = PredictedPointArray( + buf=pred_points_dset, shape=len(pred_points_dset) + ) # Extend the tracks list with a None track. 
We will signify this with a -1 in the # data which will map to last element of tracks @@ -1389,23 +1622,35 @@ def load_hdf5(cls, filename: str, # Create the instances instances = [] for i in instances_dset: - track = tracks[i['track']] - skeleton = labels.skeletons[i['skeleton']] - - if i['instance_type'] == 0: # Instance - instance = Instance(skeleton=skeleton, track=track, - points=points[i['point_id_start']:i['point_id_end']]) - else: # PredictedInstance - instance = PredictedInstance(skeleton=skeleton, track=track, - points=pred_points[i['point_id_start']:i['point_id_end']], - score=i['score']) + track = tracks[i["track"]] + skeleton = labels.skeletons[i["skeleton"]] + + if i["instance_type"] == 0: # Instance + instance = Instance( + skeleton=skeleton, + track=track, + points=points[i["point_id_start"] : i["point_id_end"]], + ) + else: # PredictedInstance + instance = PredictedInstance( + skeleton=skeleton, + track=track, + points=pred_points[i["point_id_start"] : i["point_id_end"]], + score=i["score"], + ) instances.append(instance) # Create the labeled frames - frames = [LabeledFrame(video=labels.videos[frame['video']], - frame_idx=frame['frame_idx'], - instances=instances[frame['instance_id_start']:frame['instance_id_end']]) - for i, frame in enumerate(frames_dset)] + frames = [ + LabeledFrame( + video=labels.videos[frame["video"]], + frame_idx=frame["frame_idx"], + instances=instances[ + frame["instance_id_start"] : frame["instance_id_end"] + ], + ) + for i, frame in enumerate(frames_dset) + ] labels.labeled_frames = frames @@ -1430,18 +1675,19 @@ def load_file(cls, filename: str, *args, **kwargs): raise ValueError(f"Cannot detect filetype for {filename}") @classmethod - def save_file(cls, labels: 'Labels', filename: str, *args, **kwargs): + def save_file(cls, labels: "Labels", filename: str, *args, **kwargs): """Save file, detecting format from filename.""" if filename.endswith((".json", ".zip")): compress = filename.endswith(".zip") - cls.save_json(labels = labels, filename = filename, - compress = compress) + cls.save_json(labels=labels, filename=filename, compress=compress) elif filename.endswith(".h5"): - cls.save_hdf5(labels = labels, filename = filename) + cls.save_hdf5(labels=labels, filename=filename) else: raise ValueError(f"Cannot detect filetype for {filename}") - def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', all_labels: bool = False): + def save_frame_data_imgstore( + self, output_dir: str = "./", format: str = "png", all_labels: bool = False + ): """ Write all labeled frames from all videos to a collection of imgstore datasets. This only writes frames that have been labeled. 
Videos without any labeled frames @@ -1460,14 +1706,18 @@ def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', # For each label imgstore_vids = [] for v_idx, v in enumerate(self.videos): - frame_nums = [lf.frame_idx for lf in self.labeled_frames - if v == lf.video - and (all_labels or lf.has_user_instances)] + frame_nums = [ + lf.frame_idx + for lf in self.labeled_frames + if v == lf.video and (all_labels or lf.has_user_instances) + ] # Join with "/" instead of os.path.join() since we want # path to work on Windows and Posix systems - frames_filename = output_dir + f'/frame_data_vid{v_idx}' - vid = v.to_imgstore(path=frames_filename, frame_numbers=frame_nums, format=format) + frames_filename = output_dir + f"/frame_data_vid{v_idx}" + vid = v.to_imgstore( + path=frames_filename, frame_numbers=frame_nums, format=format + ) # Close the video for now vid.close() @@ -1476,7 +1726,6 @@ def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', return imgstore_vids - @staticmethod def _unwrap_mat_scalar(a): if a.shape == (1,): @@ -1499,11 +1748,13 @@ def load_mat(cls, filename): # If the video file isn't found, try in the same dir as the mat file if not os.path.exists(box_path): file_dir = os.path.dirname(filename) - box_path_name = box_path.split("\\")[-1] # assume windows path + box_path_name = box_path.split("\\")[-1] # assume windows path box_path = os.path.join(file_dir, box_path_name) if os.path.exists(box_path): - vid = Video.from_hdf5(dataset="box", filename=box_path, input_format="channels_first") + vid = Video.from_hdf5( + dataset="box", filename=box_path, input_format="channels_first" + ) else: vid = None @@ -1513,12 +1764,12 @@ def load_mat(cls, filename): edges_ = mat_contents["skeleton"]["edges"] points_ = mat_contents["positions"] - edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing + edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing nodes = Labels._unwrap_mat_array(nodes_) edges = Labels._unwrap_mat_array(edges_) - nodes = list(map(str, nodes)) # convert np._str to str + nodes = list(map(str, nodes)) # convert np._str to str sk = Skeleton(name=filename) sk.add_nodes(nodes) @@ -1529,14 +1780,14 @@ def load_mat(cls, filename): node_count, _, frame_count = points_.shape for i in range(frame_count): - new_inst = Instance(skeleton = sk) + new_inst = Instance(skeleton=sk) for node_idx, node in enumerate(nodes): x = points_[node_idx][0][i] y = points_[node_idx][1][i] new_inst[node] = Point(x, y) if len(new_inst.points): new_frame = LabeledFrame(video=vid, frame_idx=i) - new_frame.instances = new_inst, + new_frame.instances = (new_inst,) labeled_frames.append(new_frame) labels = cls(labeled_frames=labeled_frames, videos=[vid], skeletons=[sk]) @@ -1572,7 +1823,7 @@ def load_deeplabcut_csv(cls, filename): # x2 = config['x2'] # y2 = config['y2'] - data = pd.read_csv(filename, header=[1,2]) + data = pd.read_csv(filename, header=[1, 2]) # Create the skeleton from the list of nodes in the csv file # Note that DeepLabCut doesn't have edges, so these will have to be added by user later @@ -1585,7 +1836,7 @@ def load_deeplabcut_csv(cls, filename): # This may not be ideal for large projects, since we're reading in # each image and then writing it out in a new directory. 
- img_files = data.ix[:,0] # get list of all images + img_files = data.ix[:, 0] # get list of all images # the image filenames in the csv may not match where the user has them # so we'll change the directory to match where the user has the csv @@ -1612,7 +1863,7 @@ def fix_img_path(img_dir, img_filename): # get points for each node instance_points = dict() for node in node_names: - x, y = data[(node, 'x')][i], data[(node, 'y')][i] + x, y = data[(node, "x")][i], data[(node, "y")][i] instance_points[node] = Point(x, y) # create instance with points (we can assume there's only one instance per frame) instance = Instance(skeleton=skeleton, points=instance_points) @@ -1625,6 +1876,7 @@ def fix_img_path(img_dir, img_filename): @classmethod def make_video_callback(cls, search_paths=None): search_paths = search_paths or [] + def video_callback(video_list, new_paths=search_paths): # Check each video for video_item in video_list: @@ -1653,16 +1905,20 @@ def video_callback(video_list, new_paths=search_paths): video_item["backend"]["filename"] = check_path is_found = True break + return video_callback @classmethod def make_gui_video_callback(cls, search_paths): search_paths = search_paths or [] + def gui_video_callback(video_list, new_paths=search_paths): import os from PySide2.QtWidgets import QFileDialog, QMessageBox - has_shown_prompt = False # have we already alerted user about missing files? + has_shown_prompt = ( + False + ) # have we already alerted user about missing files? basename_list = [] @@ -1696,30 +1952,42 @@ def gui_video_callback(video_list, new_paths=search_paths): break # if we found this file, then move on to the next file - if is_found: continue + if is_found: + continue # Since we couldn't find the file on our own, prompt the user. print(f"Unable to find: {current_filename}") - QMessageBox(text=f"We're unable to locate one or more video files for this project. Please locate {current_filename}.").exec_() + QMessageBox( + text=f"We're unable to locate one or more video files for this project. Please locate {current_filename}." + ).exec_() has_shown_prompt = True current_root, current_ext = os.path.splitext(current_basename) caption = f"Please locate {current_basename}..." - filters = [f"{current_root} file (*{current_ext})", "Any File (*.*)"] + filters = [ + f"{current_root} file (*{current_ext})", + "Any File (*.*)", + ] dir = None if len(new_paths) == 0 else new_paths[-1] - new_filename, _ = QFileDialog.getOpenFileName(None, dir=dir, caption=caption, filter=";;".join(filters)) + new_filename, _ = QFileDialog.getOpenFileName( + None, dir=dir, caption=caption, filter=";;".join(filters) + ) # if we got an answer, then update filename for video if len(new_filename): video_item["backend"]["filename"] = new_filename # keep track of the directory chosen by user new_paths.append(os.path.dirname(new_filename)) basename_list.append(current_basename) + return gui_video_callback -def load_labels_json_old(data_path: str, parsed_json: dict = None, - adjust_matlab_indexing: bool = True, - fix_rel_paths: bool = True) -> Labels: +def load_labels_json_old( + data_path: str, + parsed_json: dict = None, + adjust_matlab_indexing: bool = True, + fix_rel_paths: bool = True, +) -> Labels: """ Simple utitlity code to load data from Talmo's old JSON format into newer Labels object. 
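make_video_callback and make_gui_video_callback above exist so that load_json can repair video paths that have moved since a project was saved; the non-GUI variant simply walks a list of candidate directories. A short sketch of that path, with both file locations as placeholders:

    from sleap.io.dataset import Labels

    # Look for any missing videos in these directories.
    callback = Labels.make_video_callback(search_paths=["/data/videos"])

    # load_json hands the raw video dicts to the callback before structuring
    # them, so relocated movie files are found without editing the project.
    labels = Labels.load_json("experiment.json", video_callback=callback)
    print(f"{len(labels)} labeled frames across {len(labels.videos)} videos")
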
@@ -1765,7 +2033,10 @@ def load_labels_json_old(data_path: str, parsed_json: dict = None, if adjust_matlab_indexing: edges = np.array(edges) - 1 for (src_idx, dst_idx) in edges: - skeleton.add_edge(data["skeleton"]["nodeNames"][src_idx], data["skeleton"]["nodeNames"][dst_idx]) + skeleton.add_edge( + data["skeleton"]["nodeNames"][src_idx], + data["skeleton"]["nodeNames"][dst_idx], + ) if fix_rel_paths: for i, row in videos.iterrows(): @@ -1781,13 +2052,17 @@ def load_labels_json_old(data_path: str, parsed_json: dict = None, if videos.at[i, "format"] == "media": vid = Video.from_media(videos.at[i, "filepath"]) else: - vid = Video.from_hdf5(filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"]) + vid = Video.from_hdf5( + filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] + ) video_objects[videos.at[i, "id"]] = vid # A function to get all the instances for a particular video frame def get_frame_instances(video_id, frame_idx): - is_in_frame = (points["videoId"] == video_id) & (points["frameIdx"] == frame_idx) + is_in_frame = (points["videoId"] == video_id) & ( + points["frameIdx"] == frame_idx + ) if not is_in_frame.any(): return [] @@ -1795,8 +2070,12 @@ def get_frame_instances(video_id, frame_idx): frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) for i, instance_id in enumerate(frame_instance_ids): is_instance = is_in_frame & (points["instanceId"] == instance_id) - instance_points = {data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) for x, y, n, v in - zip(*[points[k][is_instance] for k in ["x", "y", "node", "visible"]])} + instance_points = { + data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) + for x, y, n, v in zip( + *[points[k][is_instance] for k in ["x", "y", "node", "visible"]] + ) + } instance = Instance(skeleton=skeleton, points=instance_points) instances.append(instance) @@ -1804,12 +2083,20 @@ def get_frame_instances(video_id, frame_idx): return instances # Get the unique labeled frames and construct a list of LabeledFrame objects for them. - frame_keys = list({(videoId, frameIdx) for videoId, frameIdx in zip(points["videoId"], points["frameIdx"])}) + frame_keys = list( + { + (videoId, frameIdx) + for videoId, frameIdx in zip(points["videoId"], points["frameIdx"]) + } + ) frame_keys.sort() labels = [] for videoId, frameIdx in frame_keys: - label = LabeledFrame(video=video_objects[videoId], frame_idx=frameIdx, - instances = get_frame_instances(videoId, frameIdx)) + label = LabeledFrame( + video=video_objects[videoId], + frame_idx=frameIdx, + instances=get_frame_instances(videoId, frameIdx), + ) labels.append(label) return Labels(labels) diff --git a/sleap/io/legacy.py b/sleap/io/legacy.py index 4f6c36293..8ddfe7fad 100644 --- a/sleap/io/legacy.py +++ b/sleap/io/legacy.py @@ -11,10 +11,13 @@ from ..nn.tracking import Track + def load_predicted_labels_json_old( - data_path: str, parsed_json: dict = None, - adjust_matlab_indexing: bool = True, - fix_rel_paths: bool = True) -> Labels: + data_path: str, + parsed_json: dict = None, + adjust_matlab_indexing: bool = True, + fix_rel_paths: bool = True, +) -> Labels: """ Simple utitlity code to load data from Talmo's old JSON format into newer Labels object. 
This loads the prediced instances @@ -53,7 +56,10 @@ def load_predicted_labels_json_old( if adjust_matlab_indexing: edges = np.array(edges) - 1 for (src_idx, dst_idx) in edges: - skeleton.add_edge(data["skeleton"]["nodeNames"][src_idx], data["skeleton"]["nodeNames"][dst_idx]) + skeleton.add_edge( + data["skeleton"]["nodeNames"][src_idx], + data["skeleton"]["nodeNames"][dst_idx], + ) if fix_rel_paths: for i, row in videos.iterrows(): @@ -69,22 +75,32 @@ def load_predicted_labels_json_old( if videos.at[i, "format"] == "media": vid = Video.from_media(videos.at[i, "filepath"]) else: - vid = Video.from_hdf5(filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"]) + vid = Video.from_hdf5( + filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] + ) video_objects[videos.at[i, "id"]] = vid - track_ids = predicted_instances['trackId'].values + track_ids = predicted_instances["trackId"].values unique_track_ids = np.unique(track_ids) - spawned_on = {track_id: predicted_instances.loc[predicted_instances['trackId'] == track_id]['frameIdx'].values[0] - for track_id in unique_track_ids} - tracks = {i: Track(name=str(i), spawned_on=spawned_on[i]) - for i in np.unique(predicted_instances['trackId'].values).tolist()} + spawned_on = { + track_id: predicted_instances.loc[predicted_instances["trackId"] == track_id][ + "frameIdx" + ].values[0] + for track_id in unique_track_ids + } + tracks = { + i: Track(name=str(i), spawned_on=spawned_on[i]) + for i in np.unique(predicted_instances["trackId"].values).tolist() + } # A function to get all the instances for a particular video frame def get_frame_predicted_instances(video_id, frame_idx): points = predicted_points - is_in_frame = (points["videoId"] == video_id) & (points["frameIdx"] == frame_idx) + is_in_frame = (points["videoId"] == video_id) & ( + points["frameIdx"] == frame_idx + ) if not is_in_frame.any(): return [] @@ -92,28 +108,54 @@ def get_frame_predicted_instances(video_id, frame_idx): frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) for i, instance_id in enumerate(frame_instance_ids): is_instance = is_in_frame & (points["instanceId"] == instance_id) - track_id = predicted_instances.loc[predicted_instances['id'] == instance_id]['trackId'].values[0] - match_score = predicted_instances.loc[predicted_instances['id'] == instance_id]['matching_score'].values[0] - track_score = predicted_instances.loc[predicted_instances['id'] == instance_id]['tracking_score'].values[0] - instance_points = {data["skeleton"]["nodeNames"][n]: PredictedPoint(x, y, visible=v, score=confidence) - for x, y, n, v, confidence in - zip(*[points[k][is_instance] for k in ["x", "y", "node", "visible", "confidence"]])} - - instance = PredictedInstance(skeleton=skeleton, - points=instance_points, - track=tracks[track_id], - score=match_score) + track_id = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["trackId"].values[0] + match_score = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["matching_score"].values[0] + track_score = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["tracking_score"].values[0] + instance_points = { + data["skeleton"]["nodeNames"][n]: PredictedPoint( + x, y, visible=v, score=confidence + ) + for x, y, n, v, confidence in zip( + *[ + points[k][is_instance] + for k in ["x", "y", "node", "visible", "confidence"] + ] + ) + } + + instance = PredictedInstance( + skeleton=skeleton, + points=instance_points, + track=tracks[track_id], + score=match_score, + ) 
instances.append(instance) return instances # Get the unique labeled frames and construct a list of LabeledFrame objects for them. - frame_keys = list({(videoId, frameIdx) for videoId, frameIdx in zip(predicted_points["videoId"], predicted_points["frameIdx"])}) + frame_keys = list( + { + (videoId, frameIdx) + for videoId, frameIdx in zip( + predicted_points["videoId"], predicted_points["frameIdx"] + ) + } + ) frame_keys.sort() labels = [] for videoId, frameIdx in frame_keys: - label = LabeledFrame(video=video_objects[videoId], frame_idx=frameIdx, - instances = get_frame_predicted_instances(videoId, frameIdx)) + label = LabeledFrame( + video=video_objects[videoId], + frame_idx=frameIdx, + instances=get_frame_predicted_instances(videoId, frameIdx), + ) labels.append(label) - return Labels(labels) \ No newline at end of file + return Labels(labels) diff --git a/sleap/io/video.py b/sleap/io/video.py index de56c034f..ef5096be3 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) + @attr.s(auto_attribs=True, cmp=False) class HDF5Video: """ @@ -49,7 +50,9 @@ def __attrs_post_init__(self): try: self.__file_h5 = h5.File(self.filename, "r") except OSError as ex: - raise FileNotFoundError(f"Could not find HDF5 file {self.filename}") from ex + raise FileNotFoundError( + f"Could not find HDF5 file {self.filename}" + ) from ex else: self.__file_h5 = None @@ -58,12 +61,15 @@ def __attrs_post_init__(self): self.__dataset_h5 = self.dataset self.__file_h5 = self.__dataset_h5.file self.dataset = self.__dataset_h5.name - elif (self.dataset is not None) and isinstance(self.dataset, str) and (self.__file_h5 is not None): + elif ( + (self.dataset is not None) + and isinstance(self.dataset, str) + and (self.__file_h5 is not None) + ): self.__dataset_h5 = self.__file_h5[self.dataset] else: self.__dataset_h5 = None - @input_format.validator def check(self, attribute, value): if value not in ["channels_first", "channels_last"]: @@ -88,10 +94,12 @@ def matches(self, other): Returns: True if attributes match, False otherwise """ - return self.filename == other.filename and \ - self.dataset == other.dataset and \ - self.convert_range == other.convert_range and \ - self.input_format == other.input_format + return ( + self.filename == other.filename + and self.dataset == other.dataset + and self.convert_range == other.convert_range + and self.input_format == other.input_format + ) # The properties and methods below complete our contract with the # higher level Video interface. @@ -116,7 +124,7 @@ def height(self): def dtype(self): return self.__dataset_h5.dtype - def get_frame(self, idx):# -> np.ndarray: + def get_frame(self, idx): # -> np.ndarray: """ Get a frame from the underlying HDF5 video data. @@ -131,7 +139,7 @@ def get_frame(self, idx):# -> np.ndarray: if self.input_format == "channels_first": frame = np.transpose(frame, (2, 1, 0)) - if self.convert_range and np.max(frame) <= 1.: + if self.convert_range and np.max(frame) <= 1.0: frame = (frame * 255).astype(int) return frame @@ -149,6 +157,7 @@ class MediaVideo: grayscale: Whether the video is grayscale or not. "auto" means detect based on first frame. 
""" + filename: str = attr.ib() # grayscale: bool = attr.ib(default=None, converter=bool) grayscale: bool = attr.ib() @@ -167,7 +176,9 @@ def __reader(self): # Load if not already loaded if self._reader_ is None: if not os.path.isfile(self.filename): - raise FileNotFoundError(f"Could not find filename video filename named {self.filename}") + raise FileNotFoundError( + f"Could not find filename video filename named {self.filename}" + ) # Try and open the file either locally in current directory or with full path self._reader_ = cv2.VideoCapture(self.filename) @@ -175,7 +186,9 @@ def __reader(self): # If the user specified None for grayscale bool, figure it out based on the # the first frame of data. if self._detect_grayscale is True: - self.grayscale = bool(np.alltrue(self.__test_frame[..., 0] == self.__test_frame[..., -1])) + self.grayscale = bool( + np.alltrue(self.__test_frame[..., 0] == self.__test_frame[..., -1]) + ) # Return cached reader return self._reader_ @@ -200,10 +213,11 @@ def matches(self, other): Returns: True if attributes match, False otherwise """ - return self.filename == other.filename and \ - self.grayscale == other.grayscale and \ - self.bgr == other.bgr - + return ( + self.filename == other.filename + and self.grayscale == other.grayscale + and self.bgr == other.bgr + ) @property def fps(self): @@ -249,10 +263,10 @@ def get_frame(self, idx, grayscale=None): grayscale = self.grayscale if grayscale: - frame = frame[...,0][...,None] + frame = frame[..., 0][..., None] if self.bgr: - frame = frame[...,::-1] + frame = frame[..., ::-1] return frame @@ -267,6 +281,7 @@ class NumpyVideo: * numpy data shape: (frames, width, height, channels) """ + filename: attr.ib() def __attrs_post_init__(self): @@ -284,7 +299,9 @@ def __attrs_post_init__(self): try: self.__data = np.load(self.filename) except OSError as ex: - raise FileNotFoundError(f"Could not find filename {self.filename}") from ex + raise FileNotFoundError( + f"Could not find filename {self.filename}" + ) from ex else: self.__data = None @@ -352,9 +369,9 @@ def __attrs_post_init__(self): # If the filename does not contain metadata.yaml, append it to the filename # assuming that this is a directory that contains the imgstore. - if 'metadata.yaml' not in self.filename: + if "metadata.yaml" not in self.filename: # Use "/" since this works on Windows and posix - self.filename = self.filename + '/metadata.yaml' + self.filename = self.filename + "/metadata.yaml" # Make relative path into absolute, ImgStores don't work properly it seems # without full paths if we change working directories. Video.fixup_path will @@ -376,7 +393,10 @@ def matches(self, other): Returns: True if attributes match, False otherwise """ - return self.filename == other.filename and self.index_by_original == other.index_by_original + return ( + self.filename == other.filename + and self.index_by_original == other.index_by_original + ) @property def __store(self): @@ -437,8 +457,9 @@ def get_frame(self, frame_number) -> np.ndarray: if self.index_by_original: img, (frame_number, frame_timestamp) = self.__store.get_image(frame_number) else: - img, (frame_number, frame_timestamp) = self.__store.get_image(frame_number=None, - frame_index=frame_number) + img, (frame_number, frame_timestamp) = self.__store.get_image( + frame_number=None, frame_index=frame_number + ) # If the frame has one channel, add a singleton channel as it seems other # video implementations do this. 
@@ -572,7 +593,7 @@ def get_frames(self, idxs: Union[int, Iterable[int]]) -> np.ndarray: The requested video frames with shape (len(idxs), width, height, channels) """ if np.isscalar(idxs): - idxs = [idxs,] + idxs = [idxs] return np.stack([self.get_frame(idx) for idx in idxs], axis=0) def __getitem__(self, idxs): @@ -582,10 +603,13 @@ def __getitem__(self, idxs): return self.get_frames(idxs) @classmethod - def from_hdf5(cls, dataset: Union[str, h5.Dataset], - filename: Union[str, h5.File] = None, - input_format: str = "channels_last", - convert_range: bool = True): + def from_hdf5( + cls, + dataset: Union[str, h5.Dataset], + filename: Union[str, h5.File] = None, + input_format: str = "channels_last", + convert_range: bool = True, + ): """ Create an instance of a video object from an HDF5 file and dataset. This is a helper method that invokes the HDF5Video backend. @@ -602,11 +626,11 @@ def from_hdf5(cls, dataset: Union[str, h5.Dataset], """ filename = Video.fixup_path(filename) backend = HDF5Video( - filename=filename, - dataset=dataset, - input_format=input_format, - convert_range=convert_range - ) + filename=filename, + dataset=dataset, + input_format=input_format, + convert_range=convert_range, + ) return cls(backend=backend) @classmethod @@ -670,7 +694,9 @@ def from_filename(cls, filename: str, *args, **kwargs): raise ValueError("Could not detect backend for specified filename.") @classmethod - def imgstore_from_filenames(cls, filenames: list, output_filename: str, *args, **kwargs): + def imgstore_from_filenames( + cls, filenames: list, output_filename: str, *args, **kwargs + ): """Create an imagestore from a list of image files. Args: @@ -686,9 +712,9 @@ def imgstore_from_filenames(cls, filenames: list, output_filename: str, *args, * img_shape = first_img.shape # create the imagestore - store = imgstore.new_for_format('png', - mode='w', basedir=output_filename, - imgshape=img_shape) + store = imgstore.new_for_format( + "png", mode="w", basedir=output_filename, imgshape=img_shape + ) # read each frame and write it to the imagestore # unfortunately imgstore doesn't let us just add the file @@ -703,12 +729,15 @@ def imgstore_from_filenames(cls, filenames: list, output_filename: str, *args, * @classmethod def to_numpy(cls, frame_data: np.array, file_name: str): - np.save(file_name, frame_data, 'w') + np.save(file_name, frame_data, "w") - def to_imgstore(self, path, - frame_numbers: List[int] = None, - format: str = "png", - index_by_original: bool = True): + def to_imgstore( + self, + path, + frame_numbers: List[int] = None, + format: str = "png", + index_by_original: bool = True, + ): """ Read frames from an arbitrary video backend and store them in a loopbio imgstore. This should facilitate conversion of any video to a loopbio imgstore. 
@@ -750,28 +779,36 @@ def to_imgstore(self, path, # new_backend = self.backend.copy_to(path) # return self.__class__(backend=new_backend) - store = imgstore.new_for_format(format, - mode='w', basedir=path, - imgshape=(self.shape[1], self.shape[2], self.shape[3]), - chunksize=1000) + store = imgstore.new_for_format( + format, + mode="w", + basedir=path, + imgshape=(self.shape[1], self.shape[2], self.shape[3]), + chunksize=1000, + ) # Write the JSON for the original video object to the metadata # of the imgstore for posterity store.add_extra_data(source_sleap_video_obj=Video.cattr().unstructure(self)) import time + for frame_num in frame_numbers: store.add_image(self.get_frame(frame_num), frame_num, time.time()) # If there are no frames to save for this video, add a dummy frame # since we can't save an empty imgstore. if len(frame_numbers) == 0: - store.add_image(np.zeros((self.shape[1], self.shape[2], self.shape[3])), 0, time.time()) + store.add_image( + np.zeros((self.shape[1], self.shape[2], self.shape[3])), 0, time.time() + ) store.close() # Return an ImgStoreVideo object referencing this new imgstore. - return self.__class__(backend=ImgStoreVideo(filename=path, index_by_original=index_by_original)) + return self.__class__( + backend=ImgStoreVideo(filename=path, index_by_original=index_by_original) + ) @staticmethod def cattr(): @@ -785,10 +822,10 @@ def cattr(): # When we are structuring video backends, try to fixup the video file paths # in case they are coming from a different computer or the file has been moved. def fixup_video(x, cl): - if 'filename' in x: - x['filename'] = Video.fixup_path(x['filename']) - if 'file' in x: - x['file'] = Video.fixup_path(x['file']) + if "filename" in x: + x["filename"] = Video.fixup_path(x["filename"]) + if "file" in x: + x["file"] = Video.fixup_path(x["file"]) return cl(**x) @@ -833,7 +870,7 @@ def fixup_path(path, raise_error=False) -> str: # Special case: this is an ImgStore path! We cant use # basename because it will strip the directory name off - elif path.endswith('metadata.yaml'): + elif path.endswith("metadata.yaml"): # Get the parent dir of the YAML file. img_store_dir = os.path.basename(os.path.split(path)[0]) @@ -846,4 +883,3 @@ def fixup_path(path, raise_error=False) -> str: else: logger.warning(f"Cannot find a video file: {path}") return path - diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 1949ed58a..58f4091fc 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -13,11 +13,13 @@ from threading import Thread import logging + logger = logging.getLogger(__name__) # Object that signals shutdown _sentinel = object() + def reader(out_q: Queue, video: Video, frames: List[int]): """Read frame images from video and send them into queue. 
@@ -32,7 +34,7 @@ def reader(out_q: Queue, video: Video, frames: List[int]): total_count = len(frames) chunk_size = 64 - chunk_count = math.ceil(total_count/chunk_size) + chunk_count = math.ceil(total_count / chunk_size) logger.info(f"Chunks: {chunk_count}, chunk size: {chunk_size}") @@ -50,7 +52,7 @@ def reader(out_q: Queue, video: Video, frames: List[int]): video_frame_images = video[frames_idx_chunk] elapsed = clock() - t0 - fps = len(frames_idx_chunk)/elapsed + fps = len(frames_idx_chunk) / elapsed logger.debug(f"reading chunk {i} in {elapsed} s = {fps} fps") i += 1 @@ -59,6 +61,7 @@ def reader(out_q: Queue, video: Video, frames: List[int]): # send _sentinal object into queue to signal that we're done out_q.put(_sentinel) + def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): """Annotate frame images (draw instances). @@ -89,14 +92,15 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): imgs = [] for i, frame_idx in enumerate(frames_idx_chunk): img = get_frame_image( - video_frame=video_frame_images[i], - video_idx=video_idx, - frame_idx=frame_idx, - labels=labels) + video_frame=video_frame_images[i], + video_idx=video_idx, + frame_idx=frame_idx, + labels=labels, + ) imgs.append(img) elapsed = clock() - t0 - fps = len(imgs)/elapsed + fps = len(imgs) / elapsed logger.debug(f"drawing chunk {chunk_i} in {elapsed} s = {fps} fps") chunk_i += 1 out_q.put(imgs) @@ -104,8 +108,8 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): # send _sentinal object into queue to signal that we're done out_q.put(_sentinel) -def writer(in_q: Queue, progress_queue: Queue, - filename: str, fps: int, img_w_h: tuple): + +def writer(in_q: Queue, progress_queue: Queue, filename: str, fps: int, img_w_h: tuple): """Write annotated images to video. 
Args: @@ -123,7 +127,7 @@ def writer(in_q: Queue, progress_queue: Queue, cv2.setNumThreads(usable_cpu_count()) - fourcc = cv2.VideoWriter_fourcc(*'MJPG') + fourcc = cv2.VideoWriter_fourcc(*"MJPG") out = cv2.VideoWriter(filename, fourcc, fps, img_w_h) start_time = clock() @@ -143,7 +147,7 @@ def writer(in_q: Queue, progress_queue: Queue, out.write(img) elapsed = clock() - t0 - fps = len(data)/elapsed + fps = len(data) / elapsed logger.debug(f"writing chunk {i} in {elapsed} s = {fps} fps") i += 1 @@ -155,13 +159,15 @@ def writer(in_q: Queue, progress_queue: Queue, # send (-1, time) to signal done progress_queue.put((-1, total_elapsed)) + def save_labeled_video( - filename: str, - labels: Labels, - video: Video, - frames: List[int], - fps: int=15, - gui_progress: bool=False): + filename: str, + labels: Labels, + video: Video, + frames: List[int], + fps: int = 15, + gui_progress: bool = False, +): """Function to generate and save video with annotations.""" output_size = (video.height, video.width) @@ -173,12 +179,14 @@ def save_labeled_video( q2 = Queue() progress_queue = Queue() - thread_read = Thread(target=reader, args=(q1, video, frames,)) - thread_mark = Thread(target=marker, args=(q1, q2, labels, labels.videos.index(video))) - thread_write = Thread(target=writer, args=( - q2, progress_queue, filename, - fps, (video.width, video.height), - )) + thread_read = Thread(target=reader, args=(q1, video, frames)) + thread_mark = Thread( + target=marker, args=(q1, q2, labels, labels.videos.index(video)) + ) + thread_write = Thread( + target=writer, + args=(q2, progress_queue, filename, fps, (video.width, video.height)), + ) thread_read.start() thread_mark.start() @@ -189,9 +197,8 @@ def save_labeled_video( from PySide2 import QtWidgets, QtCore progress_win = QtWidgets.QProgressDialog( - f"Generating video with {len(frames)} frames...", - "Cancel", - 0, len(frames)) + f"Generating video with {len(frames)} frames...", "Cancel", 0, len(frames) + ) progress_win.setMinimumWidth(300) progress_win.setWindowModality(QtCore.Qt.WindowModal) @@ -201,19 +208,22 @@ def save_labeled_video( break if progress_win is not None and progress_win.wasCanceled(): break - fps = frames_complete/elapsed + fps = frames_complete / elapsed remaining_frames = len(frames) - frames_complete - remaining_time = remaining_frames/fps + remaining_time = remaining_frames / fps if gui_progress: progress_win.setValue(frames_complete) else: - print(f"Finished {frames_complete} frames in {elapsed} s, fps = {fps}, approx {remaining_time} s remaining") + print( + f"Finished {frames_complete} frames in {elapsed} s, fps = {fps}, approx {remaining_time} s remaining" + ) elapsed = clock() - t0 - fps = len(frames)/elapsed + fps = len(frames) / elapsed print(f"Done in {elapsed} s, fps = {fps}.") + def img_to_cv(img): # Convert RGB to BGR for OpenCV if img.shape[-1] == 3: @@ -223,27 +233,31 @@ def img_to_cv(img): img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) return img + def get_frame_image(video_frame, video_idx, frame_idx, labels): img = img_to_cv(video_frame) plot_instances_cv(img, video_idx, frame_idx, labels) return img + def _point_int_tuple(point): return int(point.x), int(point.y) + def plot_instances_cv(img, video_idx, frame_idx, labels): - cmap = ([ - [0, 114, 189], - [217, 83, 25], - [237, 177, 32], - [126, 47, 142], - [119, 172, 48], - [77, 190, 238], - [162, 20, 47], - ]) + cmap = [ + [0, 114, 189], + [217, 83, 25], + [237, 177, 32], + [126, 47, 142], + [119, 172, 48], + [77, 190, 238], + [162, 20, 47], + ] lfs = 
labels.find(labels.videos[video_idx], frame_idx) - if len(lfs) == 0: return + if len(lfs) == 0: + return count_no_track = 0 for i, instance in enumerate(lfs[0].instances_to_show): @@ -254,10 +268,11 @@ def plot_instances_cv(img, video_idx, frame_idx, labels): track_idx = len(labels.tracks) + count_no_track count_no_track += 1 - inst_color = cmap[track_idx%len(cmap)] + inst_color = cmap[track_idx % len(cmap)] plot_instance_cv(img, instance, inst_color) + def plot_instance_cv(img, instance, color, marker_radius=4): # RGB -> BGR for cv2 @@ -266,23 +281,31 @@ def plot_instance_cv(img, instance, color, marker_radius=4): for (node, point) in instance.nodes_points: # plot node at point if point.visible and not point.isnan(): - cv2.circle(img, - _point_int_tuple(point), - marker_radius, - cv_color, - lineType=cv2.LINE_AA) + cv2.circle( + img, + _point_int_tuple(point), + marker_radius, + cv_color, + lineType=cv2.LINE_AA, + ) for (src, dst) in instance.skeleton.edges: # Make sure that both nodes are present in this instance before drawing edge if src in instance and dst in instance: - if instance[src].visible and instance[dst].visible \ - and not instance[src].isnan() and not instance[dst].isnan(): + if ( + instance[src].visible + and instance[dst].visible + and not instance[src].isnan() + and not instance[dst].isnan() + ): cv2.line( - img, - _point_int_tuple(instance[src]), - _point_int_tuple(instance[dst]), - cv_color, - lineType=cv2.LINE_AA) + img, + _point_int_tuple(instance[src]), + _point_int_tuple(instance[dst]), + cv_color, + lineType=cv2.LINE_AA, + ) + if __name__ == "__main__": @@ -291,13 +314,21 @@ def plot_instance_cv(img, instance, color, marker_radius=4): parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to labels json file") - parser.add_argument('-o', '--output', type=str, default=None, - help='The output filename for the video') - parser.add_argument('-f', '--fps', type=int, default=15, - help='Frames per second') - parser.add_argument('--frames', type=frame_list, default="", - help='list of frames to predict. Either comma separated list (e.g. 1,2,3) or ' - 'a range separated by hyphen (e.g. 1-3). (default is entire video)') + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="The output filename for the video", + ) + parser.add_argument("-f", "--fps", type=int, default=15, help="Frames per second") + parser.add_argument( + "--frames", + type=frame_list, + default="", + help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " + "a range separated by hyphen (e.g. 1-3). 
(default is entire video)", + ) args = parser.parse_args() video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) @@ -312,10 +343,12 @@ def plot_instance_cv(img, instance, color, marker_radius=4): filename = args.output or args.data_path + ".avi" - save_labeled_video(filename=filename, - labels=labels, - video=labels.videos[0], - frames=frames, - fps=args.fps) + save_labeled_video( + filename=filename, + labels=labels, + video=labels.videos[0], + frames=frames, + fps=args.fps, + ) print(f"Video saved as: {filename}") diff --git a/sleap/nn/architectures/__init__.py b/sleap/nn/architectures/__init__.py index 86c13db10..8b653eb77 100644 --- a/sleap/nn/architectures/__init__.py +++ b/sleap/nn/architectures/__init__.py @@ -9,4 +9,6 @@ available_arch_names = [arch.__name__ for arch in available_archs] BackboneType = TypeVar("BackboneType", *available_archs) -__all__ = ["available_archs", "available_arch_names", "BackboneType"] + [arch.__name__ for arch in available_archs] +__all__ = ["available_archs", "available_arch_names", "BackboneType"] + [ + arch.__name__ for arch in available_archs +] diff --git a/sleap/nn/architectures/common.py b/sleap/nn/architectures/common.py index 61c747332..fb0804ebd 100644 --- a/sleap/nn/architectures/common.py +++ b/sleap/nn/architectures/common.py @@ -4,6 +4,7 @@ from keras.layers import Conv2D, BatchNormalization, Add + def expand_to_n(x, n): """Expands an object `x` to `n` elements if scalar. @@ -18,15 +19,14 @@ def expand_to_n(x, n): """ if not isinstance(x, (collections.Sequence, np.ndarray)): - x = [x,] - + x = [x] + if np.size(x) == 1: x = np.tile(x, n) elif np.size(x) != n: raise ValueError("Variable to expand must be scalar.") - - return x + return x def conv(num_filters, kernel_size=(3, 3), activation="relu", **kwargs): @@ -41,7 +41,14 @@ def conv(num_filters, kernel_size=(3, 3), activation="relu", **kwargs): Returns: keras.layers.Conv2D instance built with presets """ - return Conv2D(num_filters, kernel_size=kernel_size, activation=activation, padding="same", **kwargs) + return Conv2D( + num_filters, + kernel_size=kernel_size, + activation=activation, + padding="same", + **kwargs + ) + def conv1(num_filters, **kwargs): """Convenience presets for 1x1 Conv2D. @@ -55,6 +62,7 @@ def conv1(num_filters, **kwargs): """ return conv(num_filters, kernel_size=(1, 1), **kwargs) + def conv3(num_filters, **kwargs): """Convenience presets for 3x3 Conv2D. @@ -67,6 +75,7 @@ def conv3(num_filters, **kwargs): """ return conv(num_filters, kernel_size=(3, 3), **kwargs) + def residual_block(x_in, num_filters=None, batch_norm=True): """Residual bottleneck block. @@ -99,26 +108,31 @@ def residual_block(x_in, num_filters=None, batch_norm=True): # Default to output the same number of channels as input if num_filters is None: num_filters = x_in.shape[-1] - + # Number of output channels must be divisible by 2 if num_filters % 2 != 0: - raise ValueError("Number of output filters must be divisible by 2 in residual blocks.") - + raise ValueError( + "Number of output filters must be divisible by 2 in residual blocks." 
+ ) + # If number of input and output channels are different, add a 1x1 conv to use as the # identity tensor to which we add the residual at the end x_identity = x_in if x_in.shape[-1] != num_filters: x_identity = conv1(num_filters)(x_in) - if batch_norm: x_identity = BatchNormalization()(x_identity) - + if batch_norm: + x_identity = BatchNormalization()(x_identity) + # Bottleneck: 1x1 -> 3x3 -> 1x1 -> Add residual to identity x = conv1(num_filters // 2)(x_in) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x = conv3(num_filters // 2)(x) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x = conv1(num_filters)(x) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x_out = Add()([x_identity, x]) return x_out - diff --git a/sleap/nn/architectures/densenet.py b/sleap/nn/architectures/densenet.py index 826b27263..97d840ac0 100644 --- a/sleap/nn/architectures/densenet.py +++ b/sleap/nn/architectures/densenet.py @@ -14,6 +14,7 @@ from keras import backend, layers, models import keras.utils as keras_utils + def dense_block(x, blocks, name): """A dense block. # Arguments @@ -24,7 +25,7 @@ def dense_block(x, blocks, name): output tensor for the block. """ for i in range(blocks): - x = conv_block(x, 32, name=name + '_block' + str(i + 1)) + x = conv_block(x, 32, name=name + "_block" + str(i + 1)) return x @@ -37,14 +38,16 @@ def transition_block(x, reduction, name): # Returns output tensor for the block. """ - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, - name=name + '_bn')(x) - x = layers.Activation('relu', name=name + '_relu')(x) - x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1, - use_bias=False, - name=name + '_conv')(x) - x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x) + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_bn")(x) + x = layers.Activation("relu", name=name + "_relu")(x) + x = layers.Conv2D( + int(backend.int_shape(x)[bn_axis] * reduction), + 1, + use_bias=False, + name=name + "_conv", + )(x) + x = layers.AveragePooling2D(2, strides=2, name=name + "_pool")(x) return x @@ -57,30 +60,24 @@ def conv_block(x, growth_rate, name): # Returns Output tensor for the block. 
""" - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x1 = layers.BatchNormalization(axis=bn_axis, - epsilon=1.001e-5, - name=name + '_0_bn')(x) - x1 = layers.Activation('relu', name=name + '_0_relu')(x1) - x1 = layers.Conv2D(4 * growth_rate, 1, - use_bias=False, - name=name + '_1_conv')(x1) - x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, - name=name + '_1_bn')(x1) - x1 = layers.Activation('relu', name=name + '_1_relu')(x1) - x1 = layers.Conv2D(growth_rate, 3, - padding='same', - use_bias=False, - name=name + '_2_conv')(x1) - x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_0_bn")( + x + ) + x1 = layers.Activation("relu", name=name + "_0_relu")(x1) + x1 = layers.Conv2D(4 * growth_rate, 1, use_bias=False, name=name + "_1_conv")(x1) + x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn")( + x1 + ) + x1 = layers.Activation("relu", name=name + "_1_relu")(x1) + x1 = layers.Conv2D( + growth_rate, 3, padding="same", use_bias=False, name=name + "_2_conv" + )(x1) + x = layers.Concatenate(axis=bn_axis, name=name + "_concat")([x, x1]) return x -def DenseNet(blocks, - output_channels, - input_tensor=None, - input_shape=None, - **kwargs): +def DenseNet(blocks, output_channels, input_tensor=None, input_shape=None, **kwargs): """Instantiates the DenseNet architecture. Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is @@ -123,7 +120,6 @@ def DenseNet(blocks, or invalid input shape. """ - if input_tensor is None: img_input = layers.Input(shape=input_shape) else: @@ -132,29 +128,29 @@ def DenseNet(blocks, else: img_input = input_tensor - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input) - x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x) - x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x) - x = layers.Activation('relu', name='conv1/relu')(x) + x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1/conv")(x) + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name="conv1/bn")(x) + x = layers.Activation("relu", name="conv1/relu")(x) x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x) - x = layers.MaxPooling2D(3, strides=2, name='pool1')(x) + x = layers.MaxPooling2D(3, strides=2, name="pool1")(x) - x = dense_block(x, blocks[0], name='conv2') - x = transition_block(x, 0.5, name='pool2') - x = dense_block(x, blocks[1], name='conv3') - x = transition_block(x, 0.5, name='pool3') - x = dense_block(x, blocks[2], name='conv4') - x = transition_block(x, 0.5, name='pool4') - x = dense_block(x, blocks[3], name='conv5') + x = dense_block(x, blocks[0], name="conv2") + x = transition_block(x, 0.5, name="pool2") + x = dense_block(x, blocks[1], name="conv3") + x = transition_block(x, 0.5, name="pool3") + x = dense_block(x, blocks[2], name="conv4") + x = transition_block(x, 0.5, name="pool4") + x = dense_block(x, blocks[3], name="conv5") - x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name='bn')(x) - x = layers.Activation('relu', name='relu')(x) + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name="bn")(x) + x = layers.Activation("relu", name="relu")(x) - x = 
layers.Conv2D(filters=output_channels, kernel_size=(3, 3), padding="same", name="output")(x) + x = layers.Conv2D( + filters=output_channels, kernel_size=(3, 3), padding="same", name="output" + )(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. @@ -165,13 +161,13 @@ def DenseNet(blocks, # Create model. if blocks == [6, 12, 24, 16]: - model = models.Model(inputs, x, name='densenet121') + model = models.Model(inputs, x, name="densenet121") elif blocks == [6, 12, 32, 32]: - model = models.Model(inputs, x, name='densenet169') + model = models.Model(inputs, x, name="densenet169") elif blocks == [6, 12, 48, 32]: - model = models.Model(inputs, x, name='densenet201') + model = models.Model(inputs, x, name="densenet201") else: - model = models.Model(inputs, x, name='densenet') + model = models.Model(inputs, x, name="densenet") return model @@ -215,4 +211,4 @@ def DenseNet(blocks, # include_top, weights, # input_tensor, input_shape, # pooling, classes, -# **kwargs) \ No newline at end of file +# **kwargs) diff --git a/sleap/nn/architectures/hourglass.py b/sleap/nn/architectures/hourglass.py index 70881fac8..f520c8ec3 100644 --- a/sleap/nn/architectures/hourglass.py +++ b/sleap/nn/architectures/hourglass.py @@ -1,7 +1,21 @@ import attr -from sleap.nn.architectures.common import residual_block, expand_to_n, conv, conv1, conv3 -from keras.layers import Conv2D, BatchNormalization, Add, MaxPool2D, UpSampling2D, Concatenate, Conv2DTranspose +from sleap.nn.architectures.common import ( + residual_block, + expand_to_n, + conv, + conv1, + conv3, +) +from keras.layers import ( + Conv2D, + BatchNormalization, + Add, + MaxPool2D, + UpSampling2D, + Concatenate, + Conv2DTranspose, +) @attr.s(auto_attribs=True) @@ -36,6 +50,7 @@ class StackedHourglass: initial_stride: Stride of first convolution to use for reducing input resolution. """ + num_stacks: int = 3 num_filters: int = 32 depth: int = 3 @@ -45,7 +60,6 @@ class StackedHourglass: interp: str = "bilinear" initial_stride: int = 1 - def output(self, x_in, num_output_channels): """ Generate a tensorflow graph for the backbone and return the output tensor. @@ -62,7 +76,15 @@ def output(self, x_in, num_output_channels): return stacked_hourglass(x_in, num_output_channels, **attr.asdict(self)) -def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm=True, upsampling_layers=True, interp="bilinear"): +def hourglass_block( + x_in, + num_output_channels, + num_filters, + depth=3, + batch_norm=True, + upsampling_layers=True, + interp="bilinear", +): """Creates a single hourglass block. This function builds an hourglass block from residual blocks and max pooling. @@ -96,15 +118,21 @@ def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm= x_out: tf.Tensor of the output of the block of the same width and height as the input with `num_output_channels` channels. """ - + # Check if input tensor has the right number of channels if x_in.shape[-1] != num_filters: - raise ValueError("Input tensor must have the same number of channels as the intermediate output of the hourglass (%d)." % num_filters) - + raise ValueError( + "Input tensor must have the same number of channels as the intermediate output of the hourglass (%d)." + % num_filters + ) + # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**depth) != 0 or x_in.shape[-2] % (2**depth) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." 
% (2**depth)) - + if x_in.shape[-2] % (2 ** depth) != 0 or x_in.shape[-2] % (2 ** depth) != 0: + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." + % (2 ** depth) + ) + # Down x = x_in blocks_down = [] @@ -112,41 +140,59 @@ def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm= x = residual_block(x, num_filters, batch_norm) blocks_down.append(x) x = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - + x = residual_block(x, num_filters, batch_norm) - + # Middle x_identity = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = Add()([x_identity, x]) - + # Up for x_down in blocks_down[::-1]: x_down = residual_block(x_down, num_filters, batch_norm) if upsampling_layers: - x = UpSampling2D(size=(2,2), interpolation=interp)(x) + x = UpSampling2D(size=(2, 2), interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) + x = Conv2DTranspose( + num_filters, + kernel_size=3, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) x = Add()([x_down, x]) x = residual_block(x, num_filters, batch_norm) - + # Head x = conv1(num_filters)(x) - if batch_norm: x = BatchNormalization()(x) - + if batch_norm: + x = BatchNormalization()(x) + x_out = conv1(num_output_channels, activation="linear")(x) - + x = conv1(num_filters, activation="linear")(x) x_ = conv1(num_filters, activation="linear")(x_out) x = Add()([x_in, x, x_]) - + return x, x_out -def stacked_hourglass(x_in, num_output_channels, num_stacks=3, num_filters=32, depth=3, batch_norm=True, - intermediate_inputs=True, upsampling_layers=True, interp="bilinear", initial_stride=1): +def stacked_hourglass( + x_in, + num_output_channels, + num_stacks=3, + num_filters=32, + depth=3, + batch_norm=True, + intermediate_inputs=True, + upsampling_layers=True, + interp="bilinear", + initial_stride=1, +): """Stacked hourglass block. This function builds and connects multiple hourglass blocks. See `hourglass` for @@ -181,7 +227,6 @@ def stacked_hourglass(x_in, num_output_channels, num_stacks=3, num_filters=32, d as the input with `num_output_channels` channels. 
""" - # Expand block-specific parameters if scalars provided num_filters = expand_to_n(num_filters, num_stacks) depth = expand_to_n(depth, num_stacks) @@ -195,12 +240,12 @@ def stacked_hourglass(x_in, num_output_channels, num_stacks=3, num_filters=32, d # Batchnorm after the intial down sampling if batch_norm[0]: x = BatchNormalization()(x) - + # Make sure first block gets the right number of channels # x = x_in if x.shape[-1] != num_filters[0]: x = residual_block(x, num_filters[0], batch_norm[0]) - + # Create individual hourglasses and collect intermediate outputs x_in = x x_outs = [] @@ -209,9 +254,15 @@ def stacked_hourglass(x_in, num_output_channels, num_stacks=3, num_filters=32, d x = Concatenate()([x, x_in]) x = residual_block(x, num_filters[i], batch_norm[i]) - x, x_out = hourglass_block(x, num_output_channels, num_filters[i], - depth=depth[i], batch_norm=batch_norm[i], upsampling_layers=upsampling_layers[i], interp=interp[i]) + x, x_out = hourglass_block( + x, + num_output_channels, + num_filters[i], + depth=depth[i], + batch_norm=batch_norm[i], + upsampling_layers=upsampling_layers[i], + interp=interp[i], + ) x_outs.append(x_out) - - return x_outs + return x_outs diff --git a/sleap/nn/architectures/leap.py b/sleap/nn/architectures/leap.py index b943347af..51b418343 100644 --- a/sleap/nn/architectures/leap.py +++ b/sleap/nn/architectures/leap.py @@ -48,7 +48,15 @@ def output(self, x_in, num_output_channels): return leap_cnn(x_in, num_output_channels, **attr.asdict(self)) -def leap_cnn(x_in, num_output_channels, down_blocks=3, up_blocks=3, upsampling_layers=True, num_filters=64, interp="bilinear"): +def leap_cnn( + x_in, + num_output_channels, + down_blocks=3, + up_blocks=3, + upsampling_layers=True, + num_filters=64, + interp="bilinear", +): """LEAP CNN block. Implementation generalized from original paper (`Pereira et al., 2019 @@ -74,25 +82,50 @@ def leap_cnn(x_in, num_output_channels, down_blocks=3, up_blocks=3, upsampling_l """ # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**down_blocks) != 0 or x_in.shape[-2] % (2**down_blocks) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." % (2**down_blocks)) + if ( + x_in.shape[-2] % (2 ** down_blocks) != 0 + or x_in.shape[-2] % (2 ** down_blocks) != 0 + ): + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." 
+ % (2 ** down_blocks) + ) x = x_in for i in range(down_blocks): - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) x = MaxPool2D(pool_size=2, strides=2, padding="same")(x) for i in range(up_blocks, 0, -1): if upsampling_layers: x = UpSampling2D(interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters * (2 ** i), kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - - x = Conv2D(num_output_channels, kernel_size=3, padding="same", activation="linear")(x) + x = Conv2DTranspose( + num_filters * (2 ** i), + kernel_size=3, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + + x = Conv2D(num_output_channels, kernel_size=3, padding="same", activation="linear")( + x + ) return x diff --git a/sleap/nn/architectures/resnet.py b/sleap/nn/architectures/resnet.py index cef678071..4589fa1f3 100644 --- a/sleap/nn/architectures/resnet.py +++ b/sleap/nn/architectures/resnet.py @@ -4,9 +4,10 @@ import attr + @attr.s(auto_attribs=True) class ResNet50: - """ResNet50 pretrained backbone. + """ResNet50 pretrained backbone. Args: x_in: Input 4-D tf.Tensor or instantiated layer. @@ -21,14 +22,14 @@ class ResNet50: False, random weights are used for initialization. """ - upsampling_layers: bool = True - interp: str = "bilinear" - up_blocks: int = 5 - refine_conv_up: bool = False - pretrained: bool = True + upsampling_layers: bool = True + interp: str = "bilinear" + up_blocks: int = 5 + refine_conv_up: bool = False + pretrained: bool = True - def output(self, x_in, num_output_channels): - """ + def output(self, x_in, num_output_channels): + """ Generate a tensorflow graph for the backbone and return the output tensor. Args: @@ -41,59 +42,71 @@ def output(self, x_in, num_output_channels): Returns: x_out: tf.Tensor of the output of the block of with `num_output_channels` channels. """ - return resnet50(x_in, num_output_channels, **attr.asdict(self)) + return resnet50(x_in, num_output_channels, **attr.asdict(self)) - @property - def down_blocks(self): - """Returns the number of downsampling steps in the model.""" + @property + def down_blocks(self): + """Returns the number of downsampling steps in the model.""" - # This is a fixed constant for ResNet50. - return 5 - + # This is a fixed constant for ResNet50. 
+ return 5 - @property - def output_scale(self): - """Returns relative scaling factor of this backbone.""" + @property + def output_scale(self): + """Returns relative scaling factor of this backbone.""" - return (1 / (2 ** (self.down_blocks - self.up_blocks))) + return 1 / (2 ** (self.down_blocks - self.up_blocks)) def preprocess_input(X): - """Rescale input to [-1, 1] and tile if not RGB.""" - X = (X * 2) - 1 - - if tf.shape(X)[-1] != 3: - X = tf.tile(X, [1, 1, 1, 3]) - - return X - - -def resnet50(x_in, num_output_channels, up_blocks=5, upsampling_layers=True, - interp="bilinear", refine_conv_up=False, pretrained=True): - """Build ResNet50 backbone.""" - - # Input should be rescaled from [0, 1] to [-1, 1] and needs to be 3 channels (RGB) - x = keras.layers.Lambda(preprocess_input)(x_in) - - # Automatically downloads weights - resnet_model = applications.ResNet50( - include_top=False, - input_shape=(int(x_in.shape[-3]), int(x_in.shape[-2]), 3), - weights="imagenet" if pretrained else None, + """Rescale input to [-1, 1] and tile if not RGB.""" + X = (X * 2) - 1 + + if tf.shape(X)[-1] != 3: + X = tf.tile(X, [1, 1, 1, 3]) + + return X + + +def resnet50( + x_in, + num_output_channels, + up_blocks=5, + upsampling_layers=True, + interp="bilinear", + refine_conv_up=False, + pretrained=True, +): + """Build ResNet50 backbone.""" + + # Input should be rescaled from [0, 1] to [-1, 1] and needs to be 3 channels (RGB) + x = keras.layers.Lambda(preprocess_input)(x_in) + + # Automatically downloads weights + resnet_model = applications.ResNet50( + include_top=False, + input_shape=(int(x_in.shape[-3]), int(x_in.shape[-2]), 3), + weights="imagenet" if pretrained else None, ) - # Output size is reduced by factor of 32 (2 ** 5) - x = resnet_model(x) + # Output size is reduced by factor of 32 (2 ** 5) + x = resnet_model(x) - for i in range(up_blocks): - if upsampling_layers: - x = keras.layers.UpSampling2D(size=(2, 2), interpolation=interp)(x) - else: - x = keras.layers.Conv2DTranspose(2 ** (8 - i), kernel_size=3, strides=2, padding="same", kernel_initializer="glorot_normal")(x) + for i in range(up_blocks): + if upsampling_layers: + x = keras.layers.UpSampling2D(size=(2, 2), interpolation=interp)(x) + else: + x = keras.layers.Conv2DTranspose( + 2 ** (8 - i), + kernel_size=3, + strides=2, + padding="same", + kernel_initializer="glorot_normal", + )(x) - if refine_conv_up: - x = keras.layers.Conv2D(2 ** (8 - i), kernel_size=1, padding="same")(x) + if refine_conv_up: + x = keras.layers.Conv2D(2 ** (8 - i), kernel_size=1, padding="same")(x) - x = keras.layers.Conv2D(num_output_channels, (3, 3), padding="same")(x) + x = keras.layers.Conv2D(num_output_channels, (3, 3), padding="same")(x) - return x + return x diff --git a/sleap/nn/architectures/unet.py b/sleap/nn/architectures/unet.py index 6df5e02e2..75e37b4cd 100644 --- a/sleap/nn/architectures/unet.py +++ b/sleap/nn/architectures/unet.py @@ -29,6 +29,7 @@ class UNet: interp: Method to use for interpolation when upsampling smaller features. 
""" + down_blocks: int = 3 up_blocks: int = 3 convs_per_depth: int = 2 @@ -105,8 +106,17 @@ def output(self, x_in, num_output_channels): return stacked_unet(x_in, num_output_channels, **attr.asdict(self)) -def unet(x_in, num_output_channels, down_blocks=3, up_blocks=3, convs_per_depth=2, num_filters=16, - kernel_size=5, upsampling_layers=True, interp="bilinear"): +def unet( + x_in, + num_output_channels, + down_blocks=3, + up_blocks=3, + convs_per_depth=2, + num_filters=16, + kernel_size=5, + upsampling_layers=True, + interp="bilinear", +): """U-net block. Implementation based off of `CARE @@ -137,50 +147,73 @@ def unet(x_in, num_output_channels, down_blocks=3, up_blocks=3, convs_per_depth= """ # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**down_blocks) != 0 or x_in.shape[-2] % (2**down_blocks) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." % (2**down_blocks)) + if ( + x_in.shape[-2] % (2 ** down_blocks) != 0 + or x_in.shape[-2] % (2 ** down_blocks) != 0 + ): + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." + % (2 ** down_blocks) + ) # Ensure we have a tuple in case scalar provided kernel_size = expand_to_n(kernel_size, 2) # Input tensor x = x_in - + # Downsampling skip_layers = [] for n in range(down_blocks): for i in range(convs_per_depth): x = conv(num_filters * 2 ** n, kernel_size=kernel_size)(x) skip_layers.append(x) - x = MaxPool2D(pool_size=(2,2))(x) + x = MaxPool2D(pool_size=(2, 2))(x) # Middle for i in range(convs_per_depth - 1): x = conv(num_filters * 2 ** down_blocks, kernel_size=kernel_size)(x) - x = conv(num_filters * 2 ** max(0, down_blocks-1), kernel_size=kernel_size)(x) + x = conv(num_filters * 2 ** max(0, down_blocks - 1), kernel_size=kernel_size)(x) # Upsampling (with skips) - for n in range(down_blocks-1, down_blocks-up_blocks-1, -1): + for n in range(down_blocks - 1, down_blocks - up_blocks - 1, -1): if upsampling_layers: - x = UpSampling2D(size=(2,2), interpolation=interp)(x) + x = UpSampling2D(size=(2, 2), interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters * 2 ** n, kernel_size=kernel_size, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) + x = Conv2DTranspose( + num_filters * 2 ** n, + kernel_size=kernel_size, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) x = Concatenate(axis=-1)([x, skip_layers[n]]) - + for i in range(convs_per_depth - 1): x = conv(num_filters * 2 ** n, kernel_size=kernel_size)(x) - x = conv(num_filters * 2 ** max(0, n-1), kernel_size=kernel_size)(x) - + x = conv(num_filters * 2 ** max(0, n - 1), kernel_size=kernel_size)(x) + # Final layer x_out = conv(num_output_channels, activation="linear")(x) return x_out -def stacked_unet(x_in, num_output_channels, num_stacks=3, depth=3, convs_per_depth=2, num_filters=16, kernel_size=5, - upsampling_layers=True, intermediate_inputs=True, interp="bilinear"): +def stacked_unet( + x_in, + num_output_channels, + num_stacks=3, + depth=3, + convs_per_depth=2, + num_filters=16, + kernel_size=5, + upsampling_layers=True, + intermediate_inputs=True, + interp="bilinear", +): """Stacked U-net block. See `unet` for more specifics on the implementation. 
@@ -226,11 +259,17 @@ def stacked_unet(x_in, num_output_channels, num_stacks=3, depth=3, convs_per_dep if i > 0 and intermediate_inputs: x = Concatenate()([x, x_in]) - x_out = unet(x, num_output_channels, depth=depth[i], convs_per_depth=convs_per_depth[i], - num_filters=num_filters[i], kernel_size=kernel_size[i], - upsampling_layers=upsampling_layers[i], interp=interp[i]) + x_out = unet( + x, + num_output_channels, + depth=depth[i], + convs_per_depth=convs_per_depth[i], + num_filters=num_filters[i], + kernel_size=kernel_size[i], + upsampling_layers=upsampling_layers[i], + interp=interp[i], + ) x_outs.append(x_out) x = x_out - - return x_outs + return x_outs diff --git a/sleap/nn/augmentation.py b/sleap/nn/augmentation.py index 04654fabc..dfe097a3d 100644 --- a/sleap/nn/augmentation.py +++ b/sleap/nn/augmentation.py @@ -57,7 +57,9 @@ def __attrs_post_init__(self): # Setup batching all_idx = np.arange(self.num_samples) - self.batches = np.array_split(all_idx, np.ceil(self.num_samples / self.batch_size)) + self.batches = np.array_split( + all_idx, np.ceil(self.num_samples / self.batch_size) + ) # Initial shuffling if self.shuffle_initially: @@ -67,10 +69,16 @@ def __attrs_post_init__(self): # TODO: translation? self.aug_stack = [] if self.rotation is not None: - self.rotation = self.rotation if isinstance(self.rotation, tuple) else (-self.rotation, self.rotation) + self.rotation = ( + self.rotation + if isinstance(self.rotation, tuple) + else (-self.rotation, self.rotation) + ) if self.scale is not None and self.scale[0] != self.scale[1]: self.scale = (min(self.scale), max(self.scale)) - self.aug_stack.append(imgaug.augmenters.Affine(rotate=self.rotation, scale=self.scale)) + self.aug_stack.append( + imgaug.augmenters.Affine(rotate=self.rotation, scale=self.scale) + ) else: self.aug_stack.append(imgaug.augmenters.Affine(rotate=self.rotation)) @@ -110,11 +118,13 @@ def shuffle(self, batches_only=False): # Re-batch after shuffling all_idx = np.arange(self.num_samples) np.random.shuffle(all_idx) - self.batches = np.array_split(all_idx, np.ceil(self.num_samples / self.batch_size)) - + self.batches = np.array_split( + all_idx, np.ceil(self.num_samples / self.batch_size) + ) + def __len__(self): return len(self.batches) - + def __getitem__(self, batch_idx): aug_det = self.aug.to_deterministic() idx = self.batches[batch_idx] @@ -137,18 +147,18 @@ def __getitem__(self, batch_idx): # Combine each list of point arrays (per frame) to single KeypointsOnImage # points: frames -> instances -> point_array frames_in_batch = [self.points[i] for i in idx] - points_per_instance_per_frame = [[pa.shape[0] for pa in frame] for frame in frames_in_batch] + points_per_instance_per_frame = [ + [pa.shape[0] for pa in frame] for frame in frames_in_batch + ] koi_in_frame = [] for i, frame in enumerate(frames_in_batch): if len(frame): koi = imgaug.augmentables.kps.KeypointsOnImage.from_xy_array( - np.concatenate(frame), - shape=X[i].shape) + np.concatenate(frame), shape=X[i].shape + ) else: - koi = imgaug.augmentables.kps.KeypointsOnImage( - [], - shape=X[i].shape) + koi = imgaug.augmentables.kps.KeypointsOnImage([], shape=X[i].shape) koi_in_frame.append(koi) # Augment KeypointsOnImage @@ -165,7 +175,7 @@ def __getitem__(self, batch_idx): frame_point_arrays = [] offset = 0 for point_count in points_per_instance_per_frame[i]: - inst_points = frame[offset:offset+point_count] + inst_points = frame[offset : offset + point_count] frame_point_arrays.append(inst_points) offset += point_count 
split_points.append(frame_point_arrays) @@ -192,18 +202,22 @@ def make_cattr(X=None, Y=None, Points=None): # parameters. aug_cattr.register_unstructure_hook( Augmenter, - lambda x: - attr.asdict(x, - filter=attr.filters.exclude( - attr.fields(Augmenter).X, - attr.fields(Augmenter).Y, - attr.fields(Augmenter).Points))) + lambda x: attr.asdict( + x, + filter=attr.filters.exclude( + attr.fields(Augmenter).X, + attr.fields(Augmenter).Y, + attr.fields(Augmenter).Points, + ), + ), + ) # We the user needs to unstructure, what images, outputs, and points should # they use. We didn't serialize these, just the parameters. if X is not None: - aug_cattr.register_structure_hook(Augmenter, - lambda x: Augmenter(X=X,Y=Y,Points=Points, **x)) + aug_cattr.register_structure_hook( + Augmenter, lambda x: Augmenter(X=X, Y=Y, Points=Points, **x) + ) return aug_cattr @@ -211,22 +225,24 @@ def make_cattr(X=None, Y=None, Points=None): def demo_augmentation(): from sleap.io.dataset import Labels from sleap.nn.datagen import generate_training_data - from sleap.nn.datagen import generate_confmaps_from_points, generate_pafs_from_points + from sleap.nn.datagen import ( + generate_confmaps_from_points, + generate_pafs_from_points, + ) data_path = "tests/data/json_format_v1/centered_pair.json" labels = Labels.load_json(data_path) # Generate raw training data skeleton = labels.skeletons[0] - imgs, points = generate_training_data(labels, params = dict( - scale = 1, - instance_crop = True, - min_crop_size = 0, - negative_samples = 0)) + imgs, points = generate_training_data( + labels, + params=dict(scale=1, instance_crop=True, min_crop_size=0, negative_samples=0), + ) shape = (imgs.shape[1], imgs.shape[2]) def datagen_from_points(points): -# return generate_pafs_from_points(points, skeleton, shape) + # return generate_pafs_from_points(points, skeleton, shape) return generate_confmaps_from_points(points, skeleton, shape) # Augment @@ -239,12 +255,13 @@ def datagen_from_points(points): from PySide2.QtWidgets import QApplication # Visualize augmented training data - vid = Video.from_numpy(imgs*255) + vid = Video.from_numpy(imgs * 255) app = QApplication([]) demo_confmaps(aug_out, vid) -# demo_pafs(aug_out, vid) + # demo_pafs(aug_out, vid) app.exec_() + def demo_bad_augmentation(): from sleap.io.dataset import Labels from sleap.nn.datagen import generate_images, generate_confidence_maps @@ -262,7 +279,7 @@ def demo_bad_augmentation(): confmaps = generate_confidence_maps(labels) # Augment - aug = Augmenter(X=imgs, Y=confmaps, scale=(.5, 2)) + aug = Augmenter(X=imgs, Y=confmaps, scale=(0.5, 2)) imgs, confmaps = aug[0] from sleap.io.video import Video @@ -271,11 +288,12 @@ def demo_bad_augmentation(): from PySide2.QtWidgets import QApplication # Visualize augmented training data - vid = Video.from_numpy(imgs*255) + vid = Video.from_numpy(imgs * 255) app = QApplication([]) demo_confmaps(confmaps, vid) app.exec_() + if __name__ == "__main__": demo_augmentation() diff --git a/sleap/nn/datagen.py b/sleap/nn/datagen.py index cdb38202c..165280949 100644 --- a/sleap/nn/datagen.py +++ b/sleap/nn/datagen.py @@ -10,6 +10,7 @@ from sleap.io.dataset import Labels + def generate_training_data(labels, params): """ Generate imgs (ndarray) and points (list) to use for training. 
@@ -33,53 +34,72 @@ def generate_training_data(labels, params): resize_hack = not params["instance_crop"] - imgs = generate_images(labels, params["scale"], - frame_limit=params.get("frame_limit", None), - resize_hack=resize_hack) + imgs = generate_images( + labels, + params["scale"], + frame_limit=params.get("frame_limit", None), + resize_hack=resize_hack, + ) - points = generate_points(labels, params["scale"], - frame_limit=params.get("frame_limit", None)) + points = generate_points( + labels, params["scale"], frame_limit=params.get("frame_limit", None) + ) if params["instance_crop"]: # Crop and include any *random* negative samples imgs, points = instance_crops( - imgs, points, - min_crop_size = params["min_crop_size"], - negative_samples = params["negative_samples"]) + imgs, + points, + min_crop_size=params["min_crop_size"], + negative_samples=params["negative_samples"], + ) # Include any *specific* negative samples imgs, points = add_negative_anchor_crops( - labels, - imgs, points, - scale=params["scale"]) + labels, imgs, points, scale=params["scale"] + ) return imgs, points -def generate_images(labels:Labels, scale: float=1.0, - resize_hack: bool=True, frame_limit: int=None) -> np.ndarray: + +def generate_images( + labels: Labels, + scale: float = 1.0, + resize_hack: bool = True, + frame_limit: int = None, +) -> np.ndarray: """ Generate a ndarray of the image data for any user labeled frames. Wrapper that calls generate_images_from_list() with list of all frames that were labeled by user. """ - frame_list = [(lf.video, lf.frame_idx) - for lf in labels.user_labeled_frames[:frame_limit]] + frame_list = [ + (lf.video, lf.frame_idx) for lf in labels.user_labeled_frames[:frame_limit] + ] return generate_images_from_list(labels, frame_list, scale, resize_hack) -def generate_points(labels:Labels, scale: float=1.0, frame_limit: int=None) -> list: + +def generate_points( + labels: Labels, scale: float = 1.0, frame_limit: int = None +) -> list: """Generates point data for instances for any user labeled frames. Wrapper that calls generate_points_from_list() with list of all frames that were labeled by user. """ - frame_list = [(lf.video, lf.frame_idx) - for lf in labels.user_labeled_frames[:frame_limit]] + frame_list = [ + (lf.video, lf.frame_idx) for lf in labels.user_labeled_frames[:frame_limit] + ] return generate_points_from_list(labels, frame_list, scale) + def generate_images_from_list( - labels:Labels, frame_list: List[Tuple], - scale: float=1.0, resize_hack: bool=True) -> np.ndarray: + labels: Labels, + frame_list: List[Tuple], + scale: float = 1.0, + resize_hack: bool = True, +) -> np.ndarray: """ Generate a ndarray of the image data for given list of frames @@ -98,11 +118,11 @@ def generate_images_from_list( # rescale by factor y, x, c = img.shape if scale != 1.0 or resize_hack: - y_scaled, x_scaled = int(y//(1/scale)), int(x//(1/scale)) + y_scaled, x_scaled = int(y // (1 / scale)), int(x // (1 / scale)) # FIXME: hack to resize image so dimensions are divisible by 8 if resize_hack: - y_scaled, x_scaled = y_scaled//8*8, x_scaled//8*8 + y_scaled, x_scaled = y_scaled // 8 * 8, x_scaled // 8 * 8 if (x, y) != (x_scaled, y_scaled): # resize image @@ -122,7 +142,10 @@ def generate_images_from_list( return imgs -def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: float=1.0) -> list: + +def generate_points_from_list( + labels: Labels, frame_list: List[Tuple], scale: float = 1.0 +) -> list: """Generates point data for instances in specified frames. 
Output is in the format expected by @@ -139,22 +162,28 @@ def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: flo a list (each frame) of lists (each instance) of ndarrays (of points) i.e., frames -> instances -> point_array """ + def lf_points_from_singleton(lf_singleton): - if len(lf_singleton) == 0: return [] + if len(lf_singleton) == 0: + return [] lf = lf_singleton[0] - points = [inst.points_array*scale - for inst in lf.user_instances] + points = [inst.points_array * scale for inst in lf.user_instances] return points lfs = [labels.find(video, frame_idx) for (video, frame_idx) in frame_list] return list(map(lf_points_from_singleton, lfs)) -def generate_confmaps_from_points(frames_inst_points, - skeleton: Optional['Skeleton'], - shape, - node_count: Optional[int] = None, - sigma:float=5.0, scale:float=1.0, output_size=None) -> np.ndarray: + +def generate_confmaps_from_points( + frames_inst_points, + skeleton: Optional["Skeleton"], + shape, + node_count: Optional[int] = None, + sigma: float = 5.0, + scale: float = 1.0, + output_size=None, +) -> np.ndarray: """ Generates confmaps for set of frames. This is used to generate confmaps on the fly during training, @@ -175,15 +204,16 @@ def generate_confmaps_from_points(frames_inst_points, full_size = shape if output_size is None: - output_size = (shape[0] // (1/scale), shape[1] // (1/scale)) + output_size = (shape[0] // (1 / scale), shape[1] // (1 / scale)) output_size = tuple(map(int, output_size)) - ball = _get_conf_ball(output_size, sigma*scale) + ball = _get_conf_ball(output_size, sigma * scale) num_frames = len(frames_inst_points) - confmaps = np.zeros((num_frames, output_size[0], output_size[1], node_count), - dtype="float32") + confmaps = np.zeros( + (num_frames, output_size[0], output_size[1], node_count), dtype="float32" + ) for frame_idx, points_arrays in enumerate(frames_inst_points): for inst_points in points_arrays: @@ -191,12 +221,21 @@ def generate_confmaps_from_points(frames_inst_points, if not np.isnan(np.sum(inst_points[node_idx])): x = inst_points[node_idx][0] * scale y = inst_points[node_idx][1] * scale - _raster_ball(arr=confmaps[frame_idx], ball=ball, c=node_idx, x=x, y=y) + _raster_ball( + arr=confmaps[frame_idx], ball=ball, c=node_idx, x=x, y=y + ) return confmaps -def generate_pafs_from_points(frames_inst_points, skeleton, shape, - sigma:float=5.0, scale:float=1.0, output_size=None) -> np.ndarray: + +def generate_pafs_from_points( + frames_inst_points, + skeleton, + shape, + sigma: float = 5.0, + scale: float = 1.0, + output_size=None, +) -> np.ndarray: """ Generates pafs for set of frames. 
This is used to generate pafs on the fly during training, @@ -212,7 +251,7 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, """ full_size = shape if output_size is None: - output_size = (shape[0] // (1/scale), shape[1] // (1/scale)) + output_size = (shape[0] // (1 / scale), shape[1] // (1 / scale)) # TODO: throw warning for truncation errors full_size = tuple(map(int, full_size)) @@ -221,8 +260,9 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, num_frames = len(frames_inst_points) num_channels = len(skeleton.edges) * 2 - pafs = np.zeros((num_frames, output_size[0], output_size[1], num_channels), - dtype="float32") + pafs = np.zeros( + (num_frames, output_size[0], output_size[1], num_channels), dtype="float32" + ) for frame_idx, points_arrays in enumerate(frames_inst_points): for inst_points in points_arrays: for c, (src_node, dst_node) in enumerate(skeleton.edges): @@ -236,19 +276,23 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, return pafs + def _get_conf_ball(output_size, sigma): # Pre-allocate coordinate grid xv = np.linspace(0, output_size[1] - 1, output_size[1], dtype="float32") yv = np.linspace(0, output_size[0] - 1, output_size[0], dtype="float32") XX, YY = np.meshgrid(xv, yv) - x, y = output_size[1]//2, output_size[0]//2 + x, y = output_size[1] // 2, output_size[0] // 2 ball_full = np.exp(-((YY - y) ** 2 + (XX - x) ** 2) / (2 * sigma ** 2)) - window_size = int(sigma*4) - ball_window = ball_full[y-window_size:y+window_size, x-window_size:x+window_size] + window_size = int(sigma * 4) + ball_window = ball_full[ + y - window_size : y + window_size, x - window_size : x + window_size + ] return ball_window + def _raster_ball(arr, ball, c, x, y): x, y = int(x), int(y) ball_h, ball_w = ball.shape @@ -257,8 +301,8 @@ def _raster_ball(arr, ball, c, x, y): ball_slice_y = slice(0, ball_h) ball_slice_x = slice(0, ball_w) - arr_slice_y = slice(y-ball_h//2, y+ball_h//2) - arr_slice_x = slice(x-ball_w//2, x+ball_w//2) + arr_slice_y = slice(y - ball_h // 2, y + ball_h // 2) + arr_slice_x = slice(x - ball_w // 2, x + ball_w // 2) # crop ball if it would be out of array bounds # i.e., it's close to edge @@ -275,24 +319,26 @@ def _raster_ball(arr, ball, c, x, y): if arr_slice_y.stop > out_h: cut = arr_slice_y.stop - out_h arr_slice_y = slice(arr_slice_y.start, out_h) - ball_slice_y = slice(0, ball_h-cut) + ball_slice_y = slice(0, ball_h - cut) if arr_slice_x.stop > out_w: cut = arr_slice_x.stop - out_w arr_slice_x = slice(arr_slice_x.start, out_w) - ball_slice_x = slice(0, ball_w-cut) + ball_slice_x = slice(0, ball_w - cut) - if ball_slice_x.stop <= ball_slice_x.start \ - or ball_slice_y.stop <= ball_slice_y.start: + if ( + ball_slice_x.stop <= ball_slice_x.start + or ball_slice_y.stop <= ball_slice_y.start + ): return # impose ball on array arr[arr_slice_y, arr_slice_x, c] = np.maximum( - arr[arr_slice_y, arr_slice_x, c], - ball[ball_slice_y, ball_slice_x] - ) + arr[arr_slice_y, arr_slice_x, c], ball[ball_slice_y, ball_slice_x] + ) + -def generate_confidence_maps(labels:Labels, sigma=5.0, scale=1): +def generate_confidence_maps(labels: Labels, sigma=5.0, scale=1): """Wrapper for generate_confmaps_from_points which takes labels instead of points.""" # TODO: multi-skeleton support @@ -306,16 +352,19 @@ def generate_confidence_maps(labels:Labels, sigma=5.0, scale=1): return confmaps + def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): # skip if any nan - if np.isnan(np.sum((x0, y0, x1, y1))): return + if np.isnan(np.sum((x0, y0, x1, y1))): 
+ return delta_x, delta_y = x1 - x0, y1 - y0 - edge_len = (delta_x ** 2 + delta_y ** 2) ** .5 + edge_len = (delta_x ** 2 + delta_y ** 2) ** 0.5 # skip if no distance between nodes - if edge_len == 0.0: return + if edge_len == 0.0: + return edge_x = delta_x / edge_len edge_y = delta_y / edge_len @@ -330,6 +379,7 @@ def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): yy = perp_y0, perp_y0 + delta_y, perp_y1 + delta_y, perp_y1 from skimage.draw import polygon, polygon_perimeter + points_y, points_x = polygon(yy, xx, (arr.shape[0], arr.shape[1])) perim_y, perim_x = polygon_perimeter(yy, xx, shape=(arr.shape[0], arr.shape[1])) @@ -341,7 +391,8 @@ def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): arr[y, x, c] = edge_x arr[y, x, c + 1] = edge_y -def generate_pafs(labels: Labels, sigma:float=5.0, scale:float=1.0) -> np.ndarray: + +def generate_pafs(labels: Labels, sigma: float = 5.0, scale: float = 1.0) -> np.ndarray: """Wrapper for generate_pafs_from_points which takes labels instead of points.""" # TODO: multi-skeleton support @@ -355,6 +406,7 @@ def generate_pafs(labels: Labels, sigma:float=5.0, scale:float=1.0) -> np.ndarra return pafs + def point_array_bounding_box(point_array: np.ndarray) -> tuple: """Returns (x0, y0, x1, y1) for box that bounds point_array.""" x0 = np.nanmin(point_array[:, 0]) @@ -363,6 +415,7 @@ def point_array_bounding_box(point_array: np.ndarray) -> tuple: y1 = np.nanmax(point_array[:, 1]) return x0, y0, x1, y1 + def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple): """Grow (x0, y0, x1, y1) so it's as large as pad_to but stays inside within. @@ -381,14 +434,14 @@ def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple * 0 <= (x1-x0) <= within w """ pad_to_y, pad_to_x = pad_to - x_margin = pad_to_x - (x1-x0) - y_margin = pad_to_y - (y1-y0) + x_margin = pad_to_x - (x1 - x0) + y_margin = pad_to_y - (y1 - y0) # initial values - x0 -= x_margin//2 - x1 += x_margin-x_margin//2 - y0 -= y_margin//2 - y1 += y_margin-y_margin//2 + x0 -= x_margin // 2 + x1 += x_margin - x_margin // 2 + y0 -= y_margin // 2 + y1 += y_margin - y_margin // 2 # adjust to stay inside within within_y, within_x = within @@ -397,36 +450,40 @@ def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple x1 = min(within_x, pad_to_x) if x1 > within_x: x1 = within_x - x0 = max(0, within_x-pad_to_x) + x0 = max(0, within_x - pad_to_x) if y0 < 0: y0 = 0 y1 = min(within_y, pad_to_y) if y1 > within_y: y1 = within_y - y0 = max(0, within_y-pad_to_y) + y0 = max(0, within_y - pad_to_y) return x0, y0, x1, y1 + def generate_centroid_points(points: list) -> list: """Takes the points for each instance and replaces it with a single centroid point.""" - centroids = [[_centroid(*point_array_bounding_box(point_array)) - for point_array in frame] for frame in points] + centroids = [ + [_centroid(*point_array_bounding_box(point_array)) for point_array in frame] + for frame in points + ] return centroids + def _to_np_point(x, y) -> np.ndarray: a = np.array((x, y)) return np.expand_dims(a, axis=0) + def _centroid(x0, y0, x1, y1) -> np.ndarray: - return _to_np_point(x = x0+(x1-x0)/2, y = y0+(y1-y0)/2) + return _to_np_point(x=x0 + (x1 - x0) / 2, y=y0 + (y1 - y0) / 2) + def instance_crops( - imgs: np.ndarray, - points: list, - min_crop_size: int=0, - negative_samples: int=0) -> Tuple[np.ndarray, List]: + imgs: np.ndarray, points: list, min_crop_size: int = 0, negative_samples: int = 0 +) -> Tuple[np.ndarray, List]: """ Take imgs, points and return imgs, points 
cropped around instances. @@ -455,29 +512,38 @@ def instance_crops( # Add bounding boxes for *random* negative samples if negative_samples > 0: - neg_img_idxs, neg_bbs = get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples) + neg_img_idxs, neg_bbs = get_random_negative_samples( + img_idxs, bbs, img_shape, negative_samples + ) neg_imgs, neg_points = _crop_and_transform(imgs, points, neg_img_idxs, neg_bbs) - crop_imgs, crop_points = _extend_imgs_points(crop_imgs, crop_points, neg_imgs, neg_points) + crop_imgs, crop_points = _extend_imgs_points( + crop_imgs, crop_points, neg_imgs, neg_points + ) return crop_imgs, crop_points + def _crop_and_transform(imgs, points, img_idxs, bbs): crop_imgs = _crop(imgs, img_idxs, bbs) crop_points = _transform_crop_points(points, img_idxs, bbs) return crop_imgs, crop_points + def _extend_imgs_points(imgs_a, points_a, imgs_b, points_b): imgs = np.concatenate((imgs_a, imgs_b)) points = points_a + points_b return imgs, points + def _pad_bbs_to_min(bbs, min_crop_size, img_shape): padded_bbs = _pad_bbs( - bbs = bbs, - box_shape = _bb_pad_shape(bbs, min_crop_size, img_shape), - img_shape = img_shape) + bbs=bbs, + box_shape=_bb_pad_shape(bbs, min_crop_size, img_shape), + img_shape=img_shape, + ) return padded_bbs + def _bb_pad_shape(bbs, min_crop_size, img_shape): """ Given a list of bounding boxes, finds the square size which will be: @@ -499,8 +565,8 @@ def _bb_pad_shape(bbs, min_crop_size, img_shape): max_width = max((x1 - x0 for (x0, y0, x1, y1) in bbs)) max_dim = max(max_height, max_width) max_dim = max(max_dim, min_crop_size) - max_dim += 20 # pad - box_side = ceil(max_dim/64)*64 # round up to nearest multiple of 64 + max_dim += 20 # pad + box_side = ceil(max_dim / 64) * 64 # round up to nearest multiple of 64 # TODO: make sure we have valid box_size @@ -509,6 +575,7 @@ def _bb_pad_shape(bbs, min_crop_size, img_shape): return box_shape + def _transform_crop_points(points, img_idxs, bbs): """Takes points on the original images and returns points in bounding boxes. 
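As a quick check of the padding helper these crop functions rely on, a worked example with illustrative values (the box and image sizes are made up):

    # Grow a 20x20 box to the requested 64x64 crop while staying inside a 512x512 image.
    # The centered expansion would go negative, so the result is clamped to the top-left corner.
    pad_rect_to(10, 10, 30, 30, pad_to=(64, 64), within=(512, 512))
    # -> (0, 0, 64, 64)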
@@ -533,14 +600,19 @@ def _transform_crop_points(points, img_idxs, bbs): crop_points = list(map(lambda i: points[i], img_idxs)) # translate points to location w/in cropped image - crop_points = [_translate_points_array(points_array, bbs[i][0], bbs[i][1]) - for i, points_array in enumerate(crop_points)] + crop_points = [ + _translate_points_array(points_array, bbs[i][0], bbs[i][1]) + for i, points_array in enumerate(crop_points) + ] return crop_points + def _translate_points_array(points_array, x, y): - if len(points_array) == 0: return points_array - return points_array - np.asarray([x,y]) + if len(points_array) == 0: + return points_array + return points_array - np.asarray([x, y]) + def merge_boxes(box_a, box_b): """Return a box that contains both boxes.""" @@ -553,10 +625,12 @@ def merge_boxes(box_a, box_b): return (c_x1, c_y1, c_x2, c_y2) + def merge_boxes_with_overlap(boxes): """Return a list of boxes after merging any overlapping boxes.""" - if len(boxes) < 2: return boxes + if len(boxes) < 2: + return boxes first_box = boxes[0] other_boxes = boxes[1:] @@ -575,6 +649,7 @@ def merge_boxes_with_overlap(boxes): return [first_box] + other_boxes + def merge_boxes_with_overlap_and_padding(boxes, pad_factor_box, within): """ Returns a list of boxes after merging any overlapping boxes @@ -589,12 +664,15 @@ def merge_boxes_with_overlap_and_padding(boxes, pad_factor_box, within): if len(merged_boxes) == len(boxes): return merged_boxes else: - return merge_boxes_with_overlap_and_padding(merged_boxes, pad_factor_box, within) + return merge_boxes_with_overlap_and_padding( + merged_boxes, pad_factor_box, within + ) + def pad_box_to_multiple(box, pad_factor_box, within): - box_h = box[3] - box[1] # difference in y - box_w = box[2] - box[0] # difference in x + box_h = box[3] - box[1] # difference in y + box_w = box[2] - box[0] # difference in x pad_h, pad_w = pad_factor_box @@ -606,7 +684,8 @@ def pad_box_to_multiple(box, pad_factor_box, within): multiple = max(multiple_h, multiple_w) # Return padded box - return pad_rect_to(*box, (pad_h*multiple, pad_w*multiple), within) + return pad_rect_to(*box, (pad_h * multiple, pad_w * multiple), within) + def bounding_box_nms(boxes, scores, iou_threshold): """ @@ -639,10 +718,10 @@ def bounding_box_nms(boxes, scores, iou_threshold): pick = [] # grab the coordinates of the bounding boxes - x1 = boxes[:,0] - y1 = boxes[:,1] - x2 = boxes[:,2] - y2 = boxes[:,3] + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] # compute the area of the bounding boxes, sort by scores area = (x2 - x1 + 1) * (y2 - y1 + 1) @@ -673,16 +752,17 @@ def bounding_box_nms(boxes, scores, iou_threshold): overlap = (w * h) / area[idxs[:last]] # delete all indexes from the index list that have - idxs = np.delete(idxs, np.concatenate(([last], - np.where(overlap > iou_threshold)[0]))) + idxs = np.delete( + idxs, np.concatenate(([last], np.where(overlap > iou_threshold)[0])) + ) # return the list of picked boxes return pick + def negative_anchor_crops( - labels: Labels, - negative_anchors: Dict['Video', Dict[int, Tuple]], - scale, crop_size) -> Tuple[np.ndarray, List]: + labels: Labels, negative_anchors: Dict["Video", Dict[int, Tuple]], scale, crop_size +) -> Tuple[np.ndarray, List]: """ Returns crops around *specific* negative samples from Labels object. 
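A small worked example of the box merging and padding defined above (coordinates are illustrative; boxes are (x0, y0, x1, y1), and the padded result assumes the usual round-up-to-multiple step in the lines elided from this hunk):

    merge_boxes((0, 0, 100, 80), (60, 40, 160, 120))
    # -> (0, 0, 160, 120), the smallest box containing both inputs

    pad_box_to_multiple((0, 0, 160, 120), pad_factor_box=(64, 64), within=(512, 512))
    # -> (0, 0, 192, 192): the 160-wide side rounds up to 3 * 64 = 192, so the box is
    #    padded to a 192 x 192 square via pad_rect_to and clamped inside the 512 x 512 image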
@@ -698,16 +778,17 @@ def negative_anchor_crops( # negative_anchors[video]: (frame_idx, x, y) for center of crop - neg_anchor_tuples = [(video, frame_idx, x, y) - for video in negative_anchors - for (frame_idx, x, y) in negative_anchors[video]] + neg_anchor_tuples = [ + (video, frame_idx, x, y) + for video in negative_anchors + for (frame_idx, x, y) in negative_anchors[video] + ] - if len(neg_anchor_tuples) == 0: return None, None + if len(neg_anchor_tuples) == 0: + return None, None - frame_list = [(video, frame_idx) - for (video, frame_idx, x, y) in neg_anchor_tuples] - anchors = [[_to_np_point(x,y)] - for (video, frame_idx, x, y) in neg_anchor_tuples] + frame_list = [(video, frame_idx) for (video, frame_idx, x, y) in neg_anchor_tuples] + anchors = [[_to_np_point(x, y)] for (video, frame_idx, x, y) in neg_anchor_tuples] imgs = generate_images_from_list(labels, frame_list, scale) points = generate_points_from_list(labels, frame_list, scale) @@ -723,39 +804,49 @@ def negative_anchor_crops( return crop_imgs, crop_points -def add_negative_anchor_crops(labels: Labels, imgs: np.ndarray, points: list, scale: float) -> Tuple[np.ndarray, List]: + +def add_negative_anchor_crops( + labels: Labels, imgs: np.ndarray, points: list, scale: float +) -> Tuple[np.ndarray, List]: """Wrapper to build and append negative anchor crops.""" # Include any *specific* negative samples neg_imgs, neg_points = negative_anchor_crops( - labels, - labels.negative_anchors, - scale=scale, - crop_size=imgs.shape[1]) + labels, labels.negative_anchors, scale=scale, crop_size=imgs.shape[1] + ) if neg_imgs is not None: imgs, points = _extend_imgs_points(imgs, points, neg_imgs, neg_points) return imgs, points + def get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples): - if len(bbs) == 0: return + if len(bbs) == 0: + return frame_count = len({frame for frame in img_idxs}) - box_side = bbs[0][2] - bbs[0][0] # x1 - x0 for the first bb + box_side = bbs[0][2] - bbs[0][0] # x1 - x0 for the first bb neg_sample_list = [] # Collect negative samples (and some extras) - for _ in range(max(int(negative_samples*1.5), negative_samples+10)): + for _ in range(max(int(negative_samples * 1.5), negative_samples + 10)): # find negative sample # pick a random image sample_img_idx = random.randrange(frame_count) # pick a random box within image - x, y = random.randrange(img_shape[1] - box_side), random.randrange(img_shape[0] - box_side) - sample_bb = (x, y, x+box_side, y+box_side) + x, y = ( + random.randrange(img_shape[1] - box_side), + random.randrange(img_shape[0] - box_side), + ) + sample_bb = (x, y, x + box_side, y + box_side) - frame_bbs = [bbs[i] for i, frame in enumerate(img_idxs) if frame == sample_img_idx] - area_covered = sum(map(lambda bb: box_overlap_area(sample_bb, bb), frame_bbs))/(box_side**2) + frame_bbs = [ + bbs[i] for i, frame in enumerate(img_idxs) if frame == sample_img_idx + ] + area_covered = sum( + map(lambda bb: box_overlap_area(sample_bb, bb), frame_bbs) + ) / (box_side ** 2) # append negative sample to lists neg_sample_list.append((area_covered, sample_img_idx, sample_bb)) @@ -766,9 +857,14 @@ def get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples): return neg_img_idxs[:negative_samples], neg_bbs[:negative_samples] + def _bbs_from_points(points): # List of bounding box for every instance - bbs = [point_array_bounding_box(point_array) for frame in points for point_array in frame] + bbs = [ + point_array_bounding_box(point_array) + for frame in points + for point_array in frame + ] bbs 
= [(int(x0), int(y0), int(x1), int(y1)) for (x0, y0, x1, y1) in bbs] # List to map bb to its img frame idx @@ -776,6 +872,7 @@ def _bbs_from_points(points): return bbs, img_idxs + def box_overlap_area(box_a, box_b): # determine the (x, y)-coordinates of the intersection rectangle xA = max(box_a[0], box_b[0]) @@ -788,16 +885,22 @@ def box_overlap_area(box_a, box_b): return inter_area + def _pad_bbs(bbs, box_shape, img_shape): return list(map(lambda bb: pad_rect_to(*bb, box_shape, img_shape), bbs)) + def _crop(imgs, img_idxs, bbs) -> np.ndarray: - imgs = [imgs[img_idxs[i], bb[1]:bb[3], bb[0]:bb[2]] for i, bb in enumerate(bbs)] # imgs[frame_idx, y0:y1, x0:x1] + imgs = [ + imgs[img_idxs[i], bb[1] : bb[3], bb[0] : bb[2]] for i, bb in enumerate(bbs) + ] # imgs[frame_idx, y0:y1, x0:x1] imgs = np.stack(imgs, axis=0) return imgs -def fullsize_points_from_crop(idx: int, point_array: np.ndarray, - bbs: list, img_idxs: list): + +def fullsize_points_from_crop( + idx: int, point_array: np.ndarray, bbs: list, img_idxs: list +): """ Map point within crop back to original image frames. @@ -811,13 +914,14 @@ def fullsize_points_from_crop(idx: int, point_array: np.ndarray, """ bb = bbs[idx] - top_left_point = ((bb[0], bb[1]),) # for (x, y) column vector + top_left_point = ((bb[0], bb[1]),) # for (x, y) column vector point_array += np.array(top_left_point) frame_idx = img_idxs[idx] return frame_idx, point_array + def demo_datagen_time(): data_path = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -828,7 +932,10 @@ def demo_datagen_time(): timing_reps = 1 import timeit - t = timeit.timeit("generate_confidence_maps(labels)", number=timing_reps, globals=globals()) + + t = timeit.timeit( + "generate_confidence_maps(labels)", number=timing_reps, globals=globals() + ) t /= timing_reps print(f"confmaps time: {t} = {t/count} s/frame for {count} frames") @@ -836,10 +943,13 @@ def demo_datagen_time(): t /= timing_reps print(f"pafs time: {t} = {t/count} s/frame for {count} frames") + def demo_datagen(): import os - data_path = "C:/Users/tdp/OneDrive/code/sandbox/leap_wt_gold_pilot/centered_pair.json" + data_path = ( + "C:/Users/tdp/OneDrive/code/sandbox/leap_wt_gold_pilot/centered_pair.json" + ) if not os.path.exists(data_path): data_path = "tests/data/json_format_v1/centered_pair.json" # data_path = "tests/data/json_format_v2/minimal_instance.json" @@ -852,12 +962,11 @@ def demo_datagen(): scale = 1 imgs, points = generate_training_data( - labels = labels, - params = dict( - scale = scale, - instance_crop = True, - min_crop_size = 0, - negative_samples = 0)) + labels=labels, + params=dict( + scale=scale, instance_crop=True, min_crop_size=0, negative_samples=0 + ), + ) print("--imgs--") print(imgs.shape) @@ -890,7 +999,9 @@ def demo_datagen(): skeleton = labels.skeletons[0] img_shape = (imgs.shape[1], imgs.shape[2]) - confmaps = generate_confmaps_from_points(points, skeleton, img_shape, scale=.5, sigma=5.0*scale) + confmaps = generate_confmaps_from_points( + points, skeleton, img_shape, scale=0.5, sigma=5.0 * scale + ) print("--confmaps--") print(confmaps.shape) print(confmaps.dtype) @@ -898,7 +1009,9 @@ def demo_datagen(): demo_confmaps(confmaps, vid) - pafs = generate_pafs_from_points(points, skeleton, img_shape, scale=.5, sigma=5.0*scale) + pafs = generate_pafs_from_points( + points, skeleton, img_shape, scale=0.5, sigma=5.0 * scale + ) print("--pafs--") print(pafs.shape) print(pafs.dtype) @@ -908,5 +1021,6 @@ def demo_datagen(): app.exec_() + if __name__ == "__main__": - demo_datagen() \ No newline at 
end of file + demo_datagen() diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 9cc6ab14e..37aaa54bb 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4,6 +4,7 @@ import os import json import logging + logger = logging.getLogger(__name__) import numpy as np @@ -34,7 +35,12 @@ from sleap.nn.datagen import merge_boxes_with_overlap_and_padding from sleap.nn.peakfinding import find_all_peaks, find_all_single_peaks from sleap.nn.peakfinding_tf import peak_tf_inference -from sleap.nn.peakmatching import match_single_peaks_all, match_peaks_paf, match_peaks_paf_par, instances_nms +from sleap.nn.peakmatching import ( + match_single_peaks_all, + match_peaks_paf, + match_peaks_paf_par, + instances_nms, +) from sleap.nn.util import batch, batch_count, save_visual_outputs @@ -54,7 +60,6 @@ def skeleton(self) -> Skeleton: return self.job.model.skeletons[0] - @property def output_type(self) -> ModelOutputType: """Returns the output type of this model.""" @@ -86,7 +91,9 @@ def output_relative_scale(self) -> float: return self.job.model.output_scale - def compute_output_shape(self, input_shape: Tuple[int], relative=True) -> Tuple[int]: + def compute_output_shape( + self, input_shape: Tuple[int], relative=True + ) -> Tuple[int]: """Returns the output tensor shape for a given input shape. Args: @@ -106,11 +113,11 @@ def compute_output_shape(self, input_shape: Tuple[int], relative=True) -> Tuple[ output_shape = ( int(input_shape[0] * scaling_factor), int(input_shape[1] * scaling_factor), - self.output_channels) + self.output_channels, + ) return output_shape - - + def load_model(self, model_path: Text = None) -> keras.Model: """Loads a saved model from disk and caches it. @@ -125,17 +132,16 @@ def load_model(self, model_path: Text = None) -> keras.Model: if not model_path: # Try the best model first. - model_path = os.path.join(self.job.save_dir, - self.job.best_model_filename) + model_path = os.path.join(self.job.save_dir, self.job.best_model_filename) # Try the final model if that didn't exist. if not os.path.exists(model_path): - model_path = os.path.join(self.job.save_dir, - self.job.final_model_filename) + model_path = os.path.join( + self.job.save_dir, self.job.final_model_filename + ) # Load from disk. - keras_model = keras.models.load_model(model_path, - custom_objects={"tf": tf}) + keras_model = keras.models.load_model(model_path, custom_objects={"tf": tf}) logger.info("Loaded model: " + model_path) # Store the loaded model path for reference. @@ -151,16 +157,14 @@ def load_model(self, model_path: Text = None) -> keras.Model: # Create input node with undetermined height/width. input_tensor = keras.layers.Input((None, None, self.input_channels)) keras_model = keras.Model( - inputs=input_tensor, - outputs=keras_model(input_tensor)) - + inputs=input_tensor, outputs=keras_model(input_tensor) + ) # Save the modified and loaded model. self._keras_model = keras_model return self.keras_model - @property def keras_model(self) -> keras.Model: """Returns the underlying Keras model, loading it if necessary.""" @@ -170,23 +174,25 @@ def keras_model(self) -> keras.Model: return self._keras_model - @property def model_path(self) -> Text: """Returns the path to the loaded model.""" if not self._model_path: - raise AttributeError("No model loaded. Call inference_model.load_model() first.") + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." 
+ ) return self._model_path - @property def trained_input_shape(self) -> Tuple[int]: """Returns the shape of the model when it was loaded.""" if not self._trained_input_shape: - raise AttributeError("No model loaded. Call inference_model.load_model() first.") + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." + ) return self._trained_input_shape @@ -194,11 +200,12 @@ def trained_input_shape(self) -> Tuple[int]: def output_channels(self) -> int: """Returns the number of output channels of the model.""" if not self._trained_input_shape: - raise AttributeError("No model loaded. Call inference_model.load_model() first.") + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." + ) return self._output_channels - @property def input_channels(self) -> int: """Returns the number of channels expected for the input data.""" @@ -206,14 +213,12 @@ def input_channels(self) -> int: # TODO: Multi-output support return self.trained_input_shape[-1] - @property def is_grayscale(self) -> bool: """Returns True if the model expects grayscale images.""" return self.input_channels == 1 - @property def down_blocks(self): """Returns the number of pooling steps applied during the model. @@ -223,12 +228,13 @@ def down_blocks(self): # TODO: Replace this with an explicit calculation that takes stride sizes into account. return self.job.model.down_blocks - - - def predict(self, X: Union[np.ndarray, List[np.ndarray]], + + def predict( + self, + X: Union[np.ndarray, List[np.ndarray]], batch_size: int = 32, - normalize: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: + normalize: bool = True, + ) -> Union[np.ndarray, List[np.ndarray]]: """Runs inference on the input data. This is a simple wrapper around the keras model predict function. @@ -248,11 +254,11 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]], # TODO: Store normalization scheme in the model metadata. if isinstance(X, np.ndarray): if X.dtype == np.dtype("uint8"): - X = X.astype("float32") / 255. + X = X.astype("float32") / 255.0 elif isinstance(X, list): for i in range(len(X)): if X[i].dtype == np.dtype("uint8"): - X[i] = X[i].astype("float32") / 255. + X[i] = X[i].astype("float32") / 255.0 return self.keras_model.predict(X, batch_size=batch_size) @@ -301,15 +307,17 @@ class Predictor: """ training_jobs: Dict[ModelOutputType, TrainingJob] = None - inference_models: Dict[ModelOutputType, InferenceModel] = attr.ib(default=attr.Factory(dict)) + inference_models: Dict[ModelOutputType, InferenceModel] = attr.ib( + default=attr.Factory(dict) + ) skeleton: Skeleton = None inference_batch_size: int = 2 read_chunk_size: int = 256 - save_frequency: int = 100 # chunks + save_frequency: int = 100 # chunks nms_min_thresh = 0.3 nms_kernel_size: int = 9 - nms_sigma: float = 3. + nms_sigma: float = 3.0 min_score_to_node_ratio: float = 0.2 min_score_midpts: float = 0.05 min_score_integral: float = 0.6 @@ -337,11 +345,12 @@ def __attrs_post_init__(self): self.inference_models[model_output_type] = InferenceModel(job=training_job) self.inference_models[model_output_type].load_model() - - def predict(self, - input_video: Union[dict, Video], - frames: Optional[List[int]] = None, - is_async: bool = False) -> List[LabeledFrame]: + def predict( + self, + input_video: Union[dict, Video], + frames: Optional[List[int]] = None, + is_async: bool = False, + ) -> List[LabeledFrame]: """Run the entire inference pipeline on an input video. 
Args: @@ -396,7 +405,6 @@ def predict(self, logger.info(" Frames: %d" % len(frames)) logger.info(" Frame shape (H x W): %d x %d" % (vid.height, vid.width)) - # Initialize tracking if self.with_tracking: tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) @@ -449,7 +457,8 @@ def predict(self, # Use centroid predictions to get subchunks of crops. subchunks_to_process = self.centroid_crop_inference( - imgs_full, frames_idx) + imgs_full, frames_idx + ) else: # Create transform object @@ -485,19 +494,19 @@ def predict(self, logger.warning("No PAF model! Running in SINGLE INSTANCE mode.") subchunk_lfs = self.single_instance_inference( - subchunk_imgs_full, - subchunk_transform, - vid) + subchunk_imgs_full, subchunk_transform, vid + ) else: # Pipeline for predicting multiple animals in a frame # This uses confidence maps and part affinity fields subchunk_lfs = self.multi_instance_inference( - subchunk_imgs_full, - subchunk_transform, - vid) + subchunk_imgs_full, subchunk_transform, vid + ) - logger.info(f" Subchunk frames with instances found: {len(subchunk_lfs)}") + logger.info( + f" Subchunk frames with instances found: {len(subchunk_lfs)}" + ) subchunk_results.append(subchunk_lfs) @@ -519,9 +528,13 @@ def predict(self, predicted_frames_chunk = [] for subchunk_frames in subchunk_results: predicted_frames_chunk.extend(subchunk_frames) - predicted_frames_chunk = LabeledFrame.merge_frames(predicted_frames_chunk, video=vid) + predicted_frames_chunk = LabeledFrame.merge_frames( + predicted_frames_chunk, video=vid + ) - logger.info(f" Instances found on {len(predicted_frames_chunk)} out of {len(imgs_full)} frames.") + logger.info( + f" Instances found on {len(predicted_frames_chunk)} out of {len(imgs_full)} frames." + ) if len(predicted_frames_chunk): @@ -545,18 +558,25 @@ def predict(self, labels = Labels(labeled_frames=predicted_frames) if self.output_path is not None: if self.output_path.endswith("json"): - Labels.save_json(labels, filename=self.output_path, compress=True) + Labels.save_json( + labels, filename=self.output_path, compress=True + ) else: Labels.save_hdf5(labels, filename=self.output_path) - logger.info(" Saved to: %s [%.1fs]" % (self.output_path, time() - t0)) + logger.info( + " Saved to: %s [%.1fs]" % (self.output_path, time() - t0) + ) elapsed = time() - t0_chunk total_elapsed = time() - t0_start fps = len(predicted_frames) / total_elapsed frames_left = len(frames) - len(predicted_frames) eta = (frames_left / fps) if fps > 0 else 0 - logger.info(" Finished chunk [%.1fs / %.1f FPS / ETA: %.1f min]" % (elapsed, fps, eta / 60)) + logger.info( + " Finished chunk [%.1fs / %.1f FPS / ETA: %.1f min]" + % (elapsed, fps, eta / 60) + ) sys.stdout.flush() @@ -602,13 +622,13 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: return result - - def centroid_crop_inference(self, - imgs: np.ndarray, - frames_idx: List[int], - box_size: int=None, - do_merge: bool=True) \ - -> List[Tuple[np.ndarray, DataTransform]]: + def centroid_crop_inference( + self, + imgs: np.ndarray, + frames_idx: List[int], + box_size: int = None, + do_merge: bool = True, + ) -> List[Tuple[np.ndarray, DataTransform]]: """ Takes stack of images and runs centroid inference to get crops. @@ -630,7 +650,10 @@ def centroid_crop_inference(self, # TODO: Replace this calculation when model-specific divisibility calculation implemented. 
divisor = 2 ** centroid_model.down_blocks - crop_within = ((imgs.shape[1] // divisor) * divisor, (imgs.shape[2] // divisor) * divisor) + crop_within = ( + (imgs.shape[1] // divisor) * divisor, + (imgs.shape[2] // divisor) * divisor, + ) logger.info(f" crop_within: {crop_within}") # Create transform @@ -638,28 +661,38 @@ def centroid_crop_inference(self, # and will also let us map the points on the scaled image to # points on the original images so we can crop original images. centroid_transform = DataTransform() - target_shape = (int(imgs.shape[1] * centroid_model.input_scale), int(imgs.shape[2] * centroid_model.input_scale)) + target_shape = ( + int(imgs.shape[1] * centroid_model.input_scale), + int(imgs.shape[2] * centroid_model.input_scale), + ) # Scale to match input size of trained centroid model. centroid_imgs_scaled = centroid_transform.scale_to( - imgs=imgs, target_size=target_shape) + imgs=imgs, target_size=target_shape + ) # Predict centroid confidence maps, then find peaks. t0 = time() - centroid_confmaps = centroid_model.predict(centroid_imgs_scaled, - batch_size=self.inference_batch_size) + centroid_confmaps = centroid_model.predict( + centroid_imgs_scaled, batch_size=self.inference_batch_size + ) - peaks, peak_vals = find_all_peaks(centroid_confmaps, - min_thresh=self.nms_min_thresh, sigma=self.nms_sigma) + peaks, peak_vals = find_all_peaks( + centroid_confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma + ) elapsed = time() - t0 total_peaks = sum([len(frame_peaks[0]) for frame_peaks in peaks]) - logger.info(f" Found {total_peaks} centroid peaks ({total_peaks / len(peaks):.2f} centroids/frame) [{elapsed:.2f}s].") + logger.info( + f" Found {total_peaks} centroid peaks ({total_peaks / len(peaks):.2f} centroids/frame) [{elapsed:.2f}s]." + ) if box_size is None: # Get training bounding box size to determine (min) centroid crop size. # TODO: fix this to use a stored value or move this logic elsewhere - crop_size = int(max(cm_model.trained_input_shape[1:3]) // cm_model.input_scale) + crop_size = int( + max(cm_model.trained_input_shape[1:3]) // cm_model.input_scale + ) bb_half = crop_size // 2 # bb_half = (crop_size + self.crop_padding) // 2 else: @@ -682,18 +715,29 @@ def centroid_crop_inference(self, for peak_i in range(frame_peaks[0].shape[0]): # Rescale peak back onto full-sized image - peak_x = int(frame_peaks[0][peak_i][0] / centroid_model.output_scale) - peak_y = int(frame_peaks[0][peak_i][1] / centroid_model.output_scale) - - boxes.append((peak_x - bb_half, peak_y - bb_half, - peak_x + bb_half, peak_y + bb_half)) + peak_x = int( + frame_peaks[0][peak_i][0] / centroid_model.output_scale + ) + peak_y = int( + frame_peaks[0][peak_i][1] / centroid_model.output_scale + ) + + boxes.append( + ( + peak_x - bb_half, + peak_y - bb_half, + peak_x + bb_half, + peak_y + bb_half, + ) + ) if do_merge: # Merge overlapping boxes and pad to multiple of crop size merged_boxes = merge_boxes_with_overlap_and_padding( - boxes=boxes, - pad_factor_box=(self.crop_growth, self.crop_growth), - within=crop_within) + boxes=boxes, + pad_factor_box=(self.crop_growth, self.crop_growth), + within=crop_within, + ) else: # Just return the boxes centered around each centroid. @@ -751,14 +795,15 @@ def centroid_crop_inference(self, # Add subchunk subchunks.append((imgs_cropped, transform)) - logger.info(f" Subchunk for size {crop_size} has {len(imgs_cropped)} crops.") + logger.info( + f" Subchunk for size {crop_size} has {len(imgs_cropped)} crops." 
+ ) else: logger.info(" No centroids found so done with this chunk.") return subchunks - def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: """Run the single instance pipeline for a stack of images. @@ -776,8 +821,11 @@ def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame # Scale to match input size of trained model. # Images are expected to be at full resolution, but may be cropped. - assert(transform.scale == 1.0) - target_shape = (int(imgs.shape[1] * cm_model.input_scale), int(imgs.shape[2] * cm_model.input_scale)) + assert transform.scale == 1.0 + target_shape = ( + int(imgs.shape[1] * cm_model.input_scale), + int(imgs.shape[2] * cm_model.input_scale), + ) imgs_scaled = transform.scale_to(imgs=imgs, target_size=target_shape) # TODO: Adjust for divisibility @@ -787,31 +835,33 @@ def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame # Run inference. t0 = time() confmaps = cm_model.predict(imgs_scaled, batch_size=self.inference_batch_size) - logger.info( " Inferred confmaps [%.1fs]" % (time() - t0)) + logger.info(" Inferred confmaps [%.1fs]" % (time() - t0)) logger.info(f" confmaps: shape={confmaps.shape}, ptp={np.ptp(confmaps)}") t0 = time() # TODO: Move this to GPU and add subpixel refinement. # Use single highest peak in channel corresponding node - points_arrays = find_all_single_peaks(confmaps, - min_thresh=self.nms_min_thresh) + points_arrays = find_all_single_peaks(confmaps, min_thresh=self.nms_min_thresh) # Adjust for multi-scale such that the points are at the scale of the transform. points_arrays = [pts / cm_model.output_relative_scale for pts in points_arrays] # Create labeled frames and predicted instances from the points. predicted_frames_chunk = match_single_peaks_all( - points_arrays=points_arrays, - skeleton=cm_model.skeleton, - transform=transform, - video=video) + points_arrays=points_arrays, + skeleton=cm_model.skeleton, + transform=transform, + video=video, + ) logger.info(" Used highest peaks to create instances [%.1fs]" % (time() - t0)) # Save confmaps if self.output_path is not None and self.save_confmaps_pafs: - raise NotImplementedError("Not saving confmaps/pafs because feature currently not working.") + raise NotImplementedError( + "Not saving confmaps/pafs because feature currently not working." + ) # Disable save_confmaps_pafs since not currently working. # The problem is that we can't put data for different crop sizes # all into a single h5 datasource. It's now possible to view live @@ -822,7 +872,6 @@ def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame return predicted_frames_chunk - def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: """Run the multi-instance inference pipeline for a stack of images. @@ -844,15 +893,20 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Scale to match input resolution of model. # Images are expected to be at full resolution, but may be cropped. - assert(transform.scale == 1.0) - cm_target_shape = (int(imgs.shape[1] * cm_model.input_scale), int(imgs.shape[2] * cm_model.input_scale)) + assert transform.scale == 1.0 + cm_target_shape = ( + int(imgs.shape[1] * cm_model.input_scale), + int(imgs.shape[2] * cm_model.input_scale), + ) imgs_scaled = transform.scale_to(imgs=imgs, target_size=cm_target_shape) if imgs_scaled.dtype == np.dtype("uint8"): # TODO: Unify normalization. - imgs_scaled = imgs_scaled.astype("float32") / 255. 
- + imgs_scaled = imgs_scaled.astype("float32") / 255.0 + # TODO: Unfuck this whole workflow if self.gpu_peak_finding: - confmaps_shape = cm_model.compute_output_shape((imgs_scaled.shape[1], imgs_scaled.shape[2])) + confmaps_shape = cm_model.compute_output_shape( + (imgs_scaled.shape[1], imgs_scaled.shape[2]) + ) peaks, peak_vals, confmaps = peak_tf_inference( model=cm_model.keras_model, confmaps_shape=confmaps_shape, @@ -863,12 +917,16 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] upsample_factor=int(self.supersample_factor / cm_model.output_scale), win_size=self.supersample_window_size, return_confmaps=self.save_confmaps_pafs, - batch_size=self.inference_batch_size - ) + batch_size=self.inference_batch_size, + ) else: - confmaps = cm_model.predict(imgs_scaled, batch_size=self.inference_batch_size) - peaks, peak_vals = find_all_peaks(confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma) + confmaps = cm_model.predict( + imgs_scaled, batch_size=self.inference_batch_size + ) + peaks, peak_vals = find_all_peaks( + confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma + ) # # Undo just the scaling so we're back to full resolution, but possibly cropped. for t in range(len(peaks)): # frames @@ -880,14 +938,27 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] transform.scale = 1.0 elapsed = time() - t0 - total_peaks = sum([len(channel_peaks) for frame_peaks in peaks for channel_peaks in frame_peaks]) - logger.info(f" Found {total_peaks} peaks ({total_peaks / len(imgs):.2f} peaks/frame) [{elapsed:.2f}s].") + total_peaks = sum( + [ + len(channel_peaks) + for frame_peaks in peaks + for channel_peaks in frame_peaks + ] + ) + logger.info( + f" Found {total_peaks} peaks ({total_peaks / len(imgs):.2f} peaks/frame) [{elapsed:.2f}s]." + ) # logger.info(f" peaks: {peaks}") # Scale to match input resolution of model. # Images are expected to be at full resolution, but may be cropped. 
- paf_target_shape = (int(imgs.shape[1] * paf_model.input_scale), int(imgs.shape[2] * paf_model.input_scale)) - if (imgs_scaled.shape[1] == paf_target_shape[0]) and (imgs_scaled.shape[2] == paf_target_shape[1]): + paf_target_shape = ( + int(imgs.shape[1] * paf_model.input_scale), + int(imgs.shape[2] * paf_model.input_scale), + ) + if (imgs_scaled.shape[1] == paf_target_shape[0]) and ( + imgs_scaled.shape[2] == paf_target_shape[1] + ): # No need to scale again if we're already there, so just adjust the stored scale transform.scale = paf_model.input_scale @@ -898,7 +969,7 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Infer pafs t0 = time() pafs = paf_model.predict(imgs_scaled, batch_size=self.inference_batch_size) - logger.info( " Inferred PAFs [%.1fs]" % (time() - t0)) + logger.info(" Inferred PAFs [%.1fs]" % (time() - t0)) logger.info(f" pafs: shape={pafs.shape}, ptp={np.ptp(pafs)}") # Adjust points to the paf output scale so we can invert later (should not incur loss of precision) @@ -910,23 +981,34 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] # Determine whether to use serial or parallel version of peak-finding # Use the serial version is we're already running in a thread pool - match_peaks_function = match_peaks_paf_par if not self.is_async else match_peaks_paf + match_peaks_function = ( + match_peaks_paf_par if not self.is_async else match_peaks_paf + ) # Match peaks via PAFs t0 = time() predicted_frames_chunk = match_peaks_function( - peaks, peak_vals, pafs, paf_model.skeleton, - transform=transform, video=video, + peaks, + peak_vals, + pafs, + paf_model.skeleton, + transform=transform, + video=video, min_score_to_node_ratio=self.min_score_to_node_ratio, min_score_midpts=self.min_score_midpts, min_score_integral=self.min_score_integral, add_last_edge=self.add_last_edge, single_per_crop=self.single_per_crop, - pool=self.pool) + pool=self.pool, + ) - total_instances = sum([len(labeled_frame) for labeled_frame in predicted_frames_chunk]) + total_instances = sum( + [len(labeled_frame) for labeled_frame in predicted_frames_chunk] + ) logger.info(" Matched peaks via PAFs [%.1fs]" % (time() - t0)) - logger.info(f" Found {total_instances} instances ({total_instances / len(imgs):.2f} instances/frame)") + logger.info( + f" Found {total_instances} instances ({total_instances / len(imgs):.2f} instances/frame)" + ) # Remove overlapping predicted instances if self.overlapping_instances_nms: @@ -935,12 +1017,16 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] n = len(lf.instances) instances_nms(lf.instances) if len(lf.instances) < n: - logger.info(f" Removed {n-len(lf.instances)} overlapping instance(s) from frame {lf.frame_idx}") + logger.info( + f" Removed {n-len(lf.instances)} overlapping instance(s) from frame {lf.frame_idx}" + ) logger.info(" Instance NMS [%.1fs]" % (clock() - t0)) # Save confmaps and pafs if self.output_path is not None and self.save_confmaps_pafs: - raise NotImplementedError("Not saving confmaps/pafs because feature currently not working.") + raise NotImplementedError( + "Not saving confmaps/pafs because feature currently not working." + ) # Disable save_confmaps_pafs since not currently working. # The problem is that we can't put data for different crop sizes # all into a single h5 datasource. 
It's now possible to view live @@ -954,7 +1040,6 @@ def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame] def main(): - def frame_list(frame_str: str): # Handle ranges of frames. Must be of the form "1-200" @@ -962,35 +1047,73 @@ def frame_list(frame_str: str): min_max = frame_str.split("-") min_frame = int(min_max[0]) max_frame = int(min_max[1]) - return list(range(min_frame, max_frame+1)) + return list(range(min_frame, max_frame + 1)) return [int(x) for x in frame_str.split(",")] if len(frame_str) else None parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to video file") - parser.add_argument("-m", "--model", dest='models', action='append', - help="Path to saved model (confmaps, pafs, ...) JSON. " - "Multiple models can be specified, each preceded by " - "--model. Confmap and PAF models are required.", - required=True) - parser.add_argument("--resize-input", dest="resize_input", action="store_const", - const=True, default=False, - help="resize the input layer to image size (default False)") - parser.add_argument("--with-tracking", dest="with_tracking", action="store_const", - const=True, default=False, - help="just visualize predicted confmaps/pafs (default False)") - parser.add_argument("--frames", type=frame_list, default="", - help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " - "a range separated by hyphen (e.g. 1-3). (default is entire video)") - parser.add_argument("-o", "--output", type=str, default=None, - help="The output filename to use for the predicted data.") - parser.add_argument("--out_format", choices=["hdf5", "json"], help="The format to use for" - " the output file. Either hdf5 or json. hdf5 is the default.", - default="hdf5") - parser.add_argument("--save-confmaps-pafs", dest="save_confmaps_pafs", action="store_const", - const=True, default=False, - help="Whether to save the confidence maps or pafs") - parser.add_argument("-v", "--verbose", help="Increase logging output verbosity.", action="store_true") + parser.add_argument( + "-m", + "--model", + dest="models", + action="append", + help="Path to saved model (confmaps, pafs, ...) JSON. " + "Multiple models can be specified, each preceded by " + "--model. Confmap and PAF models are required.", + required=True, + ) + parser.add_argument( + "--resize-input", + dest="resize_input", + action="store_const", + const=True, + default=False, + help="resize the input layer to image size (default False)", + ) + parser.add_argument( + "--with-tracking", + dest="with_tracking", + action="store_const", + const=True, + default=False, + help="just visualize predicted confmaps/pafs (default False)", + ) + parser.add_argument( + "--frames", + type=frame_list, + default="", + help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " + "a range separated by hyphen (e.g. 1-3). (default is entire video)", + ) + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="The output filename to use for the predicted data.", + ) + parser.add_argument( + "--out_format", + choices=["hdf5", "json"], + help="The format to use for" + " the output file. Either hdf5 or json. 
hdf5 is the default.", + default="hdf5", + ) + parser.add_argument( + "--save-confmaps-pafs", + dest="save_confmaps_pafs", + action="store_const", + const=True, + default=False, + help="Whether to save the confidence maps or pafs", + ) + parser.add_argument( + "-v", + "--verbose", + help="Increase logging output verbosity.", + action="store_true", + ) args = parser.parse_args() @@ -1027,14 +1150,16 @@ def frame_list(frame_str: str): img_shape = None # Create a predictor to do the work. - predictor = Predictor(training_jobs=sleap_models, + predictor = Predictor( + training_jobs=sleap_models, output_path=save_path, save_confmaps_pafs=args.save_confmaps_pafs, - with_tracking=args.with_tracking) + with_tracking=args.with_tracking, + ) # Run the inference pipeline return predictor.predict(input_video=data_path, frames=frames) if __name__ == "__main__": - main() + main() diff --git a/sleap/nn/model.py b/sleap/nn/model.py index e9d7b1480..f89826934 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -30,6 +30,7 @@ class ModelOutputType(Enum): by Cao et al. """ + CONFIDENCE_MAP = 0 PART_AFFINITY_FIELD = 1 CENTROIDS = 2 @@ -43,7 +44,9 @@ def __str__(self): return "centroids" else: # This shouldn't ever happen I don't think. - raise NotImplementedError(f"__str__ not implemented for ModelOutputType={self}") + raise NotImplementedError( + f"__str__ not implemented for ModelOutputType={self}" + ) @attr.s(auto_attribs=True) @@ -66,6 +69,7 @@ class Model: not set this value. """ + output_type: ModelOutputType backbone: BackboneType skeletons: Union[None, List[Skeleton]] = None @@ -74,12 +78,16 @@ class Model: def __attrs_post_init__(self): if not isinstance(self.backbone, tuple(available_archs)): - raise ValueError(f"backbone ({self.backbone}) is not " - f"in available architectures ({available_archs})") + raise ValueError( + f"backbone ({self.backbone}) is not " + f"in available architectures ({available_archs})" + ) - if not hasattr(self.backbone, 'output'): - raise ValueError(f"backbone ({self.backbone}) has now output method! " - f"Not a valid backbone architecture!") + if not hasattr(self.backbone, "output"): + raise ValueError( + f"backbone ({self.backbone}) has now output method! " + f"Not a valid backbone architecture!" + ) if self.backbone_name is None: self.backbone_name = self.backbone.__class__.__name__ @@ -108,9 +116,10 @@ def output(self, input_tensor, num_output_channels=None): elif self.output_type == ModelOutputType.PART_AFFINITY_FIELD: num_outputs_channels = len(self.skeleton[0].edges) * 2 else: - raise ValueError("Model.skeletons has not been set. " - "Cannot infer num output channels.") - + raise ValueError( + "Model.skeletons has not been set. " + "Cannot infer num output channels." 
+ ) return self.backbone.output(input_tensor, num_output_channels) @@ -140,7 +149,7 @@ def down_blocks(self): else: return 0 - + @property def output_scale(self): """Calculates output scale relative to input.""" @@ -148,16 +157,17 @@ def output_scale(self): if hasattr(self.backbone, "output_scale"): return self.backbone.output_scale - elif hasattr(self.backbone, "down_blocks") and hasattr(self.backbone, "up_blocks"): + elif hasattr(self.backbone, "down_blocks") and hasattr( + self.backbone, "up_blocks" + ): asym = self.backbone.down_blocks - self.backbone.up_blocks - return (1 / (2 ** asym)) + return 1 / (2 ** asym) elif hasattr(self.backbone, "initial_stride"): - return (1 / self.backbone.initial_stride) + return 1 / self.backbone.initial_stride else: return 1 - @staticmethod def _structure_model(model_dict, cls): @@ -185,7 +195,8 @@ class to use. arch_idx = available_arch_names.index(model_dict["backbone_name"]) backbone_cls = available_archs[arch_idx] - return Model(backbone=backbone_cls(**model_dict["backbone"]), - output_type=ModelOutputType(model_dict["output_type"]), - skeletons=model_dict["skeletons"] - ) + return Model( + backbone=backbone_cls(**model_dict["backbone"]), + output_type=ModelOutputType(model_dict["output_type"]), + skeletons=model_dict["skeletons"], + ) diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 360b2f0ee..3ffd8f49a 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -4,10 +4,12 @@ import zmq import jsonpickle import logging + logger = logging.getLogger(__name__) from PySide2 import QtCore, QtWidgets, QtGui, QtCharts + class LossViewer(QtWidgets.QMainWindow): def __init__(self, zmq_context=None, show_controller=True, parent=None): super(LossViewer, self).__init__(parent) @@ -16,7 +18,7 @@ def __init__(self, zmq_context=None, show_controller=True, parent=None): self.stop_button = None self.redraw_batch_interval = 40 - self.batches_to_show = 200 # -1 to show all + self.batches_to_show = 200 # -1 to show all self.ignore_outliers = False self.log_scale = True @@ -57,7 +59,7 @@ def reset(self, what=""): for s in self.series: self.series[s].pen().setColor(self.color[s]) - self.series["batch"].setMarkerSize(8.) 
+ self.series["batch"].setMarkerSize(8.0) self.chart.addSeries(self.series["batch"]) self.chart.addSeries(self.series["epoch_loss"]) @@ -120,7 +122,9 @@ def reset(self, what=""): self.batch_options = "200,1000,5000,All".split(",") for opt in self.batch_options: field.addItem(opt) - field.currentIndexChanged.connect(lambda x: self.set_batches_to_show(self.batch_options[x])) + field.currentIndexChanged.connect( + lambda x: self.set_batches_to_show(self.batch_options[x]) + ) control_layout.addWidget(field) control_layout.addStretch(1) @@ -129,7 +133,6 @@ def reset(self, what=""): self.stop_button.clicked.connect(self.stop) control_layout.addWidget(self.stop_button) - widget = QtWidgets.QWidget() widget.setLayout(control_layout) layout.addWidget(widget) @@ -185,7 +188,7 @@ def update_y_axis(self): def setup_zmq(self, zmq_context): # Progress monitoring - self.ctx_given = (zmq_context is not None) + self.ctx_given = zmq_context is not None self.ctx = zmq.Context() if zmq_context is None else zmq_context self.sub = self.ctx.socket(zmq.SUB) self.sub.subscribe("") @@ -204,7 +207,7 @@ def stop(self): if self.zmq_ctrl is not None: # send command to stop training logger.info("Sending command to stop training") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop",))) + self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) if self.stop_button is not None: self.stop_button.setText("Stopping...") self.stop_button.setEnabled(False) @@ -222,7 +225,10 @@ def add_datapoint(self, x, y, which="batch"): if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: xs, ys = self.X, self.Y else: - xs, ys = self.X[-self.batches_to_show:], self.Y[-self.batches_to_show:] + xs, ys = ( + self.X[-self.batches_to_show :], + self.Y[-self.batches_to_show :], + ) points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] self.series["batch"].replace(points) @@ -234,12 +240,12 @@ def add_datapoint(self, x, y, which="batch"): if self.ignore_outliers: dy = np.ptp(ys) * 0.02 # Set Y scale to exclude outliers - q1, q3 = np.quantile(ys, (.25, .75)) - iqr = q3-q1 # interquartile range + q1, q3 = np.quantile(ys, (0.25, 0.75)) + iqr = q3 - q1 # interquartile range low = q1 - iqr * 1.5 high = q3 + iqr * 1.5 - low = max(low, min(ys) - dy) # keep within range of data + low = max(low, min(ys) - dy) # keep within range of data high = min(high, max(ys) + dy) else: # Set Y scale to show all points @@ -248,7 +254,7 @@ def add_datapoint(self, x, y, which="batch"): high = max(ys) + dy if self.log_scale: - low = max(low, 1e-5) # for log scale, low cannot be 0 + low = max(low, 1e-5) # for log scale, low cannot be 0 self.chart.axisY().setRange(low, high) @@ -290,16 +296,28 @@ def check_messages(self, timeout=10): self.epoch = msg["epoch"] elif msg["event"] == "epoch_end": self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint((self.epoch+1)*self.epoch_size, msg["logs"]["loss"], "epoch_loss") + self.add_datapoint( + (self.epoch + 1) * self.epoch_size, + msg["logs"]["loss"], + "epoch_loss", + ) if "val_loss" in msg["logs"].keys(): self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint((self.epoch+1)*self.epoch_size, msg["logs"]["val_loss"], "val_loss") + self.add_datapoint( + (self.epoch + 1) * self.epoch_size, + msg["logs"]["val_loss"], + "val_loss", + ) elif msg["event"] == "batch_end": self.last_batch_number = msg["logs"]["batch"] - self.add_datapoint((self.epoch * self.epoch_size) + msg["logs"]["batch"], msg["logs"]["loss"]) + self.add_datapoint( + 
(self.epoch * self.epoch_size) + msg["logs"]["batch"], + msg["logs"]["loss"], + ) self.update_runtime() + if __name__ == "__main__": app = QtWidgets.QApplication([]) win = LossViewer() @@ -307,11 +325,11 @@ def check_messages(self, timeout=10): def test_point(x=[0]): x[0] += 1 - i = x[0]+1 - win.add_datapoint(i, i%30+1) + i = x[0] + 1 + win.add_datapoint(i, i % 30 + 1) t = QtCore.QTimer() t.timeout.connect(test_point) t.start(0) - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/nn/peakfinding.py b/sleap/nn/peakfinding.py index d9f600d7f..603910cec 100644 --- a/sleap/nn/peakfinding.py +++ b/sleap/nn/peakfinding.py @@ -1,6 +1,7 @@ import cv2 import numpy as np + def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): """ Find peaks via non-maximum suppresion using OpenCV. """ @@ -10,12 +11,10 @@ def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): # Blur if sigma is not None: - I = cv2.GaussianBlur(I, (9,9), sigma) + I = cv2.GaussianBlur(I, (9, 9), sigma) # Maximum filter - kernel = np.array([[1,1,1], - [1,0,1], - [1,1,1]]).astype("uint8") + kernel = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]]).astype("uint8") m = cv2.dilate(I, kernel) # Convert to points @@ -24,7 +23,7 @@ def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): # Return if return_val: - vals = np.array([I[pt[1],pt[0]] for pt in pts]) + vals = np.array([I[pt[1], pt[0]] for pt in pts]) return pts.astype("float32"), vals else: return pts.astype("float32") @@ -38,7 +37,9 @@ def find_all_peaks(confmaps, min_thresh=0.3, sigma=3): peaks_i = [] peak_vals_i = [] for i in range(confmap.shape[-1]): - peak, val = impeaksnms_cv(confmap[...,i], min_thresh=min_thresh, sigma=sigma, return_val=True) + peak, val = impeaksnms_cv( + confmap[..., i], min_thresh=min_thresh, sigma=sigma, return_val=True + ) peaks_i.append(peak) peak_vals_i.append(val) peaks.append(peaks_i) @@ -46,6 +47,7 @@ def find_all_peaks(confmaps, min_thresh=0.3, sigma=3): return peaks, peak_vals + def find_all_single_peaks(confmaps, min_thresh=0.3): """ Finds single peak for each frame/channel in a stack of conf maps. @@ -57,13 +59,17 @@ def find_all_single_peaks(confmaps, min_thresh=0.3): all_point_arrays = [] for confmap in confmaps: - peaks_vals = [image_single_peak(confmap[...,i], min_thresh) for i in range(confmap.shape[-1])] + peaks_vals = [ + image_single_peak(confmap[..., i], min_thresh) + for i in range(confmap.shape[-1]) + ] peaks_vals = [(*point, val) for point, val in peaks_vals] points_array = np.stack(peaks_vals, axis=0) all_point_arrays.append(points_array) return all_point_arrays + def image_single_peak(I, min_thresh): peak = np.unravel_index(I.argmax(), I.shape) val = I[peak] @@ -74,4 +80,4 @@ def image_single_peak(I, min_thresh): else: y, x = peak - return (x, y), val \ No newline at end of file + return (x, y), val diff --git a/sleap/nn/peakfinding_tf.py b/sleap/nn/peakfinding_tf.py index a857ea425..5c7a1d93e 100644 --- a/sleap/nn/peakfinding_tf.py +++ b/sleap/nn/peakfinding_tf.py @@ -11,6 +11,7 @@ from sleap.nn.util import batch + def find_maxima_tf(x): col_max = tf.reduce_max(x, axis=1) @@ -24,7 +25,8 @@ def find_maxima_tf(x): maxima = tf.concat([rows, cols], -1) # max_val = tf.reduce_max(col_max, axis=1) # should match tf.reduce_max(x, axis=[1,2]) - return maxima #, max_val + return maxima # , max_val + def impeaksnms_tf(I, min_thresh=0.3): @@ -32,9 +34,7 @@ def impeaksnms_tf(I, min_thresh=0.3): # less than min_thresh are set to 0. 
It = tf.cast(I > min_thresh, I.dtype) * I - kernel = np.array([[0, 0, 0], - [0, -1, 0], - [0, 0, 0]])[..., None] + kernel = np.array([[0, 0, 0], [0, -1, 0], [0, 0, 0]])[..., None] # kernel = np.array([[1, 1, 1], # [1, 0, 1], # [1, 1, 1]])[..., None] @@ -46,12 +46,20 @@ def impeaksnms_tf(I, min_thresh=0.3): return inds, peak_vals -def find_peaks_tf(confmaps, confmaps_shape, min_thresh=0.3, upsample_factor: int = 1, win_size: int = 5): +def find_peaks_tf( + confmaps, + confmaps_shape, + min_thresh=0.3, + upsample_factor: int = 1, + win_size: int = 5, +): # n, h, w, c = confmaps.get_shape().as_list() h, w, c = confmaps_shape - unrolled_confmaps = tf.reshape(tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1]) # (nc, h, w, 1) + unrolled_confmaps = tf.reshape( + tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1] + ) # (nc, h, w, 1) peak_inds, peak_vals = impeaksnms_tf(unrolled_confmaps, min_thresh=min_thresh) channel_sample_ind, y, x, _ = tf.split(peak_inds, 4, axis=1) @@ -71,21 +79,24 @@ def find_peaks_tf(confmaps, confmaps_shape, min_thresh=0.3, upsample_factor: int # Get the boxes coordinates centered on the peaks, normalized to image # coordinates box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32)) - top_left = (tf.cast(peaks[:, 1:3], tf.float32) + - tf.constant([-offset, -offset], dtype="float32")) / (h - 1.0) - bottom_right = (tf.cast(peaks[:, 1:3], tf.float32) + tf.constant([offset, offset], dtype="float32")) / (w - 1.0) + top_left = ( + tf.cast(peaks[:, 1:3], tf.float32) + + tf.constant([-offset, -offset], dtype="float32") + ) / (h - 1.0) + bottom_right = ( + tf.cast(peaks[:, 1:3], tf.float32) + + tf.constant([offset, offset], dtype="float32") + ) / (w - 1.0) boxes = tf.concat([top_left, bottom_right], axis=1) small_windows = tf.image.crop_and_resize( - unrolled_confmaps, - boxes, - box_ind, - crop_size=[win_size, win_size]) + unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size] + ) # Upsample cropped windows windows = tf.image.resize_bicubic( - small_windows, - [upsample_factor * win_size, upsample_factor * win_size]) + small_windows, [upsample_factor * win_size, upsample_factor * win_size] + ) windows = tf.squeeze(windows) @@ -93,34 +104,37 @@ def find_peaks_tf(confmaps, confmaps_shape, min_thresh=0.3, upsample_factor: int windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2) # Adjust back to resolution before upsampling - windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(upsample_factor, tf.float32) + windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast( + upsample_factor, tf.float32 + ) # Convert to offsets relative to the original peaks (center of cropped windows) windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2) - windows_offsets = tf.pad(windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0) # (nc, 4) + windows_offsets = tf.pad( + windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0 + ) # (nc, 4) # Apply offsets peaks = tf.cast(peaks, tf.float32) + windows_offsets return peaks, peak_vals + # Blurring: # Ref: https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow -def gaussian_kernel(size: int, - mean: float, - std: float, - ): +def gaussian_kernel(size: int, mean: float, std: float): """Makes 2D gaussian Kernel for convolution.""" d = tf.distributions.Normal(mean, std) - vals = d.prob(tf.range(start = -size, limit = size + 1, dtype = tf.float32)) - gauss_kernel = tf.einsum("i,j->ij", - vals, - vals) + vals = 
d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32)) + gauss_kernel = tf.einsum("i,j->ij", vals, vals) return gauss_kernel / tf.reduce_sum(gauss_kernel) -def peak_tf_inference(model, data, + +def peak_tf_inference( + model, + data, confmaps_shape: Tuple[int], min_thresh: float = 0.3, gaussian_size: int = 9, @@ -128,14 +142,15 @@ def peak_tf_inference(model, data, upsample_factor: int = 1, return_confmaps: bool = False, batch_size: int = 4, - win_size: int = 7): + win_size: int = 7, +): sess = keras.backend.get_session() # TODO: Unfuck this. confmaps = model.outputs[-1] h, w, c = confmaps_shape - + if gaussian_size > 0 and gaussian_sigma > 0: # Make Gaussian Kernel with desired specs. @@ -149,14 +164,22 @@ def peak_tf_inference(model, data, pointwise_filter = tf.eye(c, batch_shape=[1, 1]) # Convolve. - confmaps = tf.nn.separable_conv2d(confmaps, gauss_kernel, pointwise_filter, - strides=[1, 1, 1, 1], padding="SAME") + confmaps = tf.nn.separable_conv2d( + confmaps, + gauss_kernel, + pointwise_filter, + strides=[1, 1, 1, 1], + padding="SAME", + ) - # Setup peak finding computations. - peaks, peak_vals = find_peaks_tf(confmaps, - confmaps_shape=confmaps_shape, min_thresh=min_thresh, - upsample_factor=upsample_factor, win_size=win_size) + peaks, peak_vals = find_peaks_tf( + confmaps, + confmaps_shape=confmaps_shape, + min_thresh=min_thresh, + upsample_factor=upsample_factor, + win_size=win_size, + ) # We definitely want to capture the peaks in the output # We will map the tensorflow outputs onto a dict to return @@ -166,7 +189,10 @@ def peak_tf_inference(model, data, outputs_dict["confmaps"] = confmaps # Convert dict to list of keys and list of tensors (to evaluate) - outputs_keys, output_tensors = list(outputs_dict.keys()), list(outputs_dict.values()) + outputs_keys, output_tensors = ( + list(outputs_dict.keys()), + list(outputs_dict.values()), + ) # Run the graph and retrieve output arrays. peaks_arr = [] @@ -204,16 +230,23 @@ def peak_tf_inference(model, data, # Use indices to convert matrices to lists of lists # (this matches the format of cpu-based peak-finding) - peak_list, peak_val_list = split_matrices_by_double_index(sample_channel_ind, peak_points, peak_vals_arr, - n_samples=len(data), n_channels=c) + peak_list, peak_val_list = split_matrices_by_double_index( + sample_channel_ind, + peak_points, + peak_vals_arr, + n_samples=len(data), + n_channels=c, + ) return peak_list, peak_val_list, confmaps + def split_matrices_by_double_index(idxs, *data_list, n_samples=None, n_channels=None): """Convert data matrices to lists of lists expected by other functions.""" # Return empty array if there are no idxs - if len(idxs) == 0: return [], [] + if len(idxs) == 0: + return [], [] # Determine the list length for major and minor indices if n_samples is None: @@ -251,4 +284,4 @@ def split_matrices_by_double_index(idxs, *data_list, n_samples=None, n_channels= for data_matrix_idx in range(data_matrix_count): r[data_matrix_idx].append(major[data_matrix_idx]) - return r \ No newline at end of file + return r diff --git a/sleap/nn/peakmatching.py b/sleap/nn/peakmatching.py index 9487cc32a..228bd8ee1 100644 --- a/sleap/nn/peakmatching.py +++ b/sleap/nn/peakmatching.py @@ -4,6 +4,7 @@ from sleap.instance import LabeledFrame, PredictedPoint, PredictedInstance from sleap.info.metrics import calculate_pairwise_cost + def match_single_peaks_frame(points_array, skeleton, transform, img_idx): """ Make instance from points array returned by single peak finding. 
@@ -13,10 +14,11 @@ def match_single_peaks_frame(points_array, skeleton, transform, img_idx): Returns: PredictedInstance, or None if no points. """ - if points_array.shape[0] == 0: return None + if points_array.shape[0] == 0: + return None # apply inverse transform to points - points_array[...,0:2] = transform.invert(img_idx, points_array[...,0:2]) + points_array[..., 0:2] = transform.invert(img_idx, points_array[..., 0:2]) pts = dict() for i, node in enumerate(skeleton.nodes): @@ -29,11 +31,14 @@ def match_single_peaks_frame(points_array, skeleton, transform, img_idx): matched_instance = None if len(pts) > 0: # FIXME: how should we calculate score for instance? - inst_score = np.sum(points_array[...,2]) / len(pts) - matched_instance = PredictedInstance(skeleton=skeleton, points=pts, score=inst_score) + inst_score = np.sum(points_array[..., 2]) / len(pts) + matched_instance = PredictedInstance( + skeleton=skeleton, points=pts, score=inst_score + ) return matched_instance + def match_single_peaks_all(points_arrays, skeleton, video, transform): """ Make labeled frames for the results of single peak finding. @@ -52,6 +57,7 @@ def match_single_peaks_all(points_arrays, skeleton, video, transform): predicted_frames.append(new_lf) return predicted_frames + def improfile(I, p0, p1, max_points=None): """ Returns values of the image I evaluated along the line formed @@ -73,7 +79,7 @@ def improfile(I, p0, p1, max_points=None): I = np.squeeze(I) # Find number of points to extract - n = np.sqrt((p0[0] - p1[0])**2 + (p0[1] - p1[1])**2) + n = np.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2) n = max(n, 1) if max_points is not None: n = min(n, max_points) @@ -84,15 +90,23 @@ def improfile(I, p0, p1, max_points=None): y = np.round(np.linspace(p0[1], p1[1], n)).astype("int32") # Extract values and concatenate into vector - vals = np.stack([I[yi,xi] for xi, yi in zip(x,y)]) + vals = np.stack([I[yi, xi] for xi, yi in zip(x, y)]) return vals -def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx, - min_score_to_node_ratio=0.4, - min_score_midpts=0.05, - min_score_integral=0.8, - add_last_edge=False, - single_per_crop=True): + +def match_peaks_frame( + peaks_t, + peak_vals_t, + pafs_t, + skeleton, + transform, + img_idx, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, +): """ Matches single frame """ @@ -116,8 +130,8 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx for k, edge in enumerate(skeleton.edge_names): src_node_idx = skeleton.node_to_index(edge[0]) dst_node_idx = skeleton.node_to_index(edge[1]) - paf_x = pafs_t[...,2*k] - paf_y = pafs_t[...,2*k+1] + paf_x = pafs_t[..., 2 * k] + paf_y = pafs_t[..., 2 * k + 1] # Make sure matrix has rows for these nodes if len(peaks_t) <= src_node_idx or len(peaks_t) <= dst_node_idx: @@ -156,16 +170,23 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx # Compute score score_midpts = vec_x * vec[0] + vec_y * vec[1] - score_with_dist_prior = np.mean(score_midpts) + min(0.5 * paf_x.shape[0] / norm - 1, 0) + score_with_dist_prior = np.mean(score_midpts) + min( + 0.5 * paf_x.shape[0] / norm - 1, 0 + ) score_integral = np.mean(score_midpts > min_score_midpts) - if score_with_dist_prior > 0 and score_integral > min_score_integral: + if ( + score_with_dist_prior > 0 + and score_integral > min_score_integral + ): connection_candidates.append([i, j, score_with_dist_prior]) # Sort candidates for current 
edge by descending score - connection_candidates = sorted(connection_candidates, key=lambda x: x[2], reverse=True) + connection_candidates = sorted( + connection_candidates, key=lambda x: x[2], reverse=True + ) # Add to list of candidates for next step - connection = np.zeros((0,5)) # src_id, dst_id, paf_score, i, j + connection = np.zeros((0, 5)) # src_id, dst_id, paf_score, i, j for candidate in connection_candidates: i, j, score = candidate # Add to connections if node is not already included @@ -180,20 +201,27 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx connection_all.append(connection) # Greedy matching of each edge candidate set - subset = -1 * np.ones((0, len(skeleton.nodes)+2)) # ids, overall score, number of parts - candidate = np.array([y for x in peaks_t for y in x]) # flattened set of all points - candidate_scores = np.array([y for x in peak_vals_t for y in x]) # flattened set of all peak scores + subset = -1 * np.ones( + (0, len(skeleton.nodes) + 2) + ) # ids, overall score, number of parts + candidate = np.array([y for x in peaks_t for y in x]) # flattened set of all points + candidate_scores = np.array( + [y for x in peak_vals_t for y in x] + ) # flattened set of all peak scores for k, edge in enumerate(skeleton.edge_names): # No matches for this edge if k in special_k: continue # Get IDs for current connection - partAs = connection_all[k][:,0] - partBs = connection_all[k][:,1] + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] # Get edge - indexA, indexB = (skeleton.node_to_index(edge[0]), skeleton.node_to_index(edge[1])) + indexA, indexB = ( + skeleton.node_to_index(edge[0]), + skeleton.node_to_index(edge[1]), + ) # Loop through all candidates for current edge for i in range(len(connection_all[k])): @@ -209,18 +237,24 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx # One of the two candidate points found in matched subset if found == 1: j = subset_idx[0] - if subset[j][indexB] != partBs[i]: # did we already assign this part? - subset[j][indexB] = partBs[i] # assign part - subset[j][-1] += 1 # increment instance part counter - subset[j][-2] += candidate_scores[int(partBs[i])] + connection_all[k][i][2] # add peak + edge score + if subset[j][indexB] != partBs[i]: # did we already assign this part? 
+ subset[j][indexB] = partBs[i] # assign part + subset[j][-1] += 1 # increment instance part counter + subset[j][-2] += ( + candidate_scores[int(partBs[i])] + connection_all[k][i][2] + ) # add peak + edge score # Both candidate points found in matched subset elif found == 2: - j1, j2 = subset_idx # get indices in matched subset - membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] # count number of instances per body parts + j1, j2 = subset_idx # get indices in matched subset + membership = ( + (subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int) + )[ + :-2 + ] # count number of instances per body parts # All body parts are disjoint, merge them if np.all(membership < 2): - subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][:-2] += subset[j2][:-2] + 1 subset[j1][-2:] += subset[j2][-2:] subset[j1][-2] += connection_all[k][i][2] subset = np.delete(subset, j2, axis=0) @@ -229,19 +263,25 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx else: subset[j1][indexB] = partBs[i] subset[j1][-1] += 1 - subset[j1][-2] += candidate_scores[partBs[i].astype(int)] + connection_all[k][i][2] + subset[j1][-2] += ( + candidate_scores[partBs[i].astype(int)] + + connection_all[k][i][2] + ) # Neither point found, create a new subset (if not the last edge) - elif found == 0 and (add_last_edge or (k < (len(skeleton.edges)-1))): - row = -1 * np.ones(len(skeleton.nodes)+2) - row[indexA] = partAs[i] # ID - row[indexB] = partBs[i] # ID - row[-1] = 2 # initial count - row[-2] = sum(candidate_scores[connection_all[k][i, :2].astype(int)]) + connection_all[k][i][2] # score - subset = np.vstack([subset, row]) # add to matched subset + elif found == 0 and (add_last_edge or (k < (len(skeleton.edges) - 1))): + row = -1 * np.ones(len(skeleton.nodes) + 2) + row[indexA] = partAs[i] # ID + row[indexB] = partBs[i] # ID + row[-1] = 2 # initial count + row[-2] = ( + sum(candidate_scores[connection_all[k][i, :2].astype(int)]) + + connection_all[k][i][2] + ) # score + subset = np.vstack([subset, row]) # add to matched subset # Filter small instances - score_to_node_ratio = subset[:,-2] / subset[:,-1] + score_to_node_ratio = subset[:, -2] / subset[:, -1] subset = subset[score_to_node_ratio > min_score_to_node_ratio, :] # Apply inverse transform to points to return to full resolution, uncropped image coordinates @@ -257,28 +297,31 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx for i, node_name in enumerate(skeleton.node_names): if match[i] >= 0: match_idx = int(match[i]) - pt = PredictedPoint(x=candidate[match_idx, 0], y=candidate[match_idx, 1], - score=candidate_scores[match_idx]) + pt = PredictedPoint( + x=candidate[match_idx, 0], + y=candidate[match_idx, 1], + score=candidate_scores[match_idx], + ) pts[node_name] = pt if len(pts): - matched_instances_t.append(PredictedInstance(skeleton=skeleton, - points=pts, - score=match[-2])) + matched_instances_t.append( + PredictedInstance(skeleton=skeleton, points=pts, score=match[-2]) + ) # For centroid crop just return instance closest to centroid # if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: - # crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box - # crop_centroid = transform.invert(img_idx, crop_centroid) # relative to original image + # crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box + # crop_centroid = transform.invert(img_idx, crop_centroid) # 
relative to original image - # # sort by distance from crop centroid - # matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) + # # sort by distance from crop centroid + # matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) - # # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") + # # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") - # # just use closest - # matched_instances_t = matched_instances_t[0:1] + # # just use closest + # matched_instances_t = matched_instances_t[0:1] if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: # Just keep highest scoring instance @@ -286,55 +329,91 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx return matched_instances_t -def match_peaks_paf(peaks, peak_vals, pafs, skeleton, - video, transform, - min_score_to_node_ratio=0.4, min_score_midpts=0.05, - min_score_integral=0.8, add_last_edge=False, single_per_crop=True, - **kwargs): + +def match_peaks_paf( + peaks, + peak_vals, + pafs, + skeleton, + video, + transform, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, + **kwargs +): """ Computes PAF-based peak matching via greedy assignment """ # Process each frame predicted_frames = [] - for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate(zip(peaks, peak_vals, pafs)): - instances = match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, - transform, img_idx, - min_score_to_node_ratio=min_score_to_node_ratio, - min_score_midpts=min_score_midpts, - min_score_integral=min_score_integral, - add_last_edge=add_last_edge, - single_per_crop=single_per_crop) + for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate( + zip(peaks, peak_vals, pafs) + ): + instances = match_peaks_frame( + peaks_t, + peak_vals_t, + pafs_t, + skeleton, + transform, + img_idx, + min_score_to_node_ratio=min_score_to_node_ratio, + min_score_midpts=min_score_midpts, + min_score_integral=min_score_integral, + add_last_edge=add_last_edge, + single_per_crop=single_per_crop, + ) frame_idx = transform.get_frame_idxs(img_idx) - predicted_frames.append(LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)) + predicted_frames.append( + LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + ) # Combine LabeledFrame objects for the same video frame predicted_frames = LabeledFrame.merge_frames(predicted_frames, video=video) return predicted_frames -def match_peaks_paf_par(peaks, peak_vals, pafs, skeleton, - video, transform, - min_score_to_node_ratio=0.4, - min_score_midpts=0.05, - min_score_integral=0.8, - add_last_edge=False, - single_per_crop=True, - pool=None, **kwargs): + +def match_peaks_paf_par( + peaks, + peak_vals, + pafs, + skeleton, + video, + transform, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, + pool=None, + **kwargs +): """ Parallel version of PAF peak matching """ if pool is None: import multiprocessing + pool = multiprocessing.Pool() futures = [] - for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate(zip(peaks, peak_vals, pafs)): - future = pool.apply_async(match_peaks_frame, - [peaks_t, peak_vals_t, pafs_t, skeleton], - dict(transform=transform, img_idx=img_idx, - min_score_to_node_ratio=min_score_to_node_ratio, - 
min_score_midpts=min_score_midpts, - min_score_integral=min_score_integral, - add_last_edge=add_last_edge, - single_per_crop=single_per_crop,)) + for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate( + zip(peaks, peak_vals, pafs) + ): + future = pool.apply_async( + match_peaks_frame, + [peaks_t, peak_vals_t, pafs_t, skeleton], + dict( + transform=transform, + img_idx=img_idx, + min_score_to_node_ratio=min_score_to_node_ratio, + min_score_midpts=min_score_midpts, + min_score_integral=min_score_integral, + add_last_edge=add_last_edge, + single_per_crop=single_per_crop, + ), + ) futures.append(future) predicted_frames = [] @@ -347,32 +426,46 @@ def match_peaks_paf_par(peaks, peak_vals, pafs, skeleton, # an expensive operation. for i in range(len(instances)): points = {node.name: point for node, point in instances[i].nodes_points} - instances[i] = PredictedInstance(skeleton=skeleton, points=points, score=instances[i].score) + instances[i] = PredictedInstance( + skeleton=skeleton, points=points, score=instances[i].score + ) - predicted_frames.append(LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)) + predicted_frames.append( + LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + ) # Combine LabeledFrame objects for the same video frame predicted_frames = LabeledFrame.merge_frames(predicted_frames, video=video) return predicted_frames -def instances_nms(instances: List[PredictedInstance], thresh: float=4) -> List[PredictedInstance]: + +def instances_nms( + instances: List[PredictedInstance], thresh: float = 4 +) -> List[PredictedInstance]: """Remove overlapping instances from list.""" - if len(instances) <= 1: return + if len(instances) <= 1: + return # Look for overlapping instances - overlap_matrix = calculate_pairwise_cost(instances, instances, - cost_function = lambda x: np.nan if all(np.isnan(x)) else np.nanmean(x)) + overlap_matrix = calculate_pairwise_cost( + instances, + instances, + cost_function=lambda x: np.nan if all(np.isnan(x)) else np.nanmean(x), + ) # Set diagonals over threshold since an instance doesn't overlap with itself - np.fill_diagonal(overlap_matrix, thresh+1) - overlap_matrix[np.isnan(overlap_matrix)] = thresh+1 + np.fill_diagonal(overlap_matrix, thresh + 1) + overlap_matrix[np.isnan(overlap_matrix)] = thresh + 1 instances_to_remove = [] def sort_funct(inst_idx): # sort by number of points in instance, then by prediction score (desc) - return (len(instances[inst_idx].nodes), -getattr(instances[inst_idx], "score", 0)) + return ( + len(instances[inst_idx].nodes), + -getattr(instances[inst_idx], "score", 0), + ) while np.nanmin(overlap_matrix) < thresh: # Find the pair of instances with greatest overlap @@ -384,8 +477,8 @@ def sort_funct(inst_idx): keep_idx = idxs[-1] # Remove this instance from overlap matrix - overlap_matrix[pick_idx, :] = thresh+1 - overlap_matrix[:, pick_idx] = thresh+1 + overlap_matrix[pick_idx, :] = thresh + 1 + overlap_matrix[:, pick_idx] = thresh + 1 # Add to list of instances that we'll remove. # We'll remove these later so list index doesn't change now. 
@@ -394,4 +487,4 @@ def sort_funct(inst_idx): # Remove selected instances from list # Note that we're modifying the original list in place for inst in instances_to_remove: - instances.remove(inst) \ No newline at end of file + instances.remove(inst) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index aac286de2..413aaa573 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -28,13 +28,14 @@ class ShiftedInstance: Args: parent: The Instance that this optical flow shifted instance is derived from. """ - parent: Union[Instance, 'ShiftedInstance'] = attr.ib() + + parent: Union[Instance, "ShiftedInstance"] = attr.ib() frame: Union[LabeledFrame, None] = attr.ib() points: np.ndarray = attr.ib() @property @functools.lru_cache() - def source(self) -> 'Instance': + def source(self) -> "Instance": """ Recursively discover root instance to a chain of flow shifted instances. @@ -79,6 +80,7 @@ def get_points_array(self, *args, **kwargs): """ return self.points + @attr.s(slots=True) class Tracks: instances: Dict[int, list] = attr.ib(default=attr.Factory(dict)) @@ -91,13 +93,21 @@ def get_frame_instances(self, frame_idx: int, max_shift=None): # Filter if max_shift is not None: - instances = [instance for instance in instances if isinstance(instance, Instance) or ( - isinstance(instance, ShiftedInstance) and ( - (frame_idx - instance.source.frame_idx) <= max_shift))] + instances = [ + instance + for instance in instances + if isinstance(instance, Instance) + or ( + isinstance(instance, ShiftedInstance) + and ((frame_idx - instance.source.frame_idx) <= max_shift) + ) + ] return instances - def add_instance(self, instance: Union[Instance, 'ShiftedInstance'], frame_index: int): + def add_instance( + self, instance: Union[Instance, "ShiftedInstance"], frame_index: int + ): frame_instances = self.instances.get(frame_index, []) frame_instances.append(instance) self.instances[frame_index] = frame_instances @@ -113,11 +123,17 @@ def get_last_known(self, curr_frame_index: int = None, max_shift: int = None): return list(self.last_known_instance.values()) else: if max_shift is None: - return [i for i in self.last_known_instance.values() - if i.track == curr_frame_index] + return [ + i + for i in self.last_known_instance.values() + if i.track == curr_frame_index + ] else: - return [i for i in self.last_known_instance.values() - if (curr_frame_index-i.frame_idx) < max_shift] + return [ + i + for i in self.last_known_instance.values() + if (curr_frame_index - i.frame_idx) < max_shift + ] def update_track_last_known(self, frame: LabeledFrame, max_shift: int = None): for i in frame.instances: @@ -126,9 +142,11 @@ def update_track_last_known(self, frame: LabeledFrame, max_shift: int = None): # Remove tracks from the dict that have exceeded the max_shift horizon if max_shift is not None: - del_tracks = [track - for track, instance in self.last_known_instance.items() - if (frame.frame_idx-instance.frame_idx) > max_shift] + del_tracks = [ + track + for track, instance in self.last_known_instance.items() + if (frame.frame_idx - instance.frame_idx) > max_shift + ] for key in del_tracks: del self.last_known_instance[key] @@ -153,7 +171,7 @@ class FlowShiftTracker: """ window: int = 10 - of_win_size: Tuple = (21,21) + of_win_size: Tuple = (21, 21) of_max_level: int = 3 of_max_count: int = 30 of_epsilon: float = 0.01 @@ -168,7 +186,7 @@ def __attrs_post_init__(self): def _fix_img(self, img: np.ndarray): # Drop single channel dimension and convert to uint8 in [0, 255] range - curr_img = 
(np.squeeze(img)*255).astype(np.uint8) + curr_img = (np.squeeze(img) * 255).astype(np.uint8) np.clip(curr_img, 0, 255) # If we still have 3 dimensions the image is color, need to convert @@ -178,9 +196,7 @@ def _fix_img(self, img: np.ndarray): return curr_img - def process(self, - imgs: np.ndarray, - labeled_frames: List[LabeledFrame]): + def process(self, imgs: np.ndarray, labeled_frames: List[LabeledFrame]): """ Flow shift track a batch of frames with matched instances for each frame represented as a list of LabeledFrame's. @@ -215,7 +231,9 @@ def process(self, # known instance for each track. Do this for the last frame and # skip on the first frame. if img_idx > 0: - self.tracks.update_track_last_known(labeled_frames[img_idx-1], max_shift=None) + self.tracks.update_track_last_known( + labeled_frames[img_idx - 1], max_shift=None + ) # Copy the actual frame index for this labeled frame, we will # use this a lot. @@ -232,30 +250,45 @@ def process(self, instance.track = Track(spawned_on=t, name=f"{i}") self.tracks.add_instance(instance, frame_index=t) - logger.debug(f"[t = {t}] Created {len(self.tracks.tracks)} initial tracks") + logger.debug( + f"[t = {t}] Created {len(self.tracks.tracks)} initial tracks" + ) self.last_img = self._fix_img(imgs[img_idx].copy()) continue # Get all points in reference frame - instances_ref = self.tracks.get_frame_instances(self.last_frame_index, max_shift=self.window - 1) + instances_ref = self.tracks.get_frame_instances( + self.last_frame_index, max_shift=self.window - 1 + ) pts_ref = [instance.get_points_array() for instance in instances_ref] - tmp = min([instance.frame_idx for instance in instances_ref] + - [instance.source.frame_idx for instance in instances_ref - if isinstance(instance, ShiftedInstance)]) + tmp = min( + [instance.frame_idx for instance in instances_ref] + + [ + instance.source.frame_idx + for instance in instances_ref + if isinstance(instance, ShiftedInstance) + ] + ) logger.debug(f"[t = {t}] Using {len(instances_ref)} refs back to t = {tmp}") curr_img = self._fix_img(imgs[img_idx].copy()) - pts_fs, status, err = \ - cv2.calcOpticalFlowPyrLK(self.last_img, curr_img, - (np.concatenate(pts_ref, axis=0)).astype("float32"), - None, winSize=self.of_win_size, - maxLevel=self.of_max_level, - criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, - self.of_max_count, self.of_epsilon)) + pts_fs, status, err = cv2.calcOpticalFlowPyrLK( + self.last_img, + curr_img, + (np.concatenate(pts_ref, axis=0)).astype("float32"), + None, + winSize=self.of_win_size, + maxLevel=self.of_max_level, + criteria=( + cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + self.of_max_count, + self.of_epsilon, + ), + ) self.last_img = curr_img # Split by instance @@ -266,22 +299,30 @@ def process(self, err = np.split(err, sections, axis=0) # Store shifted instances with metadata - shifted_instances = [ShiftedInstance(parent=ref, points=pts, frame=frame) - for ref, pts, found in zip(instances_ref, pts_fs, status) - if np.sum(found) > 0] + shifted_instances = [ + ShiftedInstance(parent=ref, points=pts, frame=frame) + for ref, pts, found in zip(instances_ref, pts_fs, status) + if np.sum(found) > 0 + ] # Get the track present in the shifted instances shifted_tracks = list({instance.track for instance in shifted_instances}) - last_known = self.tracks.get_last_known(curr_frame_index=t, max_shift=self.window) + last_known = self.tracks.get_last_known( + curr_frame_index=t, max_shift=self.window + ) alive_tracks = {i.track for i in last_known} # If we didn't get any shifted 
instances from the reference frame, use the last # know positions for each track. if len(shifted_instances) == 0: - logger.debug(f"[t = {t}] Optical flow failed, using last known positions for each track.") + logger.debug( + f"[t = {t}] Optical flow failed, using last known positions for each track." + ) shifted_instances = self.tracks.get_last_known() - shifted_tracks = list({instance.track for instance in shifted_instances}) + shifted_tracks = list( + {instance.track for instance in shifted_instances} + ) else: # We might have got some shifted instances, but make sure we aren't missing any # tracks @@ -298,21 +339,35 @@ def process(self, continue # Reduce distances by track - unassigned_pts = np.stack(instances_pts, axis=0) # instances x nodes x 2 - logger.debug(f"[t = {t}] Flow shift matching {len(unassigned_pts)} " - f"instances to {len(shifted_tracks)} ref tracks") + unassigned_pts = np.stack(instances_pts, axis=0) # instances x nodes x 2 + logger.debug( + f"[t = {t}] Flow shift matching {len(unassigned_pts)} " + f"instances to {len(shifted_tracks)} ref tracks" + ) cost_matrix = np.full((len(unassigned_pts), len(shifted_tracks)), np.nan) for i, track in enumerate(shifted_tracks): # Get shifted points for current track - track_pts = np.stack([instance.get_points_array() - for instance in shifted_instances - if instance.track == track], axis=0) # track_instances x nodes x 2 + track_pts = np.stack( + [ + instance.get_points_array() + for instance in shifted_instances + if instance.track == track + ], + axis=0, + ) # track_instances x nodes x 2 # Compute pairwise distances between points - distances = np.sqrt(np.sum((np.expand_dims(unassigned_pts / img_scale, axis=1) - - np.expand_dims(track_pts, axis=0)) ** 2, - axis=-1)) # unassigned_instances x track_instances x nodes + distances = np.sqrt( + np.sum( + ( + np.expand_dims(unassigned_pts / img_scale, axis=1) + - np.expand_dims(track_pts, axis=0) + ) + ** 2, + axis=-1, + ) + ) # unassigned_instances x track_instances x nodes # Reduce over nodes and instances distances = -np.nansum(np.exp(-distances), axis=(1, 2)) @@ -328,18 +383,23 @@ def process(self, frame.instances[i].track = shifted_tracks[j] self.tracks.add_instance(frame.instances[i], frame_index=t) - logger.debug(f"[t = {t}] Assigned instance {i} to existing track " - f"{shifted_tracks[j].name} (cost = {cost_matrix[i,j]})") + logger.debug( + f"[t = {t}] Assigned instance {i} to existing track " + f"{shifted_tracks[j].name} (cost = {cost_matrix[i,j]})" + ) # Spawn new tracks for unassigned instances for i, pts in enumerate(unassigned_pts): - if i in assigned_ind: continue + if i in assigned_ind: + continue instance = frame.instances[i] instance.track = Track(spawned_on=t, name=f"{len(self.tracks.tracks)}") self.tracks.add_instance(instance, frame_index=t) - logger.debug(f"[t = {t}] Assigned remaining instance {i} to newly " - f"spawned track {instance.track.name} " - f"(best cost = {cost_matrix[i,:].min()})") + logger.debug( + f"[t = {t}] Assigned remaining instance {i} to newly " + f"spawned track {instance.track.name} " + f"(best cost = {cost_matrix[i,:].min()})" + ) # Update the last know data structures for the last frame. 
self.tracks.update_track_last_known(labeled_frames[img_idx - 1], max_shift=None) @@ -353,9 +413,11 @@ def occupancy(self): occ = np.zeros((len(self.tracks.tracks), int(num_frames)), dtype="bool") for t in range(int(num_frames)): instances = self.tracks.get_frame_instances(t) - instances = [instance for instance in instances if isinstance(instance, Instance)] + instances = [ + instance for instance in instances if isinstance(instance, Instance) + ] for instance in instances: - occ[self.tracks.tracks.index(instance.track),t] = True + occ[self.tracks.tracks.index(instance.track), t] = True return occ @@ -370,22 +432,34 @@ def generate_tracks(self): instance_tracks = np.full((num_frames, num_nodes, 2, num_tracks), np.nan) for t in range(num_frames): instances = self.tracks.get_frame_instances(t) - instances = [instance for instance in instances if isinstance(instance, Instance)] + instances = [ + instance for instance in instances if isinstance(instance, Instance) + ] for instance in instances: - instance_tracks[t, :, :, self.tracks.tracks.index(instance.track)] = instance.points + instance_tracks[ + t, :, :, self.tracks.tracks.index(instance.track) + ] = instance.points return instance_tracks def generate_shifted_data(self): """ Generate arrays with all shifted instance data """ - shifted_instances = [y for x in self.tracks.instances.values() - for y in x if isinstance(y, ShiftedInstance)] + shifted_instances = [ + y + for x in self.tracks.instances.values() + for y in x + if isinstance(y, ShiftedInstance) + ] - track_id = np.array([self.tracks.tracks.index(instance.track) for instance in shifted_instances]) + track_id = np.array( + [self.tracks.tracks.index(instance.track) for instance in shifted_instances] + ) frame_idx = np.array([instance.frame_idx for instance in shifted_instances]) - frame_idx_source = np.array([instance.source.frame_idx for instance in shifted_instances]) + frame_idx_source = np.array( + [instance.source.frame_idx for instance in shifted_instances] + ) points = np.stack([instance.points for instance in shifted_instances], axis=0) return track_id, frame_idx, frame_idx_source, points diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 93e58ba0a..a05f1e013 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -2,6 +2,7 @@ import json import logging + logger = logging.getLogger(__name__) import numpy as np @@ -20,8 +21,22 @@ from pathlib import Path, PureWindowsPath from keras import backend as K -from keras.layers import Input, Conv2D, BatchNormalization, Add, MaxPool2D, UpSampling2D, Concatenate -from keras.callbacks import ReduceLROnPlateau, EarlyStopping, TensorBoard, LambdaCallback, ModelCheckpoint +from keras.layers import ( + Input, + Conv2D, + BatchNormalization, + Add, + MaxPool2D, + UpSampling2D, + Concatenate, +) +from keras.callbacks import ( + ReduceLROnPlateau, + EarlyStopping, + TensorBoard, + LambdaCallback, + ModelCheckpoint, +) from sklearn.model_selection import train_test_split @@ -30,7 +45,14 @@ from sleap.nn.augmentation import Augmenter from sleap.nn.model import Model, ModelOutputType from sleap.nn.monitor import LossViewer -from sleap.nn.datagen import generate_training_data, generate_confmaps_from_points, generate_pafs_from_points, generate_images, generate_points, generate_centroid_points +from sleap.nn.datagen import ( + generate_training_data, + generate_confmaps_from_points, + generate_pafs_from_points, + generate_images, + generate_points, + generate_centroid_points, +) @attr.s(auto_attribs=True) @@ -133,15 +155,17 @@ 
class Trainer: min_crop_size: int = 32 negative_samples: int = 10 - def train(self, - model: Model, - labels: Union[str, Labels, Dict], - run_name: str = None, - save_dir: Union[str, None] = None, - tensorboard_dir: Union[str, None] = None, - control_zmq_port: int = 9000, - progress_report_zmq_port: int = 9001, - multiprocessing_workers: int = 0) -> str: + def train( + self, + model: Model, + labels: Union[str, Labels, Dict], + run_name: str = None, + save_dir: Union[str, None] = None, + tensorboard_dir: Union[str, None] = None, + control_zmq_port: int = 9000, + progress_report_zmq_port: int = 9001, + multiprocessing_workers: int = 0, + ) -> str: """ Train a given model using labels and the Trainer's current hyper-parameter settings. This method executes synchronously, thus it blocks until training is finished. @@ -176,7 +200,6 @@ def train(self, elif type(labels) is dict: labels = Labels.from_json(labels) - # FIXME: We need to handle multiple skeletons. skeleton = labels.skeletons[0] @@ -187,22 +210,24 @@ def train(self, # Generate CENTROID training data if model.output_type == ModelOutputType.CENTROIDS: imgs = generate_images(labels, scale=self.scale) - points = generate_centroid_points( - generate_points(labels, scale=self.scale)) + points = generate_centroid_points(generate_points(labels, scale=self.scale)) # Generate REGULAR training data else: imgs, points = generate_training_data( - labels, - params = dict( - scale = self.scale, - instance_crop = self.instance_crop, - min_crop_size = self.min_crop_size, - negative_samples = self.negative_samples)) + labels, + params=dict( + scale=self.scale, + instance_crop=self.instance_crop, + min_crop_size=self.min_crop_size, + negative_samples=self.negative_samples, + ), + ) # Split data into train/validation - imgs_train, imgs_val, outputs_train, outputs_val = \ - train_test_split(imgs, points, test_size=self.val_size) + imgs_train, imgs_val, outputs_train, outputs_val = train_test_split( + imgs, points, test_size=self.val_size + ) # Free up the original datasets after test and train split. del imgs, points @@ -227,10 +252,17 @@ def train(self, # then the output (confidence maps or part affinity fields) will # be at a different scale than the input (images). input_img_size = (imgs_train.shape[1], imgs_train.shape[2]) - output_img_size = (int(input_img_size[0] * model.output_scale), int(input_img_size[1] * model.output_scale)) + output_img_size = ( + int(input_img_size[0] * model.output_scale), + int(input_img_size[1] * model.output_scale), + ) - logger.info(f"Training set: {imgs_train.shape} -> {output_img_size}, {num_outputs_channels} channels") - logger.info(f"Validation set: {imgs_val.shape} -> {output_img_size}, {num_outputs_channels} channels") + logger.info( + f"Training set: {imgs_train.shape} -> {output_img_size}, {num_outputs_channels} channels" + ) + logger.info( + f"Validation set: {imgs_val.shape} -> {output_img_size}, {num_outputs_channels} channels" + ) # Input layer img_input = Input((img_height, img_width, img_channels)) @@ -244,7 +276,8 @@ def train(self, if img_height % pool_factor != 0 or img_width % pool_factor != 0: logger.warning( f"Image dimensions ({img_height}, {img_width}) are " - f"not divisible by the pooling factor ({pool_factor}).") + f"not divisible by the pooling factor ({pool_factor})." 
+ ) # gap_height = (np.ceil(img_height / pool_factor) * pool_factor) - img_height # gap_width = (np.ceil(img_width / pool_factor) * pool_factor) - img_width @@ -254,24 +287,25 @@ def train(self, # Solution: https://www.tensorflow.org/api_docs/python/tf/pad + Lambda layer + corresponding crop at the end? # Instantiate the backbone, this builds the Tensorflow graph - x_outs = model.output(input_tensor=img_input, num_output_channels=num_outputs_channels) + x_outs = model.output( + input_tensor=img_input, num_output_channels=num_outputs_channels + ) # Create training model by combining the input layer and backbone graph. keras_model = keras.Model(inputs=img_input, outputs=x_outs) # Specify the optimizer. if self.optimizer.lower() == "adam": - _optimizer = keras.optimizers.Adam(lr=self.learning_rate, amsgrad=self.amsgrad) + _optimizer = keras.optimizers.Adam( + lr=self.learning_rate, amsgrad=self.amsgrad + ) elif self.optimizer.lower() == "rmsprop": _optimizer = keras.optimizers.RMSprop(lr=self.learning_rate) else: raise ValueError(f"Unknown optimizer, value = {self.optimizer}!") # Compile the Keras model - keras_model.compile( - optimizer=_optimizer, - loss="mean_squared_error", - ) + keras_model.compile(optimizer=_optimizer, loss="mean_squared_error") logger.info("Params: {:,}".format(keras_model.count_params())) # Default to one loop through dataset per epoch @@ -284,17 +318,34 @@ def train(self, # Setup data generation if model.output_type == ModelOutputType.CONFIDENCE_MAP: + def datagen_function(points): - return generate_confmaps_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=model.output_scale) + return generate_confmaps_from_points( + points, + skeleton, + input_img_size, + sigma=self.sigma, + scale=model.output_scale, + ) + elif model.output_type == ModelOutputType.PART_AFFINITY_FIELD: + def datagen_function(points): - return generate_pafs_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=model.output_scale) + return generate_pafs_from_points( + points, + skeleton, + input_img_size, + sigma=self.sigma, + scale=model.output_scale, + ) + elif model.output_type == ModelOutputType.CENTROIDS: + def datagen_function(points): - return generate_confmaps_from_points(points, None, input_img_size, - node_count=1, sigma=self.sigma) + return generate_confmaps_from_points( + points, None, input_img_size, node_count=1, sigma=self.sigma + ) + else: datagen_function = None @@ -303,15 +354,23 @@ def datagen_function(points): # Initialize data generator with augmentation train_datagen = Augmenter( - imgs_train, points=outputs_train, - datagen=datagen_function, output_names=keras_model.output_names, - batch_size=self.batch_size, shuffle_initially=self.shuffle_initially, + imgs_train, + points=outputs_train, + datagen=datagen_function, + output_names=keras_model.output_names, + batch_size=self.batch_size, + shuffle_initially=self.shuffle_initially, rotation=self.augment_rotation, - scale=(self.augment_scale_min, self.augment_scale_max)) + scale=(self.augment_scale_min, self.augment_scale_max), + ) - train_run = TrainingJob(model=model, trainer=self, - save_dir=save_dir, run_name=run_name, - labels_filename=labels_file_name) + train_run = TrainingJob( + model=model, + trainer=self, + save_dir=save_dir, + run_name=run_name, + labels_filename=labels_file_name, + ) # Setup saving save_path = None @@ -319,8 +378,10 @@ def datagen_function(points): # Generate run name if run_name is None: timestamp = datetime.now().strftime("%y%m%d_%H%M%S") - train_run.run_name = 
f"{timestamp}.{str(model.output_type)}." \ - f"{model.name}.n={num_total}" + train_run.run_name = ( + f"{timestamp}.{str(model.output_type)}." + f"{model.name}.n={num_total}" + ) # Build save path save_path = os.path.join(save_dir, train_run.run_name) @@ -328,8 +389,10 @@ def datagen_function(points): # Check if it already exists if os.path.exists(save_path): - logger.warning(f"Save path already exists. " - f"Previous run data may be overwritten!") + logger.warning( + f"Save path already exists. " + f"Previous run data may be overwritten!" + ) # Create run folder os.makedirs(save_path, exist_ok=True) @@ -339,19 +402,25 @@ def datagen_function(points): if len(keras_model.output_names) > 1: monitor_metric_name = "val_" + keras_model.output_names[-1] + "_loss" callbacks = self._setup_callbacks( - train_run, save_path, train_datagen, - tensorboard_dir, control_zmq_port, + train_run, + save_path, + train_datagen, + tensorboard_dir, + control_zmq_port, progress_report_zmq_port, output_type=str(model.output_type), monitor_metric_name=monitor_metric_name, - ) + ) # Train! history = keras_model.fit_generator( train_datagen, steps_per_epoch=steps_per_epoch, epochs=self.num_epochs, - validation_data=(imgs_val, {output_name: outputs_val for output_name in keras_model.output_names}), + validation_data=( + imgs_val, + {output_name: outputs_val for output_name in keras_model.output_names}, + ), callbacks=callbacks, verbose=2, use_multiprocessing=multiprocessing_workers > 0, @@ -361,7 +430,9 @@ def datagen_function(points): # Save once done training if save_path is not None: final_model_path = os.path.join(save_path, "final_model.h5") - keras_model.save(filepath=final_model_path, overwrite=True, include_optimizer=True) + keras_model.save( + filepath=final_model_path, overwrite=True, include_optimizer=True + ) logger.info(f"Saved final model: {final_model_path}") # TODO: save training history @@ -401,12 +472,17 @@ def train_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: return pool, result - def _setup_callbacks(self, train_run: 'TrainingJob', - save_path, train_datagen, - tensorboard_dir, control_zmq_port, - progress_report_zmq_port, - output_type, - monitor_metric_name="val_loss"): + def _setup_callbacks( + self, + train_run: "TrainingJob", + save_path, + train_datagen, + tensorboard_dir, + control_zmq_port, + progress_report_zmq_port, + output_type, + monitor_metric_name="val_loss", + ): """ Setup callbacks for the call to Keras fit_generator. 
@@ -421,61 +497,98 @@ def _setup_callbacks(self, train_run: 'TrainingJob', if save_path is not None: if self.save_every_epoch: full_path = os.path.join(save_path, "newest_model.h5") - train_run.newest_model_filename = os.path.relpath(full_path, train_run.save_dir) + train_run.newest_model_filename = os.path.relpath( + full_path, train_run.save_dir + ) callbacks.append( - ModelCheckpoint(filepath=full_path, - monitor=monitor_metric_name, save_best_only=False, - save_weights_only=False, period=1)) + ModelCheckpoint( + filepath=full_path, + monitor=monitor_metric_name, + save_best_only=False, + save_weights_only=False, + period=1, + ) + ) if self.save_best_val: full_path = os.path.join(save_path, "best_model.h5") - train_run.best_model_filename = os.path.relpath(full_path, train_run.save_dir) + train_run.best_model_filename = os.path.relpath( + full_path, train_run.save_dir + ) callbacks.append( - ModelCheckpoint(filepath=full_path, - monitor=monitor_metric_name, save_best_only=True, - save_weights_only=False, period=1)) + ModelCheckpoint( + filepath=full_path, + monitor=monitor_metric_name, + save_best_only=True, + save_weights_only=False, + period=1, + ) + ) TrainingJob.save_json(train_run, f"{save_path}.json") # Callbacks: Shuffle after every epoch if self.shuffle_every_epoch: callbacks.append( - LambdaCallback(on_epoch_end=lambda epoch, logs: train_datagen.shuffle())) + LambdaCallback(on_epoch_end=lambda epoch, logs: train_datagen.shuffle()) + ) # Callbacks: LR reduction callbacks.append( - ReduceLROnPlateau(min_delta=self.reduce_lr_min_delta, - factor=self.reduce_lr_factor, - patience=self.reduce_lr_patience, - cooldown=self.reduce_lr_cooldown, - min_lr=self.reduce_lr_min_lr, - monitor=monitor_metric_name, mode="auto", verbose=1, ) + ReduceLROnPlateau( + min_delta=self.reduce_lr_min_delta, + factor=self.reduce_lr_factor, + patience=self.reduce_lr_patience, + cooldown=self.reduce_lr_cooldown, + min_lr=self.reduce_lr_min_lr, + monitor=monitor_metric_name, + mode="auto", + verbose=1, + ) ) # Callbacks: Early stopping callbacks.append( - EarlyStopping(monitor=monitor_metric_name, - min_delta=self.early_stopping_min_delta, - patience=self.early_stopping_patience, verbose=1)) + EarlyStopping( + monitor=monitor_metric_name, + min_delta=self.early_stopping_min_delta, + patience=self.early_stopping_patience, + verbose=1, + ) + ) # Callbacks: Tensorboard if tensorboard_dir is not None: callbacks.append( - TensorBoard(log_dir=f"{tensorboard_dir}/{output_type}{time()}", - batch_size=32, update_freq=150, histogram_freq=0, - write_graph=False, write_grads=False, write_images=False, - embeddings_freq=0, embeddings_layer_names=None, - embeddings_metadata=None, embeddings_data=None)) + TensorBoard( + log_dir=f"{tensorboard_dir}/{output_type}{time()}", + batch_size=32, + update_freq=150, + histogram_freq=0, + write_graph=False, + write_grads=False, + write_images=False, + embeddings_freq=0, + embeddings_layer_names=None, + embeddings_metadata=None, + embeddings_data=None, + ) + ) # Callbacks: ZMQ control if control_zmq_port is not None: callbacks.append( - TrainingControllerZMQ(address="tcp://127.0.0.1", - port=control_zmq_port, - topic="", poll_timeout=10)) + TrainingControllerZMQ( + address="tcp://127.0.0.1", + port=control_zmq_port, + topic="", + poll_timeout=10, + ) + ) # Callbacks: ZMQ progress reporter if progress_report_zmq_port is not None: callbacks.append( - ProgressReporterZMQ(port=progress_report_zmq_port, what=output_type)) + ProgressReporterZMQ(port=progress_report_zmq_port, 
what=output_type) + ) return callbacks @@ -502,6 +615,7 @@ class TrainingJob: from the final state of training. Set to None if save_dir is None. This model file is not created until training is finished. """ + model: Model trainer: Trainer labels_filename: Union[str, None] = None @@ -512,7 +626,7 @@ class TrainingJob: final_model_filename: Union[str, None] = None @staticmethod - def save_json(training_job: 'TrainingJob', filename: str): + def save_json(training_job: "TrainingJob", filename: str): """ Save a training run to a JSON file. @@ -524,7 +638,7 @@ def save_json(training_job: 'TrainingJob', filename: str): None """ - with open(filename, 'w') as file: + with open(filename, "w") as file: # We have some skeletons to deal with, make sure to setup a Skeleton cattr. my_cattr = Skeleton.make_cattr() @@ -532,7 +646,6 @@ def save_json(training_job: 'TrainingJob', filename: str): json_str = json.dumps(dicts) file.write(json_str) - @classmethod def load_json(cls, filename: str): """ @@ -556,8 +669,9 @@ def load_json(cls, filename: str): if ("model" in dicts) and ("skeletons" in dicts["model"]): if dicts["model"]["skeletons"]: dicts["model"]["skeletons"] = converter.structure( - dicts["model"]["skeletons"], List[Skeleton]) - + dicts["model"]["skeletons"], List[Skeleton] + ) + else: dicts["model"]["skeletons"] = [] @@ -598,7 +712,9 @@ def __init__(self, address="tcp://127.0.0.1", port=9000, topic="", poll_timeout= self.socket = self.context.socket(zmq.SUB) self.socket.subscribe(self.topic) self.socket.connect(self.address) - logger.info(f"Training controller subscribed to: {self.address} (topic: {self.topic})") + logger.info( + f"Training controller subscribed to: {self.address} (topic: {self.topic})" + ) # TODO: catch/throw exception about failure to connect @@ -665,17 +781,26 @@ def on_train_begin(self, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="train_begin", logs=logs))) - + self.socket.send_string( + jsonpickle.encode(dict(what=self.what, event="train_begin", logs=logs)) + ) def on_batch_begin(self, batch, logs=None): """A backwards compatibility alias for `on_train_batch_begin`.""" # self.logger.info("batch_begin") - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="batch_begin", batch=batch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="batch_begin", batch=batch, logs=logs) + ) + ) def on_batch_end(self, batch, logs=None): """A backwards compatibility alias for `on_train_batch_end`.""" - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="batch_end", batch=batch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="batch_end", batch=batch, logs=logs) + ) + ) def on_epoch_begin(self, epoch, logs=None): """Called at the start of an epoch. @@ -686,7 +811,11 @@ def on_epoch_begin(self, epoch, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="epoch_begin", epoch=epoch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="epoch_begin", epoch=epoch, logs=logs) + ) + ) def on_epoch_end(self, epoch, logs=None): """Called at the end of an epoch. @@ -698,7 +827,11 @@ def on_epoch_end(self, epoch, logs=None): validation epoch if validation is performed. 
Validation result keys are prefixed with `val_`. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="epoch_end", epoch=epoch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="epoch_end", epoch=epoch, logs=logs) + ) + ) def on_train_end(self, logs=None): """Called at the end of training. @@ -707,31 +840,48 @@ def on_train_end(self, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="train_end", logs=logs))) + self.socket.send_string( + jsonpickle.encode(dict(what=self.what, event="train_end", logs=logs)) + ) + def main(): from PySide2 import QtWidgets -# from sleap.nn.architectures.unet import UNet -# model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, -# backbone=UNet(num_filters=16, depth=3, up_blocks=2)) + # from sleap.nn.architectures.unet import UNet + # model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, + # backbone=UNet(num_filters=16, depth=3, up_blocks=2)) from sleap.nn.architectures.leap import LeapCNN - model = Model(output_type=ModelOutputType.PART_AFFINITY_FIELD, - backbone=LeapCNN(down_blocks=3, up_blocks=2, - upsampling_layers=True, num_filters=32, interp="bilinear")) + + model = Model( + output_type=ModelOutputType.PART_AFFINITY_FIELD, + backbone=LeapCNN( + down_blocks=3, + up_blocks=2, + upsampling_layers=True, + num_filters=32, + interp="bilinear", + ), + ) # Setup a Trainer object to train the model above - trainer = Trainer(val_size=0.1, batch_size=4, - num_epochs=10, steps_per_epoch=5, - save_best_val=True, - save_every_epoch=True) + trainer = Trainer( + val_size=0.1, + batch_size=4, + num_epochs=10, + steps_per_epoch=5, + save_best_val=True, + save_every_epoch=True, + ) # Run training asynchronously - pool, result = trainer.train_async(model=model, - labels=Labels.load_json("tests/data/json_format_v1/centered_pair.json"), - save_dir='test_train/', - run_name="training_run_2") + pool, result = trainer.train_async( + model=model, + labels=Labels.load_json("tests/data/json_format_v1/centered_pair.json"), + save_dir="test_train/", + run_name="training_run_2", + ) app = QtWidgets.QApplication() @@ -743,7 +893,7 @@ def main(): while not result.ready(): app.processEvents() - result.wait(.01) + result.wait(0.01) print("Get") train_job_path = result.get() @@ -755,13 +905,21 @@ def main(): # Now lets load the training job we just ran train_job = TrainingJob.load_json(train_job_path) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.newest_model_filename)) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.best_model_filename)) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.final_model_filename)) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.newest_model_filename) + ) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.best_model_filename) + ) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.final_model_filename) + ) import sys + sys.exit(0) + def run(labels_filename: str, job_filename: str): labels = Labels.load_file(labels_filename) @@ -771,12 +929,13 @@ def run(labels_filename: str, job_filename: str): save_dir = os.path.join(os.path.dirname(labels_filename), "models") job.trainer.train( - model=job.model, - labels=labels, - save_dir=save_dir, - control_zmq_port=None, - progress_report_zmq_port=None - ) + model=job.model, + labels=labels, + 
save_dir=save_dir, + control_zmq_port=None, + progress_report_zmq_port=None, + ) + if __name__ == "__main__": import argparse @@ -790,7 +949,9 @@ def run(labels_filename: str, job_filename: str): job_filename = args.profile_path if not os.path.exists(job_filename): - profile_dir = resource_filename(Requirement.parse("sleap"), "sleap/training_profiles") + profile_dir = resource_filename( + Requirement.parse("sleap"), "sleap/training_profiles" + ) if os.path.exists(os.path.join(profile_dir, job_filename)): job_filename = os.path.join(profile_dir, job_filename) else: @@ -799,7 +960,4 @@ def run(labels_filename: str, job_filename: str): print(f"Training labels file: {args.labels_path}") print(f"Training profile: {job_filename}") - run( - labels_filename=args.labels_path, - job_filename=job_filename) - + run(labels_filename=args.labels_path, job_filename=job_filename) diff --git a/sleap/nn/transform.py b/sleap/nn/transform.py index 0f5cd6ab2..89f71b77e 100644 --- a/sleap/nn/transform.py +++ b/sleap/nn/transform.py @@ -5,6 +5,7 @@ from sleap.nn.datagen import _bbs_from_points, _pad_bbs, _crop + @attr.s(auto_attribs=True, slots=True) class DataTransform: """ @@ -23,7 +24,9 @@ def _init_frame_idxs(self, frame_count): self.frame_idxs = list(range(frame_count)) def get_data_idxs(self, frame_idx): - return [i for i in range(len(self.frame_idxs)) if self.frame_idxs[i] == frame_idx] + return [ + i for i in range(len(self.frame_idxs)) if self.frame_idxs[i] == frame_idx + ] def get_frame_idxs(self, idxs): if type(idxs) == int: @@ -51,7 +54,7 @@ def scale_to(self, imgs, target_size): self._init_frame_idxs(img_count) # update object state (so we can invert) - self.scale = self.scale * (h/img_h) + self.scale = self.scale * (h / img_h) # return the scaled images return self._scale(imgs, target_size) @@ -67,7 +70,7 @@ def invert_scale(self, imgs): """ # determine target size for inverting scale img_count, img_h, img_w, img_channels = imgs.shape - target_size = (img_h * int(1/self.scale), img_w * int(1/self.scale)) + target_size = (img_h * int(1 / self.scale), img_w * int(1 / self.scale)) return self.scale_to(imgs, target_size) @@ -79,12 +82,14 @@ def _scale(self, imgs, target_size): if (img_h, img_w) != target_size: # build ndarray for new size - scaled_imgs = np.zeros((imgs.shape[0], h, w, imgs.shape[3]), dtype=imgs.dtype) + scaled_imgs = np.zeros( + (imgs.shape[0], h, w, imgs.shape[3]), dtype=imgs.dtype + ) for i in range(imgs.shape[0]): # resize using cv2 img = cv2.resize(imgs[i, :, :], (w, h)) - # add back singleton channel (removed by cv2) + # add back singleton channel (removed by cv2) if img_channels == 1: img = img[..., None] else: @@ -99,7 +104,7 @@ def _scale(self, imgs, target_size): return scaled_imgs - def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int=0): + def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int = 0): """ Crop images around centroid points. Updates state of DataTransform object so we can later invert on points. @@ -122,7 +127,7 @@ def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int=0): # Crop images return self.crop(imgs, bbs, idxs) - def crop(self, imgs:np.ndarray, boxes: list, idxs: list) -> np.ndarray: + def crop(self, imgs: np.ndarray, boxes: list, idxs: list) -> np.ndarray: """ Crop images to given boxes. 
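# A minimal usage sketch for the DataTransform scaling methods reformatted in
# the hunks above (scale_to / invert_scale). It assumes DataTransform() can be
# constructed with its attrs defaults (scale starting at 1.0); the image array
# follows the (frames, height, width, channels) shape convention used above.
import numpy as np
from sleap.nn.transform import DataTransform

imgs = np.zeros((4, 512, 512, 1), dtype="uint8")

transform = DataTransform()
small = transform.scale_to(imgs, target_size=(256, 256))  # records scale = 0.5
restored = transform.invert_scale(small)                  # uses the stored scale to go back
assert restored.shape == imgs.shape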
@@ -172,7 +177,7 @@ def invert(self, idx: int, point_array: np.ndarray) -> np.ndarray: # translate point_array using corresponding bounding_box bb = self.bounding_boxes[idx] - top_left_point = ((bb[0], bb[1]),) # for (x, y) row vector + top_left_point = ((bb[0], bb[1]),) # for (x, y) row vector new_point_array += np.array(top_left_point) - return new_point_array \ No newline at end of file + return new_point_array diff --git a/sleap/nn/util.py b/sleap/nn/util.py index 9eb4647a7..b4cfd65ae 100644 --- a/sleap/nn/util.py +++ b/sleap/nn/util.py @@ -4,10 +4,13 @@ def batch_count(data, batch_size): """Return number of batch_size batches into which data can be divided.""" from math import ceil + return ceil(len(data) / batch_size) -def batch(data: Sequence, batch_size: int) -> Generator[Tuple[int, int, Sequence], None, None]: +def batch( + data: Sequence, batch_size: int +) -> Generator[Tuple[int, int, Sequence], None, None]: """Iterate over sequence data in batches. Arguments: @@ -18,10 +21,10 @@ def batch(data: Sequence, batch_size: int) -> Generator[Tuple[int, int, Sequence * batch number (int) * row offset (int) * batch_size number of items from data - """ + """ total_row_count = len(data) for start in range(0, total_row_count, batch_size): - i = start//batch_size + i = start // batch_size end = min(start + batch_size, total_row_count) yield i, start, data[start:end] @@ -36,7 +39,7 @@ def save_visual_outputs(output_path: str, data: dict): # output_path is full path to labels.json, so replace "json" with "h5" viz_output_path = output_path if viz_output_path.endswith(".json"): - viz_output_path = viz_output_path[:-(len(".json"))] + viz_output_path = viz_output_path[: -(len(".json"))] viz_output_path += ".h5" # write file @@ -45,11 +48,15 @@ def save_visual_outputs(output_path: str, data: dict): val = np.array(val) if key in f: f[key].resize(f[key].shape[0] + val.shape[0], axis=0) - f[key][-val.shape[0]:] = val + f[key][-val.shape[0] :] = val else: maxshape = (None, *val.shape[1:]) - f.create_dataset(key, data=val, maxshape=maxshape, - compression="gzip", compression_opts=9) + f.create_dataset( + key, + data=val, + maxshape=maxshape, + compression="gzip", + compression_opts=9, + ) # logger.info(" Saved visual outputs [%.1fs]" % (time() - t0)) - diff --git a/sleap/rangelist.py b/sleap/rangelist.py index 6db9d7e6e..2ccc1175c 100644 --- a/sleap/rangelist.py +++ b/sleap/rangelist.py @@ -7,6 +7,7 @@ from typing import List, Tuple + class RangeList: """ Class for manipulating a list of range intervals. 
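# A short sketch of the batching helpers from sleap/nn/util.py shown above;
# the data list here is purely illustrative.
from sleap.nn.util import batch, batch_count

data = list(range(10))
assert batch_count(data, 4) == 3  # ceil(10 / 4)

for batch_idx, row_offset, rows in batch(data, batch_size=4):
    # yields (0, 0, [0..3]), (1, 4, [4..7]), (2, 8, [8, 9])
    print(batch_idx, row_offset, rows)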
@@ -37,15 +38,16 @@ def is_empty(self): @property def start(self): """Return the start value of range (or None if empty).""" - if self.is_empty: return None + if self.is_empty: + return None return self.list[0][0] def add(self, val, tolerance=0): """Add a single value, merges to last range if contiguous.""" if self.list and self.list[-1][1] + tolerance >= val: - self.list[-1] = (self.list[-1][0], val+1) + self.list[-1] = (self.list[-1][0], val + 1) else: - self.list.append((val, val+1)) + self.list.append((val, val + 1)) def insert(self, new_range: tuple): """Add a new range, merging to adjacent/overlapping ranges as appropriate.""" @@ -72,7 +74,8 @@ def cut(self, cut: int): def cut_range(self, cut: tuple): """Return three lists, everthing before/within/after cut range.""" - if not self.list: return [], [], [] + if not self.list: + return [], [], [] cut = self._as_tuple(cut) a, r = self.cut_(self.list, cut[0]) @@ -83,7 +86,8 @@ def cut_range(self, cut: tuple): @staticmethod def _as_tuple(x): """Return tuple (converting from range if necessary).""" - if isinstance(x, range): return x.start, x.stop + if isinstance(x, range): + return x.start, x.stop return x @staticmethod @@ -120,14 +124,17 @@ def join_(cls, list_list: List[List[Tuple[int]]]): Returns: range list that joins all of the lists in list_list """ - if len(list_list) == 1: return list_list[0] - if len(list_list) == 2: return cls.join_pair_(list_list[0], list_list[1]) + if len(list_list) == 1: + return list_list[0] + if len(list_list) == 2: + return cls.join_pair_(list_list[0], list_list[1]) return cls.join_pair_(list_list[0], cls.join_(list_list[1:])) @staticmethod def join_pair_(list_a: List[Tuple[int]], list_b: List[Tuple[int]]): """Return a single pair of lists that joins two input lists.""" - if not list_a or not list_b: return list_a + list_b + if not list_a or not list_b: + return list_a + list_b last_a = list_a[-1] first_b = list_b[0] diff --git a/sleap/skeleton.py b/sleap/skeleton.py index c2c6db3dd..ab9b5ee28 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -32,6 +32,7 @@ class EdgeType(Enum): * SYMMETRY - these edges represent symmetrical relationships between parts (e.g. left and right arms) """ + BODY = 1 SYMMETRY = 2 @@ -44,10 +45,10 @@ class Node: """ name: str - weight: float = 1. + weight: float = 1.0 @staticmethod - def from_names(name_list: str): + def from_names(name_list: str) -> List["Node"]: """Convert list of node names to list of nodes objects.""" nodes = [] for name in name_list: @@ -55,11 +56,11 @@ def from_names(name_list: str): return nodes @classmethod - def as_node(cls, node): + def as_node(cls, node: Union[str, "Node"]) -> "Node": """Convert given `node` to `Node` object (if not already).""" return node if isinstance(node, cls) else cls(node) - def matches(self, other): + def matches(self, other: "Node") -> bool: """ Check whether all attributes match between two nodes. @@ -99,14 +100,13 @@ def __init__(self, name: str = None): if name is None or not isinstance(name, str) or not name: name = "Skeleton-" + str(next(self._skeleton_idx)) - # Since networkx does not keep edges in the order we insert them we need # to keep track of how many edges have been inserted so we can number them # as they are inserted and sort them by this numbering when the edge list # is returned. 
self._graph: nx.MultiDiGraph = nx.MultiDiGraph(name=name, num_edges_inserted=0) - def matches(self, other: 'Skeleton'): + def matches(self, other: "Skeleton") -> bool: """ Compare this `Skeleton` to another, ignoring skeleton name and the identities of the `Node` objects in each graph. @@ -117,11 +117,14 @@ def matches(self, other: 'Skeleton'): Returns: True if match, False otherwise. """ + def dict_match(dict1, dict2): return dict1 == dict2 # Check if the graphs are iso-morphic - is_isomorphic = nx.is_isomorphic(self._graph, other._graph, node_match=dict_match) + is_isomorphic = nx.is_isomorphic( + self._graph, other._graph, node_match=dict_match + ) if not is_isomorphic: return False @@ -137,7 +140,11 @@ def dict_match(dict1, dict2): @property def graph(self): - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.BODY] + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.BODY + ] # TODO: properly induce subgraph for MultiDiGraph # Currently, NetworkX will just return the nodes in the subgraph. # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges @@ -145,11 +152,15 @@ def graph(self): @property def graph_symmetry(self): - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.SYMMETRY + ] return self._graph.edge_subgraph(edges) @staticmethod - def find_unique_nodes(skeletons: List['Skeleton']): + def find_unique_nodes(skeletons: List["Skeleton"]): """ Given list of skeletons, return a list of unique node objects across all skeletons. @@ -173,11 +184,19 @@ def make_cattr(idx_to_node: Dict[int, Node] = None): Returns: A cattr.Converter() instance ready for skeleton serialization and deserialization. """ - node_to_idx = {node:idx for idx,node in idx_to_node.items()} if idx_to_node is not None else None + node_to_idx = ( + {node: idx for idx, node in idx_to_node.items()} + if idx_to_node is not None + else None + ) _cattr = cattr.Converter() - _cattr.register_unstructure_hook(Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx)) - _cattr.register_structure_hook(Skeleton, lambda x, cls: Skeleton.from_dict(x, idx_to_node)) + _cattr.register_unstructure_hook( + Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx) + ) + _cattr.register_structure_hook( + Skeleton, lambda x, cls: Skeleton.from_dict(x, idx_to_node) + ) return _cattr @property @@ -204,13 +223,15 @@ def name(self, name: str): Raises: NotImplementedError """ - raise NotImplementedError("Cannot change Skeleton name, it is immutable since " + - "it is used for hashing. Create a copy of the skeleton " + - "with new name using " + - f"new_skeleton = Skeleton.rename(skeleton, '{name}'))") + raise NotImplementedError( + "Cannot change Skeleton name, it is immutable since " + + "it is used for hashing. Create a copy of the skeleton " + + "with new name using " + + f"new_skeleton = Skeleton.rename(skeleton, '{name}'))" + ) @classmethod - def rename_skeleton(cls, skeleton: 'Skeleton', name: str) -> 'Skeleton': + def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": """ A skeleton object cannot change its name. This property is immutable because it is used to hash skeletons. 
If you want to rename a Skeleton you must use this classmethod. @@ -254,9 +275,11 @@ def edges(self): Returns: list of (src_node, dst_node) """ - edge_list = [(d['edge_insert_idx'], src, dst) - for src, dst, key, d in self._graph.edges(keys=True, data=True) - if d['type'] == EdgeType.BODY] + edge_list = [ + (d["edge_insert_idx"], src, dst) + for src, dst, key, d in self._graph.edges(keys=True, data=True) + if d["type"] == EdgeType.BODY + ] # We don't want to return the edge list in the order it is stored. We # want to use the insertion order. Sort by the insertion index for each @@ -272,9 +295,11 @@ def edge_names(self): Returns: list of (src_node.name, dst_node.name) """ - edge_list = [(d['edge_insert_idx'], src.name, dst.name) - for src, dst, key, d in self._graph.edges(keys=True, data=True) - if d['type'] == EdgeType.BODY] + edge_list = [ + (d["edge_insert_idx"], src.name, dst.name) + for src, dst, key, d in self._graph.edges(keys=True, data=True) + if d["type"] == EdgeType.BODY + ] # We don't want to return the edge list in the order it is stored. We # want to use the insertion order. Sort by the insertion index for each @@ -290,7 +315,11 @@ def edges_full(self): Returns: list of (src_node, dst_node, key, attributes) """ - return [(src, dst, key, attr) for src, dst, key, attr in self._graph.edges(keys=True, data=True) if attr["type"] == EdgeType.BODY] + return [ + (src, dst, key, attr) + for src, dst, key, attr in self._graph.edges(keys=True, data=True) + if attr["type"] == EdgeType.BODY + ] @property def symmetries(self): @@ -300,7 +329,11 @@ def symmetries(self): list of (node1, node2) """ # Find all symmetric edges - symmetries = [(src, dst) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + symmetries = [ + (src, dst) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.SYMMETRY + ] # Get rid of duplicates symmetries = list(set([tuple(set(e)) for e in symmetries])) return symmetries @@ -315,7 +348,11 @@ def symmetries_full(self): list of (node1, node2, key, attr) """ # Find all symmetric edges - return [(src, dst, key, attr) for src, dst, key, attr in self._graph.edges(keys=True, data=True) if attr["type"] == EdgeType.SYMMETRY] + return [ + (src, dst, key, attr) + for src, dst, key, attr in self._graph.edges(keys=True, data=True) + if attr["type"] == EdgeType.SYMMETRY + ] def node_to_index(self, node: Union[str, Node]): """ @@ -378,7 +415,9 @@ def delete_node(self, name: str): node = self.find_node(name) self._graph.remove_node(node) except nx.NetworkXError: - raise ValueError("The node named ({}) does not exist, cannot remove it.".format(name)) + raise ValueError( + "The node named ({}) does not exist, cannot remove it.".format(name) + ) def find_node(self, name: str): """Find node in skeleton by name of node. 
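# A small sketch of the node-management API covered by the hunks above
# (add_nodes, node_to_index, find_node, delete_node); the node names are
# illustrative.
from sleap.skeleton import Skeleton

skel = Skeleton("example")
skel.add_nodes(["head", "thorax", "abdomen"])

assert skel.node_to_index("thorax") == 1
assert skel.find_node("head").name == "head"

skel.delete_node("abdomen")
assert "abdomen" not in skel.node_names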
@@ -426,17 +465,31 @@ def add_edge(self, source: str, destination: str): destination_node = self.find_node(destination) if source_node is None: - raise ValueError("Skeleton does not have source node named ({})".format(source)) + raise ValueError( + "Skeleton does not have source node named ({})".format(source) + ) if destination_node is None: - raise ValueError("Skeleton does not have destination node named ({})".format(destination)) + raise ValueError( + "Skeleton does not have destination node named ({})".format(destination) + ) if self._graph.has_edge(source_node, destination_node): - raise ValueError("Skeleton already has an edge between ({}) and ({}).".format(source, destination)) - - self._graph.add_edge(source_node, destination_node, type = EdgeType.BODY, - edge_insert_idx = self._graph.graph['num_edges_inserted']) - self._graph.graph['num_edges_inserted'] = self._graph.graph['num_edges_inserted'] + 1 + raise ValueError( + "Skeleton already has an edge between ({}) and ({}).".format( + source, destination + ) + ) + + self._graph.add_edge( + source_node, + destination_node, + type=EdgeType.BODY, + edge_insert_idx=self._graph.graph["num_edges_inserted"], + ) + self._graph.graph["num_edges_inserted"] = ( + self._graph.graph["num_edges_inserted"] + 1 + ) def delete_edge(self, source: str, destination: str): """Delete an edge between two nodes. @@ -461,13 +514,21 @@ def delete_edge(self, source: str, destination: str): destination_node = self.find_node(destination) if source_node is None: - raise ValueError("Skeleton does not have source node named ({})".format(source)) + raise ValueError( + "Skeleton does not have source node named ({})".format(source) + ) if destination_node is None: - raise ValueError("Skeleton does not have destination node named ({})".format(destination)) + raise ValueError( + "Skeleton does not have destination node named ({})".format(destination) + ) if not self._graph.has_edge(source_node, destination_node): - raise ValueError("Skeleton has no edge between ({}) and ({}).".format(source, destination)) + raise ValueError( + "Skeleton has no edge between ({}) and ({}).".format( + source, destination + ) + ) self._graph.remove_edge(source_node, destination_node) @@ -494,10 +555,14 @@ def add_symmetry(self, node1: str, node2: str): raise ValueError("Cannot add symmetry to the same node.") if self.get_symmetry(node1) is not None: - raise ValueError(f"{node1} is already symmetric with {self.get_symmetry(node1)}.") + raise ValueError( + f"{node1} is already symmetric with {self.get_symmetry(node1)}." + ) if self.get_symmetry(node2) is not None: - raise ValueError(f"{node2} is already symmetric with {self.get_symmetry(node2)}.") + raise ValueError( + f"{node2} is already symmetric with {self.get_symmetry(node2)}." 
+ ) self._graph.add_edge(node1_node, node2_node, type=EdgeType.SYMMETRY) self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) @@ -515,10 +580,19 @@ def delete_symmetry(self, node1: str, node2: str): node1_node = self.find_node(node1) node2_node = self.find_node(node2) - if self.get_symmetry(node1) != node2_node or self.get_symmetry(node2) != node1_node: + if ( + self.get_symmetry(node1) != node2_node + or self.get_symmetry(node2) != node1_node + ): raise ValueError(f"Nodes {node1}, {node2} are not symmetric.") - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges([node1_node, node2_node], keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges( + [node1_node, node2_node], keys=True, data="type" + ) + if edge_type == EdgeType.SYMMETRY + ] self._graph.remove_edges_from(edges) def get_symmetry(self, node: str): @@ -532,7 +606,11 @@ def get_symmetry(self, node: str): """ node_node = self.find_node(node) - symmetry = [dst for src, dst, edge_type in self._graph.edges(node_node, data="type") if edge_type == EdgeType.SYMMETRY] + symmetry = [ + dst + for src, dst, edge_type in self._graph.edges(node_node, data="type") + if edge_type == EdgeType.SYMMETRY + ] if len(symmetry) == 0: return None @@ -658,11 +736,14 @@ def has_edge(self, source_name: str, dest_name: str) -> bool: True is yes, False if no. """ - source_node, destination_node = self.find_node(source_name), self.find_node(dest_name) + source_node, destination_node = ( + self.find_node(source_name), + self.find_node(dest_name), + ) return self._graph.has_edge(source_node, destination_node) @staticmethod - def to_dict(obj: 'Skeleton', node_to_idx: Dict[Node, int] = None): + def to_dict(obj: "Skeleton", node_to_idx: Dict[Node, int] = None): # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. @@ -682,9 +763,11 @@ def to_json(self, node_to_idx: Dict[Node, int] = None) -> str: Returns: A string containing the JSON representation of the Skeleton. 
""" - jsonpickle.set_encoder_options('simplejson', sort_keys=True, indent=4) + jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) # map nodes to int + indexed_node_graph = nx.relabel_nodes( + G=self._graph, mapping=node_to_idx + ) # map nodes to int else: indexed_node_graph = self._graph @@ -708,7 +791,7 @@ def save_json(self, filename: str, node_to_idx: Dict[Node, int] = None): json_str = self.to_json(node_to_idx) - with open(filename, 'w') as file: + with open(filename, "w") as file: file.write(json_str) @classmethod @@ -749,7 +832,7 @@ def load_json(cls, filename: str, idx_to_node: Dict[int, Node] = None): """ - with open(filename, 'r') as file: + with open(filename, "r") as file: skeleton = Skeleton.from_json(file.read(), idx_to_node) return skeleton @@ -768,15 +851,16 @@ def load_hdf5(cls, file: Union[str, h5.File], name: str): """ if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = Skeleton._load_hdf5(_file) # Load all skeletons else: skeletons = Skeleton._load_hdf5(file) return skeletons[name] @classmethod - def load_all_hdf5(cls, file: Union[str, h5.File], - return_dict: bool = False) -> Union[List['Skeleton'], Dict[str, 'Skeleton']]: + def load_all_hdf5( + cls, file: Union[str, h5.File], return_dict: bool = False + ) -> Union[List["Skeleton"], Dict[str, "Skeleton"]]: """ Load all skeletons found in the HDF5 file. @@ -791,7 +875,7 @@ def load_all_hdf5(cls, file: Union[str, h5.File], """ if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = Skeleton._load_hdf5(_file) # Load all skeletons else: skeletons = Skeleton._load_hdf5(file) @@ -804,7 +888,7 @@ def load_all_hdf5(cls, file: Union[str, h5.File], def _load_hdf5(cls, file: h5.File): skeletons = {} - for name, json_str in file['skeleton'].attrs.items(): + for name, json_str in file["skeleton"].attrs.items(): skeletons[name] = Skeleton.from_json(json_str) return skeletons @@ -817,7 +901,7 @@ def save_hdf5(self, file: Union[str, h5.File]): self._save_hdf5(file) @classmethod - def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List['Skeleton']): + def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): """ Convenience method to save a list of skeletons to HDF5 file. Skeletons are saved as attributes of a /skeleton group in the file. 
@@ -851,10 +935,10 @@ def _save_hdf5(self, file: h5.File): """ # All skeleton will be put as sub-groups in the skeleton group - if 'skeleton' not in file: - all_sk_group = file.create_group('skeleton', track_order=True) + if "skeleton" not in file: + all_sk_group = file.create_group("skeleton", track_order=True) else: - all_sk_group = file.require_group('skeleton') + all_sk_group = file.require_group("skeleton") # Write the dataset to JSON string, then store it in a string # attribute @@ -879,14 +963,16 @@ def load_mat(cls, filename: str): skel_mat = loadmat(filename) skel_mat["nodes"] = skel_mat["nodes"][0][0] # convert to scalar - skel_mat["edges"] = skel_mat["edges"] - 1 # convert to 0-based indexing + skel_mat["edges"] = skel_mat["edges"] - 1 # convert to 0-based indexing - node_names = skel_mat['nodeNames'] + node_names = skel_mat["nodeNames"] node_names = [str(n[0][0]) for n in node_names] skeleton.add_nodes(node_names) for k in range(len(skel_mat["edges"])): edge = skel_mat["edges"][k] - skeleton.add_edge(source=node_names[edge[0]], destination=node_names[edge[1]]) + skeleton.add_edge( + source=node_names[edge[0]], destination=node_names[edge[1]] + ) return skeleton diff --git a/sleap/util.py b/sleap/util.py index 4ea5dc180..9f689a4f2 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -10,6 +10,9 @@ import attr import psutil +from typing import Hashable, Iterable, List, Optional + + def attr_to_dtype(cls): """Convert classes with basic types to numpy composite dtypes. @@ -23,16 +26,21 @@ def attr_to_dtype(cls): if field.type == str: dtype_list.append((field.name, h5.special_dtype(vlen=str))) elif field.type is None: - raise TypeError(f"numpy dtype for {cls} cannot be constructed because no " + - "type information found. Make sure each field is type annotated.") + raise TypeError( + f"numpy dtype for {cls} cannot be constructed because no " + + "type information found. Make sure each field is type annotated." + ) elif field.type in [str, int, float, bool]: dtype_list.append((field.name, field.type)) else: - raise TypeError(f"numpy dtype for {cls} cannot be constructed because no " + - f"{field.type} is not supported.") + raise TypeError( + f"numpy dtype for {cls} cannot be constructed because no " + + f"{field.type} is not supported." + ) return np.dtype(dtype_list) + def usable_cpu_count() -> int: """Get number of CPUs usable by the current process. @@ -50,6 +58,7 @@ def usable_cpu_count() -> int: result = os.cpu_count() return result + def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): """ Saves dictionary to an HDF5 file, calls itself recursively if items in @@ -60,7 +69,8 @@ def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): h5file: The HDF5 filename object to save the data to. Assume it is open. path: The path to group save the dict under. dic: The dict to save. - + Raises: + ValueError: If type for item in dict cannot be saved. 
Returns: None """ @@ -74,23 +84,24 @@ def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): items_encoded = [] for it in item: if isinstance(it, str): - items_encoded.append(it.encode('utf8')) + items_encoded.append(it.encode("utf8")) else: items_encoded.append(it) h5file[path + key] = np.asarray(items_encoded) elif isinstance(item, (str)): - h5file[path + key] = item.encode('utf8') + h5file[path + key] = item.encode("utf8") elif isinstance(item, (np.ndarray, np.int64, np.float64, str, bytes, float)): h5file[path + key] = item elif isinstance(item, dict): - save_dict_to_hdf5(h5file, path + key + '/', item) + save_dict_to_hdf5(h5file, path + key + "/", item) elif isinstance(item, int): h5file[path + key] = item else: - raise ValueError('Cannot save %s type'%type(item)) + raise ValueError("Cannot save %s type" % type(item)) + -def frame_list(frame_str: str): +def frame_list(frame_str: str) -> Optional[List[int]]: """Convert 'n-m' string to list of ints. Args: @@ -100,15 +111,16 @@ def frame_list(frame_str: str): """ # Handle ranges of frames. Must be of the form "1-200" - if '-' in frame_str: - min_max = frame_str.split('-') + if "-" in frame_str: + min_max = frame_str.split("-") min_frame = int(min_max[0]) max_frame = int(min_max[1]) - return list(range(min_frame, max_frame+1)) + return list(range(min_frame, max_frame + 1)) return [int(x) for x in frame_str.split(",")] if len(frame_str) else None -def uniquify(seq): + +def uniquify(seq: Iterable[Hashable]) -> List: """ Given a list, return unique elements but preserve order. @@ -126,8 +138,19 @@ def uniquify(seq): # https://twitter.com/raymondh/status/944125570534621185 return list(dict.fromkeys(seq)) -def weak_filename_match(filename_a, filename_b): - """Check if paths probably point to same file.""" + +def weak_filename_match(filename_a: str, filename_b: str) -> bool: + """ + Check if paths probably point to same file. + + Compares the filename and names of two directories up. + + Args: + filename_a: first path to check + filename_b: path to check against first path + Returns: + True if the paths probably match. + """ # convert all path separators to / filename_a = filename_a.replace("\\", "/") filename_b = filename_b.replace("\\", "/") From 980b5220b059d485baad2e6ff9a2ace2080f354e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 16:55:59 -0400 Subject: [PATCH 117/176] Add env variables for coveralls. --- appveyor.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/appveyor.yml b/appveyor.yml index 70aa5aede..7f1a75c1c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -10,6 +10,11 @@ environment: BUILD_DIR: "build" + # Variables for coveralls + CI_NAME: "appveyor" + CI_BRANCH: APPVEYOR_REPO_BRANCH + CI_BUILD_NUMBER: APPVEYOR_BUILD_NUMBER + conda_access_token: secure: d+v++uejbVEhIuaJSuFIOA== matrix: From 32c455995bc9861a2f8a4d37fbde70d1062cf8de Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 17:15:14 -0400 Subject: [PATCH 118/176] Use coveralls command line args. 
--- appveyor.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 7f1a75c1c..4e7a2b4ae 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,11 +9,6 @@ environment: secure: T7XuBtHDu85Tk/d1AeyfhW3CVyzaoddTWmR4xsPIdQ3di0R6x8ncWqw3KrYXkWJm BUILD_DIR: "build" - - # Variables for coveralls - CI_NAME: "appveyor" - CI_BRANCH: APPVEYOR_REPO_BRANCH - CI_BUILD_NUMBER: APPVEYOR_BUILD_NUMBER conda_access_token: secure: d+v++uejbVEhIuaJSuFIOA== @@ -76,7 +71,7 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - cmd: coveralls + - cmd: coveralls --commitId $env:APPVEYOR_REPO_COMMIT --commitBranch $env:APPVEYOR_REPO_BRANCH --commitAuthor $env:APPVEYOR_REPO_COMMIT_AUTHOR --commitEmail $env:APPVEYOR_REPO_COMMIT_AUTHOR_EMAIL --commitMessage $env:APPVEYOR_REPO_COMMIT_MESSAGE --jobId $env:APPVEYOR_BUILD_NUMBER --serviceName appveyor # here we are going to override common configuration for: From d06f865b753625e7b1ccf353079fc9a938a377c1 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 19:03:09 -0400 Subject: [PATCH 119/176] coveralls --service arg --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 4e7a2b4ae..488fccb52 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,7 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - cmd: coveralls --commitId $env:APPVEYOR_REPO_COMMIT --commitBranch $env:APPVEYOR_REPO_BRANCH --commitAuthor $env:APPVEYOR_REPO_COMMIT_AUTHOR --commitEmail $env:APPVEYOR_REPO_COMMIT_AUTHOR_EMAIL --commitMessage $env:APPVEYOR_REPO_COMMIT_MESSAGE --jobId $env:APPVEYOR_BUILD_NUMBER --serviceName appveyor + - cmd: coveralls --service appveyor # here we are going to override common configuration for: From 53dac36f216873cbda708f111090421bdf6299d8 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 19:21:02 -0400 Subject: [PATCH 120/176] coveralls --service=arg --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 488fccb52..e8004585b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,7 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - cmd: coveralls --service appveyor + - cmd: coveralls --service=appveyor # here we are going to override common configuration for: From d24f15a4a4c5e5393faf15a6a08f9ea5e9416bdd Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 19:30:47 -0400 Subject: [PATCH 121/176] Add .coveralls.yml --- .coveralls.yml | 1 + appveyor.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 .coveralls.yml diff --git a/.coveralls.yml b/.coveralls.yml new file mode 100644 index 000000000..d09470693 --- /dev/null +++ b/.coveralls.yml @@ -0,0 +1 @@ +service_name: appveyor \ No newline at end of file diff --git a/appveyor.yml b/appveyor.yml index e8004585b..b5ea78828 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,7 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - cmd: coveralls --service=appveyor + - coveralls # here we are going to override common configuration for: From 335e0b6cf4a613e69a633a463fe5c697bc92d7c5 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 21:28:21 -0400 Subject: [PATCH 122/176] Better typing and docstrings. 
--- sleap/skeleton.py | 287 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 198 insertions(+), 89 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index ab9b5ee28..c7c96f5bc 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -1,9 +1,9 @@ -"""Implementation of skeleton data structure and API. - -This module implements and API for creating animal skeleton's in LEAP. The goal -is to provide a common interface for defining the parts of the animal, their -connection to each other, and needed meta-data. +""" +Implementation of skeleton data structure and API. +This module implements and API for creating animal skeletons. The goal +is to provide a common interface for defining the parts of the animal, +their connection to each other, and needed meta-data. """ import attr @@ -16,7 +16,7 @@ from enum import Enum from itertools import count -from typing import Iterable, Union, List, Dict +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import networkx as nx from networkx.readwrite import json_graph @@ -65,7 +65,7 @@ def matches(self, other: "Node") -> bool: Check whether all attributes match between two nodes. Args: - other: The node to compare to this one. + other: The `Node` to compare to this one. Returns: True if all attributes match, False otherwise. @@ -74,17 +74,15 @@ def matches(self, other: "Node") -> bool: class Skeleton: - """The main object for representing animal skeletons in LEAP. + """The main object for representing animal skeletons. The skeleton represents the constituent parts of the animal whose pose is being estimated. - """ - - """ - A index variable used to give skeletons a default name that attempts + An index variable used to give skeletons a default name that attempts to be unique across all skeletons. """ + _skeleton_idx = count(0) def __init__(self, name: str = None): @@ -160,29 +158,31 @@ def graph_symmetry(self): return self._graph.edge_subgraph(edges) @staticmethod - def find_unique_nodes(skeletons: List["Skeleton"]): + def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: """ - Given list of skeletons, return a list of unique node objects across all skeletons. + Find all unique nodes from a list of skeletons. Args: skeletons: The list of skeletons. Returns: - A list of unique node objects. + A list of unique `Node` objects. """ return list({node for skeleton in skeletons for node in skeleton.nodes}) @staticmethod - def make_cattr(idx_to_node: Dict[int, Node] = None): + def make_cattr(idx_to_node: Dict[int, Node] = None) -> Callable: """ - Create a cattr.Converter() that registers structure and unstructure hooks for - Skeleton objects that handle serialization of skeletons objects. + Make cattr.Convert() for `Skeleton`. + + Make a cattr.Converter() that registers structure and unstructure + hooks for Skeleton objects to handle serialization of skeletons. Args: idx_to_node: A dict that maps node index to Node objects. Returns: - A cattr.Converter() instance ready for skeleton serialization and deserialization. + A cattr.Converter() instance for skeleton serialization and deserialization. """ node_to_idx = ( {node: idx for idx, node in idx_to_node.items()} @@ -200,7 +200,7 @@ def make_cattr(idx_to_node: Dict[int, Node] = None): return _cattr @property - def name(self): + def name(self) -> str: """Get the name of the skeleton. Returns: @@ -211,8 +211,10 @@ def name(self): @name.setter def name(self, name: str): """ - A skeleton object cannot change its name. 
This property is immutable because it is - used to hash skeletons. If you want to rename a Skeleton you must use the class + A skeleton object cannot change its name. + + This property is immutable because it is used to hash skeletons. + If you want to rename a Skeleton you must use the class method :code:`rename_skeleton`: >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") @@ -221,7 +223,7 @@ def name(self, name: str): name: The name of the Skeleton. Raises: - NotImplementedError + NotImplementedError: Error is always raised. """ raise NotImplementedError( "Cannot change Skeleton name, it is immutable since " @@ -233,8 +235,10 @@ def name(self, name: str): @classmethod def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": """ - A skeleton object cannot change its name. This property is immutable because it is - used to hash skeletons. If you want to rename a Skeleton you must use this classmethod. + Make copy of skeleton with new name. + + This property is immutable because it is used to hash skeletons. + If you want to rename a Skeleton you must use this class method. >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") @@ -251,7 +255,7 @@ def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": return new_skeleton @property - def nodes(self): + def nodes(self) -> List[Node]: """Get a list of :class:`Node`s. Returns: @@ -260,7 +264,7 @@ def nodes(self): return list(self._graph.nodes) @property - def node_names(self): + def node_names(self) -> List[str]: """Get a list of node names. Returns: @@ -269,7 +273,7 @@ def node_names(self): return [node.name for node in self.nodes] @property - def edges(self): + def edges(self) -> List[Tuple[Node, Node]]: """Get a list of edge tuples. Returns: @@ -289,7 +293,7 @@ def edges(self): return edge_list @property - def edge_names(self): + def edge_names(self) -> List[Tuple[str, str]]: """Get a list of edge name tuples. Returns: @@ -309,7 +313,7 @@ def edge_names(self): return [(src.name, dst.name) for src, dst in self.edges] @property - def edges_full(self): + def edges_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of edge tuples with keys and attributes. Returns: @@ -322,7 +326,7 @@ def edges_full(self): ] @property - def symmetries(self): + def symmetries(self) -> List[Tuple[Node, Node]]: """Get a list of all symmetries without duplicates. Returns: @@ -339,10 +343,11 @@ def symmetries(self): return symmetries @property - def symmetries_full(self): + def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of all symmetries with keys and attributes. - Note: The returned list will contain duplicates (node1, node2) and (node2, node1). + Note: The returned list will contain duplicates (node1, node2) + and (node2, node1). Returns: list of (node1, node2, key, attr) @@ -354,13 +359,16 @@ def symmetries_full(self): if attr["type"] == EdgeType.SYMMETRY ] - def node_to_index(self, node: Union[str, Node]): + def node_to_index(self, node: Union[str, Node]) -> int: """ - Return the index of the node, accepts either a node or string name of a Node. + Return the index of the node, accepts either `Node` or name. Args: node: The name of the node or the Node object. + Raises: + ValueError if node cannot be found in skeleton. + Returns: The index of the node in the graph. """ @@ -374,7 +382,11 @@ def add_node(self, name: str): """Add a node representing an animal part to the skeleton. 
Args: - name: The name of the node to add to the skeleton. This name must be unique within the skeleton. + name: The name of the node to add to the skeleton. + This name must be unique within the skeleton. + + Raises: + ValueError: If name is not unique. Returns: None @@ -387,7 +399,7 @@ def add_node(self, name: str): self._graph.add_node(Node(name)) - def add_nodes(self, name_list: list): + def add_nodes(self, name_list: List[str]): """ Add a list of nodes representing animal parts to the skeleton. @@ -406,7 +418,10 @@ def delete_node(self, name: str): The method removes a node from the skeleton and any edge that is connected to it. Args: - name: The name of the edge to remove + name: The name of the node to remove + + Raises: + ValueError: If node cannot be found. Returns: None @@ -419,14 +434,14 @@ def delete_node(self, name: str): "The node named ({}) does not exist, cannot remove it.".format(name) ) - def find_node(self, name: str): + def find_node(self, name: str) -> Node: """Find node in skeleton by name of node. Args: name: The name of the :class:`Node` (or a :class:`Node`) Returns: - Node, or None if no match found + `Node`, or None if no match found """ if isinstance(name, Node): name = name.name @@ -447,10 +462,12 @@ def add_edge(self, source: str, destination: str): Args: source: The name of the source node. destination: The name of the destination node. + Raises: + ValueError: If source or destination nodes cannot be found, + or if edge already exists between those nodes. Returns: - None - + None. """ if isinstance(source, Node): source_node = source @@ -498,6 +515,10 @@ def delete_edge(self, source: str, destination: str): source: The name of the source node. destination: The name of the destination node. + Raises: + ValueError: If skeleton does not have either source node, + destination node, or edge between them. + Returns: None """ @@ -542,6 +563,10 @@ def add_symmetry(self, node1: str, node2: str): node1: The name of the first part in the symmetric pair node2: The name of the second part in the symmetric pair + Raises: + ValueError: If node1 and node2 match, or if there is already + a symmetry between them. + Returns: None @@ -574,6 +599,9 @@ def delete_symmetry(self, node1: str, node2: str): node1: The name of the first part in the symmetric pair node2: The name of the second part in the symmetric pair + Raises: + ValueError: If there's no symmetry between node1 and node2. + Returns: None """ @@ -595,12 +623,15 @@ def delete_symmetry(self, node1: str, node2: str): ] self._graph.remove_edges_from(edges) - def get_symmetry(self, node: str): + def get_symmetry(self, node: str) -> Optional[Node]: """ Returns the node symmetric with the specified node. Args: node: The name of the node to query. + Raises: + ValueError: If node has more than one symmetry. + Returns: The symmetric :class:`Node`, None if no symmetry """ @@ -619,7 +650,7 @@ def get_symmetry(self, node: str): else: raise ValueError(f"{node} has more than one symmetry.") - def get_symmetry_name(self, node: str): + def get_symmetry_name(self, node: str) -> Optional[str]: """Returns the name of the node symmetric with the specified node. Args: @@ -638,6 +669,9 @@ def __getitem__(self, node_name: str) -> dict: Args: node_name: The name from which to retrieve data. + Raises: + ValueError: If node cannot be found. + Returns: A dictionary of data associated with this node. @@ -678,7 +712,11 @@ def relabel_nodes(self, mapping: Dict[str, str]): Relabel the nodes of the skeleton. 
Args: - mapping: A dictionary with the old labels as keys and new labels as values. A partial mapping is allowed. + mapping: A dictionary with the old labels as keys and new + labels as values. A partial mapping is allowed. + + Raises: + ValueError: If node already present with one of the new names. Returns: None @@ -743,22 +781,60 @@ def has_edge(self, source_name: str, dest_name: str) -> bool: return self._graph.has_edge(source_node, destination_node) @staticmethod - def to_dict(obj: "Skeleton", node_to_idx: Dict[Node, int] = None): + def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> Dict: + """ + Convert `Skeleton` to dict; used for saving as JSON. + + Args: + obj: the `Skeleton` + node_to_idx: optional dict which maps `Node` objects + to index in some list. This is used when saving `Labels` + where we want to serialize the `Nodes` outside the + `Skeleton` object. + If given, then we replace each `Node` with specified + index before converting `Skeleton`. Otherwise, we + convert `Node`s with the rest of the `Skeleton`. + Returns: + dict with data from `Skeleton` + """ # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. return json.loads(obj.to_json(node_to_idx)) @classmethod - def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None): + def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": + """ + Create `Skeleton` from dict; used for loading from JSON. + + Args: + d: the `dict` from which to deserialize + node_to_idx: optional dict which maps `Node` objects + to index in some list. This is used when saving `Labels` + where we want to serialize the `Nodes` outside the + `Skeleton` object. + If given, then we can replace the int graph nodes + with appropriate `Node` objects. Otherwise, we'll + leave the nodes as is. + + Returns: + `Skeleton`. + + """ return Skeleton.from_json(json.dumps(d), node_to_idx) - def to_json(self, node_to_idx: Dict[Node, int] = None) -> str: + def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ Convert the skeleton to a JSON representation. Args: - node_to_idx (optional): Map for converting `Node` nodes to int + node_to_idx: optional dict which maps `Node` objects + to index in some list. This is used when saving `Labels` + where we want to serialize the `Nodes` outside the + `Skeleton` object. + If given, then we replace each `Node` with specified + index before converting `Skeleton`. Otherwise, we + convert `Node`s with the rest of the `Skeleton`. Returns: A string containing the JSON representation of the Skeleton. @@ -776,18 +852,25 @@ def to_json(self, node_to_idx: Dict[Node, int] = None) -> str: return json_str - def save_json(self, filename: str, node_to_idx: Dict[Node, int] = None): - """Save the skeleton as JSON file. + def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None): + """ + Save the skeleton as JSON file. - Output the complete skeleton to a file in JSON format. + Output the complete skeleton to a file in JSON format. - Args: - filename: The filename to save the JSON to. - node_to_idx (optional): Map for converting `Node` nodes to int + Args: + filename: The filename to save the JSON to. + node_to_idx: optional dict which maps `Node` objects + to index in some list. This is used when saving `Labels` + where we want to serialize the `Nodes` outside the + `Skeleton` object. + If given, then we can replace the int graph nodes + with appropriate `Node` objects. 
Otherwise, we'll + leave the nodes as is. - Returns: - None - """ + Returns: + None + """ json_str = self.to_json(node_to_idx) @@ -795,16 +878,22 @@ def save_json(self, filename: str, node_to_idx: Dict[Node, int] = None): file.write(json_str) @classmethod - def from_json(cls, json_str: str, idx_to_node: Dict[int, Node] = None): + def from_json( + cls, json_str: str, idx_to_node: Dict[int, Node] = None + ) -> "Skeleton": """ - Parse a JSON string containing the Skeleton object and create an instance from it. + Instantiate `Skeleton` from JSON string. Args: json_str: The JSON encoded Skeleton. - idx_to_node (optional): Map for converting int node in json back to corresponding `Node`. + idx_to_node: optional dict which maps an int (indexing a + list of `Node`s) to the already deserialized `Node`. + This should invert `node_to_idx` we used when saving. + If not given, then we'll assume each `Node` was left + in the `Skeleton` when it was saved. Returns: - An instance of the Skeleton object decoded from the JSON. + An instance of the `Skeleton` object decoded from the JSON. """ graph = json_graph.node_link_graph(jsonpickle.decode(json_str)) @@ -818,7 +907,9 @@ def from_json(cls, json_str: str, idx_to_node: Dict[int, Node] = None): return skeleton @classmethod - def load_json(cls, filename: str, idx_to_node: Dict[int, Node] = None): + def load_json( + cls, filename: str, idx_to_node: Dict[int, Node] = None + ) -> "Skeleton": """Load a skeleton from a JSON file. This method will load the Skeleton from JSON file saved with; :meth:`~Skeleton.save_json` @@ -828,17 +919,17 @@ def load_json(cls, filename: str, idx_to_node: Dict[int, Node] = None): idx_to_node (optional): Map for converting int node in json back to corresponding `Node`. Returns: - The Skeleton object stored in the JSON filename. + The `Skeleton` object stored in the JSON filename. """ with open(filename, "r") as file: - skeleton = Skeleton.from_json(file.read(), idx_to_node) + skeleton = cls.from_json(file.read(), idx_to_node) return skeleton @classmethod - def load_hdf5(cls, file: Union[str, h5.File], name: str): + def load_hdf5(cls, file: Union[str, h5.File], name: str) -> List["Skeleton"]: """ Load a specific skeleton (by name) from the HDF5 file. @@ -847,13 +938,13 @@ def load_hdf5(cls, file: Union[str, h5.File], name: str): name: The name of the skeleton. Returns: - The skeleton instance stored in the HDF5 file. + The specified `Skeleton` instance stored in the HDF5 file. """ if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = cls._load_hdf5(_file) # Load all skeletons else: - skeletons = Skeleton._load_hdf5(file) + skeletons = cls._load_hdf5(file) return skeletons[name] @@ -866,18 +957,20 @@ def load_all_hdf5( Args: file: The file name or open h5.File - return_dict: True if the the return value should be a dict where the - keys are skeleton names and values the corresponding skeleton. False - if the return should just be a list of the skeletons. + return_dict: Whether the the return value should be a dict + where the keys are skeleton names and values the + corresponding skeleton. If False, then method will + return just a list of the skeletons. Returns: - The skeleton instances stored in the HDF5 file. Either in List or Dict form. + The skeleton instances stored in the HDF5 file. + Either in List or Dict form. 
""" if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = cls._load_hdf5(_file) # Load all skeletons else: - skeletons = Skeleton._load_hdf5(file) + skeletons = cls._load_hdf5(file) if return_dict: return skeletons @@ -889,27 +982,24 @@ def _load_hdf5(cls, file: h5.File): skeletons = {} for name, json_str in file["skeleton"].attrs.items(): - skeletons[name] = Skeleton.from_json(json_str) + skeletons[name] = cls.from_json(json_str) return skeletons - def save_hdf5(self, file: Union[str, h5.File]): - if isinstance(file, str): - with h5.File(file) as _file: - self._save_hdf5(_file) - else: - self._save_hdf5(file) - @classmethod def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): """ - Convenience method to save a list of skeletons to HDF5 file. Skeletons are saved - as attributes of a /skeleton group in the file. + Convenience method to save a list of skeletons to HDF5 file. + + Skeletons are saved as attributes of a /skeleton group in the file. Args: file: The filename or the open h5.File object. skeletons: The list of skeletons to save. + Raises: + ValueError: If multiple skeletons have the same name. + Returns: None """ @@ -923,12 +1013,29 @@ def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): for skeleton in skeletons: skeleton.save_hdf5(file) + def save_hdf5(self, file: Union[str, h5.File]): + """ + Wrapper for HDF5 saving which takes either filename or h5.File. + + Args: + file: can be filename (string) or `h5.File` object + + Returns: + None + """ + + if isinstance(file, str): + with h5.File(file) as _file: + self._save_hdf5(_file) + else: + self._save_hdf5(file) + def _save_hdf5(self, file: h5.File): """ Actual implementation of HDF5 saving. Args: - file: The open h5.File to write the skeleton data too. + file: The open h5.File to write the skeleton data to. Returns: None @@ -945,10 +1052,12 @@ def _save_hdf5(self, file: h5.File): all_sk_group.attrs[self.name] = np.string_(self.to_json()) @classmethod - def load_mat(cls, filename: str): + def load_mat(cls, filename: str) -> "Skeleton": """ - Load the skeleton from a Matlab MAT file. This is to support backwards - compatibility with old LEAP MATLAB code and datasets. + Load the skeleton from a Matlab MAT file. + + This is to support backwards compatibility with old LEAP + MATLAB code and datasets. Args: filename: The name of the skeleton file From 2e4455dc174aa28b1edce9af51f27ea588539ca7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 21:30:14 -0400 Subject: [PATCH 123/176] Remove unused, undocumented, untested methods. --- sleap/skeleton.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index c7c96f5bc..e607f9bc6 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -136,27 +136,6 @@ def dict_match(dict1, dict2): # Check if the two graphs are equal return True - @property - def graph(self): - edges = [ - (src, dst, key) - for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") - if edge_type == EdgeType.BODY - ] - # TODO: properly induce subgraph for MultiDiGraph - # Currently, NetworkX will just return the nodes in the subgraph. 
- # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges - return self._graph.edge_subgraph(edges) - - @property - def graph_symmetry(self): - edges = [ - (src, dst, key) - for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") - if edge_type == EdgeType.SYMMETRY - ] - return self._graph.edge_subgraph(edges) - @staticmethod def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: """ From 16d7fe162caddecb8141426dc5cec01b1e7e9cc5 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 25 Sep 2019 21:38:39 -0400 Subject: [PATCH 124/176] Fixed typing on make_cattr. --- sleap/skeleton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index e607f9bc6..a4c5ce923 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -150,7 +150,7 @@ def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: return list({node for skeleton in skeletons for node in skeleton.nodes}) @staticmethod - def make_cattr(idx_to_node: Dict[int, Node] = None) -> Callable: + def make_cattr(idx_to_node: Dict[int, Node] = None) -> cattr.Converter: """ Make cattr.Convert() for `Skeleton`. From 6f5eb418110def0c25020cc5b69d1a7d6fb9da65 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 05:50:46 -0400 Subject: [PATCH 125/176] appveyor debugging --- appveyor.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index b5ea78828..a7837d9bf 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,8 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - coveralls + - cmd: set + - cmd: coveralls # here we are going to override common configuration for: From 4c194eec31982fcc9112f00b9af7ff6c4ec228bd Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 05:59:11 -0400 Subject: [PATCH 126/176] Revert "Remove unused, undocumented, methods." This reverts commit 2e4455dc174aa28b1edce9af51f27ea588539ca7. --- sleap/skeleton.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index a4c5ce923..5596db6ae 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -136,6 +136,27 @@ def dict_match(dict1, dict2): # Check if the two graphs are equal return True + @property + def graph(self): + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.BODY + ] + # TODO: properly induce subgraph for MultiDiGraph + # Currently, NetworkX will just return the nodes in the subgraph. 
+ # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges + return self._graph.edge_subgraph(edges) + + @property + def graph_symmetry(self): + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.SYMMETRY + ] + return self._graph.edge_subgraph(edges) + @staticmethod def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: """ From c27d4e6f116892639b2c02c84d762905ed3c34cb Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 06:17:05 -0400 Subject: [PATCH 127/176] coveralls instead of python-coveralls --- appveyor.yml | 1 - dev_requirements.txt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index a7837d9bf..d069bdb97 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -71,7 +71,6 @@ test_script: - cmd: pytest --cov=sleap tests/ on_success: - - cmd: set - cmd: coveralls # here we are going to override common configuration diff --git a/dev_requirements.txt b/dev_requirements.txt index 0f005e97b..7b70af633 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,4 +4,4 @@ pytest-cov ipython sphinx sphinx_rtd_theme -python-coveralls \ No newline at end of file +coveralls \ No newline at end of file From 274ceca9f23c33d69f07c6b429dd3e90e5e81f7a Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 06:17:05 -0400 Subject: [PATCH 128/176] coveralls instead of python-coveralls --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 0f005e97b..7b70af633 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,4 +4,4 @@ pytest-cov ipython sphinx sphinx_rtd_theme -python-coveralls \ No newline at end of file +coveralls \ No newline at end of file From 1415335b0e927117f5c93c2fa9a97254e237068e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 08:50:47 -0400 Subject: [PATCH 129/176] Wrap docstrings to 79 chars. --- sleap/skeleton.py | 165 ++++++++++++++++++++++++++-------------------- 1 file changed, 95 insertions(+), 70 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 5596db6ae..1f3c21e98 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -26,11 +26,13 @@ class EdgeType(Enum): """ The skeleton graph can store different types of edges to represent - different things. All edges must specify one or more of the following types. + different things. All edges must specify one or more of the + following types: - * BODY - these edges represent connections between parts or landmarks. - * SYMMETRY - these edges represent symmetrical relationships between - parts (e.g. left and right arms) + * BODY - these edges represent connections between parts or + landmarks. + * SYMMETRY - these edges represent symmetrical relationships + between parts (e.g. left and right arms) """ BODY = 1 @@ -74,21 +76,24 @@ def matches(self, other: "Node") -> bool: class Skeleton: - """The main object for representing animal skeletons. + """ + The main object for representing animal skeletons. - The skeleton represents the constituent parts of the animal whose pose - is being estimated. + The skeleton represents the constituent parts of the animal whose + pose is being estimated. - An index variable used to give skeletons a default name that attempts - to be unique across all skeletons. + An index variable used to give skeletons a default name that should + be unique across all skeletons. 
""" _skeleton_idx = count(0) def __init__(self, name: str = None): - """Initialize an empty skeleton object. + """ + Initialize an empty skeleton object. - Skeleton objects, once they are created can be modified by adding nodes and edges. + Skeleton objects, once created, can be modified by adding nodes + and edges. Args: name: A name for this skeleton. @@ -175,14 +180,15 @@ def make_cattr(idx_to_node: Dict[int, Node] = None) -> cattr.Converter: """ Make cattr.Convert() for `Skeleton`. - Make a cattr.Converter() that registers structure and unstructure + Make a cattr.Converter() that registers structure/unstructure hooks for Skeleton objects to handle serialization of skeletons. Args: idx_to_node: A dict that maps node index to Node objects. Returns: - A cattr.Converter() instance for skeleton serialization and deserialization. + A cattr.Converter() instance for skeleton serialization + and deserialization. """ node_to_idx = ( {node: idx for idx, node in idx_to_node.items()} @@ -217,7 +223,8 @@ def name(self, name: str): If you want to rename a Skeleton you must use the class method :code:`rename_skeleton`: - >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") + >>> new_skeleton = Skeleton.rename_skeleton( + >>> skeleton=old_skeleton, name="New Name") Args: name: The name of the Skeleton. @@ -240,7 +247,8 @@ def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": This property is immutable because it is used to hash skeletons. If you want to rename a Skeleton you must use this class method. - >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") + >>> new_skeleton = Skeleton.rename_skeleton( + >>> skeleton=old_skeleton, name="New Name") Args: skeleton: The skeleton to copy. @@ -415,7 +423,8 @@ def add_nodes(self, name_list: List[str]): def delete_node(self, name: str): """Remove a node from the skeleton. - The method removes a node from the skeleton and any edge that is connected to it. + The method removes a node from the skeleton and any edge that is + connected to it. Args: name: The name of the node to remove @@ -554,10 +563,10 @@ def delete_edge(self, source: str, destination: str): self._graph.remove_edge(source_node, destination_node) def add_symmetry(self, node1: str, node2: str): - """Specify that two parts (nodes) in the skeleton are symmetrical. + """Specify that two parts (nodes) in skeleton are symmetrical. Certain parts of an animal body can be related as symmetrical - parts in a pair. For example, the left and right hands of a person. + parts in a pair. For example, left and right hands of a person. Args: node1: The name of the first part in the symmetric pair @@ -593,11 +602,12 @@ def add_symmetry(self, node1: str, node2: str): self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) def delete_symmetry(self, node1: str, node2: str): - """Deletes a previously established symmetry relationship between two nodes. + """ + Deletes a previously established symmetry between two nodes. Args: - node1: The name of the first part in the symmetric pair - node2: The name of the second part in the symmetric pair + node1: The name of the first part in the symmetric pair. + node2: The name of the second part in the symmetric pair. Raises: ValueError: If there's no symmetry between node1 and node2. 
@@ -624,7 +634,8 @@ def delete_symmetry(self, node1: str, node2: str): self._graph.remove_edges_from(edges) def get_symmetry(self, node: str) -> Optional[Node]: - """ Returns the node symmetric with the specified node. + """ + Returns the node symmetric with the specified node. Args: node: The name of the node to query. @@ -651,7 +662,8 @@ def get_symmetry(self, node: str) -> Optional[Node]: raise ValueError(f"{node} has more than one symmetry.") def get_symmetry_name(self, node: str) -> Optional[str]: - """Returns the name of the node symmetric with the specified node. + """ + Returns the name of the node symmetric with the specified node. Args: node: The name of the node to query. @@ -664,7 +676,7 @@ def get_symmetry_name(self, node: str) -> Optional[str]: def __getitem__(self, node_name: str) -> dict: """ - Retrieves the node data associated with Skeleton node. + Retrieves the node data associated with skeleton node. Args: node_name: The name from which to retrieve data. @@ -783,19 +795,20 @@ def has_edge(self, source_name: str, dest_name: str) -> bool: @staticmethod def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> Dict: """ - Convert `Skeleton` to dict; used for saving as JSON. + Convert skeleton to dict; used for saving as JSON. Args: - obj: the `Skeleton` - node_to_idx: optional dict which maps `Node` objects - to index in some list. This is used when saving `Labels` - where we want to serialize the `Nodes` outside the - `Skeleton` object. - If given, then we replace each `Node` with specified - index before converting `Skeleton`. Otherwise, we - convert `Node`s with the rest of the `Skeleton`. + obj: the :object:`Skeleton` to convert + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node`s with the rest of + the :class:`Skeleton`. Returns: - dict with data from `Skeleton` + dict with data from skeleton """ # This is a weird hack to serialize the whole _graph into a dict. @@ -805,39 +818,41 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D @classmethod def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": """ - Create `Skeleton` from dict; used for loading from JSON. + Create skeleton from dict; used for loading from JSON. Args: - d: the `dict` from which to deserialize - node_to_idx: optional dict which maps `Node` objects - to index in some list. This is used when saving `Labels` - where we want to serialize the `Nodes` outside the - `Skeleton` object. - If given, then we can replace the int graph nodes - with appropriate `Node` objects. Otherwise, we'll - leave the nodes as is. + d: the dict from which to deserialize + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node`s with the rest of + the :class:`Skeleton`. Returns: - `Skeleton`. + :class:`Skeleton`. 
""" return Skeleton.from_json(json.dumps(d), node_to_idx) def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ - Convert the skeleton to a JSON representation. + Convert the :class:`Skeleton` to a JSON representation. Args: - node_to_idx: optional dict which maps `Node` objects - to index in some list. This is used when saving `Labels` - where we want to serialize the `Nodes` outside the - `Skeleton` object. - If given, then we replace each `Node` with specified - index before converting `Skeleton`. Otherwise, we - convert `Node`s with the rest of the `Skeleton`. + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node`s with the rest of + the :class:`Skeleton`. Returns: - A string containing the JSON representation of the Skeleton. + A string containing the JSON representation of the skeleton. """ jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: @@ -854,19 +869,20 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None): """ - Save the skeleton as JSON file. + Save the :class:`Skeleton` as JSON file. Output the complete skeleton to a file in JSON format. Args: filename: The filename to save the JSON to. - node_to_idx: optional dict which maps `Node` objects - to index in some list. This is used when saving `Labels` - where we want to serialize the `Nodes` outside the - `Skeleton` object. - If given, then we can replace the int graph nodes - with appropriate `Node` objects. Otherwise, we'll - leave the nodes as is. + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node`s with the rest of + the :class:`Skeleton`. Returns: None @@ -882,15 +898,16 @@ def from_json( cls, json_str: str, idx_to_node: Dict[int, Node] = None ) -> "Skeleton": """ - Instantiate `Skeleton` from JSON string. + Instantiate :class:`Skeleton` from JSON string. Args: json_str: The JSON encoded Skeleton. idx_to_node: optional dict which maps an int (indexing a - list of `Node`s) to the already deserialized `Node`. + list of :class:`Node`s) to the already deserialized + :class:`Node`. This should invert `node_to_idx` we used when saving. - If not given, then we'll assume each `Node` was left - in the `Skeleton` when it was saved. + If not given, then we'll assume each :class:`Node` was + left in the :class:`Skeleton` when it was saved. Returns: An instance of the `Skeleton` object decoded from the JSON. @@ -910,13 +927,20 @@ def from_json( def load_json( cls, filename: str, idx_to_node: Dict[int, Node] = None ) -> "Skeleton": - """Load a skeleton from a JSON file. + """ + Load a skeleton from a JSON file. - This method will load the Skeleton from JSON file saved with; :meth:`~Skeleton.save_json` + This method will load the Skeleton from JSON file saved + with; :meth:`~Skeleton.save_json` Args: - filename: The file that contains the JSON specifying the skeleton. 
- idx_to_node (optional): Map for converting int node in json back to corresponding `Node`. + filename: The file that contains the JSON. + idx_to_node: optional dict which maps an int (indexing a + list of :class:`Node`s) to the already deserialized + :class:`Node`. + This should invert `node_to_idx` we used when saving. + If not given, then we'll assume each :class:`Node` was + left in the :class:`Skeleton` when it was saved. Returns: The `Skeleton` object stored in the JSON filename. @@ -991,7 +1015,8 @@ def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): """ Convenience method to save a list of skeletons to HDF5 file. - Skeletons are saved as attributes of a /skeleton group in the file. + Skeletons are saved as attributes of a /skeleton group in the + file. Args: file: The filename or the open h5.File object. From 3686e536a41f1163a0ec445010e7c8f4c2bcd4e2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 10:08:34 -0400 Subject: [PATCH 130/176] Better typing and docstrings. --- sleap/instance.py | 530 +++++++++++++++++++++++++++++------------ tests/test_instance.py | 62 ++--- 2 files changed, 417 insertions(+), 175 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index 62c0ff0e2..10140bebf 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1,12 +1,23 @@ """ +Data structures for all labeled data contained with a SLEAP project. +The relationships between objects in this module: + +* A `LabeledFrame` can contain zero or more `Instance`s + (and `PredictedInstance`s). + +* `Instance`s (and `PredictedInstance`s) have `PointArray` + (or `PredictedPointArray`). + +* A `PointArray` (or `PredictedPointArray`) contains zero or more + `Point`s (or `PredictedPoint`s), ideally as many as there are in the + associated :class:`Skeleton` although these can get out of sync if the + skeleton is manipulated. """ import math import numpy as np -import h5py as h5 -import pandas as pd import cattr from copy import copy @@ -27,13 +38,13 @@ class Point(np.record): """ - A very simple class to define a labelled point and any metadata associated with it. + A labelled point and any metadata associated with it. Args: - x: The horizontal pixel location of the point within the image frame. - y: The vertical pixel location of the point within the image frame. + x: The horizontal pixel location of point within image frame. + y: The vertical pixel location of point within image frame. visible: Whether point is visible in the labelled image or not. - complete: Has the point been verified by the a user labeler. + complete: Has the point been verified by the user labeler. """ # Define the dtype from the point class attributes plus some @@ -47,7 +58,7 @@ def __new__( y: float = math.nan, visible: bool = True, complete: bool = False, - ): + ) -> "Point": # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -64,10 +75,10 @@ def __new__( return val - def __str__(self): + def __str__(self) -> str: return f"({self.x}, {self.y})" - def isnan(self): + def isnan(self) -> bool: """ Are either of the coordinates a NaN value. @@ -84,15 +95,16 @@ def isnan(self): class PredictedPoint(Point): """ - A predicted point is an output of the inference procedure. It has all - the properties of a labeled point with an accompanying score. + A predicted point is an output of the inference procedure. + + It has all the properties of a labeled point, plus a score. 
Args: - x: The horizontal pixel location of the point within the image frame. - y: The vertical pixel location of the point within the image frame. + x: The horizontal pixel location of point within image frame. + y: The vertical pixel location of point within image frame. visible: Whether point is visible in the labelled image or not. - complete: Has the point been verified by the a user labeler. - score: The point level prediction score. + complete: Has the point been verified by the user labeler. + score: The point-level prediction score. """ # Define the dtype from the point class attributes plus some @@ -109,7 +121,7 @@ def __new__( visible: bool = True, complete: bool = False, score: float = 0.0, - ): + ) -> "PredictedPoint": # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -128,7 +140,7 @@ def __new__( return val @classmethod - def from_point(cls, point: Point, score: float = 0.0): + def from_point(cls, point: Point, score: float = 0.0) -> "PredictedPoint": """ Create a PredictedPoint from a Point @@ -169,7 +181,7 @@ def __new__( byteorder=None, aligned=False, order="C", - ): + ) -> "PointArray": dtype = subtype._record_type.dtype @@ -196,16 +208,21 @@ def __new__( def __array_finalize__(self, obj): """ - Overide __array_finalize__ on recarray because it converting the dtype - of any np.void subclass to np.record, we don't want this. + Override :method:`np.recarray.__array_finalize__()`. + + Overide __array_finalize__ on recarray because it converting the + dtype of any np.void subclass to np.record, we don't want this. """ pass @classmethod - def make_default(cls, size: int): + def make_default(cls, size: int) -> "PointArray": """ - Construct a point array of specific size where each value in the array - is assigned the default values for a Point. + Construct a point array where points are all set to default. + + The constructed :class:`PointArray` will have specified size + and each value in the array is assigned the default values for + a :class:`Point``. Args: size: The number of points to allocate. @@ -217,7 +234,8 @@ def make_default(cls, size: int): p[:] = cls._record_type() return p - def __getitem__(self, indx): + def __getitem__(self, indx: int) -> "Point": + """Get point by its index in the array.""" obj = super(np.recarray, self).__getitem__(indx) # copy behavior of getattr, except that here @@ -235,17 +253,21 @@ def __getitem__(self, indx): return obj @classmethod - def from_array(cls, a: "PointArray"): + def from_array(cls, a: "PointArray") -> "PointArray": """ - Convert a PointArray to a new PointArray - (or child class, i.e., PredictedPointArray), - use the default attribute values for new array. + Converts a :class:`PointArray` (or child) to a new instance. + + This will convert an object to the same type as itself, + so a :class:`PredictedPointArray` will result in the same. + + Uses the default attribute values for new array. Args: a: The array to convert. Returns: - A PredictedPointArray with the same points as a. + A :class:`PointArray` or :class:`PredictedPointArray` with + the same points as a. """ v = cls.make_default(len(a)) @@ -264,7 +286,7 @@ class PredictedPointArray(PointArray): _record_type = PredictedPoint @classmethod - def to_array(cls, a: "PredictedPointArray"): + def to_array(cls, a: "PredictedPointArray") -> "PointArray": """ Convert a PredictedPointArray to a normal PointArray. 
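For reference, the point classes documented above can be used roughly as follows (a minimal sketch, not part of the patch; it assumes only the constructors and classmethods shown in this diff):

    from sleap.instance import (
        Point, PredictedPoint, PointArray, PredictedPointArray
    )

    # Default points are allocated with NaN coordinates, i.e. "not labelled".
    parr = PointArray.make_default(3)
    assert parr[0].isnan()

    # Points store coordinates plus visibility/complete flags;
    # predicted points additionally carry a score.
    pt = Point(x=1.0, y=2.0, visible=True)
    pred = PredictedPoint.from_point(pt, score=0.9)

    # Arrays convert between the plain and predicted variants, filling any
    # fields missing from the source (such as score) with default values.
    pred_arr = PredictedPointArray.from_array(parr)
    plain_arr = PredictedPointArray.to_array(pred_arr)
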
@@ -285,12 +307,12 @@ def to_array(cls, a: "PredictedPointArray"): @attr.s(slots=True, cmp=False) class Track: """ - A track object is associated with a set of animal/object instances across multiple - frames of video. This allows tracking of unique entities in the video over time and - space. + A track object is associated with a set of animal/object instances + across multiple frames of video. This allows tracking of unique + entities in the video over time and space. Args: - spawned_on: The frame of the video that this track was spawned on. + spawned_on: The video frame that this track was spawned on. name: A name given to this track for identifying purposes. """ @@ -319,17 +341,21 @@ def matches(self, other: "Track"): @attr.s(cmp=False, slots=True) class Instance: """ - The class :class:`Instance` represents a labelled instance of skeleton + The class :class:`Instance` represents a labelled instance of a skeleton. Args: skeleton: The skeleton that this instance is associated with. - points: A dictionary where keys are skeleton node names and values are Point objects. Alternatively, - a point array whose length and order matches skeleton.nodes - track: An optional multi-frame object track associated with this instance. - This allows individual animals/objects to be tracked across frames. - from_predicted: The predicted instance (if any) that this was copied from. - frame: A back reference to the LabeledFrame that this Instance belongs to. - This field is set when Instances are added to LabeledFrame objects. + points: A dictionary where keys are skeleton node names and + values are Point objects. Alternatively, a point array whose + length and order matches skeleton.nodes. + track: An optional multi-frame object track associated with + this instance. This allows individual animals/objects to be + tracked across frames. + from_predicted: The predicted instance (if any) that this was + copied from. + frame: A back reference to the :class:`LabeledFrame` that this + :class:`Instance` belongs to. This field is set when + instances are added to :class:`LabeledFrame` objects. """ skeleton: Skeleton = attr.ib() @@ -343,22 +369,47 @@ class Instance: _point_array_type = PointArray @from_predicted.validator - def _validate_from_predicted_(self, attribute, from_predicted): + def _validate_from_predicted_( + self, attribute, from_predicted: Optional["PredictedInstance"] + ): + """ + Validation method called by attrs. + + Checks that from_predicted is None or :class:`PredictedInstance` + + Args: + attribute: Attribute being validated; not used. + from_predicted: Value being validated. + + Raises: + TypeError: If from_predicted is anything other than None + or a `PredictedInstance`. + + """ if from_predicted is not None and type(from_predicted) != PredictedInstance: raise TypeError( f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})" ) @_points.validator - def _validate_all_points(self, attribute, points): + def _validate_all_points(self, attribute, points: Union[dict, PointArray]): """ - Function that makes sure all the _points defined for the skeleton are found in the skeleton. + Validation method called by attrs. - Returns: - None + Checks that all the _points defined for the skeleton are found + in the skeleton. + + Args: + attribute: Attribute being validated; not used. + points: Either dict of points or PointArray + If dict, keys should be node names. Raises: - ValueError: If a point is associated with a skeleton node name that doesn't exist. 
+ ValueError: If a point is associated with a skeleton node + name that doesn't exist. + + Returns: + None """ if type(points) is dict: is_string_dict = set(map(type, points)) == {str} @@ -375,6 +426,22 @@ def _validate_all_points(self, attribute, points): ) def __attrs_post_init__(self): + """ + Method called by attrs after __init__() + + Initializes points if none were specified when creating object, + caches list of nodes so what we can still find points in array + if the `Skeleton` changes. + + Args: + None + + Raises: + ValueError: If object has no `Skeleton`. + + Returns: + None + """ if not self.skeleton: raise ValueError("No skeleton set for Instance") @@ -399,7 +466,26 @@ def __attrs_post_init__(self): self._nodes = self.skeleton.nodes @staticmethod - def _points_dict_to_array(points, parray, skeleton): + def _points_dict_to_array( + points: Dict[Union[str, Node], Point], parray: PointArray, skeleton: Skeleton + ): + """ + Sets values in given :class:`PointsArray` from dictionary. + + Args: + points: The dictionary of points. Keys can be either node + names or :class:`Node`s, values are :class:`Point`s. + parray: The :class:`PointsArray` which is being updated. + skeleton: The :class:`Skeleton` which contains the nodes + referenced in the dictionary of points. + + Raises: + ValueError: If dictionary keys are not either all strings + or all :class:`Node`s. + + Returns: + None + """ # Check if the dict contains all strings is_string_dict = set(map(type, points)) == {str} @@ -432,27 +518,35 @@ def _points_dict_to_array(points, parray, skeleton): except: pass - def _node_to_index(self, node_name): + def _node_to_index(self, node: Union[str, Node]) -> int: """ Helper method to get the index of a node from its name. Args: - node_name: The name of the node. + node: Node name or :class:`Node` object. Returns: The index of the node on skeleton graph. """ - return self.skeleton.node_to_index(node_name) + return self.skeleton.node_to_index(node) - def __getitem__(self, node): + def __getitem__( + self, node: Union[List[Union[str, Node]], Union[str, Node]] + ) -> Union[List[Point], Point]: """ - Get the Points associated with particular skeleton node or list of skeleton nodes + Get the Points associated with particular skeleton node(s). Args: - node: A single node or list of nodes within the skeleton associated with this instance. + node: A single node or list of nodes within the skeleton + associated with this instance. + + Raises: + KeyError: If node cannot be found in skeleton. Returns: - A single point of list of _points related to the nodes provided as argument. + Either a single point (if a single node given), or + a list of points (if a list of nodes given) corresponding + to each node. """ @@ -472,18 +566,19 @@ def __getitem__(self, node): f"The underlying skeleton ({self.skeleton}) has no node '{node}'" ) - def __contains__(self, node): + def __contains__(self, node: Union[str, Node]) -> bool: """ - Returns True if this instance has a point with the specified node. + Whether this instance has a point with the specified node. Args: - node: node name + node: Node name or :class:`Node` object. Returns: - bool: True if the point with the node name specified has a point in this instance. + bool: True if the point with the node name specified has a + point in this instance. """ - if type(node) is Node: + if isinstance(node, Node): node = node.name if node not in self.skeleton: @@ -494,7 +589,26 @@ def __contains__(self, node): # If the points are nan, then they haven't been allocated. 
return not self._points[node_idx].isnan() - def __setitem__(self, node, value): + def __setitem__( + self, + node: Union[List[Union[str, Node]], Union[str, Node]], + value: Union[List[Point], Point], + ): + """ + Set the point(s) for given node(s). + + Args: + node: Either node (by name or `Node`) or list of nodes. + value: Either `Point` or list of `Point`s. + + Raises: + IndexError: If lengths of lists don't match, or if exactly + one of the inputs is a list. + KeyError: If skeleton does not have (one of) the node(s). + + Returns: + None + """ # Make sure node and value, if either are lists, are of compatible size if type(node) is not list and type(value) is list and len(value) != 1: @@ -521,8 +635,19 @@ def __setitem__(self, node, value): f"The underlying skeleton ({self.skeleton}) has no node '{node}'" ) - def __delitem__(self, node): - """ Delete node key and points associated with that node. """ + def __delitem__(self, node: Union[str, Node]): + """ + Delete node key and points associated with that node. + + Args: + node: Node name or :class:`Node` object. + + Raises: + KeyError: If skeleton does not have the node. + + Returns: + None + """ try: node_idx = self._node_to_index(node) self._points[node_idx].x = math.nan @@ -532,12 +657,14 @@ def __delitem__(self, node): f"The underlying skeleton ({self.skeleton}) has no node '{node}'" ) - def matches(self, other): + def matches(self, other: "Instance") -> bool: """ - Compare this `Instance` to another, modulo the particular `Node` objects. + Whether two instances match by value. + + Checks the types, points, track, and frame index. Args: - other: The other instance. + other: The other :class:`Instance`. Returns: True if match, False otherwise. @@ -564,15 +691,11 @@ def matches(self, other): return True @property - def nodes(self): + def nodes(self) -> Tuple[Node, ...]: """ - Get the list of nodes that have been labelled for this instance. - - Returns: - A tuple of nodes that have been labelled for this instance. - + The tuple of nodes that have been labelled for this instance. """ - self.fix_array() + self._fix_array() return tuple( self._nodes[i] for i, point in enumerate(self._points) @@ -580,30 +703,32 @@ def nodes(self): ) @property - def nodes_points(self): + def nodes_points(self) -> List[Tuple[Node, Point]]: """ - Return view object that displays a list of the instance's (node, point) tuple pairs - for all labelled point. - - Returns: - The instance's (node, point) tuple pairs for all labelled point. + The list of (node, point) tuples for all labelled points. """ names_to_points = dict(zip(self.nodes, self.points)) return names_to_points.items() @property - def points(self) -> Tuple[Point]: + def points(self) -> Tuple[Point, ...]: """ - Return the list of labelled points, in order they were labelled. - - Returns: - The list of labelled points, in order they were labelled. + The tuple of labelled points, in order they were labelled. """ - self.fix_array() + self._fix_array() return tuple(point for point in self._points if not point.isnan()) - def fix_array(self): - """Fix points array after nodes have been added or removed.""" + def _fix_array(self): + """ + Fixes PointArray after nodes have been added or removed. + + This updates the PointArray as required by comparing the cached + list of nodes to the nodes in the `Skeleton` object (which may + have changed). 
+ + Returns: + None + """ # Check if cached skeleton nodes are different than current nodes if self._nodes != self.skeleton.nodes: @@ -628,19 +753,21 @@ def get_points_array( Args: copy: If True, the return a copy of the points array as an - Nx2 ndarray where first column is x and second column is y. - If False, return a view of the underlying recarray. + Nx2 ndarray where first column is x and second is y. + If False, return a view of the underlying recarray. invisible_as_nan: Should invisible points be marked as NaN. - full: If True, return the raw underlying recarray with all attributes - of the point, if not, return just the x and y coordinate. Assumes - copy is False and invisible_as_nan is False. + full: If True, return the raw underlying recarray with all + attributes of the point. + Otherwise, return just the x and y coordinate. + Assumes copy is False and invisible_as_nan is False. Returns: A Nx2 array containing x and y coordinates of each point - as the rows of the array and N is the number of nodes in the skeleton. - The order of the rows corresponds to the ordering of the skeleton nodes. - Any skeleton node not defined will have NaNs present. + as the rows of the array and N is the number of nodes in the + skeleton. The order of the rows corresponds to the ordering + of the skeleton nodes. Any skeleton node not defined will + have NaNs present. """ - self.fix_array() + self._fix_array() if full: return self._points @@ -657,6 +784,15 @@ def get_points_array( @property def points_array(self) -> np.ndarray: + """ + Nx2 array of x and y for visible points. + + Row in arrow corresponds to order of points in skeleton. + Invisible points will have nans. + + Returns: + ndarray of visible point coordinates. + """ return self.get_points_array(invisible_as_nan=True) @property @@ -667,13 +803,15 @@ def centroid(self) -> np.ndarray: return centroid @property - def frame_idx(self) -> Union[None, int]: + def frame_idx(self) -> Optional[int]: """ - Get the index of the frame that this instance was found on. This is a convenience - method for Instance.frame.frame_idx. + Get the index of the frame that this instance was found on. + + This is a convenience method for Instance.frame.frame_idx. Returns: - The frame number this instance was found on. + The frame number this instance was found on, or None if the + instance is not associated with frame. """ if self.frame is None: return None @@ -684,11 +822,10 @@ def frame_idx(self) -> Union[None, int]: @attr.s(cmp=False, slots=True) class PredictedInstance(Instance): """ - A predicted instance is an output of the inference procedure. It is - the main output of the inference procedure. + A predicted instance is an output of the inference procedure. Args: - score: The instance level prediction score. + score: The instance-level prediction score. """ score: float = attr.ib(default=0.0, converter=float) @@ -703,12 +840,13 @@ def __attrs_post_init__(self): raise ValueError("PredictedInstance should not have from_predicted.") @classmethod - def from_instance(cls, instance: Instance, score): + def from_instance(cls, instance: Instance, score: float): """ - Create a PredictedInstance from and Instance object. The fields are - copied in a shallow manner with the exception of points. For each - point in the instance an PredictedPoint is created with score set - to default value. + Create a :class:`PredictedInstance` from an :class:`Instance`. + + The fields are copied in a shallow manner with the exception of + points. 
For each point in the instance a :class:`PredictedPoint` + is created with score set to default value. Args: instance: The Instance object to shallow copy data from. @@ -727,13 +865,17 @@ def from_instance(cls, instance: Instance, score): return cls(**kw_args) -def make_instance_cattr(): +def make_instance_cattr() -> cattr.Converter: """ - Create a cattr converter for handling Lists of Instances/PredictedInstances + Create a cattr converter for Lists of Instances/PredictedInstances. + + This is required because cattrs doesn't automatically detect the + class when the attributes of one class are a subset of another. Returns: - A cattr converter with hooks registered for structuring and unstructuring - Instances. + A cattr converter with hooks registered for structuring and + unstructuring :class:`Instance`s and + :class:`PredictedInstance`s. """ converter = cattr.Converter() @@ -774,6 +916,7 @@ def structure_points(x, type): converter.register_structure_hook(Union[Point, PredictedPoint], structure_points) + # Function to determine object type for objects being structured. def structure_instances_list(x, type): inst_list = [] for inst_data in x: @@ -815,6 +958,14 @@ def structure_point_array(x, t): @attr.s(auto_attribs=True) class LabeledFrame: + """ + Holds labeled data for a single frame of a video. + + Args: + video: The :class:`Video` associated with this frame. + frame_idx: The index of frame in video. + """ + video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( @@ -822,21 +973,31 @@ class LabeledFrame: ) def __attrs_post_init__(self): + """ + Called by attrs. + + Updates :attribute:`Instance.frame` for each instance associated + with this :class:`LabeledFrame`. + """ # Make sure all instances have a reference to this frame for instance in self.instances: instance.frame = self - def __len__(self): + def __len__(self) -> int: + """Returns number of instances associated with frame.""" return len(self.instances) - def __getitem__(self, index): + def __getitem__(self, index) -> Instance: + """Returns instance (retrieved by index).""" return self.instances.__getitem__(index) - def index(self, value: Instance): + def index(self, value: Instance) -> int: + """Returns index of given :class:`Instance`.""" return self.instances.index(value) def __delitem__(self, index): + """Removes instance (by index) from frame.""" value = self.instances.__getitem__(index) self.instances.__delitem__(index) @@ -844,19 +1005,54 @@ def __delitem__(self, index): # Modify the instance to remove reference to this frame value.frame = None - def insert(self, index, value: Instance): + def insert(self, index: int, value: Instance): + """ + Adds instance to frame. + + Args: + index: The index in list of frame instances where we should + insert the new instance. + value: The instance to associate with frame. + + Returns: + None. + """ self.instances.insert(index, value) # Modify the instance to have a reference back to this frame value.frame = self def __setitem__(self, index, value: Instance): + """ + Sets nth instance in frame to the given instance. + + Args: + index: The index of instance to replace with new instance. + value: The new instance to associate with frame. + + Returns: + None. 
+ """ self.instances.__setitem__(index, value) # Modify the instance to have a reference back to this frame value.frame = self - def find(self, track=-1, user=False): + def find( + self, track: Optional[Union[Track, int]] = -1, user: bool = False + ) -> List[Instance]: + """ + Retrieves instances (if any) matching specifications. + + Args: + track: The :class:`Track` to match. Note that None will only + match instances where :attribute:`Instance.track` is + None. If track is -1, then we'll match any track. + user: Whether to only match user (non-predicted) instances. + + Returns: + List of instances. + """ instances = self.instances if user: instances = list(filter(lambda inst: type(inst) == Instance, instances)) @@ -865,24 +1061,22 @@ def find(self, track=-1, user=False): return instances @property - def instances(self): - """ - A list of instances to associated with this frame. - - Returns: - A list of instances to associated with this frame. - """ + def instances(self) -> List[Instance]: + """Returns list of all instances associated with this frame.""" return self._instances @instances.setter def instances(self, instances: List[Instance]): """ - Set the list of instances assigned to this frame. Note: whenever an instance - is associated with a LabeledFrame that Instance objects frame property will - be overwritten to the LabeledFrame. + Sets the list of instances associated with this frame. + + Updates the `frame` attribute on each instance to the + :class:`LabeledFrame` which will contain the instance. + The list of instances replaces instances that were previously + associated with frame. Args: - instances: A list of instances to associated with this frame. + instances: A list of instances associated with this frame. Returns: None @@ -895,21 +1089,30 @@ def instances(self, instances: List[Instance]): self._instances = instances @property - def user_instances(self): + def user_instances(self) -> List[Instance]: + """Returns list of user instances associated with this frame.""" return [ inst for inst in self._instances if not isinstance(inst, PredictedInstance) ] @property - def predicted_instances(self): + def predicted_instances(self) -> List[PredictedInstance]: + """Returns list of predicted instances associated with frame.""" return [inst for inst in self._instances if isinstance(inst, PredictedInstance)] @property - def has_user_instances(self): + def has_user_instances(self) -> bool: + """Whether the frame contains any user instances.""" return len(self.user_instances) > 0 @property - def unused_predictions(self): + def unused_predictions(self) -> List[Instance]: + """ + Returns list of "unused" :class:`PredictedInstance`s in frame. + + This is all the :class:`PredictedInstance`s which do not have + a corresponding :class:`Instance` in the same track in frame. + """ unused_predictions = [] any_tracks = [inst.track for inst in self._instances if inst.track is not None] if len(any_tracks): @@ -942,10 +1145,15 @@ def unused_predictions(self): return unused_predictions @property - def instances_to_show(self): + def instances_to_show(self) -> List[Instance]: """ - Return a list of instances associated with this frame, but excluding any - predicted instances for which there's a corresponding regular instance. + Return a list of instances to show in GUI for this frame. + + This list will not include any predicted instances for which + there's a corresponding regular instance. + + Returns: + List of instances to show in GUI. 
""" unused_predictions = self.unused_predictions inst_to_show = [ @@ -996,18 +1204,24 @@ def merge_frames(labeled_frames, video, remove_redundant=True): @classmethod def complex_merge_between( cls, base_labels: "Labels", new_frames: List["LabeledFrame"] - ): - """Merge new_frames into base_labels cleanly when possible, - return conflicts if any. + ) -> Tuple[Dict[Video, Dict[int, List[Instance]]], List[Instance], List[Instance]]: + """ + Merge data from new frames into a :class:`Labels` object. + + Everything that can be merged cleanly is merged, any conflicts + are returned. Args: - base_labels - new_frames + base_labels: The :class:`Labels` into which we are merging. + new_frames: The list of :class:`LabeledFrame`s from + which we are merging. Returns: tuple of three items: - * dict with {video: list (per frame) of list of merged instances - * list of conflicting instances in base - * list of conflicting instances in new_frames + * Dictionary, keys are :class:`Video`, values are + dictionary in which keys are frame index (int) + and value is list of :class:`Instance`s + * list of conflicting :class:`Instance`s from base + * list of conflicting :class:`Instance`s from new frames """ merged = dict() extra_base = [] @@ -1017,10 +1231,14 @@ def complex_merge_between( base_lfs = base_labels.find(new_frame.video, new_frame.frame_idx) merged_instances = None + # If the base doesn't have a frame corresponding this new + # frame, then it can be merged cleanly. if not base_lfs: base_labels.labeled_frames.append(new_frame) merged_instances = new_frame.instances else: + # There's a corresponding frame in the base labels, + # so try merging the data. merged_instances, extra_base_frame, extra_new_frame = cls.complex_frame_merge( base_lfs[0], new_frame ) @@ -1036,8 +1254,28 @@ def complex_merge_between( return merged, extra_base, extra_new @classmethod - def complex_frame_merge(cls, base_frame, new_frame): - """Merge two frames, return conflicts if any.""" + def complex_frame_merge( + cls, base_frame: "LabeledFrame", new_frame: "LabeledFrame" + ) -> Tuple[List[Instance], List[Instance], List[Instance]]: + """ + Merge two frames, return conflicts if any. + + A conflict occurs when + * each frame has Instances which don't perfectly match those + in the other frame, or + * each frame has PredictedInstances which don't perfectly match + those in the other frame. + + Args: + base_frame: The `LabeledFrame` into which we want to merge. + new_frame: The `LabeledFrame` from which we want to merge. + + Returns: + tuple of three items: + * list of instances that were merged + * list of conflicting instances from base + * list of conflicting instances from new + """ merged_instances = [] redundant_instances = [] extra_base_instances = copy(base_frame.instances) diff --git a/tests/test_instance.py b/tests/test_instance.py index e61d07f11..d368ae477 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -8,6 +8,7 @@ from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame + def test_instance_node_get_set_item(skeleton): """ Test basic get item and set item functionality of instances. @@ -32,7 +33,7 @@ def test_instance_node_multi_get_set_item(skeleton): Test basic get item and set item functionality of instances. 
""" node_names = ["left-wing", "head", "right-wing"] - points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3,6)} + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) @@ -55,7 +56,7 @@ def test_non_exist_node(skeleton): instance["non-existent-node"].x = 1 with pytest.raises(KeyError): - instance = Instance(skeleton=skeleton, points = {"non-exist": Point()}) + instance = Instance(skeleton=skeleton, points={"non-exist": Point()}) def test_instance_point_iter(skeleton): @@ -67,7 +68,7 @@ def test_instance_point_iter(skeleton): instance = Instance(skeleton=skeleton, points=points) - assert [node.name for node in instance.nodes] == ['head', 'left-wing', 'right-wing'] + assert [node.name for node in instance.nodes] == ["head", "left-wing", "right-wing"] assert np.allclose([p.x for p in instance.points], [1, 2, 3]) assert np.allclose([p.y for p in instance.points], [4, 5, 6]) @@ -83,28 +84,29 @@ def test_skeleton_node_name_change(): """ s = Skeleton("Test") - s.add_nodes(['a', 'b', 'c', 'd', 'e']) - s.add_edge('a', 'b') + s.add_nodes(["a", "b", "c", "d", "e"]) + s.add_edge("a", "b") instance = Instance(s) - instance['a'] = Point(1,2) - instance['b'] = Point(3,4) + instance["a"] = Point(1, 2) + instance["b"] = Point(3, 4) # Rename the node - s.relabel_nodes({'a': 'A'}) + s.relabel_nodes({"a": "A"}) # Reference to the old node name should raise a KeyError with pytest.raises(KeyError): - instance['a'].x = 2 + instance["a"].x = 2 # Make sure the A now references the same point on the instance - assert instance['A'] == Point(1, 2) - assert instance['b'] == Point(3, 4) + assert instance["A"] == Point(1, 2) + assert instance["b"] == Point(3, 4) + def test_instance_comparison(skeleton): node_names = ["left-wing", "head", "right-wing"] - points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3,6)} + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) instance2 = copy.deepcopy(instance1) @@ -119,9 +121,10 @@ def test_instance_comparison(skeleton): assert not instance1.matches(instance2) instance2 = copy.deepcopy(instance1) - instance2.skeleton.add_node('extra_node') + instance2.skeleton.add_node("extra_node") assert not instance1.matches(instance2) + def test_points_array(skeleton): """ Test conversion of instances to points array""" @@ -133,26 +136,27 @@ def test_points_array(skeleton): pts = instance1.get_points_array() assert pts.shape == (len(skeleton.nodes), 2) - assert np.allclose(pts[skeleton.node_to_index('left-wing'), :], [2, 5]) - assert np.allclose(pts[skeleton.node_to_index('head'), :], [1, 4]) - assert np.allclose(pts[skeleton.node_to_index('right-wing'), :], [3, 6]) - assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() + assert np.allclose(pts[skeleton.node_to_index("left-wing"), :], [2, 5]) + assert np.allclose(pts[skeleton.node_to_index("head"), :], [1, 4]) + assert np.allclose(pts[skeleton.node_to_index("right-wing"), :], [3, 6]) + assert np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() # Now change a point, make sure it is reflected - instance1['head'].x = 0 - instance1['thorax'] = Point(1, 2) + instance1["head"].x = 0 + instance1["thorax"] = Point(1, 2) pts = instance1.get_points_array() - assert np.allclose(pts[skeleton.node_to_index('head'), :], [0, 4]) - assert np.allclose(pts[skeleton.node_to_index('thorax'), :], [1, 2]) + assert 
np.allclose(pts[skeleton.node_to_index("head"), :], [0, 4]) + assert np.allclose(pts[skeleton.node_to_index("thorax"), :], [1, 2]) # Make sure that invisible points are nan iff invisible_as_nan=True - instance1['thorax'] = Point(1, 2, visible=False) + instance1["thorax"] = Point(1, 2, visible=False) pts = instance1.get_points_array() - assert not np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() + assert not np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() pts = instance1.points_array - assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() + assert np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() + def test_modifying_skeleton(skeleton): node_names = ["left-wing", "head", "right-wing"] @@ -162,16 +166,17 @@ def test_modifying_skeleton(skeleton): assert len(instance1.points) == 3 - skeleton.add_node('new test node') + skeleton.add_node("new test node") - instance1.fix_array() # update with changes from skeleton - instance1['new test node'] = Point(7,8) + instance1.points # this updates instance with changes from skeleton + instance1["new test node"] = Point(7, 8) assert len(instance1.points) == 4 - skeleton.delete_node('head') + skeleton.delete_node("head") assert len(instance1.points) == 3 + def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): """ Test whether links between labeled frames and instances are kept @@ -183,4 +188,3 @@ def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): assert frame.instances[0].frame == frame assert frame[0].frame == frame assert frame[0].frame_idx == 0 - From 264dd51c14ef60580e08c3549b1ccdd65cf35849 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 10:15:52 -0400 Subject: [PATCH 131/176] Better typing and docstrings. --- sleap/util.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/sleap/util.py b/sleap/util.py index 9f689a4f2..581ea0e09 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -10,14 +10,16 @@ import attr import psutil -from typing import Hashable, Iterable, List, Optional +from typing import Any, Hashable, Iterable, List, Optional -def attr_to_dtype(cls): - """Convert classes with basic types to numpy composite dtypes. +def attr_to_dtype(cls: Any): + """ + Converts classes with basic types to numpy composite dtypes. Arguments: cls: class to convert + Returns: numpy dtype. """ @@ -42,7 +44,8 @@ def attr_to_dtype(cls): def usable_cpu_count() -> int: - """Get number of CPUs usable by the current process. + """ + Gets number of CPUs usable by the current process. Takes into consideration cpusets restrictions. @@ -61,16 +64,22 @@ def usable_cpu_count() -> int: def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): """ - Saves dictionary to an HDF5 file, calls itself recursively if items in - dictionary are not np.ndarray, np.int64, np.float64, str, bytes. Objects - must be iterable. + Saves dictionary to an HDF5 file. + + Calls itself recursively if items in dictionary are not + `np.ndarray`, `np.int64`, `np.float64`, `str`, or bytes. + Objects must be iterable. Args: - h5file: The HDF5 filename object to save the data to. Assume it is open. + h5file: The HDF5 filename object to save the data to. + Assume it is open. path: The path to group save the dict under. dic: The dict to save. + Raises: ValueError: If type for item in dict cannot be saved. 
+ + Returns: None """ @@ -102,10 +111,12 @@ def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): def frame_list(frame_str: str) -> Optional[List[int]]: - """Convert 'n-m' string to list of ints. + """ + Converts 'n-m' string to list of ints. Args: frame_str: string representing range + Returns: List of ints, or None if string does not represent valid range. """ @@ -122,7 +133,7 @@ def frame_list(frame_str: str) -> Optional[List[int]]: def uniquify(seq: Iterable[Hashable]) -> List: """ - Given a list, return unique elements but preserve order. + Returns unique elements from list, preserving order. Note: This will not work on Python 3.5 or lower since dicts don't preserve order. @@ -131,7 +142,8 @@ def uniquify(seq: Iterable[Hashable]) -> List: seq: The list to remove duplicates from. Returns: - The unique elements from the input list extracted in original order. + The unique elements from the input list extracted in original + order. """ # Raymond Hettinger @@ -148,6 +160,7 @@ def weak_filename_match(filename_a: str, filename_b: str) -> bool: Args: filename_a: first path to check filename_b: path to check against first path + Returns: True if the paths probably match. """ From e0c642de7de90cbf146fd5f13779beb24b5e8ffc Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 10:20:11 -0400 Subject: [PATCH 132/176] Reformat all tests with black. --- tests/conftest.py | 2 +- tests/fixtures/datasets.py | 9 +- tests/fixtures/instances.py | 24 ++-- tests/fixtures/skeletons.py | 19 ++-- tests/fixtures/videos.py | 26 ++++- tests/gui/test_active.py | 18 ++- tests/gui/test_conf_maps_view.py | 5 +- tests/gui/test_dataviews.py | 13 ++- tests/gui/test_import.py | 25 +++-- tests/gui/test_multicheck.py | 17 +-- tests/gui/test_quiver.py | 11 +- tests/gui/test_shortcuts.py | 3 +- tests/gui/test_slider.py | 15 +-- tests/gui/test_tracks.py | 24 ++-- tests/gui/test_video_player.py | 8 +- tests/io/test_dataset.py | 185 +++++++++++++++++++++++-------- tests/io/test_video.py | 77 ++++++++----- tests/io/test_visuals.py | 17 +-- tests/nn/test_datagen.py | 8 +- tests/nn/test_inference.py | 35 ++++-- tests/nn/test_tracking.py | 2 + tests/nn/test_training.py | 23 ++-- tests/test_point_array.py | 29 +++-- tests/test_rangelist.py | 21 ++-- tests/test_skeleton.py | 165 ++++++++++++++------------- tests/test_util.py | 25 +++-- 26 files changed, 513 insertions(+), 293 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 52b682e44..8c850b0ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ try: import pytestqt except: - logging.warning('Could not import PySide2 or pytestqt, skipping GUI tests.') + logging.warning("Could not import PySide2 or pytestqt, skipping GUI tests.") collect_ignore_glob = ["gui/*"] from tests.fixtures.skeletons import * diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index ec370a899..fa1a0150b 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -9,6 +9,7 @@ TEST_JSON_MIN_LABELS = "tests/data/json_format_v2/minimal_instance.json" TEST_MAT_LABELS = "tests/data/mat/labels.mat" + @pytest.fixture def centered_pair_labels(): return Labels.load_json(TEST_JSON_LABELS) @@ -18,14 +19,17 @@ def centered_pair_labels(): def centered_pair_predictions(): return Labels.load_json(TEST_JSON_PREDICTIONS) + @pytest.fixture def min_labels(): return Labels.load_json(TEST_JSON_MIN_LABELS) + @pytest.fixture def mat_labels(): return Labels.load_mat(TEST_MAT_LABELS) + @pytest.fixture def multi_skel_vid_labels(hdf5_vid, 
small_robot_mp4_vid, skeleton, stickman): """ @@ -60,7 +64,9 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): stickman_instances = [] for i in range(6): - stickman_instances.append(Instance(skeleton=stickman, track=stick_tracks[i])) + stickman_instances.append( + Instance(skeleton=stickman, track=stick_tracks[i]) + ) for node in stickman.nodes: stickman_instances[i][node] = Point(x=i % vid.width, y=i % vid.height) @@ -70,4 +76,3 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): labels = Labels(labels) return labels - diff --git a/tests/fixtures/instances.py b/tests/fixtures/instances.py index dcac5ce29..862577457 100644 --- a/tests/fixtures/instances.py +++ b/tests/fixtures/instances.py @@ -12,21 +12,23 @@ def instances(skeleton): instances = [] for i in range(NUM_INSTANCES): instance = Instance(skeleton=skeleton) - instance['head'] = Point(i*1, i*2) - instance['left-wing'] = Point(10 + i * 1, 10 + i * 2) - instance['right-wing'] = Point(20 + i * 1, 20 + i * 2) + instance["head"] = Point(i * 1, i * 2) + instance["left-wing"] = Point(10 + i * 1, 10 + i * 2) + instance["right-wing"] = Point(20 + i * 1, 20 + i * 2) # Lets make an NaN entry to test skip_nan as well - instance['thorax'] + instance["thorax"] instances.append(instance) return instances + @pytest.fixture def predicted_instances(instances): return [PredictedInstance.from_instance(i, 1.0) for i in instances] + @pytest.fixture def multi_skel_instances(skeleton, stickman): """ @@ -39,21 +41,21 @@ def multi_skel_instances(skeleton, stickman): instances = [] for i in range(NUM_INSTANCES): instance = Instance(skeleton=skeleton, video=None, frame_idx=i) - instance['head'] = Point(i*1, i*2) - instance['left-wing'] = Point(10 + i * 1, 10 + i * 2) - instance['right-wing'] = Point(20 + i * 1, 20 + i * 2) + instance["head"] = Point(i * 1, i * 2) + instance["left-wing"] = Point(10 + i * 1, 10 + i * 2) + instance["right-wing"] = Point(20 + i * 1, 20 + i * 2) # Lets make an NaN entry to test skip_nan as well - instance['thorax'] + instance["thorax"] instances.append(instance) # Setup some instances of the stick man on the same frames for i in range(NUM_INSTANCES): instance = Instance(skeleton=stickman, video=None, frame_idx=i) - instance['head'] = Point(i * 10, i * 20) - instance['body'] = Point(100 + i * 1, 100 + i * 2) - instance['left-arm'] = Point(200 + i * 1, 200 + i * 2) + instance["head"] = Point(i * 10, i * 20) + instance["body"] = Point(100 + i * 1, 100 + i * 2) + instance["left-arm"] = Point(200 + i * 1, 200 + i * 2) instances.append(instance) diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 13a2b741f..c340270bb 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -2,23 +2,27 @@ from sleap.skeleton import Skeleton + @pytest.fixture def stickman(): # Make a skeleton with a space in its name to test things. 
stickman = Skeleton("Stick man") - stickman.add_nodes(['head', 'neck', 'body', 'right-arm', 'left-arm', 'right-leg', 'left-leg']) - stickman.add_edge('neck', 'head') - stickman.add_edge('body', 'neck') - stickman.add_edge('body', 'right-arm') - stickman.add_edge('body', 'left-arm') - stickman.add_edge('body', 'right-leg') - stickman.add_edge('body', 'left-leg') + stickman.add_nodes( + ["head", "neck", "body", "right-arm", "left-arm", "right-leg", "left-leg"] + ) + stickman.add_edge("neck", "head") + stickman.add_edge("body", "neck") + stickman.add_edge("body", "right-arm") + stickman.add_edge("body", "left-arm") + stickman.add_edge("body", "right-leg") + stickman.add_edge("body", "left-leg") stickman.add_symmetry(node1="left-arm", node2="right-arm") stickman.add_symmetry(node1="left-leg", node2="right-leg") return stickman + @pytest.fixture def skeleton(): @@ -36,4 +40,3 @@ def skeleton(): skeleton.add_symmetry(node1="left-wing", node2="right-wing") return skeleton - diff --git a/tests/fixtures/videos.py b/tests/fixtures/videos.py index ea4369790..fc55e6019 100644 --- a/tests/fixtures/videos.py +++ b/tests/fixtures/videos.py @@ -8,26 +8,42 @@ TEST_H5_AFFINITY = "/pafs" TEST_H5_INPUT_FORMAT = "channels_first" + @pytest.fixture def hdf5_vid(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_DSET, input_format=TEST_H5_INPUT_FORMAT) + return Video.from_hdf5( + filename=TEST_H5_FILE, dataset=TEST_H5_DSET, input_format=TEST_H5_INPUT_FORMAT + ) + @pytest.fixture def hdf5_confmaps(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_CONFMAPS, input_format=TEST_H5_INPUT_FORMAT) - + return Video.from_hdf5( + filename=TEST_H5_FILE, + dataset=TEST_H5_CONFMAPS, + input_format=TEST_H5_INPUT_FORMAT, + ) + + @pytest.fixture def hdf5_affinity(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_AFFINITY, input_format=TEST_H5_INPUT_FORMAT, convert_range=False) + return Video.from_hdf5( + filename=TEST_H5_FILE, + dataset=TEST_H5_AFFINITY, + input_format=TEST_H5_INPUT_FORMAT, + convert_range=False, + ) TEST_SMALL_ROBOT_MP4_FILE = "tests/data/videos/small_robot.mp4" TEST_SMALL_CENTERED_PAIR_VID = "tests/data/videos/centered_pair_small.mp4" + @pytest.fixture def small_robot_mp4_vid(): return Video.from_media(TEST_SMALL_ROBOT_MP4_FILE) + @pytest.fixture def centered_pair_vid(): - return Video.from_media(TEST_SMALL_CENTERED_PAIR_VID) \ No newline at end of file + return Video.from_media(TEST_SMALL_CENTERED_PAIR_VID) diff --git a/tests/gui/test_active.py b/tests/gui/test_active.py index 04e1e80ae..b3c565e4f 100644 --- a/tests/gui/test_active.py +++ b/tests/gui/test_active.py @@ -5,13 +5,18 @@ from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.nn.model import ModelOutputType -from sleap.gui.active import ActiveLearningDialog, make_default_training_jobs, find_saved_jobs, add_frames_from_json +from sleap.gui.active import ( + ActiveLearningDialog, + make_default_training_jobs, + find_saved_jobs, + add_frames_from_json, +) + def test_active_gui(qtbot, centered_pair_labels): win = ActiveLearningDialog( - labels_filename="foo.json", - labels=centered_pair_labels, - mode="expert") + labels_filename="foo.json", labels=centered_pair_labels, mode="expert" + ) win.show() qtbot.addWidget(win) @@ -25,6 +30,7 @@ def test_active_gui(qtbot, centered_pair_labels): jobs = win._get_current_training_jobs() assert ModelOutputType.PART_AFFINITY_FIELD not in jobs + def test_make_default_training_jobs(): jobs = make_default_training_jobs() @@ -35,6 +41,7 @@ def 
test_make_default_training_jobs(): assert jobs[output_type].model.output_type == output_type assert jobs[output_type].best_model_filename is None + def test_find_saved_jobs(): jobs_a = find_saved_jobs("tests/data/training_profiles/set_a") assert len(jobs_a) == 3 @@ -59,6 +66,7 @@ def test_find_saved_jobs(): assert os.path.basename(paths[0]) == "test_confmaps.json" assert os.path.basename(paths[1]) == "default_confmaps.json" + def test_add_frames_from_json(): vid_a = Video.from_filename("foo.mp4") vid_b = Video.from_filename("bar.mp4") @@ -131,4 +139,4 @@ def test_add_frames_from_json(): assert len(labels_with_skeleton.videos) == 2 assert len(labels_with_skeleton.skeletons) == 1 - labels_with_skeleton.to_dict() \ No newline at end of file + labels_with_skeleton.to_dict() diff --git a/tests/gui/test_conf_maps_view.py b/tests/gui/test_conf_maps_view.py index eafb56497..5a97276a8 100644 --- a/tests/gui/test_conf_maps_view.py +++ b/tests/gui/test_conf_maps_view.py @@ -5,13 +5,14 @@ import PySide2.QtCore as QtCore + def test_gui_conf_maps(qtbot, hdf5_confmaps): - + vp = QtVideoPlayer() vp.show() conf_maps = ConfMapsPlot(hdf5_confmaps.get_frame(1), show_box=False) vp.view.scene.addItem(conf_maps) - + # make sure we're showing all the channels assert len(conf_maps.childItems()) == 6 diff --git a/tests/gui/test_dataviews.py b/tests/gui/test_dataviews.py index 3a7541681..9af6dedb8 100644 --- a/tests/gui/test_dataviews.py +++ b/tests/gui/test_dataviews.py @@ -8,8 +8,8 @@ SkeletonNodesTable, SkeletonEdgesTable, LabeledFrameTable, - SkeletonNodeModel - ) + SkeletonNodeModel, +) def test_skeleton_nodes(qtbot, centered_pair_predictions): @@ -24,8 +24,13 @@ def test_skeleton_nodes(qtbot, centered_pair_predictions): table = VideosTable(centered_pair_predictions.videos) table.selectRow(0) - assert table.model().data(table.currentIndex()).find("centered_pair_low_quality.mp4") > -1 + assert ( + table.model().data(table.currentIndex()).find("centered_pair_low_quality.mp4") + > -1 + ) - table = LabeledFrameTable(centered_pair_predictions.labels[13], centered_pair_predictions) + table = LabeledFrameTable( + centered_pair_predictions.labels[13], centered_pair_predictions + ) table.selectRow(1) assert table.model().data(table.currentIndex()) == "21/24" diff --git a/tests/gui/test_import.py b/tests/gui/test_import.py index 8fd84f1e8..760d1839e 100644 --- a/tests/gui/test_import.py +++ b/tests/gui/test_import.py @@ -2,39 +2,46 @@ import PySide2.QtCore as QtCore + def test_gui_import(qtbot): file_names = [ - "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5", - "tests/data/videos/small_robot.mp4", - ] + "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5", + "tests/data/videos/small_robot.mp4", + ] importer = ImportParamDialog(file_names) importer.show() qtbot.addWidget(importer) - + data = importer.get_data() assert len(data) == 2 assert len(data[0]["params"]) > 1 - + for import_item in importer.import_widgets: btn = import_item.enabled_checkbox_widget with qtbot.waitSignal(btn.stateChanged, timeout=10): qtbot.mouseClick(btn, QtCore.Qt.LeftButton) assert not import_item.is_enabled() - + assert len(importer.get_data()) == 0 - + for import_item in importer.import_widgets: btn = import_item.enabled_checkbox_widget with qtbot.waitSignal(btn.stateChanged, timeout=10): qtbot.mouseClick(btn, QtCore.Qt.LeftButton) assert import_item.is_enabled() - + assert len(importer.get_data()) == 2 + def test_video_import_detect_params(): - importer = ImportParamDialog(["tests/data/videos/centered_pair_small.mp4", 
"tests/data/videos/small_robot.mp4"]) + importer = ImportParamDialog( + [ + "tests/data/videos/centered_pair_small.mp4", + "tests/data/videos/small_robot.mp4", + ] + ) data = importer.get_data() assert data[0]["params"]["grayscale"] == True diff --git a/tests/gui/test_multicheck.py b/tests/gui/test_multicheck.py index cc936d2de..2a2ee0bbe 100644 --- a/tests/gui/test_multicheck.py +++ b/tests/gui/test_multicheck.py @@ -2,23 +2,24 @@ import PySide2.QtCore as QtCore + def test_gui_video(qtbot): cs = MultiCheckWidget(count=10, title="Test", default=True) cs.show() qtbot.addWidget(cs) - + assert cs.getSelected() == list(range(10)) - + for btn in cs.check_group.buttons(): # click all the odd buttons to uncheck them if cs.check_group.id(btn) % 2 == 1: - qtbot.mouseClick(btn, QtCore.Qt.LeftButton) - assert cs.getSelected() == list(range(0,10,2)) - - cs.setSelected([1,2,3]) - assert cs.getSelected() == [1,2,3] - + qtbot.mouseClick(btn, QtCore.Qt.LeftButton) + assert cs.getSelected() == list(range(0, 10, 2)) + + cs.setSelected([1, 2, 3]) + assert cs.getSelected() == [1, 2, 3] + # Watch for the app.worker.finished signal, then start the worker. with qtbot.waitSignal(cs.selectionChanged, timeout=10): qtbot.mouseClick(cs.check_group.buttons()[0], QtCore.Qt.LeftButton) diff --git a/tests/gui/test_quiver.py b/tests/gui/test_quiver.py index a1b877c14..6875cbd2d 100644 --- a/tests/gui/test_quiver.py +++ b/tests/gui/test_quiver.py @@ -5,17 +5,16 @@ import PySide2.QtCore as QtCore + def test_gui_quiver(qtbot, hdf5_affinity): - + vp = QtVideoPlayer() vp.show() affinity_fields = MultiQuiverPlot( - frame=hdf5_affinity.get_frame(0)[265:275,238:248], - show=[0,1], - decimation=1 - ) + frame=hdf5_affinity.get_frame(0)[265:275, 238:248], show=[0, 1], decimation=1 + ) vp.view.scene.addItem(affinity_fields) - + # make sure we're showing all the channels we selected assert len(affinity_fields.childItems()) == 2 # make sure we're showing all arrows in first channel diff --git a/tests/gui/test_shortcuts.py b/tests/gui/test_shortcuts.py index d6524dcb9..67c900bca 100644 --- a/tests/gui/test_shortcuts.py +++ b/tests/gui/test_shortcuts.py @@ -2,6 +2,7 @@ from sleap.gui.shortcuts import Shortcuts + def test_shortcuts(): shortcuts = Shortcuts() @@ -9,4 +10,4 @@ def test_shortcuts(): assert shortcuts["new"] == QKeySequence.fromString("Ctrl+N") shortcuts["new"] = QKeySequence.fromString("Ctrl+Shift+N") assert shortcuts["new"] == QKeySequence.fromString("Ctrl+Shift+N") - assert list(shortcuts[0:2].keys()) == ["new", "open"] \ No newline at end of file + assert list(shortcuts[0:2].keys()) == ["new", "open"] diff --git a/tests/gui/test_slider.py b/tests/gui/test_slider.py index f69164f3b..0d05b057b 100644 --- a/tests/gui/test_slider.py +++ b/tests/gui/test_slider.py @@ -1,26 +1,27 @@ from sleap.gui.slider import VideoSlider + def test_slider(qtbot, centered_pair_predictions): - + labels = centered_pair_predictions - - slider = VideoSlider(min=0, max=1200, val=15, marks=(10,15)) - + + slider = VideoSlider(min=0, max=1200, val=15, marks=(10, 15)) + assert slider.value() == 15 slider.setValue(20) assert slider.value() == 20 - + assert slider.getSelection() == (0, 0) slider.startSelection(3) slider.endSelection(5) assert slider.getSelection() == (3, 5) slider.clearSelection() assert slider.getSelection() == (0, 0) - + initial_height = slider.maximumHeight() slider.setTracks(20) assert slider.maximumHeight() != initial_height - + slider.setTracksFromLabels(labels, labels.videos[0]) assert len(slider.getMarks()) == 40 diff --git 
a/tests/gui/test_tracks.py b/tests/gui/test_tracks.py index 4c4481931..b92f773e3 100644 --- a/tests/gui/test_tracks.py +++ b/tests/gui/test_tracks.py @@ -1,37 +1,39 @@ from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay from sleap.io.video import Video + def test_track_trails(centered_pair_predictions): - + labels = centered_pair_predictions - trail_manager = TrackTrailOverlay(labels, scene=None, trail_length = 6) - + trail_manager = TrackTrailOverlay(labels, scene=None, trail_length=6) + frames = trail_manager.get_frame_selection(labels.videos[0], 27) assert len(frames) == 6 assert frames[0].frame_idx == 22 - + tracks = trail_manager.get_tracks_in_frame(labels.videos[0], 27) assert len(tracks) == 2 assert tracks[0].name == "1" assert tracks[1].name == "2" trails = trail_manager.get_track_trails(frames, tracks[0]) - + assert len(trails) == 24 - - test_trail = [(245.0, 208.0), + + test_trail = [ + (245.0, 208.0), (245.0, 207.0), (245.0, 206.0), (246.0, 205.0), (247.0, 203.0), - (248.0, 202.0) - ] + (248.0, 202.0), + ] assert test_trail in trails - + # Test track colors color_manager = TrackColorManager(labels=labels) tracks = trail_manager.get_tracks_in_frame(labels.videos[0], 1099) assert len(tracks) == 5 - assert color_manager.get_color(tracks[3]) == [119, 172, 48] \ No newline at end of file + assert color_manager.get_color(tracks[3]) == [119, 172, 48] diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index e83bae07f..ff981a5dd 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -2,6 +2,7 @@ import PySide2.QtCore as QtCore + def test_gui_video(qtbot): vp = QtVideoPlayer() vp.show() @@ -13,6 +14,7 @@ def test_gui_video(qtbot): # for i in range(20): # qtbot.mouseClick(vp.btn, QtCore.Qt.LeftButton) + def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels): vp = QtVideoPlayer(small_robot_mp4_vid) qtbot.addWidget(vp) @@ -22,7 +24,7 @@ def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels): def plot_instances(vp, idx): for instance in labeled_frames[test_frame_idx].instances: - vp.addInstance(instance=instance, color=(0,0,128)) + vp.addInstance(instance=instance, color=(0, 0, 128)) vp.changedPlot.connect(plot_instances) vp.view.updatedViewer.emit() @@ -36,7 +38,7 @@ def plot_instances(vp, idx): vp.zoomToFit() # Check that we zoomed correctly - assert(vp.view.zoomFactor > 1) + assert vp.view.zoomFactor > 1 vp.instances[0].updatePoints(complete=True) @@ -63,4 +65,4 @@ def plot_instances(vp, idx): qtbot.keyClick(vp, QtCore.Qt.Key_1) assert cb.args[0] == [1, 0] - assert vp.close() \ No newline at end of file + assert vp.close() diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 599d0f476..54fd4781a 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -8,9 +8,10 @@ from sleap.io.dataset import Labels, load_labels_json_old from sleap.gui.suggestions import VideoFrameSuggestions -TEST_H5_DATASET = 'tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5' +TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" -def _check_labels_match(expected_labels, other_labels, format = 'png'): + +def _check_labels_match(expected_labels, other_labels, format="png"): """ A utitlity function to check whether to sets of labels match. This doesn't directly compares some things (like video objects). 
@@ -43,7 +44,10 @@ def dict_match(dict1, dict2): # Check if the graphs are iso-morphic import networkx as nx - is_isomorphic = nx.is_isomorphic(self._graph, other._graph, node_match=dict_match) + + is_isomorphic = nx.is_isomorphic( + self._graph, other._graph, node_match=dict_match + ) if not is_isomorphic: assert False @@ -69,11 +73,14 @@ def dict_match(dict1, dict2): # Compare the first frames of the videos, do it on a small sub-region to # make the test reasonable in time. - if format is 'png': + if format is "png": assert np.allclose(frame_data, expected_frame_data) # Compare the instances - assert all(i1.matches(i2) for (i1, i2) in zip(expected_label.instances, label.instances)) + assert all( + i1.matches(i2) + for (i1, i2) in zip(expected_label.instances, label.instances) + ) # This test takes to long, break after 20 or so. if frame_idx > 20: @@ -81,7 +88,7 @@ def dict_match(dict1, dict2): def test_labels_json(tmpdir, multi_skel_vid_labels): - json_file_path = os.path.join(tmpdir, 'dataset.json') + json_file_path = os.path.join(tmpdir, "dataset.json") if os.path.isfile(json_file_path): os.remove(json_file_path) @@ -112,17 +119,38 @@ def test_labels_json(tmpdir, multi_skel_vid_labels): assert multi_skel_vid_labels.nodes[3] in loaded_labels.nodes assert multi_skel_vid_labels.videos[0] in loaded_labels.videos + def test_load_labels_json_old(tmpdir): - new_file_path = os.path.join(tmpdir, 'centered_pair_v2.json') + new_file_path = os.path.join(tmpdir, "centered_pair_v2.json") # Function to run some checks on loaded labels def check_labels(labels): - skel_node_names = ['head', 'neck', 'thorax', 'abdomen', 'wingL', - 'wingR', 'forelegL1', 'forelegL2', 'forelegL3', - 'forelegR1', 'forelegR2', 'forelegR3', 'midlegL1', - 'midlegL2', 'midlegL3', 'midlegR1', 'midlegR2', - 'midlegR3', 'hindlegL1', 'hindlegL2', 'hindlegL3', - 'hindlegR1', 'hindlegR2', 'hindlegR3'] + skel_node_names = [ + "head", + "neck", + "thorax", + "abdomen", + "wingL", + "wingR", + "forelegL1", + "forelegL2", + "forelegL3", + "forelegR1", + "forelegR2", + "forelegR3", + "midlegL1", + "midlegL2", + "midlegL3", + "midlegR1", + "midlegR2", + "midlegR3", + "hindlegL1", + "hindlegL2", + "hindlegL3", + "hindlegR1", + "hindlegR2", + "hindlegR3", + ] # Do some basic checks assert len(labels) == 70 @@ -168,7 +196,7 @@ def test_label_accessors(centered_pair_labels): next(f) next(f) # test that iterator now has fewer items left - assert len(list(f)) == 70-3 + assert len(list(f)) == 70 - 3 assert labels.instance_count(video, 15) == 2 assert labels.instance_count(video, 7) == 0 @@ -204,7 +232,7 @@ def test_label_mutability(): dummy_video = Video(backend=MediaVideo) dummy_skeleton = Skeleton() dummy_instance = Instance(dummy_skeleton) - dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance]) labels = Labels() labels.append(dummy_frame) @@ -221,7 +249,7 @@ def test_label_mutability(): dummy_video2 = Video(backend=MediaVideo) dummy_skeleton2 = Skeleton(name="dummy2") dummy_instance2 = Instance(dummy_skeleton2) - dummy_frame2 = LabeledFrame(dummy_video2, frame_idx=0, instances=[dummy_instance2,]) + dummy_frame2 = LabeledFrame(dummy_video2, frame_idx=0, instances=[dummy_instance2]) assert dummy_video2 not in labels assert dummy_skeleton2 not in labels assert dummy_frame2 not in labels @@ -245,9 +273,9 @@ def test_label_mutability(): for f in dummy_frames + dummy_frames2: labels.append(f) - assert(len(labels) == 20) + assert len(labels) 
== 20 labels.remove_video(dummy_video2) - assert(len(labels) == 10) + assert len(labels) == 10 assert len(labels.find(dummy_video)) == 10 assert dummy_frame in labels @@ -260,6 +288,7 @@ def test_label_mutability(): labels.remove_video(dummy_video) assert len(labels.find(dummy_video)) == 0 + def test_labels_merge(): dummy_video = Video(backend=MediaVideo) dummy_skeleton = Skeleton() @@ -270,8 +299,8 @@ def test_labels_merge(): # Add 10 instances with different points (so they aren't "redundant") for i in range(10): - instance = Instance(skeleton=dummy_skeleton, points=dict(node=Point(i,i))) - dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[instance,]) + instance = Instance(skeleton=dummy_skeleton, points=dict(node=Point(i, i))) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[instance]) dummy_frames.append(dummy_frame) labels.labeled_frames.extend(dummy_frames) @@ -282,6 +311,7 @@ def test_labels_merge(): assert len(labels) == 1 assert len(labels.labeled_frames[0].instances) == 10 + def test_complex_merge(): dummy_video_a = Video.from_filename("foo.mp4") dummy_video_b = Video.from_filename("foo.mp4") @@ -293,26 +323,40 @@ def test_complex_merge(): dummy_skeleton_b.add_node("node") dummy_instances_a = [] - dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1,1)))) - dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2,2)))) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1, 1))) + ) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2, 2))) + ) labels_a = Labels() - labels_a.append(LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a)) + labels_a.append( + LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a) + ) dummy_instances_b = [] - dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1,1)))) - dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(3,3)))) + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1, 1))) + ) + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(3, 3))) + ) labels_b = Labels() - labels_b.append(LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b)) # conflict - labels_b.append(LabeledFrame(dummy_video_b, frame_idx=1, instances=dummy_instances_b)) # clean + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b) + ) # conflict + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=1, instances=dummy_instances_b) + ) # clean merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) # Check that we have the cleanly merged frame assert dummy_video_a in merged - assert len(merged[dummy_video_a]) == 1 # one merged frame - assert len(merged[dummy_video_a][1]) == 2 # with two instances + assert len(merged[dummy_video_a]) == 1 # one merged frame + assert len(merged[dummy_video_a][1]) == 2 # with two instances # Check that labels_a includes redundant and clean assert len(labels_a.labeled_frames) == 2 @@ -339,6 +383,7 @@ def test_complex_merge(): assert len(labels_a.labeled_frames[0].instances) == 2 assert labels_a.labeled_frames[0].instances[1].points[0].x == 3 + def test_merge_predictions(): dummy_video_a = Video.from_filename("foo.mp4") dummy_video_b = Video.from_filename("foo.mp4") @@ -350,30 +395,46 @@ def test_merge_predictions(): 
dummy_skeleton_b.add_node("node") dummy_instances_a = [] - dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1,1)))) - dummy_instances_a.append(Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2,2)))) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1, 1))) + ) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2, 2))) + ) labels_a = Labels() - labels_a.append(LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a)) + labels_a.append( + LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a) + ) dummy_instances_b = [] - dummy_instances_b.append(Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1,1)))) - dummy_instances_b.append(PredictedInstance(skeleton=dummy_skeleton_b, points=dict(node=Point(3,3)), score=1)) + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1, 1))) + ) + dummy_instances_b.append( + PredictedInstance( + skeleton=dummy_skeleton_b, points=dict(node=Point(3, 3)), score=1 + ) + ) labels_b = Labels() - labels_b.append(LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b)) + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b) + ) # Frames have one redundant instance (perfect match) and all the # non-matching instances are different types (one predicted, one not). merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) assert len(merged[dummy_video_a]) == 1 - assert len(merged[dummy_video_a][0]) == 1 # the predicted instance was merged + assert len(merged[dummy_video_a][0]) == 1 # the predicted instance was merged assert not extra_a assert not extra_b + def skeleton_ids_from_label_instances(labels): return list(map(id, (lf.instances[0].skeleton for lf in labels.labeled_frames))) + def test_duplicate_skeletons_serializing(): vid = Video.from_filename("foo.mp4") @@ -386,6 +447,7 @@ def test_duplicate_skeletons_serializing(): new_labels = Labels(labeled_frames=[lf_a, lf_b]) new_labels_json = new_labels.to_dict() + def test_distinct_skeletons_serializing(): vid = Video.from_filename("foo.mp4") @@ -401,6 +463,7 @@ def test_distinct_skeletons_serializing(): # Make sure we can serialize this new_labels_json = new_labels.to_dict() + def test_unify_skeletons(): vid = Video.from_filename("foo.mp4") @@ -422,6 +485,7 @@ def test_unify_skeletons(): # Make sure we can serialize this labels.to_dict() + def test_dont_unify_skeletons(): vid = Video.from_filename("foo.mp4") @@ -441,6 +505,7 @@ def test_dont_unify_skeletons(): # Make sure we can serialize this labels.to_dict() + def test_instance_access(): labels = Labels() @@ -449,31 +514,49 @@ def test_instance_access(): dummy_video2 = Video(backend=MediaVideo) for i in range(10): - labels.append(LabeledFrame(dummy_video, frame_idx=i, instances=[Instance(dummy_skeleton), Instance(dummy_skeleton)])) + labels.append( + LabeledFrame( + dummy_video, + frame_idx=i, + instances=[Instance(dummy_skeleton), Instance(dummy_skeleton)], + ) + ) for i in range(10): - labels.append(LabeledFrame(dummy_video2, frame_idx=i, instances=[Instance(dummy_skeleton), Instance(dummy_skeleton), Instance(dummy_skeleton)])) + labels.append( + LabeledFrame( + dummy_video2, + frame_idx=i, + instances=[ + Instance(dummy_skeleton), + Instance(dummy_skeleton), + Instance(dummy_skeleton), + ], + ) + ) assert len(labels.all_instances) == 50 assert len(list(labels.instances(video=dummy_video))) == 20 assert 
len(list(labels.instances(video=dummy_video2))) == 30 + def test_suggestions(small_robot_mp4_vid): dummy_video = small_robot_mp4_vid dummy_skeleton = Skeleton() dummy_instance = Instance(dummy_skeleton) - dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance]) labels = Labels() labels.append(dummy_frame) suggestions = dict() suggestions[dummy_video] = VideoFrameSuggestions.suggest( - dummy_video, - params=dict(method="random", per_video=13)) + dummy_video, params=dict(method="random", per_video=13) + ) labels.set_suggestions(suggestions) assert len(labels.get_video_suggestions(dummy_video)) == 13 + def test_negative_anchors(): video = Video.from_filename("foo.mp4") labels = Labels() @@ -487,12 +570,13 @@ def test_negative_anchors(): labels.remove_negative_anchors(video, 1) assert len(labels.negative_anchors[video]) == 1 + def test_load_labels_mat(mat_labels): assert len(mat_labels.nodes) == 6 assert len(mat_labels) == 43 -@pytest.mark.parametrize("format", ['png', 'mjpeg/avi']) +@pytest.mark.parametrize("format", ["png", "mjpeg/avi"]) def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): """ Test saving and loading a labels dataset with frame data included @@ -502,8 +586,13 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): # Lets take a subset of the labels so this doesn't take too long multi_skel_vid_labels.labeled_frames = multi_skel_vid_labels.labeled_frames[5:30] - filename = os.path.join(tmpdir, 'test.json') - Labels.save_json(multi_skel_vid_labels, filename=filename, save_frame_data=True, frame_data_format=format) + filename = os.path.join(tmpdir, "test.json") + Labels.save_json( + multi_skel_vid_labels, + filename=filename, + save_frame_data=True, + frame_data_format=format, + ) # Load the data back in loaded_labels = Labels.load_json(f"{filename}.zip") @@ -517,7 +606,7 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): def test_labels_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") Labels.save_hdf5(filename=filename, labels=labels) @@ -528,7 +617,7 @@ def test_labels_hdf5(multi_skel_vid_labels, tmpdir): def test_labels_predicted_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") # Lets promote some of these Instances to predicted instances for label in labels: @@ -559,9 +648,10 @@ def test_labels_predicted_hdf5(multi_skel_vid_labels, tmpdir): loaded_labels = Labels.load_hdf5(filename=filename) _check_labels_match(labels, loaded_labels) + def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") # Save each frame of the Labels dataset one by one in append # mode @@ -578,4 +668,3 @@ def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): loaded_labels = Labels.load_hdf5(filename=filename) _check_labels_match(labels, loaded_labels) - diff --git a/tests/io/test_video.py b/tests/io/test_video.py index e7d18bc1d..3e17991fb 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -11,83 +11,92 @@ # of redundant test code here. 
# See: https://github.com/pytest-dev/pytest/issues/349 + def test_from_filename(): assert type(Video.from_filename(TEST_H5_FILE).backend) == HDF5Video assert type(Video.from_filename(TEST_SMALL_ROBOT_MP4_FILE).backend) == MediaVideo + def test_hdf5_get_shape(hdf5_vid): - assert(hdf5_vid.shape == (42, 512, 512, 1)) + assert hdf5_vid.shape == (42, 512, 512, 1) def test_hdf5_len(hdf5_vid): - assert(len(hdf5_vid) == 42) + assert len(hdf5_vid) == 42 def test_hdf5_dtype(hdf5_vid): - assert(hdf5_vid.dtype == np.uint8) + assert hdf5_vid.dtype == np.uint8 def test_hdf5_get_frame(hdf5_vid): - assert(hdf5_vid.get_frame(0).shape == (512, 512, 1)) + assert hdf5_vid.get_frame(0).shape == (512, 512, 1) def test_hdf5_get_frames(hdf5_vid): - assert(hdf5_vid.get_frames(0).shape == (1, 512, 512, 1)) - assert(hdf5_vid.get_frames([0,1]).shape == (2, 512, 512, 1)) + assert hdf5_vid.get_frames(0).shape == (1, 512, 512, 1) + assert hdf5_vid.get_frames([0, 1]).shape == (2, 512, 512, 1) def test_hdf5_get_item(hdf5_vid): - assert(hdf5_vid[0].shape == (1, 512, 512, 1)) - assert(np.alltrue(hdf5_vid[1:10:3] == hdf5_vid.get_frames([1, 4, 7]))) + assert hdf5_vid[0].shape == (1, 512, 512, 1) + assert np.alltrue(hdf5_vid[1:10:3] == hdf5_vid.get_frames([1, 4, 7])) + def test_hd5f_file_not_found(): with pytest.raises(FileNotFoundError): - Video.from_hdf5("non-existent-filename.h5", 'dataset_name') + Video.from_hdf5("non-existent-filename.h5", "dataset_name") + def test_mp4_get_shape(small_robot_mp4_vid): - assert(small_robot_mp4_vid.shape == (166, 320, 560, 3)) + assert small_robot_mp4_vid.shape == (166, 320, 560, 3) def test_mp4_len(small_robot_mp4_vid): - assert(len(small_robot_mp4_vid) == 166) + assert len(small_robot_mp4_vid) == 166 def test_mp4_dtype(small_robot_mp4_vid): - assert(small_robot_mp4_vid.dtype == np.uint8) + assert small_robot_mp4_vid.dtype == np.uint8 def test_mp4_get_frame(small_robot_mp4_vid): - assert(small_robot_mp4_vid.get_frame(0).shape == (320, 560, 3)) + assert small_robot_mp4_vid.get_frame(0).shape == (320, 560, 3) def test_mp4_get_frames(small_robot_mp4_vid): - assert(small_robot_mp4_vid.get_frames(0).shape == (1, 320, 560, 3)) - assert(small_robot_mp4_vid.get_frames([0,1]).shape == (2, 320, 560, 3)) + assert small_robot_mp4_vid.get_frames(0).shape == (1, 320, 560, 3) + assert small_robot_mp4_vid.get_frames([0, 1]).shape == (2, 320, 560, 3) def test_mp4_get_item(small_robot_mp4_vid): - assert(small_robot_mp4_vid[0].shape == (1, 320, 560, 3)) - assert(np.alltrue(small_robot_mp4_vid[1:10:3] == small_robot_mp4_vid.get_frames([1, 4, 7]))) + assert small_robot_mp4_vid[0].shape == (1, 320, 560, 3) + assert np.alltrue( + small_robot_mp4_vid[1:10:3] == small_robot_mp4_vid.get_frames([1, 4, 7]) + ) + def test_mp4_file_not_found(): with pytest.raises(FileNotFoundError): vid = Video.from_media("non-existent-filename.mp4") vid.channels + def test_numpy_frames(small_robot_mp4_vid): - clip_frames = small_robot_mp4_vid.get_frames((3,7,9)) + clip_frames = small_robot_mp4_vid.get_frames((3, 7, 9)) np_vid = Video.from_numpy(clip_frames) assert np.all(np.equal(np_vid.get_frame(1), small_robot_mp4_vid.get_frame(7))) -@pytest.mark.parametrize("format", ['png', 'jpg', "mjpeg/avi"]) + +@pytest.mark.parametrize("format", ["png", "jpg", "mjpeg/avi"]) def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): - path = os.path.join(tmpdir, 'test_imgstore') + path = os.path.join(tmpdir, "test_imgstore") # If format is video, test saving all the frames. 
if format == "mjpeg/avi": - frame_indices = None + frame_indices = None else: frame_indices = [0, 1, 5] @@ -95,9 +104,13 @@ def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): # video. if format == "png": # Check that the default format is "png" - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices + ) else: - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices, format=format) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices, format=format + ) if frame_indices is None: assert small_robot_mp4_vid.num_frames == imgstore_vid.num_frames @@ -107,19 +120,21 @@ def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): assert type(imgstore_vid.get_frame(i)) == np.ndarray else: - assert(imgstore_vid.num_frames == len(frame_indices)) + assert imgstore_vid.num_frames == len(frame_indices) # Make sure we can read arbitrary frames by imgstore frame number for i in frame_indices: assert type(imgstore_vid.get_frame(i)) == np.ndarray - assert(imgstore_vid.channels == 3) - assert(imgstore_vid.height == 320) - assert(imgstore_vid.width == 560) + assert imgstore_vid.channels == 3 + assert imgstore_vid.height == 320 + assert imgstore_vid.width == 560 # Check the image data is exactly the same when lossless is used. if format == "png": - assert np.allclose(imgstore_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91) + assert np.allclose( + imgstore_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 + ) def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): @@ -127,11 +142,13 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): Test different types of indexing (by frame number or index) supported by only imgstore videos. 
""" - path = os.path.join(tmpdir, 'test_imgstore') + path = os.path.join(tmpdir, "test_imgstore") frame_indices = [20, 40, 15] - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices, index_by_original=False) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices, index_by_original=False + ) # Index by frame index in imgstore frames = imgstore_vid.get_frames([0, 1, 2]) diff --git a/tests/io/test_visuals.py b/tests/io/test_visuals.py index 15c55e005..9887a38c4 100644 --- a/tests/io/test_visuals.py +++ b/tests/io/test_visuals.py @@ -1,11 +1,14 @@ import os from sleap.io.visuals import save_labeled_video + def test_write_visuals(tmpdir, centered_pair_predictions): - path = os.path.join(tmpdir, 'clip.avi') - save_labeled_video(filename=path, - labels=centered_pair_predictions, - video=centered_pair_predictions.videos[0], - frames=(0,1,2), - fps=15) - assert os.path.exists(path) \ No newline at end of file + path = os.path.join(tmpdir, "clip.avi") + save_labeled_video( + filename=path, + labels=centered_pair_predictions, + video=centered_pair_predictions.videos[0], + frames=(0, 1, 2), + fps=15, + ) + assert os.path.exists(path) diff --git a/tests/nn/test_datagen.py b/tests/nn/test_datagen.py index 97d0d5edd..a671d1a26 100644 --- a/tests/nn/test_datagen.py +++ b/tests/nn/test_datagen.py @@ -1,5 +1,6 @@ from sleap.nn.datagen import generate_images, generate_confidence_maps, generate_pafs + def test_datagen(min_labels): import numpy as np @@ -9,15 +10,14 @@ def test_datagen(min_labels): assert imgs.shape == (1, 384, 384, 1) assert imgs.dtype == np.dtype("float32") - assert math.isclose(np.ptp(imgs), .898, abs_tol=.01) + assert math.isclose(np.ptp(imgs), 0.898, abs_tol=0.01) confmaps = generate_confidence_maps(min_labels) assert confmaps.shape == (1, 384, 384, 2) assert confmaps.dtype == np.dtype("float32") - assert math.isclose(np.ptp(confmaps), .999, abs_tol=.01) - + assert math.isclose(np.ptp(confmaps), 0.999, abs_tol=0.01) pafs = generate_pafs(min_labels) assert pafs.shape == (1, 384, 384, 2) assert pafs.dtype == np.dtype("float32") - assert math.isclose(np.ptp(pafs), 1.57, abs_tol=.01) \ No newline at end of file + assert math.isclose(np.ptp(pafs), 1.57, abs_tol=0.01) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index b9c1b4088..d75f28dee 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -12,6 +12,7 @@ from sleap.io.dataset import Labels + def check_labels(labels): # Make sure there are 1100 frames @@ -33,15 +34,19 @@ def check_labels(labels): # FIXME: We need more checks here. 
+ def test_load_old_json(): - labels = load_predicted_labels_json_old("tests/data/json_format_v1/centered_pair.json") + labels = load_predicted_labels_json_old( + "tests/data/json_format_v1/centered_pair.json" + ) check_labels(labels) - #Labels.save_json(labels, 'tests/data/json_format_v2/centered_pair_predictions.json') + # Labels.save_json(labels, 'tests/data/json_format_v2/centered_pair_predictions.json') + def test_save_load_json(centered_pair_predictions, tmpdir): - test_out_file = os.path.join(tmpdir, 'test_tmp.json') + test_out_file = os.path.join(tmpdir, "test_tmp.json") # Check the labels check_labels(centered_pair_predictions) @@ -54,20 +59,21 @@ def test_save_load_json(centered_pair_predictions, tmpdir): check_labels(new_labels) + def test_peaks_with_scaling(): # load from scratch so we won't change centered_pair_predictions - true_labels = Labels.load_json('tests/data/json_format_v1/centered_pair.json') + true_labels = Labels.load_json("tests/data/json_format_v1/centered_pair.json") # only use a few frames true_labels.labeled_frames = true_labels.labeled_frames[13:23:2] skeleton = true_labels.skeletons[0] imgs = generate_images(true_labels) # scaling - scale = .5 + scale = 0.5 transform = DataTransform() img_size = imgs.shape[1], imgs.shape[2] - scaled_size = int(imgs.shape[1]//(1/scale)), int(imgs.shape[2]//(1/scale)) + scaled_size = int(imgs.shape[1] // (1 / scale)), int(imgs.shape[2] // (1 / scale)) imgs = transform.scale_to(imgs, scaled_size) assert transform.scale == scale assert imgs.shape[1], imgs.shape[2] == scaled_size @@ -83,15 +89,24 @@ def test_peaks_with_scaling(): # make sure what we got from interence matches what we started with for i in range(len(new_labels.labeled_frames)): - assert len(true_labels.labeled_frames[i].instances) <= len(new_labels.labeled_frames[i].instances) + assert len(true_labels.labeled_frames[i].instances) <= len( + new_labels.labeled_frames[i].instances + ) # sort instances by location of thorax true_labels.labeled_frames[i].instances.sort(key=lambda inst: inst["thorax"]) new_labels.labeled_frames[i].instances.sort(key=lambda inst: inst["thorax"]) # make sure that each true instance has points matching one of the new instances - for inst_a, inst_b in zip(true_labels.labeled_frames[i].instances, new_labels.labeled_frames[i].instances): - + for inst_a, inst_b in zip( + true_labels.labeled_frames[i].instances, + new_labels.labeled_frames[i].instances, + ): + assert inst_a.get_points_array().shape == inst_b.get_points_array().shape # FIXME: new instances have nans, so for now just check first 5 points - assert np.allclose(inst_a.get_points_array()[0:5], inst_b.get_points_array()[0:5], atol=1/scale) + assert np.allclose( + inst_a.get_points_array()[0:5], + inst_b.get_points_array()[0:5], + atol=1 / scale, + ) diff --git a/tests/nn/test_tracking.py b/tests/nn/test_tracking.py index 7e56b1606..6f794f3e2 100644 --- a/tests/nn/test_tracking.py +++ b/tests/nn/test_tracking.py @@ -1,6 +1,7 @@ from sleap.nn.tracking import FlowShiftTracker from sleap.io.dataset import Labels + def test_flow_tracker(centered_pair_vid, centered_pair_predictions): # We are going to test tracking. The dataset we have loaded @@ -24,6 +25,7 @@ def test_flow_tracker(centered_pair_vid, centered_pair_predictions): # shouldn't be hard coded. 
assert len(tracks) == 24 + # def test_tracking_optflow_fail(centered_pair_vid, centered_pair_predictions): # frame_nums = range(0, len(centered_pair_predictions), 2) # labels = Labels([centered_pair_predictions[i] for i in frame_nums]) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index 52f661507..24f4d8e5a 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -7,21 +7,29 @@ from sleap.nn.architectures.leap import leap_cnn from sleap.nn.training import Trainer, TrainingJob + def test_model_fail_non_available_backbone(multi_skel_vid_labels): with pytest.raises(ValueError): - Model(output_type=ModelOutputType.CONFIDENCE_MAP, backbone=object(), - skeletons=multi_skel_vid_labels.skeletons) + Model( + output_type=ModelOutputType.CONFIDENCE_MAP, + backbone=object(), + skeletons=multi_skel_vid_labels.skeletons, + ) @pytest.mark.parametrize("backbone", available_archs) def test_training_job_json(tmpdir, multi_skel_vid_labels, backbone): - run_name = 'training' + run_name = "training" - model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, backbone=backbone(), - skeletons=multi_skel_vid_labels.skeletons) + model = Model( + output_type=ModelOutputType.CONFIDENCE_MAP, + backbone=backbone(), + skeletons=multi_skel_vid_labels.skeletons, + ) - train_run = TrainingJob(model=model, trainer=Trainer(), - save_dir=os.path.join(tmpdir), run_name=run_name) + train_run = TrainingJob( + model=model, trainer=Trainer(), save_dir=os.path.join(tmpdir), run_name=run_name + ) # Create and serialize training info json_path = os.path.join(tmpdir, f"{run_name}.json") @@ -39,4 +47,3 @@ def test_training_job_json(tmpdir, multi_skel_vid_labels, backbone): train_run.model.skeletons = [] assert loaded_run == train_run - diff --git a/tests/test_point_array.py b/tests/test_point_array.py index 49d452064..9410b52b5 100644 --- a/tests/test_point_array.py +++ b/tests/test_point_array.py @@ -3,8 +3,16 @@ from sleap.instance import Point, PredictedPoint, PointArray, PredictedPointArray -@pytest.mark.parametrize("p1", [Point(0.0, 0.0), PredictedPoint(0.0, 0.0, 0.0), - PointArray(3)[0], PredictedPointArray(3)[0]]) + +@pytest.mark.parametrize( + "p1", + [ + Point(0.0, 0.0), + PredictedPoint(0.0, 0.0, 0.0), + PointArray(3)[0], + PredictedPointArray(3)[0], + ], +) def test_point(p1): """ Test the Point and PredictedPoint API. This is mainly a safety @@ -40,15 +48,15 @@ def test_constructor(): assert p.score == 0.3 -@pytest.mark.parametrize('parray_cls', [PointArray, PredictedPointArray]) +@pytest.mark.parametrize("parray_cls", [PointArray, PredictedPointArray]) def test_point_array(parray_cls): p = parray_cls(5) # Make sure length works assert len(p) == 5 - assert len(p['x']) == 5 - assert len(p[['x', 'y']]) == 5 + assert len(p["x"]) == 5 + assert len(p[["x", "y"]]) == 5 # Check that single point getitem returns a Point class if parray_cls is PredictedPointArray: @@ -69,7 +77,10 @@ def test_point_array(parray_cls): # I have to convert from structured to unstructured to get this comparison # to work. 
from numpy.lib.recfunctions import structured_to_unstructured - np.testing.assert_array_equal(structured_to_unstructured(d1), structured_to_unstructured(d2)) + + np.testing.assert_array_equal( + structured_to_unstructured(d1), structured_to_unstructured(d2) + ) def test_from_and_to_array(): @@ -79,9 +90,11 @@ def test_from_and_to_array(): r = PredictedPointArray.to_array(PredictedPointArray.from_array(p)) from numpy.lib.recfunctions import structured_to_unstructured - np.testing.assert_array_equal(structured_to_unstructured(p), structured_to_unstructured(r)) + + np.testing.assert_array_equal( + structured_to_unstructured(p), structured_to_unstructured(r) + ) # Make sure conversion uses default score r = PredictedPointArray.from_array(p) assert r.score[0] == PredictedPointArray.make_default(1)[0].score - diff --git a/tests/test_rangelist.py b/tests/test_rangelist.py index 518049509..42da9e8f9 100644 --- a/tests/test_rangelist.py +++ b/tests/test_rangelist.py @@ -1,21 +1,26 @@ from sleap.rangelist import RangeList + def test_rangelist(): - a = RangeList([(1,2),(3,5),(7,13),(50,100)]) + a = RangeList([(1, 2), (3, 5), (7, 13), (50, 100)]) assert a.list == [(1, 2), (3, 5), (7, 13), (50, 100)] assert a.cut(8) == ([(1, 2), (3, 5), (7, 8)], [(8, 13), (50, 100)]) - assert a.cut_range((60,70)) == ([(1, 2), (3, 5), (7, 13), (50, 60)], [(60, 70)], [(70, 100)]) - assert a.insert((10,20)) == [(1, 2), (3, 5), (7, 20), (50, 100)] - assert a.insert((5,8)) == [(1, 2), (3, 20), (50, 100)] + assert a.cut_range((60, 70)) == ( + [(1, 2), (3, 5), (7, 13), (50, 60)], + [(60, 70)], + [(70, 100)], + ) + assert a.insert((10, 20)) == [(1, 2), (3, 5), (7, 20), (50, 100)] + assert a.insert((5, 8)) == [(1, 2), (3, 20), (50, 100)] - a.remove((5,8)) + a.remove((5, 8)) assert a.list == [(1, 2), (3, 5), (8, 20), (50, 100)] assert a.start == 1 - a.remove((1,3)) + a.remove((1, 3)) assert a.start == 3 - + b = RangeList() b.add(1) b.add(2) @@ -25,4 +30,4 @@ def test_rangelist(): b.add(9) b.add(10) - assert b.list == [(1, 3), (4, 7), (9, 11)] \ No newline at end of file + assert b.list == [(1, 3), (4, 7), (9, 11)] diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index e9f39f453..ebb88721b 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -64,7 +64,8 @@ def test_getitem_node(skeleton): skeleton["non_exist_node"] # Now try to get the head node - assert(skeleton["head"] is not None) + assert skeleton["head"] is not None + def test_contains_node(skeleton): """ @@ -86,17 +87,16 @@ def test_node_rename(skeleton): skeleton["head"] # Make sure new head has the correct name - assert(skeleton["new_head_name"] is not None) + assert skeleton["new_head_name"] is not None def test_eq(): s1 = Skeleton("s1") - s1.add_nodes(['1','2','3','4','5','6']) - s1.add_edge('1', '2') - s1.add_edge('3', '4') - s1.add_edge('5', '6') - s1.add_symmetry('3', '6') - + s1.add_nodes(["1", "2", "3", "4", "5", "6"]) + s1.add_edge("1", "2") + s1.add_edge("3", "4") + s1.add_edge("5", "6") + s1.add_symmetry("3", "6") # Make a copy check that they are equal s2 = copy.deepcopy(s1) @@ -104,22 +104,22 @@ def test_eq(): # Add an edge, check that they are not equal s2 = copy.deepcopy(s1) - s2.add_edge('5', '1') + s2.add_edge("5", "1") assert not s1.matches(s2) # Add a symmetry edge, not equal s2 = copy.deepcopy(s1) - s2.add_symmetry('5', '1') + s2.add_symmetry("5", "1") assert not s1.matches(s2) # Delete a node s2 = copy.deepcopy(s1) - s2.delete_node('5') + s2.delete_node("5") assert not s1.matches(s2) # Delete and edge, not equal s2 = 
copy.deepcopy(s1) - s2.delete_edge('1', '2') + s2.delete_edge("1", "2") assert not s1.matches(s2) # FIXME: Probably shouldn't test it this way. @@ -133,14 +133,15 @@ def test_eq(): # s2._graph.nodes['1']['test'] = 5 # assert s1 != s2 + def test_symmetry(): s1 = Skeleton("s1") - s1.add_nodes(['1','2','3','4','5','6']) - s1.add_edge('1', '2') - s1.add_edge('3', '4') - s1.add_edge('5', '6') - s1.add_symmetry('1', '5') - s1.add_symmetry('3', '6') + s1.add_nodes(["1", "2", "3", "4", "5", "6"]) + s1.add_edge("1", "2") + s1.add_edge("3", "4") + s1.add_edge("5", "6") + s1.add_symmetry("1", "5") + s1.add_symmetry("3", "6") assert s1.get_symmetry("1").name == "5" assert s1.get_symmetry("5").name == "1" @@ -149,22 +150,22 @@ def test_symmetry(): # Cannot add more than one symmetry to a node with pytest.raises(ValueError): - s1.add_symmetry('1', '6') + s1.add_symmetry("1", "6") with pytest.raises(ValueError): - s1.add_symmetry('6', '1') + s1.add_symmetry("6", "1") - s1.delete_symmetry('1', '5') + s1.delete_symmetry("1", "5") assert s1.get_symmetry("1") is None with pytest.raises(ValueError): - s1.delete_symmetry('1', '5') + s1.delete_symmetry("1", "5") def test_json(skeleton, tmpdir): """ Test saving and loading a Skeleton object in JSON. """ - JSON_TEST_FILENAME = os.path.join(tmpdir, 'skeleton.json') + JSON_TEST_FILENAME = os.path.join(tmpdir, "skeleton.json") # Save it to a JSON filename skeleton.save_json(JSON_TEST_FILENAME) @@ -173,11 +174,11 @@ def test_json(skeleton, tmpdir): skeleton_copy = Skeleton.load_json(JSON_TEST_FILENAME) # Make sure we get back the same skeleton we saved. - assert(skeleton.matches(skeleton_copy)) + assert skeleton.matches(skeleton_copy) def test_hdf5(skeleton, stickman, tmpdir): - filename = os.path.join(tmpdir, 'skeleton.h5') + filename = os.path.join(tmpdir, "skeleton.h5") if os.path.isfile(filename): os.remove(filename) @@ -209,7 +210,7 @@ def test_hdf5(skeleton, stickman, tmpdir): # Make sure we can't load a non-existent skeleton with pytest.raises(KeyError): - Skeleton.load_hdf5(filename, 'BadName') + Skeleton.load_hdf5(filename, "BadName") # Make sure we can't save skeletons with the same name with pytest.raises(ValueError): @@ -240,73 +241,83 @@ def dict_match(dict1, dict2): with pytest.raises(NotImplementedError): skeleton.name = "Test" + def test_graph_property(skeleton): assert [node for node in skeleton.graph.nodes()] == skeleton.nodes + def test_load_mat_format(): - skeleton = Skeleton.load_mat('tests/data/skeleton/leap_mat_format/skeleton_legs.mat') + skeleton = Skeleton.load_mat( + "tests/data/skeleton/leap_mat_format/skeleton_legs.mat" + ) # Check some stuff about the skeleton we loaded - assert(len(skeleton.nodes) == 24) - assert(len(skeleton.edges) == 23) + assert len(skeleton.nodes) == 24 + assert len(skeleton.edges) == 23 # The node and edge list that should be present in skeleton_legs.mat node_names = [ - 'head', - 'neck', - 'thorax', - 'abdomen', - 'wingL', - 'wingR', - 'forelegL1', - 'forelegL2', - 'forelegL3', - 'forelegR1', - 'forelegR2', - 'forelegR3', - 'midlegL1' , - 'midlegL2' , - 'midlegL3' , - 'midlegR1' , - 'midlegR2' , - 'midlegR3' , - 'hindlegL1', - 'hindlegL2', - 'hindlegL3', - 'hindlegR1', - 'hindlegR2', - 'hindlegR3'] + "head", + "neck", + "thorax", + "abdomen", + "wingL", + "wingR", + "forelegL1", + "forelegL2", + "forelegL3", + "forelegR1", + "forelegR2", + "forelegR3", + "midlegL1", + "midlegL2", + "midlegL3", + "midlegR1", + "midlegR2", + "midlegR3", + "hindlegL1", + "hindlegL2", + "hindlegL3", + "hindlegR1", + "hindlegR2", + 
"hindlegR3", + ] edges = [ - [ 2, 1], - [ 1, 0], - [ 2, 3], - [ 2, 4], - [ 2, 5], - [ 2, 6], - [ 6, 7], - [ 7, 8], - [ 2, 9], - [ 9, 10], - [10, 11], - [ 2, 12], - [12, 13], - [13, 14], - [ 2, 15], - [15, 16], - [16, 17], - [ 2, 18], - [18, 19], - [19, 20], - [ 2, 21], - [21, 22], - [22, 23]] + [2, 1], + [1, 0], + [2, 3], + [2, 4], + [2, 5], + [2, 6], + [6, 7], + [7, 8], + [2, 9], + [9, 10], + [10, 11], + [2, 12], + [12, 13], + [13, 14], + [2, 15], + [15, 16], + [16, 17], + [2, 18], + [18, 19], + [19, 20], + [2, 21], + [21, 22], + [22, 23], + ] assert [n.name for n in skeleton.nodes] == node_names # Check the edges and their order for i, edge in enumerate(skeleton.edge_names): - assert tuple(edges[i]) == (skeleton.node_to_index(edge[0]), skeleton.node_to_index(edge[1])) + assert tuple(edges[i]) == ( + skeleton.node_to_index(edge[0]), + skeleton.node_to_index(edge[1]), + ) + def test_edge_order(): """Test is edge list order is maintained upon insertion""" diff --git a/tests/test_util.py b/tests/test_util.py index f17429f0f..acd194d9e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,11 +6,13 @@ from sleap.util import attr_to_dtype, frame_list, weak_filename_match + def test_attr_to_dtype(): """ Test that we can convert classes with basic types to numpy composite dtypes. """ + @attr.s class TestAttr: a: int = attr.ib() @@ -30,10 +32,10 @@ class TestAttr3: c: Dict = attr.ib() # Dict should throw exception! dtype = attr_to_dtype(TestAttr) - dtype.fields['a'][0] == np.dtype(int) - dtype.fields['b'][0] == np.dtype(float) - dtype.fields['c'][0] == np.dtype(bool) - dtype.fields['d'][0] == np.dtype(object) + dtype.fields["a"][0] == np.dtype(int) + dtype.fields["b"][0] == np.dtype(float) + dtype.fields["c"][0] == np.dtype(bool) + dtype.fields["d"][0] == np.dtype(object) with pytest.raises(TypeError): attr_to_dtype(TestAttr2) @@ -41,17 +43,20 @@ class TestAttr3: with pytest.raises(TypeError): attr_to_dtype(TestAttr3) + def test_frame_list(): - assert frame_list("3-5") == [3,4,5] - assert frame_list("7,10") == [7,10] + assert frame_list("3-5") == [3, 4, 5] + assert frame_list("7,10") == [7, 10] + def test_weak_match(): assert weak_filename_match("one/two", "one/two") assert weak_filename_match( "M:\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_11576_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml", - "D:\\projects\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_99713_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml") - assert weak_filename_match("zero/one/two/three.mp4","other\\one\\two\\three.mp4") + "D:\\projects\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_99713_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml", + ) + assert weak_filename_match("zero/one/two/three.mp4", "other\\one\\two\\three.mp4") assert not weak_filename_match("one/two/three", "two/three") - assert not weak_filename_match("one/two/three.mp4","one/two/three.avi") - assert not weak_filename_match("foo.mp4","bar.mp4") + assert not weak_filename_match("one/two/three.mp4", "one/two/three.avi") + assert not weak_filename_match("foo.mp4", "bar.mp4") From 3cd512db6d55b74e15c9385b87b751d196c792d9 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 12:37:51 -0400 Subject: [PATCH 133/176] Refactor, add docstrings and tests. 
--- sleap/info/write_tracking_h5.py | 192 +++++++++++++++++++++++++------- tests/info/test_h5.py | 89 +++++++++++++++ 2 files changed, 238 insertions(+), 43 deletions(-) create mode 100644 tests/info/test_h5.py diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 2232823a1..17a192b8c 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -24,34 +24,41 @@ import h5py as h5 import numpy as np +from typing import Any, Dict, List, Tuple + from sleap.io.dataset import Labels -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("data_path", help="Path to labels json file") - parser.add_argument( - "--all-frames", - dest="all_frames", - action="store_const", - const=True, - default=False, - help="include all frames without predictions", - ) - args = parser.parse_args() +def get_tracks_as_np_strings(labels: Labels) -> List[np.string_]: + """Get list of track names as `np.string_`s.""" + return [np.string_(track.name) for track in labels.tracks] - video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) - labels = Labels.load_file(args.data_path, video_callback=video_callback) +def get_occupancy_and_points_matrices( + labels: Labels, all_frames: bool +) -> Tuple[np.ndarray, np.ndarray]: + """ + Builds numpy matrices with track occupancy and point location data. + + Args: + labels: The :class:`Labels` from which to get data. + all_frames: If True, then includes zeros so that frame index + will line up with columns in the output. Otherwise, + there will only be columns for the frames between the + first and last frames with labeling data. + + Returns: + tuple of two matrices: + * occupancy matrix with shape (tracks, frames) + * point location matrix with shape (frames, nodes, 2, tracks) + """ track_count = len(labels.tracks) - track_names = [np.string_(track.name) for track in labels.tracks] node_count = len(labels.skeletons[0].nodes) frame_idxs = [lf.frame_idx for lf in labels] frame_idxs.sort() - first_frame_idx = 0 if args.all_frames else frame_idxs[0] + first_frame_idx = 0 if all_frames else frame_idxs[0] frame_count = ( frame_idxs[-1] - first_frame_idx + 1 @@ -63,7 +70,7 @@ # "track_names" tracks occupancy_matrix = np.zeros((track_count, frame_count), dtype=np.uint8) - prediction_matrix = np.full( + locations_matrix = np.full( (frame_count, node_count, 2, track_count), np.nan, dtype=float ) @@ -74,39 +81,138 @@ occupancy_matrix[track_i, frame_i] = 1 inst_points = inst.points_array - prediction_matrix[frame_i, ..., track_i] = inst_points + locations_matrix[frame_i, ..., track_i] = inst_points + + return occupancy_matrix, locations_matrix + +def remove_empty_tracks_from_matrices( + track_names: List, occupancy_matrix: np.ndarray, locations_matrix: np.ndarray +) -> Tuple[List, np.ndarray, np.ndarray]: + """ + Removes matrix rows/columns for unoccupied tracks. + + Args: + track_names: List of track names + occupancy_matrix: 2d numpy matrix, rows correspond to tracks + locations_matrix: 4d numpy matrix, last index is track + + Returns: + track_names, occupancy_matrix, locations_matrix from input, + but without the rows/columns corresponding to unoccupied tracks. 
+    """
+    # Make mask with only the occupied tracks
     occupied_track_mask = np.sum(occupancy_matrix, axis=1) > 0
 
     # Ignore unoccupied tracks
     if np.sum(~occupied_track_mask):
+        print(f"ignoring {np.sum(~occupied_track_mask)} empty tracks")
+
         occupancy_matrix = occupancy_matrix[occupied_track_mask]
-        prediction_matrix = prediction_matrix[..., occupied_track_mask]
+        locations_matrix = locations_matrix[..., occupied_track_mask]
 
         track_names = [
             track_names[i] for i in range(len(track_names)) if occupied_track_mask[i]
         ]
 
-    print(f"track_occupancy: {occupancy_matrix.shape}")
-    print(f"tracks: {prediction_matrix.shape}")
-
-    output_filename = re.sub("(\.json(\.zip)?|\.h5)$", "", args.data_path)
-    output_filename = output_filename + ".tracking.h5"
-
-    with h5.File(output_filename, "w") as f:
-        # We have to transpose the arrays since MATLAB expects column-major
-        ds = f.create_dataset("track_names", data=track_names)
-        ds = f.create_dataset(
-            "track_occupancy",
-            data=np.transpose(occupancy_matrix),
-            compression="gzip",
-            compression_opts=9,
-        )
-        ds = f.create_dataset(
-            "tracks",
-            data=np.transpose(prediction_matrix),
-            compression="gzip",
-            compression_opts=9,
-        )
-
-    print(f"Saved as {output_filename}")
+    return track_names, occupancy_matrix, locations_matrix
+
+
+def write_occupancy_file(
+    output_path: str, data_dict: Dict[str, Any], transpose: bool = True
+):
+    """
+    Write HDF5 file with data from given dictionary.
+
+    Args:
+        output_path: Path of HDF5 file.
+        data_dict: Dictionary with data to save. Keys are dataset names,
+            values are the data.
+        transpose: If True, then any ndarray in data dictionary will be
+            transposed before saving. This is useful for writing files
+            that will be imported into MATLAB, which expects data in
+            column-major format.
+
+    Returns:
+        None
+    """
+
+    with h5.File(output_path, "w") as f:
+        for key, val in data_dict.items():
+            if isinstance(val, np.ndarray):
+                print(f"{key}: {val.shape}")
+
+                if transpose:
+                    # Transpose since MATLAB expects column-major
+                    f.create_dataset(
+                        key,
+                        data=np.transpose(val),
+                        compression="gzip",
+                        compression_opts=9,
+                    )
+                else:
+                    f.create_dataset(
+                        key, data=val, compression="gzip", compression_opts=9
+                    )
+            else:
+                f.create_dataset(key, data=val)
+
+    print(f"Saved as {output_path}")
+
+
+def main(labels: Labels, output_path: str, all_frames: bool = True):
+    """
+    Writes HDF5 file with matrices of track occupancy and coordinates.
+
+    Args:
+        labels: The :class:`Labels` from which to get data.
+        output_path: Path of HDF5 file to create.
+        all_frames: If True, then includes zeros so that frame index
+            will line up with columns in the output. Otherwise,
+            there will only be columns for the frames between the
+            first and last frames with labeling data.
+
+    Returns:
+        None
+    """
+    track_names = get_tracks_as_np_strings(labels)
+
+    occupancy_matrix, locations_matrix = get_occupancy_and_points_matrices(
+        labels, all_frames
+    )
+
+    track_names, occupancy_matrix, locations_matrix = remove_empty_tracks_from_matrices(
+        track_names, occupancy_matrix, locations_matrix
+    )
+
+    data_dict = dict(
+        track_names=track_names,
+        tracks=locations_matrix,
+        track_occupancy=occupancy_matrix,
+    )
+
+    write_occupancy_file(output_path, data_dict, transpose=True)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("data_path", help="Path to labels json file")
+    parser.add_argument(
+        "--all-frames",
+        dest="all_frames",
+        action="store_const",
+        const=True,
+        default=False,
+        help="include all frames without predictions",
+    )
+    args = parser.parse_args()
+
+    video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)])
+    labels = Labels.load_file(args.data_path, video_callback=video_callback)
+
+    output_path = re.sub("(\.json(\.zip)?|\.h5)$", "", args.data_path)
+    output_path = output_path + ".tracking.h5"
+
+    main(labels, output_path=output_path, all_frames=args.all_frames)
diff --git a/tests/info/test_h5.py b/tests/info/test_h5.py
new file mode 100644
index 000000000..edcfe0208
--- /dev/null
+++ b/tests/info/test_h5.py
@@ -0,0 +1,89 @@
+import os
+
+import h5py
+import numpy as np
+
+from sleap.info.write_tracking_h5 import (
+    get_tracks_as_np_strings,
+    get_occupancy_and_points_matrices,
+    remove_empty_tracks_from_matrices,
+    write_occupancy_file,
+)
+
+
+def test_output_matrices(centered_pair_predictions):
+
+    names = get_tracks_as_np_strings(centered_pair_predictions)
+    assert len(names) == 27
+    assert isinstance(names[0], np.string_)
+
+    # Remove the first labeled frame
+    del centered_pair_predictions[0]
+    assert len(centered_pair_predictions) == 1099
+
+    occupancy, points = get_occupancy_and_points_matrices(
+        centered_pair_predictions, all_frames=False
+    )
+
+    assert occupancy.shape == (27, 1099)
+    assert points.shape == (1099, 24, 2, 27)
+
+    # Make sure "all_frames" includes the missing initial frame
+    occupancy, points = get_occupancy_and_points_matrices(
+        centered_pair_predictions, all_frames=True
+    )
+
+    assert occupancy.shape == (27, 1100)
+    assert points.shape == (1100, 24, 2, 27)
+
+    # Make sure removing empty tracks doesn't yet change anything
+    names, occupancy, points = remove_empty_tracks_from_matrices(
+        names, occupancy, points
+    )
+
+    assert len(names) == 27
+    assert occupancy.shape == (27, 1100)
+    assert points.shape == (1100, 24, 2, 27)
+
+    # Remove all instances from track 13
+    vid = centered_pair_predictions.videos[0]
+    track = centered_pair_predictions.tracks[13]
+    lfs_insts = centered_pair_predictions.find_track_occupancy(vid, track)
+    for lf, instance in lfs_insts:
+        centered_pair_predictions.remove_instance(lf, instance)
+
+    # Make sure that this now removes the empty track
+    occupancy, points = get_occupancy_and_points_matrices(
+        centered_pair_predictions, all_frames=True
+    )
+    names, occupancy, points = remove_empty_tracks_from_matrices(
+        names, occupancy, points
+    )
+
+    assert len(names) == 26
+    assert occupancy.shape == (26, 1100)
+    assert points.shape == (1100, 24, 2, 26)
+
+
+def test_hdf5_saving(tmpdir):
+    path = os.path.join(tmpdir, "occupancy.h5")
+
+    x = np.array([[1, 2, 6], [3, 4, 5]])
+    data_dict = dict(x=x)
+
+    write_occupancy_file(path, data_dict, transpose=False)
+
+    with h5py.File(path, "r") as f:
+        assert f["x"].shape == x.shape
+
+
+def test_hdf5_tranposed_saving(tmpdir): + path = os.path.join(tmpdir, "transposed.h5") + + x = np.array([[1, 2, 6], [3, 4, 5]]) + data_dict = dict(x=x) + + write_occupancy_file(path, data_dict, transpose=True) + + with h5py.File(path, "r") as f: + assert f["x"].shape == np.transpose(x).shape From bc467b0f632b563e54fdbe88a1fb2ecaab657aae Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 12:59:27 -0400 Subject: [PATCH 134/176] Add docstring. --- sleap/info/labels.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sleap/info/labels.py b/sleap/info/labels.py index 5d5b8c7eb..e99d2b900 100644 --- a/sleap/info/labels.py +++ b/sleap/info/labels.py @@ -1,3 +1,6 @@ +""" +Command line utility which prints data about labels file. +""" import os from sleap.io.dataset import Labels From fd795d24c2907101e36081ff06b94d6954fffb81 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 15:12:07 -0400 Subject: [PATCH 135/176] Rename, add docstrings and tests. --- sleap/gui/app.py | 4 +- sleap/info/summary.py | 95 ++++++++++++++++++++++++++++++++++---- tests/fixtures/datasets.py | 78 ++++++++++++++++++++++++++++++- 3 files changed, 165 insertions(+), 12 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index d63f15b51..c8e3a85de 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -35,7 +35,7 @@ from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track from sleap.io.video import Video from sleap.io.dataset import Labels -from sleap.info.summary import Summary +from sleap.info.summary import StatisticSeries from sleap.gui.video import QtVideoPlayer from sleap.gui.dataviews import ( VideosTable, @@ -1029,7 +1029,7 @@ def updateSeekbarMarks(self): self.player.seekbar.setTracksFromLabels(self.labels, self.video) def setSeekbarHeader(self, graph_name): - data_obj = Summary(self.labels) + data_obj = StatisticSeries(self.labels) header_functions = { "Point Displacement (sum)": data_obj.get_point_displacement_series, "Point Displacement (max)": data_obj.get_point_displacement_series, diff --git a/sleap/info/summary.py b/sleap/info/summary.py index 9f25b2305..b907aeb09 100644 --- a/sleap/info/summary.py +++ b/sleap/info/summary.py @@ -1,12 +1,33 @@ +""" +Module for getting a series which gives some statistic based on labeling +data for each frame of some labeled video. +""" + import attr import numpy as np +from typing import Callable, Dict + +from sleap.io.dataset import Labels +from sleap.io.video import Video + @attr.s(auto_attribs=True) -class Summary: - labels: "Labels" +class StatisticSeries: + """ + Class to calculate various statistical series for labeled frames. + + Each method returns a series which is a dictionary in which keys + are frame index and value are some numerical value for the frame. - def get_point_count_series(self, video): + Args: + labels: The :class:`Labels` for which to calculate series. + """ + + labels: Labels + + def get_point_count_series(self, video: Video) -> Dict[int, float]: + """Get series with total number of labeled points in each frame.""" series = dict() for lf in self.labels.find(video): @@ -14,7 +35,20 @@ def get_point_count_series(self, video): series[lf.frame_idx] = val return series - def get_point_score_series(self, video, reduction="sum"): + def get_point_score_series( + self, video: Video, reduction: str = "sum" + ) -> Dict[int, float]: + """Get series with statistic of point scores in each frame. + + Args: + video: The :class:`Video` for which to calculate statistic. 
+ reduction: name of function applied to scores: + * sum + * min + + Returns: + The series dictionary (see class docs for details) + """ reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] series = dict() @@ -29,7 +63,18 @@ def get_point_score_series(self, video, reduction="sum"): series[lf.frame_idx] = val return series - def get_instance_score_series(self, video, reduction="sum"): + def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]: + """Get series with statistic of instance scores in each frame. + + Args: + video: The :class:`Video` for which to calculate statistic. + reduction: name of function applied to scores: + * sum + * min + + Returns: + The series dictionary (see class docs for details) + """ reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] series = dict() @@ -39,7 +84,24 @@ def get_instance_score_series(self, video, reduction="sum"): series[lf.frame_idx] = val return series - def get_point_displacement_series(self, video, reduction="sum"): + def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, float]: + """ + Get series with statistic of point displacement in each frame. + + Point displacement is the distance between the point location in + frame and the location of the corresponding point (same node, + same track) from the closest earlier labeled frame. + + Args: + video: The :class:`Video` for which to calculate statistic. + reduction: name of function applied to point scores: + * sum + * mean + * max + + Returns: + The series dictionary (see class docs for details) + """ reduce_funct = dict(sum=np.sum, mean=np.nanmean, max=np.max)[reduction] series = dict() @@ -53,14 +115,29 @@ def get_point_displacement_series(self, video, reduction="sum"): return series @staticmethod - def _calculate_frame_velocity(lf, last_lf, reduce_function): + def _calculate_frame_velocity( + lf: "LabeledFrame", last_lf: "LabeledFrame", reduce_function: Callable + ) -> float: + """ + Calculate total point displacement between two given frames. + + Args: + lf: The :class:`LabeledFrame` for which we want velocity + last_lf: The frame from which to calculate displacement. + reduce_function: Numpy function (e.g., np.sum, np.nanmean) + is applied to *point* displacement, and then those + instance values are summed for the whole frame. + + Returns: + The total velocity for instances in frame. 
+ """ val = 0 for inst in lf: if last_lf is not None: last_inst = last_lf.find(track=inst.track) if last_inst: - points_a = inst.visible_points_array - points_b = last_inst[0].visible_points_array + points_a = inst.points_array + points_b = last_inst[0].points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) inst_dist = reduce_function(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index fa1a0150b..9bb5db7ea 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,8 +1,17 @@ import os import pytest -from sleap.instance import Instance, Point, LabeledFrame, Track +from sleap.instance import ( + Instance, + PredictedInstance, + Point, + PredictedPoint, + LabeledFrame, + Track, +) +from sleap.skeleton import Skeleton from sleap.io.dataset import Labels +from sleap.io.video import Video TEST_JSON_LABELS = "tests/data/json_format_v1/centered_pair.json" TEST_JSON_PREDICTIONS = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -30,6 +39,73 @@ def mat_labels(): return Labels.load_mat(TEST_MAT_LABELS) +@pytest.fixture +def simple_predictions(): + + video = Video.from_filename("video.mp4") + + skeleton = Skeleton() + skeleton.add_node("a") + skeleton.add_node("b") + + track_a = Track(0, "a") + track_b = Track(0, "b") + + labels = Labels() + + instances = [] + instances.append( + PredictedInstance( + skeleton=skeleton, + score=2, + track=track_a, + points=dict( + a=PredictedPoint(1, 1, score=0.5), b=PredictedPoint(1, 1, score=0.5) + ), + ) + ) + instances.append( + PredictedInstance( + skeleton=skeleton, + score=5, + track=track_b, + points=dict( + a=PredictedPoint(1, 1, score=0.7), b=PredictedPoint(1, 1, score=0.7) + ), + ) + ) + + labeled_frame = LabeledFrame(video, frame_idx=0, instances=instances) + labels.append(labeled_frame) + + instances = [] + instances.append( + PredictedInstance( + skeleton=skeleton, + score=3, + track=track_a, + points=dict( + a=PredictedPoint(4, 5, score=1.5), b=PredictedPoint(1, 1, score=1.0) + ), + ) + ) + instances.append( + PredictedInstance( + skeleton=skeleton, + score=6, + track=track_b, + points=dict( + a=PredictedPoint(6, 13, score=1.7), b=PredictedPoint(1, 1, score=1.0) + ), + ) + ) + + labeled_frame = LabeledFrame(video, frame_idx=1, instances=instances) + labels.append(labeled_frame) + + return labels + + @pytest.fixture def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): """ From b3b9b5064f6b7b440c20ed577f56646b9d827d7d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 15:22:19 -0400 Subject: [PATCH 136/176] Add Track to module docstring. --- sleap/instance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sleap/instance.py b/sleap/instance.py index 10140bebf..82116d861 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -9,6 +9,8 @@ * `Instance`s (and `PredictedInstance`s) have `PointArray` (or `PredictedPointArray`). +* `Instance` (`PredictedInstance`) can be associated with a `Track` + * A `PointArray` (or `PredictedPointArray`) contains zero or more `Point`s (or `PredictedPoint`s), ideally as many as there are in the associated :class:`Skeleton` although these can get out of sync if the From d87956ef521a1853aeed720c585a784f661a127d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 26 Sep 2019 15:22:49 -0400 Subject: [PATCH 137/176] Docstring edits. 
--- sleap/io/legacy.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sleap/io/legacy.py b/sleap/io/legacy.py index 8ddfe7fad..f1af21222 100644 --- a/sleap/io/legacy.py +++ b/sleap/io/legacy.py @@ -1,15 +1,16 @@ +""" +Module for legacy LEAP dataset. +""" import json import os import numpy as np import pandas as pd -from .dataset import Labels -from .video import Video +from sleap.io.dataset import Labels +from sleap.io.video import Video -from ..instance import LabeledFrame, PredictedPoint, PredictedInstance -from ..skeleton import Skeleton - -from ..nn.tracking import Track +from sleap.instance import LabeledFrame, PredictedPoint, PredictedInstance, Track +from sleap.skeleton import Skeleton def load_predicted_labels_json_old( @@ -24,9 +25,10 @@ def load_predicted_labels_json_old( Args: data_path: The path to the JSON file. - parsed_json: The parsed json if already loaded. Save some time if already parsed. - adjust_matlab_indexing: Do we need to adjust indexing from MATLAB. - fix_rel_paths: Fix paths to videos to absolute paths. + parsed_json: The parsed json if already loaded, so we can save + some time if already parsed. + adjust_matlab_indexing: Whether to adjust indexing from MATLAB. + fix_rel_paths: Whether to fix paths to videos to absolute paths. Returns: A newly constructed Labels object. From 8bebdd91621f2b9eb1418e046e6fb0f7be3769d0 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 09:41:56 -0400 Subject: [PATCH 138/176] Better docstrings, typing, minor refactoring. --- sleap/io/dataset.py | 665 ++++++++++++++++++++----------------- sleap/io/legacy.py | 143 +++++++- sleap/util.py | 42 ++- tests/info/test_h5.py | 6 +- tests/info/test_summary.py | 42 +++ tests/io/test_dataset.py | 3 +- tests/nn/test_inference.py | 5 +- tests/test_util.py | 13 +- 8 files changed, 593 insertions(+), 326 deletions(-) create mode 100644 tests/info/test_summary.py diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index f2e3665c8..fb33847c5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1,23 +1,18 @@ -"""A LEAP Dataset represents annotated (labeled) video data. - -A LEAP Dataset stores almost all data required for training of a model. -This includes, raw video frame data, labelled instances of skeleton _points, -confidence maps, part affinity fields, and skeleton data. A LEAP :class:`.Dataset` -is a high level API to these data structures that abstracts away their underlying -storage format. +""" +A SLEAP dataset collects labeled video frames. +This contains labeled frame data (user annotations and/or predictions), +together with all the other data that is saved for a SLEAP project +(videos, skeletons, negative training sample anchors, etc.). 
""" import os import re import zipfile import atexit -import glob import attr import cattr -import json -import rapidjson import shutil import tempfile import numpy as np @@ -25,7 +20,7 @@ import h5py as h5 from collections import MutableSequence -from typing import List, Union, Dict, Optional, Tuple +from typing import Callable, List, Union, Dict, Optional try: from typing import ForwardRef @@ -46,38 +41,11 @@ PointArray, PredictedPointArray, ) -from sleap.rangelist import RangeList -from sleap.io.video import Video -from sleap.util import uniquify, weak_filename_match - - -def json_loads(json_str: str): - try: - return rapidjson.loads(json_str) - except: - return json.loads(json_str) - - -def json_dumps(d: Dict, filename: str = None): - """ - A simple wrapper around the JSON encoder we are using. - - Args: - d: The dict to write. - f: The filename to write to. - - Returns: - None - """ - import codecs - encoder = rapidjson - - if filename: - with open(filename, "w") as f: - encoder.dump(d, f, ensure_ascii=False) - else: - return encoder.dumps(d) +from sleap.io.legacy import load_labels_json_old +from sleap.io.video import Video +from sleap.rangelist import RangeList +from sleap.util import uniquify, weak_filename_match, json_dumps, json_loads """ @@ -89,22 +57,30 @@ def json_dumps(d: Dict, filename: str = None): @attr.s(auto_attribs=True) class Labels(MutableSequence): """ - The LEAP :class:`.Labels` class represents an API for accessing labeled video - frames and other associated metadata. This class is front-end for all - interactions with loading, writing, and modifying these labels. The actual - storage backend for the data is mostly abstracted away from the main - interface. - - Args: - labeled_frames: A list of `LabeledFrame`s - videos: A list of videos that these labels may or may not reference. - That is, every LabeledFrame's video will be in videos but a Video - object from videos might not have any LabeledFrame. - skeletons: A list of skeletons that these labels may or may not reference. - tracks: A list of tracks that instances can belong to. - suggestions: A dict with a list for each video of suggested frames to label. - negative_anchors: A dict with list of anchor coordinates - for negative training samples for each video. + The :class:`Labels` class collects the data for a SLEAP project. + + This class is front-end for all interactions with loading, writing, + and modifying these labels. The actual storage backend for the data + is mostly abstracted away from the main interface. + + Attributes: + labeled_frames: A list of :class:`LabeledFrame`s + videos: A list of :class:`Video`s that these labels may or may + not reference. The video for every `LabeledFrame` will be + stored in :attribute:`Labels.videos`, but some videos in + this list may not have any associated labeled frames. + skeletons: A list of :class:`Skeleton`s (again, that may or may + not be referenced by an :class:`Instance` in labeled frame). + tracks: A list of :class:`Track`s that instances can belong to. + suggestions: Dictionary that stores "suggested" frames for + videos in project. These can be suggested frames for user + to label or suggested frames for user to review. + Dictionary key is :class:`Video`, value is list of frame + indices. + negative_anchors: Dictionary that stores center-points around + which to crop as negative samples when training. + Dictionary key is :class:`Video`, value is list of + (frame index, x, y) tuples. 
""" labeled_frames: List[LabeledFrame] = attr.ib(default=attr.Factory(list)) @@ -116,23 +92,32 @@ class Labels(MutableSequence): negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) def __attrs_post_init__(self): + """ + Called by attrs after the class is instantiated. + + This updates the top level contains (videos, skeletons, etc) + from data in the labeled frames, as well as various caches. + """ # Add any videos/skeletons/nodes/tracks that are in labeled # frames but not in the lists on our object self._update_from_labels() # Update caches used to find frames by frame index - self._update_lookup_cache() + self._build_lookup_caches() # Create a variable to store a temporary storage directory # used when we unzip self.__temp_dir = None - def _update_from_labels(self, merge=False): - """Update top level attributes with data from labeled frames. + def _update_from_labels(self, merge: bool = False): + """Updates top level attributes with data from labeled frames. Args: - merge: if True, then update even if there's already data + merge: If True, then update even if there's already data. + + Returns: + None. """ # Add any videos that are present in the labels but @@ -194,7 +179,38 @@ def _update_from_labels(self, merge=False): self.tracks.extend(new_tracks) - def _update_lookup_cache(self): + def _update_containers(self, new_label: LabeledFrame): + """ Ensure that top-level containers are kept updated with new + instances of objects that come along with new labels. """ + + if new_label.video not in self.videos: + self.videos.append(new_label.video) + + for skeleton in {instance.skeleton for instance in new_label}: + if skeleton not in self.skeletons: + self.skeletons.append(skeleton) + for node in skeleton.nodes: + if node not in self.nodes: + self.nodes.append(node) + + # Add any new Tracks as well + for instance in new_label.instances: + if instance.track and instance.track not in self.tracks: + self.tracks.append(instance.track) + + # Sort the tracks again + self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) + + # Update cache datastructures + if new_label.video not in self._lf_by_video: + self._lf_by_video[new_label.video] = [] + if new_label.video not in self._frame_idx_map: + self._frame_idx_map[new_label.video] = dict() + self._lf_by_video[new_label.video].append(new_label) + self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label + + def _build_lookup_caches(self): + """Builds (or rebuilds) various caches.""" # Data structures for caching self._lf_by_video = dict() self._frame_idx_map = dict() @@ -214,16 +230,30 @@ def _update_lookup_cache(self): @property def labels(self): - """ Alias for labeled_frames """ + """Alias for labeled_frames.""" return self.labeled_frames - def __len__(self): + def __len__(self) -> int: + """Returns number of labeled frames.""" return len(self.labeled_frames) - def index(self, value): + def index(self, value) -> int: + """Returns index of labeled frame in list of labeled frames.""" return self.labeled_frames.index(value) - def __contains__(self, item): + def __contains__(self, item) -> bool: + """ + Checks if object contains the given item. + + Args: + item: The item to look for within `Labels`. + This can be :class:`LabeledFrame`, + :class:`Video`, :class:`Skeleton`, + :class:`Node`, or (:class:`Video`, frame idx) tuple. + + Returns: + True if item is found. 
+ """ if isinstance(item, LabeledFrame): return item in self.labeled_frames elif isinstance(item, Video): @@ -240,7 +270,18 @@ def __contains__(self, item): ): return self.find_first(*item) is not None - def __getitem__(self, key): + def __getitem__(self, key) -> List[LabeledFrame]: + """Returns labeled frames matching key. + + Args: + key: `Video` or (`Video`, frame index) to match against. + + Raises: + KeyError: If labeled frame for `Video` or frame index + cannot be found. + + Returns: A list with the matching labeled frame(s). + """ if isinstance(key, int): return self.labels.__getitem__(key) @@ -269,41 +310,13 @@ def __getitem__(self, key): raise KeyError("Invalid label indexing arguments.") def __setitem__(self, index, value: LabeledFrame): + """Sets labeled frame at given index.""" # TODO: Maybe we should remove this method altogether? self.labeled_frames.__setitem__(index, value) self._update_containers(value) - def _update_containers(self, new_label: LabeledFrame): - """ Ensure that top-level containers are kept updated with new - instances of objects that come along with new labels. """ - - if new_label.video not in self.videos: - self.videos.append(new_label.video) - - for skeleton in {instance.skeleton for instance in new_label}: - if skeleton not in self.skeletons: - self.skeletons.append(skeleton) - for node in skeleton.nodes: - if node not in self.nodes: - self.nodes.append(node) - - # Add any new Tracks as well - for instance in new_label.instances: - if instance.track and instance.track not in self.tracks: - self.tracks.append(instance.track) - - # Sort the tracks again - self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) - - # Update cache datastructures - if new_label.video not in self._lf_by_video: - self._lf_by_video[new_label.video] = [] - if new_label.video not in self._frame_idx_map: - self._frame_idx_map[new_label.video] = dict() - self._lf_by_video[new_label.video].append(new_label) - self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label - def insert(self, index, value: LabeledFrame): + """Inserts labeled frame at given index.""" if value in self or (value.video, value.frame_idx) in self: return @@ -311,12 +324,15 @@ def insert(self, index, value: LabeledFrame): self._update_containers(value) def append(self, value: LabeledFrame): + """Adds labeled frame to list of labeled frames.""" self.insert(len(self) + 1, value) def __delitem__(self, key): + """Removes labeled frame with given index.""" self.labeled_frames.remove(self.labeled_frames[key]) def remove(self, value: LabeledFrame): + """Removes given labeled frame.""" self.labeled_frames.remove(value) self._lf_by_video[value.video].remove(value) del self._frame_idx_map[value.video][value.frame_idx] @@ -324,19 +340,25 @@ def remove(self, value: LabeledFrame): def find( self, video: Video, - frame_idx: Union[int, range] = None, + frame_idx: Optional[Union[int, range]] = None, return_new: bool = False, ) -> List[LabeledFrame]: """ Search for labeled frames given video and/or frame index. Args: - video: a `Video` instance that is associated with the labeled frames - frame_idx: an integer specifying the frame index within the video - return_new: return singleton of new `LabeledFrame` if none found? + video: A :class:`Video` that is associated with the project. + frame_idx: The frame index (or indices) which we want to + find in the video. If a range is specified, we'll return + all frames with indices in that range. If not specific, + then we'll return all labeled frames for video. 
+ return_new: Whether to return singleton of new and empty + :class:`LabeledFrame` if none is found in project. Returns: - List of `LabeledFrame`s that match the criteria. Empty if no matches found. - + List of `LabeledFrame`s that match the criteria. + Empty if no matches found, unless return_new is True, + in which case it contains a new `LabeledFrame` with + `video` and `frame_index` set. """ null_result = ( [LabeledFrame(video=video, frame_idx=frame_idx)] if return_new else [] @@ -364,8 +386,16 @@ def find( def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): """ - Iterator over all frames in a video, starting with first frame - after specified frame_idx (or first frame in video if none specified). + Iterator over all labeled frames in a video. + + Args: + video: A :class:`Video` that is associated with the project. + from_frame_idx: The frame index from which we want to start. + Defaults to the first frame of video. + reverse: Whether to iterate over frames in reverse order. + + Yields: + :class:`LabeledFrame` """ if video not in self._frame_idx_map: return None @@ -391,15 +421,23 @@ def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): for idx in frame_idxs: yield self._frame_idx_map[video][idx] - def find_first(self, video: Video, frame_idx: int = None) -> LabeledFrame: - """ Find the first occurrence of a labeled frame for the given video and/or frame index. + def find_first( + self, video: Video, frame_idx: Optional[int] = None + ) -> Optional[LabeledFrame]: + """ + Finds the first occurrence of a matching labeled frame. + + Matches on frames for the given video and/or frame index. Args: - video: a `Video` instance that is associated with the labeled frames - frame_idx: an integer specifying the frame index within the video + video: a `Video` instance that is associated with the + labeled frames + frame_idx: an integer specifying the frame index within + the video Returns: - First `LabeledFrame` that match the criteria or None if none were found. + First `LabeledFrame` that match the criteria + or None if none were found. """ if video in self.videos: @@ -409,15 +447,23 @@ def find_first(self, video: Video, frame_idx: int = None) -> LabeledFrame: ): return label - def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: - """ Find the last occurrence of a labeled frame for the given video and/or frame index. + def find_last( + self, video: Video, frame_idx: Optional[int] = None + ) -> Optional[LabeledFrame]: + """ + Finds the last occurrence of a matching labeled frame. + + Matches on frames for the given video and/or frame index. Args: - video: A `Video` instance that is associated with the labeled frames - frame_idx: An integer specifying the frame index within the video + video: a `Video` instance that is associated with the + labeled frames + frame_idx: an integer specifying the frame index within + the video Returns: - LabeledFrame: Last label that matches the criteria or None if no results. + Last `LabeledFrame` that match the criteria + or None if none were found. """ if video in self.videos: @@ -429,9 +475,15 @@ def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: @property def user_labeled_frames(self): + """ + Returns all labeled frames with user (non-predicted) instances. + """ return [lf for lf in self.labeled_frames if lf.has_user_instances] def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: + """ + Returns labeled frames for given video with user instances. 
+ """ return [ lf for lf in self.labeled_frames @@ -441,6 +493,7 @@ def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: # Methods for instances def instance_count(self, video: Video, frame_idx: int) -> int: + """Returns number of instances matching video/frame index.""" count = 0 labeled_frame = self.find_first(video, frame_idx) if labeled_frame is not None: @@ -451,14 +504,17 @@ def instance_count(self, video: Video, frame_idx: int) -> int: @property def all_instances(self): + """Returns list of all instances.""" return list(self.instances()) @property def user_instances(self): + """Returns list of all user (non-predicted) instances.""" return [inst for inst in self.all_instances if type(inst) == Instance] def instances(self, video: Video = None, skeleton: Skeleton = None): - """ Iterate through all instances in the labels, optionally with filters. + """ + Iterate over instances in the labels, optionally with filters. Args: video: Only iterate through instances in this video @@ -475,19 +531,22 @@ def instances(self, video: Video = None, skeleton: Skeleton = None): # Methods for tracks - def get_track_occupany(self, video: Video): + def get_track_occupany(self, video: Video) -> List: + """Returns track occupancy list for given video""" try: return self._track_occupancy[video] except: return [] def add_track(self, video: Video, track: Track): + """Adds track to labels, updating occupancy.""" self.tracks.append(track) self._track_occupancy[video][track] = RangeList() def track_set_instance( self, frame: LabeledFrame, instance: Instance, new_track: Track ): + """Sets track on given instance, updating occupancy.""" self.track_swap( frame.video, new_track, @@ -499,9 +558,31 @@ def track_set_instance( instance.track = new_track def track_swap( - self, video: Video, new_track: Track, old_track: Track, frame_range: tuple + self, + video: Video, + new_track: Track, + old_track: Optional[Track], + frame_range: tuple, ): + """ + Swaps track assignment for instances in two tracks. + + If you need to change the track to or from None, you'll need + to use :method:`Labels.track_set_instance()` for each specific + instance you want to modify. + Args: + video: The :class:`Video` for which we want to swap tracks. + new_track: A :class:`Track` for which we want to swap + instances with another track. + old_track: The other :class:`Track` for swapping. + frame_range: Tuple of (start, end) frame indexes. + If you want to swap tracks on a single frame, use + (frame index, frame index + 1). + + Returns: + None. + """ # Get ranges in track occupancy cache _, within_old, _ = self._track_occupancy[video][old_track].cut_range( frame_range @@ -526,35 +607,33 @@ def track_swap( new_track_instances = self.find_track_occupancy(video, new_track, frame_range) # swap new to old tracks on all instances - for frame, instance in old_track_instances: + for instance in old_track_instances: instance.track = new_track # old_track can be `Track` or int # If int, it's index in instance list which we'll use as a pseudo-track, # but we won't set instances currently on new_track to old_track. if type(old_track) == Track: - for frame, instance in new_track_instances: + for instance in new_track_instances: instance.track = old_track def _track_remove_instance(self, frame: LabeledFrame, instance: Instance): + """Manipulates track occupancy cache.""" if instance.track not in self._track_occupancy[frame.video]: return # If this is only instance in track in frame, then remove frame from track. 
- if ( - len( - list(filter(lambda inst: inst.track == instance.track, frame.instances)) - ) - == 1 - ): + if len(frame.find(track=instance.track)) == 1: self._track_occupancy[frame.video][instance.track].remove( (frame.frame_idx, frame.frame_idx + 1) ) def remove_instance(self, frame: LabeledFrame, instance: Instance): + """Removes instance from frame, updating track occupancy.""" self._track_remove_instance(frame, instance) frame.instances.remove(instance) def add_instance(self, frame: LabeledFrame, instance: Instance): + """Adds instance to frame, updating track occupancy.""" if frame.video not in self._track_occupancy: self._track_occupancy[frame.video] = dict() @@ -576,7 +655,8 @@ def add_instance(self, frame: LabeledFrame, instance: Instance): ) frame.instances.append(instance) - def _make_track_occupany(self, video): + def _make_track_occupany(self, video: Video) -> Dict[Video, RangeList]: + """Build cached track occupancy data.""" frame_idx_map = self._frame_idx_map[video] tracks = dict() @@ -591,8 +671,8 @@ def _make_track_occupany(self, video): def find_track_occupancy( self, video: Video, track: Union[Track, int], frame_range=None - ) -> List[Tuple[LabeledFrame, Instance]]: - """Get instances for a given track. + ) -> List[Instance]: + """Get instances for a given video, track, and range of frames. Args: video: the `Video` @@ -601,7 +681,7 @@ def find_track_occupancy( If specified, only return instances on frames in range. If None, return all instances for given track. Returns: - list of `Instance` objects + List of :class:`Instance` objects. """ frame_range = range(*frame_range) if type(frame_range) == tuple else frame_range @@ -619,7 +699,7 @@ def does_track_match(inst, tr, labeled_frame): return match track_frame_inst = [ - (lf, instance) + instance for lf in self.find(video) for instance in lf.instances if does_track_match(instance, track, lf) @@ -627,9 +707,6 @@ def does_track_match(inst, tr, labeled_frame): ] return track_frame_inst - def find_track_instances(self, *args, **kwargs) -> List[Instance]: - return [inst for lf, inst in self.find_track_occupancy(*args, **kwargs)] - # Methods for suggestions def get_video_suggestions(self, video: Video) -> list: @@ -649,7 +726,7 @@ def get_suggestions(self) -> list: return suggestion_list def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: - """Returns a (video, frame_idx) tuple.""" + """Returns a (video, frame_idx) tuple seeking from given frame.""" # make sure we have valid seek_direction if seek_direction not in (-1, 1): return (None, None) @@ -788,7 +865,7 @@ def extend_from( self, new_frames: Union["Labels", List[LabeledFrame]], unify: bool = False ): """ - Merge data from another Labels object or list of LabeledFrames into self. + Merge in data from another Labels object or list of LabeledFrames. 
Arg: new_frames: the object from which to copy data @@ -823,7 +900,7 @@ def extend_from( # update top level videos/nodes/skeletons/tracks self._update_from_labels(merge=True) - self._update_lookup_cache() + self._build_lookup_caches() return True @@ -876,7 +953,7 @@ def complex_merge_between( # Add any new videos (etc) into top level lists in base base_labels._update_from_labels(merge=True) # Update caches - base_labels._update_lookup_cache() + base_labels._build_lookup_caches() # Merge suggestions and negative anchors cls.merge_container_dicts(base_labels.suggestions, new_labels.suggestions) @@ -924,10 +1001,10 @@ def finish_complex_merge( # Add any new videos (etc) into top level lists in base base_labels._update_from_labels(merge=True) # Update caches - base_labels._update_lookup_cache() + base_labels._build_lookup_caches() @staticmethod - def merge_container_dicts(dict_a, dict_b): + def merge_container_dicts(dict_a: Dict, dict_b: Dict) -> Dict: """Merge data from dict_b into dict_a.""" for key in dict_b.keys(): if key in dict_a: @@ -936,14 +1013,14 @@ def merge_container_dicts(dict_a, dict_b): else: dict_a[key] = dict_b[key] - def merge_matching_frames(self, video=None): + def merge_matching_frames(self, video: Optional[Video] = None): """ - Combine all instances from LabeledFrames that have same frame_idx. + Merge `LabeledFrame`s that are for the same video frame. Args: - video (optional): combine for this video; if None, do all videos + video: combine for this video; if None, do all videos Returns: - none + None """ if video is None: for vid in {lf.video for lf in self.labeled_frames}: @@ -961,12 +1038,14 @@ def to_dict(self, skip_labels: bool = False): JSON and HDF5 serialized datasets. Args: - skip_labels: If True, skip labels serialization and just do the metadata. + skip_labels: If True, skip labels serialization and just do the + metadata. Returns: A dict containing the followings top level keys: * version - The version of the dict/json serialization format. - * skeletons - The skeletons associated with these underlying instances. + * skeletons - The skeletons associated with these underlying + instances. * nodes - The nodes that the skeletons represent. * videos - The videos that that the instances occur on. * labels - The labeled frames @@ -1027,7 +1106,7 @@ def to_json(self): JSON structured string. Returns: - The JSON representaiton of the string. + The JSON representation of the string. """ # Unstructure the data into dicts and dump to JSON. @@ -1047,16 +1126,19 @@ def save_json( Args: labels: The labels dataset to save. filename: The filename to save the data to. - compress: Should the data be zip compressed or not? If True, the JSON will be - compressed using Python's shutil.make_archive command into a PKZIP zip file. If - compress is True then filename will have a .zip appended to it. - save_frame_data: Whether to save the image data for each frame as well. For each - video in the dataset, all frames that have labels will be stored as an imgstore - dataset. If save_frame_data is True then compress will be forced to True since - the archive must contain both the JSON data and image data stored in ImgStores. - frame_data_format: If save_frame_data is True, then this argument is used to set - the data format to use when writing frame data to ImgStore objects. Supported - formats should be: + compress: Whether the data be zip compressed or not? If True, + the JSON will be compressed using Python's shutil.make_archive + command into a PKZIP zip file. 
If compress is True then + filename will have a .zip appended to it. + save_frame_data: Whether to save the image data for each frame. + For each video in the dataset, all frames that have labels + will be stored as an imgstore dataset. + If save_frame_data is True then compress will be forced to True + since the archive must contain both the JSON data and image + data stored in ImgStores. + frame_data_format: If save_frame_data is True, then this argument + is used to set the data format to use when writing frame + data to ImgStore objects. Supported formats should be: * 'pgm', * 'bmp', @@ -1069,8 +1151,9 @@ def save_json( * 'h264/mkv', * 'avc1/mp4' - Note: 'h264/mkv' and 'avc1/mp4' require separate installation of these codecs - on your system. They are excluded from sLEAP because of their GPL license. + Note: 'h264/mkv' and 'avc1/mp4' require separate installation of + these codecs on your system. They are excluded from SLEAP + because of their GPL license. Returns: None @@ -1132,6 +1215,21 @@ def save_json( def from_json( cls, data: Union[str, dict], match_to: Optional["Labels"] = None ) -> "Labels": + """ + Create instance of class from data in dictionary. + + Method is used by other methods that load from JSON. + + Args: + data: Dictionary, deserialized from JSON. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video`s). + Returns: + A new :class:`Labels` object. + """ # Parse the json string if needed. if type(data) is str: @@ -1228,8 +1326,29 @@ def from_json( @classmethod def load_json( - cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None - ): + cls, + filename: str, + video_callback: Optional[Callable] = None, + match_to: Optional["Labels"] = None, + ) -> "Labels": + """ + Deserialize JSON file as new :class:`Labels` instance. + + Args: + filename: Path to JSON file. + video_callback: A callback function that which can modify + video paths before we try to create the corresponding + :class:`Video` objects. Usually you'll want to pass + a callback created by :method:`Labels.make_video_callback` + or :method:`Labels.make_gui_video_callback`. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video`s). + Returns: + A new :class:`Labels` object. + """ tmp_dir = None @@ -1332,7 +1451,8 @@ def load_json( return labels else: - return load_labels_json_old(data_path=filename, parsed_json=dicts) + frames = load_labels_json_old(data_path=filename, parsed_json=dicts) + return Labels(frames) @staticmethod def save_hdf5( @@ -1345,14 +1465,18 @@ def save_hdf5( Serialize the labels dataset to an HDF5 file. Args: - labels: The Labels dataset to save + labels: The :class:`Labels` dataset to save filename: The file to serialize the dataset to. - append: Whether to append these labeled frames to the file or - not. - save_frame_data: Whether to save the image frame data for any - labeled frame as well. This is useful for uploading the HDF5 for - model training when video files are to large to move. This will only - save video frames that have some labeled instances. + append: Whether to append these labeled frames to the file + or not. 
+ save_frame_data: Whether to save the image frame data for + any labeled frame as well. This is useful for uploading + the HDF5 for model training when video files are to + large to move. This will only save video frames that + have some labeled instances. NOT YET IMPLENTED. + + Raises: + NotImplementedError: If save_frame_data is True. Returns: None @@ -1655,7 +1779,7 @@ def load_hdf5( labels.labeled_frames = frames # Do the stuff that should happen after we have labeled frames - labels._update_lookup_cache() + labels._build_lookup_caches() return labels @@ -1689,19 +1813,23 @@ def save_frame_data_imgstore( self, output_dir: str = "./", format: str = "png", all_labels: bool = False ): """ - Write all labeled frames from all videos to a collection of imgstore datasets. - This only writes frames that have been labeled. Videos without any labeled frames - will be included as empty imgstores. + Write all labeled frames from all videos to imgstore datasets. + + This only writes frames that have been labeled. Videos without + any labeled frames will be included as empty imgstores. Args: - output_dir: - format: The image format to use for the data. png for lossless, jpg for lossy. - Other imgstore formats will probably work as well but have not been tested. + output_dir: Path to directory which will contain imgstores. + format: The image format to use for the data. + Use "png" for lossless, "jpg" for lossy. + Other imgstore formats will probably work as well but + have not been tested. all_labels: Include any labeled frames, not just the frames - we'll use for training (i.e., those with Instances). + we'll use for training (i.e., those with `Instance`s). Returns: - A list of ImgStoreVideo objects that represent the stored frames. + A list of :class:`ImgStoreVideo` objects with the stored + frames. """ # For each label imgstore_vids = [] @@ -1728,6 +1856,7 @@ def save_frame_data_imgstore( @staticmethod def _unwrap_mat_scalar(a): + """Extract single value from nested MATLAB file data.""" if a.shape == (1,): return Labels._unwrap_mat_scalar(a[0]) else: @@ -1735,12 +1864,20 @@ def _unwrap_mat_scalar(a): @staticmethod def _unwrap_mat_array(a): + """Extract list of values from nested MATLAB file data.""" b = a[0][0] c = [Labels._unwrap_mat_scalar(x) for x in b] return c @classmethod - def load_mat(cls, filename): + def load_mat(cls, filename: str) -> "Labels": + """Load LEAP MATLAB file as dataset. + + Args: + filename: Path to csv file. + Returns: + The :class:`Labels` dataset. + """ mat_contents = sio.loadmat(filename) box_path = Labels._unwrap_mat_scalar(mat_contents["boxPath"]) @@ -1795,7 +1932,14 @@ def load_mat(cls, filename): return labels @classmethod - def load_deeplabcut_csv(cls, filename): + def load_deeplabcut_csv(cls, filename: str) -> "Labels": + """Load DeepLabCut csv file as dataset. + + Args: + filename: Path to csv file. + Returns: + The :class:`Labels` dataset. + """ # At the moment we don't need anything from the config file, # but the code to read it is here in case we do in the future. @@ -1874,7 +2018,21 @@ def fix_img_path(img_dir, img_filename): return cls(labels) @classmethod - def make_video_callback(cls, search_paths=None): + def make_video_callback(cls, search_paths: Optional[List] = None) -> Callable: + """ + Create a non-GUI callback for finding missing videos. + + The callback can be used while loading a saved project and + allows the user to find videos which have been moved (or have + paths from a different system). 
+ + Args: + search_paths: If specified, this is a list of paths where + we'll automatically try to find the missing videos. + + Returns: + The callback function. + """ search_paths = search_paths or [] def video_callback(video_list, new_paths=search_paths): @@ -1884,7 +2042,6 @@ def video_callback(video_list, new_paths=search_paths): current_filename = video_item["backend"]["filename"] # check if we can find video if not os.path.exists(current_filename): - is_found = False current_basename = os.path.basename(current_filename) # handle unix, windows, or mixed paths @@ -1903,23 +2060,32 @@ def video_callback(video_list, new_paths=search_paths): if os.path.exists(check_path): # we found the file in a different directory video_item["backend"]["filename"] = check_path - is_found = True break return video_callback @classmethod - def make_gui_video_callback(cls, search_paths): + def make_gui_video_callback(cls, search_paths: Optional[List] = None) -> Callable: + """ + Create a callback with GUI for finding missing videos. + + The callback can be used while loading a saved project and + allows the user to find videos which have been moved (or have + paths from a different system). + + Args: + search_paths: If specified, this is a list of paths where + we'll automatically try to find the missing videos. + + Returns: + The callback function. + """ search_paths = search_paths or [] def gui_video_callback(video_list, new_paths=search_paths): import os from PySide2.QtWidgets import QFileDialog, QMessageBox - has_shown_prompt = ( - False - ) # have we already alerted user about missing files? - basename_list = [] # Check each video @@ -1960,7 +2126,6 @@ def gui_video_callback(video_list, new_paths=search_paths): QMessageBox( text=f"We're unable to locate one or more video files for this project. Please locate {current_filename}." ).exec_() - has_shown_prompt = True current_root, current_ext = os.path.splitext(current_basename) caption = f"Please locate {current_basename}..." @@ -1980,123 +2145,3 @@ def gui_video_callback(video_list, new_paths=search_paths): basename_list.append(current_basename) return gui_video_callback - - -def load_labels_json_old( - data_path: str, - parsed_json: dict = None, - adjust_matlab_indexing: bool = True, - fix_rel_paths: bool = True, -) -> Labels: - """ - Simple utitlity code to load data from Talmo's old JSON format into newer - Labels object. - - Args: - data_path: The path to the JSON file. - parsed_json: The parsed json if already loaded. Save some time if already parsed. - adjust_matlab_indexing: Do we need to adjust indexing from MATLAB. - fix_rel_paths: Fix paths to videos to absolute paths. - - Returns: - A newly constructed Labels object. 
- """ - if parsed_json is None: - data = json_loads(open(data_path).read()) - else: - data = parsed_json - - videos = pd.DataFrame(data["videos"]) - instances = pd.DataFrame(data["instances"]) - points = pd.DataFrame(data["points"]) - predicted_instances = pd.DataFrame(data["predicted_instances"]) - predicted_points = pd.DataFrame(data["predicted_points"]) - - if adjust_matlab_indexing: - instances.frameIdx -= 1 - points.frameIdx -= 1 - predicted_instances.frameIdx -= 1 - predicted_points.frameIdx -= 1 - - points.node -= 1 - predicted_points.node -= 1 - - points.x -= 1 - predicted_points.x -= 1 - - points.y -= 1 - predicted_points.y -= 1 - - skeleton = Skeleton() - skeleton.add_nodes(data["skeleton"]["nodeNames"]) - edges = data["skeleton"]["edges"] - if adjust_matlab_indexing: - edges = np.array(edges) - 1 - for (src_idx, dst_idx) in edges: - skeleton.add_edge( - data["skeleton"]["nodeNames"][src_idx], - data["skeleton"]["nodeNames"][dst_idx], - ) - - if fix_rel_paths: - for i, row in videos.iterrows(): - p = row.filepath - if not os.path.exists(p): - p = os.path.join(os.path.dirname(data_path), p) - if os.path.exists(p): - videos.at[i, "filepath"] = p - - # Make the video objects - video_objects = {} - for i, row in videos.iterrows(): - if videos.at[i, "format"] == "media": - vid = Video.from_media(videos.at[i, "filepath"]) - else: - vid = Video.from_hdf5( - filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] - ) - - video_objects[videos.at[i, "id"]] = vid - - # A function to get all the instances for a particular video frame - def get_frame_instances(video_id, frame_idx): - is_in_frame = (points["videoId"] == video_id) & ( - points["frameIdx"] == frame_idx - ) - if not is_in_frame.any(): - return [] - - instances = [] - frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) - for i, instance_id in enumerate(frame_instance_ids): - is_instance = is_in_frame & (points["instanceId"] == instance_id) - instance_points = { - data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) - for x, y, n, v in zip( - *[points[k][is_instance] for k in ["x", "y", "node", "visible"]] - ) - } - - instance = Instance(skeleton=skeleton, points=instance_points) - instances.append(instance) - - return instances - - # Get the unique labeled frames and construct a list of LabeledFrame objects for them. 
- frame_keys = list( - { - (videoId, frameIdx) - for videoId, frameIdx in zip(points["videoId"], points["frameIdx"]) - } - ) - frame_keys.sort() - labels = [] - for videoId, frameIdx in frame_keys: - label = LabeledFrame( - video=video_objects[videoId], - frame_idx=frameIdx, - instances=get_frame_instances(videoId, frameIdx), - ) - labels.append(label) - - return Labels(labels) diff --git a/sleap/io/legacy.py b/sleap/io/legacy.py index f1af21222..4340d460c 100644 --- a/sleap/io/legacy.py +++ b/sleap/io/legacy.py @@ -6,10 +6,19 @@ import numpy as np import pandas as pd -from sleap.io.dataset import Labels +from typing import List + +from sleap.util import json_loads from sleap.io.video import Video -from sleap.instance import LabeledFrame, PredictedPoint, PredictedInstance, Track +from sleap.instance import ( + LabeledFrame, + PredictedPoint, + PredictedInstance, + Track, + Point, + Instance, +) from sleap.skeleton import Skeleton @@ -18,10 +27,9 @@ def load_predicted_labels_json_old( parsed_json: dict = None, adjust_matlab_indexing: bool = True, fix_rel_paths: bool = True, -) -> Labels: +) -> List[LabeledFrame]: """ - Simple utitlity code to load data from Talmo's old JSON format into newer - Labels object. This loads the prediced instances + Load predicted instances from Talmo's old JSON format. Args: data_path: The path to the JSON file. @@ -31,7 +39,7 @@ def load_predicted_labels_json_old( fix_rel_paths: Whether to fix paths to videos to absolute paths. Returns: - A newly constructed Labels object. + List of :class:`LabeledFrame`s. """ if parsed_json is None: data = json.loads(open(data_path).read()) @@ -160,4 +168,125 @@ def get_frame_predicted_instances(video_id, frame_idx): ) labels.append(label) - return Labels(labels) + return labels + + +def load_labels_json_old( + data_path: str, + parsed_json: dict = None, + adjust_matlab_indexing: bool = True, + fix_rel_paths: bool = True, +) -> List[LabeledFrame]: + """ + Load predicted instances from Talmo's old JSON format. + + Args: + data_path: The path to the JSON file. + parsed_json: The parsed json if already loaded, so we can save + some time if already parsed. + adjust_matlab_indexing: Whether to adjust indexing from MATLAB. + fix_rel_paths: Whether to fix paths to videos to absolute paths. + + Returns: + A newly constructed Labels object. 
+ """ + if parsed_json is None: + data = json_loads(open(data_path).read()) + else: + data = parsed_json + + videos = pd.DataFrame(data["videos"]) + instances = pd.DataFrame(data["instances"]) + points = pd.DataFrame(data["points"]) + predicted_instances = pd.DataFrame(data["predicted_instances"]) + predicted_points = pd.DataFrame(data["predicted_points"]) + + if adjust_matlab_indexing: + instances.frameIdx -= 1 + points.frameIdx -= 1 + predicted_instances.frameIdx -= 1 + predicted_points.frameIdx -= 1 + + points.node -= 1 + predicted_points.node -= 1 + + points.x -= 1 + predicted_points.x -= 1 + + points.y -= 1 + predicted_points.y -= 1 + + skeleton = Skeleton() + skeleton.add_nodes(data["skeleton"]["nodeNames"]) + edges = data["skeleton"]["edges"] + if adjust_matlab_indexing: + edges = np.array(edges) - 1 + for (src_idx, dst_idx) in edges: + skeleton.add_edge( + data["skeleton"]["nodeNames"][src_idx], + data["skeleton"]["nodeNames"][dst_idx], + ) + + if fix_rel_paths: + for i, row in videos.iterrows(): + p = row.filepath + if not os.path.exists(p): + p = os.path.join(os.path.dirname(data_path), p) + if os.path.exists(p): + videos.at[i, "filepath"] = p + + # Make the video objects + video_objects = {} + for i, row in videos.iterrows(): + if videos.at[i, "format"] == "media": + vid = Video.from_media(videos.at[i, "filepath"]) + else: + vid = Video.from_hdf5( + filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] + ) + + video_objects[videos.at[i, "id"]] = vid + + # A function to get all the instances for a particular video frame + def get_frame_instances(video_id, frame_idx): + """ """ + is_in_frame = (points["videoId"] == video_id) & ( + points["frameIdx"] == frame_idx + ) + if not is_in_frame.any(): + return [] + + instances = [] + frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) + for i, instance_id in enumerate(frame_instance_ids): + is_instance = is_in_frame & (points["instanceId"] == instance_id) + instance_points = { + data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) + for x, y, n, v in zip( + *[points[k][is_instance] for k in ["x", "y", "node", "visible"]] + ) + } + + instance = Instance(skeleton=skeleton, points=instance_points) + instances.append(instance) + + return instances + + # Get the unique labeled frames and construct a list of LabeledFrame objects for them. + frame_keys = list( + { + (videoId, frameIdx) + for videoId, frameIdx in zip(points["videoId"], points["frameIdx"]) + } + ) + frame_keys.sort() + labels = [] + for videoId, frameIdx in frame_keys: + label = LabeledFrame( + video=video_objects[videoId], + frame_idx=frameIdx, + instances=get_frame_instances(videoId, frameIdx), + ) + labels.append(label) + + return labels diff --git a/sleap/util.py b/sleap/util.py index 581ea0e09..70eb7cd3e 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -2,6 +2,7 @@ A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ + import os import re @@ -9,8 +10,47 @@ import numpy as np import attr import psutil +import json +import rapidjson + +from typing import Any, Dict, Hashable, Iterable, List, Optional + + +def json_loads(json_str: str) -> Dict: + """ + A simple wrapper around the JSON decoder we are using. + + Args: + json_str: JSON string to decode. + + Returns: + Result of decoding JSON string. 
+ """ + try: + return rapidjson.loads(json_str) + except: + return json.loads(json_str) + + +def json_dumps(d: Dict, filename: str = None): + """ + A simple wrapper around the JSON encoder we are using. + + Args: + d: The dict to write. + filename: The filename to write to. + + Returns: + None + """ + + encoder = rapidjson -from typing import Any, Hashable, Iterable, List, Optional + if filename: + with open(filename, "w") as f: + encoder.dump(d, f, ensure_ascii=False) + else: + return encoder.dumps(d) def attr_to_dtype(cls: Any): diff --git a/tests/info/test_h5.py b/tests/info/test_h5.py index edcfe0208..e93e1bc7b 100644 --- a/tests/info/test_h5.py +++ b/tests/info/test_h5.py @@ -48,9 +48,9 @@ def test_output_matrices(centered_pair_predictions): # Remove all instances from track 13 vid = centered_pair_predictions.videos[0] track = centered_pair_predictions.tracks[13] - lfs_insts = centered_pair_predictions.find_track_occupancy(vid, track) - for lf, instance in lfs_insts: - centered_pair_predictions.remove_instance(lf, instance) + instances = centered_pair_predictions.find_track_occupancy(vid, track) + for instance in instances: + centered_pair_predictions.remove_instance(instance.frame, instance) # Make sure that this now remove empty track occupancy, points = get_occupancy_and_points_matrices( diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py new file mode 100644 index 000000000..2cf76c166 --- /dev/null +++ b/tests/info/test_summary.py @@ -0,0 +1,42 @@ +from sleap.info.summary import StatisticSeries + + +def test_frame_statistics(simple_predictions): + video = simple_predictions.videos[0] + stats = StatisticSeries(simple_predictions) + + x = stats.get_point_count_series(video) + assert len(x) == 2 + assert x[0] == 4 + assert x[1] == 4 + + x = stats.get_point_score_series(video, "sum") + assert len(x) == 2 + assert x[0] == 2.4 + assert x[1] == 5.2 + + x = stats.get_point_score_series(video, "min") + assert len(x) == 2 + assert x[0] == 0.5 + assert x[1] == 1.0 + + x = stats.get_instance_score_series(video, "sum") + assert len(x) == 2 + assert x[0] == 7 + assert x[1] == 9 + + x = stats.get_instance_score_series(video, "min") + assert len(x) == 2 + assert x[0] == 2 + assert x[1] == 3 + + x = stats.get_point_displacement_series(video, "mean") + assert len(x) == 2 + assert x[0] == 0 + assert x[1] == 9.0 + + x = stats.get_point_displacement_series(video, "max") + assert len(x) == 2 + assert len(x) == 2 + assert x[0] == 0 + assert x[1] == 18.0 diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 54fd4781a..b0799e402 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -5,7 +5,8 @@ from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance from sleap.io.video import Video, MediaVideo -from sleap.io.dataset import Labels, load_labels_json_old +from sleap.io.dataset import Labels +from sleap.io.legacy import load_labels_json_old from sleap.gui.suggestions import VideoFrameSuggestions TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index d75f28dee..2641b6c8a 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -36,9 +36,8 @@ def check_labels(labels): def test_load_old_json(): - labels = load_predicted_labels_json_old( - "tests/data/json_format_v1/centered_pair.json" - ) + old_json_filename = "tests/data/json_format_v1/centered_pair.json" + labels = 
Labels(load_predicted_labels_json_old(old_json_filename)) check_labels(labels) diff --git a/tests/test_util.py b/tests/test_util.py index acd194d9e..e72a78bc2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,18 @@ from typing import List, Dict -from sleap.util import attr_to_dtype, frame_list, weak_filename_match +from sleap.util import ( + json_dumps, + json_loads, + attr_to_dtype, + frame_list, + weak_filename_match, +) + + +def test_json(): + original_dict = dict(key=123) + assert original_dict == json_loads(json_dumps(original_dict)) def test_attr_to_dtype(): From 679e230df10de09a9e7dd20f8e028ef579e88a3b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 09:57:55 -0400 Subject: [PATCH 139/176] Better typing and docstrings. --- sleap/io/visuals.py | 91 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 12 deletions(-) diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 58f4091fc..04c758865 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -7,7 +7,7 @@ import numpy as np import math from time import time, clock -from typing import List +from typing import List, Tuple from queue import Queue from threading import Thread @@ -26,8 +26,11 @@ def reader(out_q: Queue, video: Video, frames: List[int]): Args: out_q: Queue to send (list of frame indexes, ndarray of frame images) for chunks of video. - video: the `Video` object to read - frames: full list frame indexes we want to read + video: The `Video` object to read. + frames: Full list frame indexes we want to read. + + Returns: + None. """ cv2.setNumThreads(usable_cpu_count()) @@ -66,10 +69,11 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): """Annotate frame images (draw instances). Args: - in_q: Queue with (list of frame indexes, ndarray of frame images) - out_q: Queue to send annotated images as (images, h, w, channels) ndarray + in_q: Queue with (list of frame indexes, ndarray of frame images). + out_q: Queue to send annotated images as + (images, h, w, channels) ndarray. labels: the `Labels` object from which to get data for annotating. - video_idx: index of `Video` in `labels.videos` list + video_idx: index of `Video` in `labels.videos` list. Returns: None. @@ -109,7 +113,13 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): out_q.put(_sentinel) -def writer(in_q: Queue, progress_queue: Queue, filename: str, fps: int, img_w_h: tuple): +def writer( + in_q: Queue, + progress_queue: Queue, + filename: str, + fps: float, + img_w_h: Tuple[int, int], +): """Write annotated images to video. Args: @@ -168,7 +178,19 @@ def save_labeled_video( fps: int = 15, gui_progress: bool = False, ): - """Function to generate and save video with annotations.""" + """Function to generate and save video with annotations. + + Args: + filename: Output filename. + labels: The dataset from which to get data. + video: The source :class:`Video` we want to annotate. + frames: List of frames to include in output video. + fps: Frames per second for output video. + gui_progress: Whether to show Qt GUI progress dialog. + + Returns: + None. 
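The reader/marker/writer trio above forms a producer-consumer pipeline: each stage pulls chunks of frames from an input Queue, processes them on its own Thread, and passes a sentinel downstream when finished. A stripped-down sketch of that pattern (illustrative only; the stage names and payloads are not the module's actual code):

    from queue import Queue
    from threading import Thread

    _sentinel = object()

    def produce(out_q: Queue):
        for chunk in range(3):            # stands in for chunks of frame images
            out_q.put(chunk)
        out_q.put(_sentinel)              # tell the downstream stage to stop

    def consume(in_q: Queue):
        while True:
            item = in_q.get()
            if item is _sentinel:
                break
            print("processing", item)     # stands in for drawing/encoding frames

    q = Queue(maxsize=10)
    Thread(target=produce, args=(q,)).start()
    consume(q)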
+ """ output_size = (video.height, video.width) print(f"Writing video with {len(frames)} frame images...") @@ -224,7 +246,8 @@ def save_labeled_video( print(f"Done in {elapsed} s, fps = {fps}.") -def img_to_cv(img): +def img_to_cv(img: np.ndarray) -> np.ndarray: + """Prepares frame image as needed for opencv.""" # Convert RGB to BGR for OpenCV if img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) @@ -234,17 +257,44 @@ def img_to_cv(img): return img -def get_frame_image(video_frame, video_idx, frame_idx, labels): +def get_frame_image( + video_frame: np.ndarray, video_idx: int, frame_idx: int, labels: Labels +) -> np.ndarray: + """Returns single annotated frame image. + + Args: + video_frame: The ndarray of the frame image. + video_idx: Index of video in :attribute:`Labels.videos` list. + frame_idx: Index of frame in video. + labels: The dataset from which to get data. + + Returns: + ndarray of frame image with visual annotations added. + """ img = img_to_cv(video_frame) plot_instances_cv(img, video_idx, frame_idx, labels) return img def _point_int_tuple(point): + """Returns (x, y) tuple from :class:`Point`.""" return int(point.x), int(point.y) -def plot_instances_cv(img, video_idx, frame_idx, labels): +def plot_instances_cv( + img: np.ndarray, video_idx: int, frame_idx: int, labels: Labels +) -> np.ndarray: + """Adds visuals annotations to single frame image. + + Args: + img: The ndarray of the frame image. + video_idx: Index of video in :attribute:`Labels.videos` list. + frame_idx: Index of frame in video. + labels: The dataset from which to get data. + + Returns: + ndarray of frame image with visual annotations added. + """ cmap = [ [0, 114, 189], [217, 83, 25], @@ -273,7 +323,24 @@ def plot_instances_cv(img, video_idx, frame_idx, labels): plot_instance_cv(img, instance, inst_color) -def plot_instance_cv(img, instance, color, marker_radius=4): +def plot_instance_cv( + img: np.ndarray, + instance: "Instance", + color: Tuple[int, int, int], + marker_radius: float = 4, +) -> np.ndarray: + """ + Add visual annotations for single instance. + + Args: + img: The ndarray of the frame image. + instance: The :class:`Instance` to add to frame image. + color: (r, g, b) color for this instance. + marker_radius: Radius of marker for instance points (nodes). + + Returns: + ndarray of frame image with visual annotations for instance added. + """ # RGB -> BGR for cv2 cv_color = color[::-1] From f21c96bcb1ab627c117aed297b6cdcdaa1618d45 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 11:02:37 -0400 Subject: [PATCH 140/176] Better typing and docstrings. --- sleap/io/video.py | 255 ++++++++++++++++++++++++++++------------------ 1 file changed, 155 insertions(+), 100 deletions(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index ef5096be3..68ebf475a 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -11,7 +11,7 @@ import cattr import logging -from typing import Iterable, Union, List +from typing import Iterable, Union, List, Tuple logger = logging.getLogger(__name__) @@ -19,15 +19,16 @@ @attr.s(auto_attribs=True, cmp=False) class HDF5Video: """ - Video data stored as 4D datasets in HDF5 files can be imported into - the sLEAP system with this class. + Video data stored as 4D datasets in HDF5 files. Args: - filename: The name of the HDF5 file where the dataset with video data is stored. + filename: The name of the HDF5 file where the dataset with video data + is stored. dataset: The name of the HDF5 dataset where the video data is stored. 
file_h5: The h5.File object that the underlying dataset is stored. dataset_h5: The h5.Dataset object that the underlying data is stored. - input_format: A string value equal to either "channels_last" or "channels_first". + input_format: A string value equal to either "channels_last" or + "channels_first". This specifies whether the underlying video data is stored as: * "channels_first": shape = (frames, channels, width, height) @@ -41,6 +42,7 @@ class HDF5Video: convert_range: bool = attr.ib(default=True) def __attrs_post_init__(self): + """Called by attrs after __init__().""" # Handle cases where the user feeds in h5.File objects instead of filename if isinstance(self.filename, h5.File): @@ -72,6 +74,7 @@ def __attrs_post_init__(self): @input_format.validator def check(self, attribute, value): + """Called by attrs to validates input format.""" if value not in ["channels_first", "channels_last"]: raise ValueError(f"HDF5Video input_format={value} invalid.") @@ -84,15 +87,15 @@ def check(self, attribute, value): self.__width_idx = 2 self.__height_idx = 1 - def matches(self, other): + def matches(self, other: "HDF5Video") -> bool: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to compare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. """ return ( self.filename == other.filename @@ -106,25 +109,30 @@ def matches(self, other): @property def frames(self): + """See :class:`Video`.""" return self.__dataset_h5.shape[0] @property def channels(self): + """See :class:`Video`.""" return self.__dataset_h5.shape[self.__channel_idx] @property def width(self): + """See :class:`Video`.""" return self.__dataset_h5.shape[self.__width_idx] @property def height(self): + """See :class:`Video`.""" return self.__dataset_h5.shape[self.__height_idx] @property def dtype(self): + """See :class:`Video`.""" return self.__dataset_h5.dtype - def get_frame(self, idx): # -> np.ndarray: + def get_frame(self, idx) -> np.ndarray: """ Get a frame from the underlying HDF5 video data. @@ -148,18 +156,19 @@ def get_frame(self, idx): # -> np.ndarray: @attr.s(auto_attribs=True, cmp=False) class MediaVideo: """ - Video data stored in traditional media formats readable by FFMPEG can be loaded - with this class. This class provides bare minimum read only interface on top of + Video data stored in traditional media formats readable by FFMPEG + + This class provides bare minimum read only interface on top of OpenCV's VideoCapture class. Args: filename: The name of the file (.mp4, .avi, etc) grayscale: Whether the video is grayscale or not. "auto" means detect - based on first frame. + based on first frame. + bgr: Whether color channels ordered as (blue, green, red). """ filename: str = attr.ib() - # grayscale: bool = attr.ib(default=None, converter=bool) grayscale: bool = attr.ib() bgr: bool = attr.ib(default=True) _detect_grayscale = False @@ -203,15 +212,15 @@ def __test_frame(self): # Return stored test frame return self._test_frame_ - def matches(self, other): + def matches(self, other: "MediaVideo") -> bool: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to compare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. 
""" return ( self.filename == other.filename @@ -220,7 +229,8 @@ def matches(self, other): ) @property - def fps(self): + def fps(self) -> float: + """Returns frames per second of video.""" return self.__reader.get(cv2.CAP_PROP_FPS) # The properties and methods below complete our contract with the @@ -228,14 +238,17 @@ def fps(self): @property def frames(self): + """See :class:`Video`.""" return int(self.__reader.get(cv2.CAP_PROP_FRAME_COUNT)) @property def frames_float(self): + """See :class:`Video`.""" return self.__reader.get(cv2.CAP_PROP_FRAME_COUNT) @property def channels(self): + """See :class:`Video`.""" if self.grayscale: return 1 else: @@ -243,17 +256,21 @@ def channels(self): @property def width(self): + """See :class:`Video`.""" return self.__test_frame.shape[1] @property def height(self): + """See :class:`Video`.""" return self.__test_frame.shape[0] @property def dtype(self): + """See :class:`Video`.""" return self.__test_frame.dtype - def get_frame(self, idx, grayscale=None): + def get_frame(self, idx: int, grayscale: bool = None) -> np.ndarray: + """See :class:`Video`.""" if self.__reader.get(cv2.CAP_PROP_POS_FRAMES) != idx: self.__reader.set(cv2.CAP_PROP_POS_FRAMES, idx) @@ -308,15 +325,15 @@ def __attrs_post_init__(self): # The properties and methods below complete our contract with the # higher level Video interface. - def matches(self, other): + def matches(self, other: "NumpyVideo") -> np.ndarray: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to comapare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. """ return np.all(self.__data == other.__data) @@ -347,17 +364,22 @@ def get_frame(self, idx): @attr.s(auto_attribs=True, cmp=False) class ImgStoreVideo: """ - Video data stored as an ImgStore dataset. See: https://github.com/loopbio/imgstore - This class is just a lightweight wrapper for reading such datasets as videos sources - for sLEAP. + Video data stored as an ImgStore dataset. + + See: https://github.com/loopbio/imgstore + This class is just a lightweight wrapper for reading such datasets as + video sources for SLEAP. Args: filename: The name of the file or directory to the imgstore. - index_by_original: ImgStores are great for storing a collection of frame - selected frames from an larger video. If the index_by_original is set to - True than the get_frame function will accept the original frame numbers of - from original video. If False, then it will accept the frame index from the - store directly. + index_by_original: ImgStores are great for storing a collection of + selected frames from an larger video. If the index_by_original is + set to True then the get_frame function will accept the original + frame numbers of from original video. If False, then it will + accept the frame index from the store directly. + Default to True so that we can use an ImgStoreVideo in a dataset + to replace another video without having to update all the frame + indices on :class:`LabeledFrame`s in the dataset. """ filename: str = attr.ib(default=None) @@ -437,14 +459,16 @@ def height(self): def dtype(self): return self.__img.dtype - def get_frame(self, frame_number) -> np.ndarray: + def get_frame(self, frame_number: int) -> np.ndarray: """ Get a frame from the underlying ImgStore video data. Args: - frame_num: The number of the frame to get. 
If index_by_original is set to True, - then this number should actually be a frame index withing the imgstore. That is, - if there are 4 frames in the imgstore, this number shoulde be from 0 to 3. + frame_number: The number of the frame to get. If + index_by_original is set to True, then this number should + actually be a frame index within the imgstore. That is, + if there are 4 frames in the imgstore, this number should be + be from 0 to 3. Returns: The numpy.ndarray representing the video frame data. @@ -508,37 +532,40 @@ def close(self): @attr.s(auto_attribs=True, cmp=False) class Video: """ - The top-level interface to any Video data used by sLEAP is represented by - the :class:`.Video` class. This class provides a common interface for - various supported video data backends. It provides the bare minimum of - properties and methods that any video data needs to support in order to - function with other sLEAP components. This interface currently only supports - reading of video data, there is no write support. Unless one is creating a new video + The top-level interface to any Video data used by SLEAP. + + This class provides a common interface for various supported video data + backends. It provides the bare minimum of properties and methods that + any video data needs to support in order to function with other SLEAP + components. This interface currently only supports reading of video + data, there is no write support. Unless one is creating a new video backend, this class should be instantiated from its various class methods for different formats. For example: - >>> video = Video.from_hdf5(filename='test.h5', dataset='box') - >>> video = Video.from_media(filename='test.mp4') + >>> video = Video.from_hdf5(filename="test.h5", dataset="box") + >>> video = Video.from_media(filename="test.mp4") Or we can use auto-detection based on filename: - >>> video = Video.from_filename(filename='test.mp4') + >>> video = Video.from_filename(filename="test.mp4") Args: - backend: A backend is and object that implements the following basic - required methods and properties + backend: A backend is an object that implements the following basic + required methods and properties * Properties * :code:`frames`: The number of frames in the video - * :code:`channels`: The number of channels in the video (e.g. 1 for grayscale, 3 for RGB) + * :code:`channels`: The number of channels in the video + (e.g. 1 for grayscale, 3 for RGB) * :code:`width`: The width of each frame in pixels * :code:`height`: The height of each frame in pixels * Methods - * :code:`get_frame(frame_index: int) -> np.ndarray(shape=(width, height, channels)`: - Get a single frame from the underlying video data + * :code:`get_frame(frame_index: int) -> np.ndarray`: + Get a single frame from the underlying video data with + output shape=(width, height, channels). """ @@ -550,11 +577,14 @@ def __getattr__(self, item): @property def num_frames(self) -> int: - """The number of frames in the video. Just an alias for frames property.""" + """ + The number of frames in the video. Just an alias for frames property. + """ return self.frames @property - def shape(self): + def shape(self) -> Tuple[int, int, int, int]: + """ Returns (frame count, height, width, channels).""" return (self.frames, self.height, self.width, self.channels) def __str__(self): @@ -590,7 +620,8 @@ def get_frames(self, idxs: Union[int, Iterable[int]]) -> np.ndarray: idxs: An iterable object that contains the indices of frames. 
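To make the backend contract above concrete, any object exposing those four properties and a get_frame() method can back a Video; a toy sketch (not an actual SLEAP backend, names are made up):

    import numpy as np

    class BlankFrameBackend:
        """Toy backend: ten blank 64x64 single-channel frames."""

        frames = 10
        channels = 1
        width = 64
        height = 64

        def get_frame(self, frame_index: int) -> np.ndarray:
            return np.zeros((64, 64, 1), dtype="uint8")

    # Wrapped as Video(backend=BlankFrameBackend()), the delegation described above
    # exposes num_frames, shape, and get_frames() just like the real backends.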
Returns: - The requested video frames with shape (len(idxs), width, height, channels) + The requested video frames with shape + (len(idxs), width, height, channels) """ if np.isscalar(idxs): idxs = [idxs] @@ -609,16 +640,18 @@ def from_hdf5( filename: Union[str, h5.File] = None, input_format: str = "channels_last", convert_range: bool = True, - ): + ) -> "Video": """ - Create an instance of a video object from an HDF5 file and dataset. This - is a helper method that invokes the HDF5Video backend. + Create an instance of a video object from an HDF5 file and dataset. + + This is a helper method that invokes the HDF5Video backend. Args: - dataset: The name of the dataset or and h5.Dataset object. If filename is - h5.File, dataset must be a str of the dataset name. + dataset: The name of the dataset or and h5.Dataset object. If + filename is h5.File, dataset must be a str of the dataset name. filename: The name of the HDF5 file or and open h5.File object. - input_format: Whether the data is oriented with "channels_first" or "channels_last" + input_format: Whether the data is oriented with "channels_first" + or "channels_last" convert_range: Whether we should convert data to [0, 255]-range Returns: @@ -634,12 +667,14 @@ def from_hdf5( return cls(backend=backend) @classmethod - def from_numpy(cls, filename, *args, **kwargs): + def from_numpy(cls, filename: str, *args, **kwargs) -> "Video": """ Create an instance of a video object from a numpy array. Args: filename: The numpy array or the name of the file + args: Arguments to pass to :class:`NumpyVideo` + kwargs: Arguments to pass to :class:`NumpyVideo` Returns: A Video object with a NumpyVideo backend @@ -649,12 +684,16 @@ def from_numpy(cls, filename, *args, **kwargs): return cls(backend=backend) @classmethod - def from_media(cls, filename: str, *args, **kwargs): + def from_media(cls, filename: str, *args, **kwargs) -> "Video": """ - Create an instance of a video object from a typical media file (e.g. .mp4, .avi). + Create an instance of a video object from a typical media file. + + For example, mp4, avi, or other types readable by FFMPEG. Args: filename: The name of the file + args: Arguments to pass to :class:`MediaVideo` + kwargs: Arguments to pass to :class:`MediaVideo` Returns: A Video object with a MediaVideo backend @@ -664,20 +703,25 @@ def from_media(cls, filename: str, *args, **kwargs): return cls(backend=backend) @classmethod - def from_filename(cls, filename: str, *args, **kwargs): + def from_filename(cls, filename: str, *args, **kwargs) -> "Video": """ - Create an instance of a video object from a filename, auto-detecting the backend. + Create an instance of a video object, auto-detecting the backend. Args: - filename: The path to the video filename. Currently supported types are: + filename: The path to the video filename. + Currently supported types are: + + * Media Videos - AVI, MP4, etc. handled by OpenCV directly + * HDF5 Datasets - .h5 files + * Numpy Arrays - npy files + * imgstore datasets - produced by loopbio's Motif recording + system. See: https://github.com/loopbio/imgstore. - * Media Videos - AVI, MP4, etc. handled by OpenCV directly - * HDF5 Datasets - .h5 files - * Numpy Arrays - npy files - * imgstore datasets - produced by loopbio's Motif recording system. See: https://github.com/loopbio/imgstore. + args: Arguments to pass to :class:`NumpyVideo` + kwargs: Arguments to pass to :class:`NumpyVideo` Returns: - A Video object with the detected backend + A Video object with the detected backend. 
""" filename = Video.fixup_path(filename) @@ -696,27 +740,27 @@ def from_filename(cls, filename: str, *args, **kwargs): @classmethod def imgstore_from_filenames( cls, filenames: list, output_filename: str, *args, **kwargs - ): - """Create an imagestore from a list of image files. + ) -> "Video": + """Create an imgstore from a list of image files. Args: filenames: List of filenames for the image files. - output_filename: Filename for the imagestore to create. + output_filename: Filename for the imgstore to create. Returns: - A `Video` object for the new imagestore. + A `Video` object for the new imgstore. """ # get the image size from the first file first_img = cv2.imread(filenames[0], flags=cv2.IMREAD_COLOR) img_shape = first_img.shape - # create the imagestore + # create the imgstore store = imgstore.new_for_format( "png", mode="w", basedir=output_filename, imgshape=img_shape ) - # read each frame and write it to the imagestore + # read each frame and write it to the imgstore # unfortunately imgstore doesn't let us just add the file for i, img_filename in enumerate(filenames): img = cv2.imread(img_filename, flags=cv2.IMREAD_COLOR) @@ -727,33 +771,33 @@ def imgstore_from_filenames( # Return an ImgStoreVideo object referencing this new imgstore. return cls(backend=ImgStoreVideo(filename=output_filename)) - @classmethod - def to_numpy(cls, frame_data: np.array, file_name: str): - np.save(file_name, frame_data, "w") - def to_imgstore( self, - path, + path: str, frame_numbers: List[int] = None, format: str = "png", index_by_original: bool = True, - ): + ) -> "Video": """ - Read frames from an arbitrary video backend and store them in a loopbio imgstore. + Converts frames from arbitrary video backend to ImgStoreVideo. + This should facilitate conversion of any video to a loopbio imgstore. Args: path: Filename or directory name to store imgstore. - frame_numbers: A list of frame numbers from the video to save. If None save - the entire video. - format: By default it will create a DirectoryImgStore with lossless PNG format. - Unless the frame_indices = None, in which case, it will default to 'mjpeg/avi' - format for video. + frame_numbers: A list of frame numbers from the video to save. + If None save the entire video. + format: By default it will create a DirectoryImgStore with lossless + PNG format unless the frame_indices = None, in which case, + it will default to 'mjpeg/avi' format for video. index_by_original: ImgStores are great for storing a collection of - selected frames from an larger video. If the index_by_original is set to - True than the get_frame function will accept the original frame numbers of - from original video. If False, then it will accept the frame index from the - store directly. + selected frames from an larger video. If the index_by_original + is set to True then the get_frame function will accept the + original frame numbers of from original video. If False, + then it will accept the frame index from the store directly. + Default to True so that we can use an ImgStoreVideo in a + dataset to replace another video without having to update + all the frame indices on :class:`LabeledFrame`s in the dataset. Returns: A new Video object that references the imgstore. @@ -813,7 +857,7 @@ def to_imgstore( @staticmethod def cattr(): """ - Return a cattr converter for serialiazing/deserializing Video objects. + Returns a cattr converter for serialiazing/deserializing Video objects. Returns: A cattr converter. 
@@ -839,18 +883,29 @@ def fixup_video(x, cl): return vid_cattr @staticmethod - def fixup_path(path, raise_error=False) -> str: + def fixup_path(path: str, raise_error: bool = False) -> str: """ - Given a path to a video try to find it. This is attempt to make the paths - serialized for different video objects portabls across multiple computers. - The default behaviour is to store whatever path is stored on the backend - object. If this is an absolute path it is almost certainly wrong when - transfered when the object is created on another computer. We try to - find the video by looking in the current working directory as well. + Tries to locate video if the given path doesn't work. + + Given a path to a video try to find it. This is attempt to make the + paths serialized for different video objects portable across multiple + computers. The default behavior is to store whatever path is stored + on the backend object. If this is an absolute path it is almost + certainly wrong when transferred when the object is created on + another computer. We try to find the video by looking in the current + working directory as well. + + Note that when loading videos during the process of deserializing a + saved :class:`Labels` dataset, it's usually preferable to fix video + paths using a `video_callback`. Args: path: The path the video asset. + Raises: + FileNotFoundError: If file still cannot be found and raise_error + is True. + Returns: The fixed up path """ From 544a0895461f449229b9151db82883c39707578a Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 11:05:07 -0400 Subject: [PATCH 141/176] Fixed docstring arg name. --- sleap/skeleton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 1f3c21e98..f5c0d6381 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -761,7 +761,7 @@ def has_nodes(self, names: Iterable[str]) -> bool: Check whether the skeleton has a list of nodes. Args: - name: The list names of the nodes to check for. + names: The list names of the nodes to check for. Returns: True for yes, False for no. From 2d1961828540992f54e6656b493fab1db98d6237 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 11:07:31 -0400 Subject: [PATCH 142/176] Fixed docstring arg name. --- sleap/io/video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sleap/io/video.py b/sleap/io/video.py index 68ebf475a..8ea759410 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -901,6 +901,7 @@ def fixup_path(path: str, raise_error: bool = False) -> str: Args: path: The path the video asset. + raise_error: Whether to raise error if we cannot find video. Raises: FileNotFoundError: If file still cannot be found and raise_error From e62a4ad1ed54787029fe6fe0b09daba096bdf1c4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 11:12:09 -0400 Subject: [PATCH 143/176] Fixed syntax error (function name). 
--- sleap/info/write_tracking_h5.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 17a192b8c..896def438 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -140,7 +140,7 @@ def write_occupancy_file( with h5.File(output_path, "w") as f: for key, val in data_dict.items(): if isinstance(val, np.ndarray): - print(f"key: {val.shape}") + print(f"{key}: {val.shape}") if transpose: # Transpose since MATLAB expects column-major @@ -155,6 +155,7 @@ def write_occupancy_file( key, data=val, compression="gzip", compression_opts=9 ) else: + print(f"{key}: {len(val)}") f.create_dataset(key, data=val) print(f"Saved as {output_path}") @@ -177,7 +178,7 @@ def main(labels: Labels, output_path: str, all_frames: bool = True): """ track_names = get_tracks_as_np_strings(labels) - occupancy_matrix, locations_matrix = get_occupancy_and_predictions_matrices( + occupancy_matrix, locations_matrix = get_occupancy_and_points_matrices( labels, all_frames ) From 34d55d436894bd5ca4e3c6b0921b79d51262087d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 14:21:46 -0400 Subject: [PATCH 144/176] Better typing and docstrings. --- sleap/gui/merge.py | 110 ++++++++++++++++++++++++++++++++++++---- sleap/io/dataset.py | 9 ++-- tests/gui/test_merge.py | 5 ++ 3 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 tests/gui/test_merge.py diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index 8aa687241..cbaa20c29 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -4,7 +4,7 @@ import attr -from typing import List +from typing import Dict, List from sleap.instance import LabeledFrame from sleap.io.dataset import Labels @@ -18,7 +18,25 @@ class MergeDialog(QtWidgets.QDialog): + """ + Dialog window for complex merging of two SLEAP datasets. + + This will immediately merge any labeled frames that can be cleanly merged, + show summary of merge and prompt user about how to handle merge conflict, + and then finish merge (resolving conflicts as the user requested). + """ + def __init__(self, base_labels: Labels, new_labels: Labels, *args, **kwargs): + """ + Creates merge dialog and begins merging. + + Args: + base_labels: The base dataset into which we're inserting data. + new_labels: New dataset from which we're getting data to insert. + + Returns: + None. + """ super(MergeDialog, self).__init__(*args, **kwargs) @@ -50,9 +68,11 @@ def __init__(self, base_labels: Labels, new_labels: Labels, *args, **kwargs): merge_table = MergeTable(merged) layout.addWidget(merge_table) - conflict_text = ( - "There are no conflicts." if not self.extra_base else "Merge conflicts:" - ) + if not self.extra_base: + conflict_text = "There are no conflicts." + else: + conflict_text = "Merge conflicts:" + conflict_label = QtWidgets.QLabel(conflict_text) layout.addWidget(conflict_label) @@ -80,6 +100,20 @@ def __init__(self, base_labels: Labels, new_labels: Labels, *args, **kwargs): self.setLayout(layout) def finishMerge(self): + """ + Finishes merge process, possibly resolving conflicts. + + This is connected to `accepted` signal. + + Args: + None. + + Raises: + ValueError: If no valid merge method was selected in dialog. + + Returns: + None. 
+ """ merge_method = self.merge_method.currentText() if merge_method == USE_BASE_STRING: Labels.finish_complex_merge(self.base_labels, self.extra_base) @@ -94,12 +128,36 @@ def finishMerge(self): class ConflictTable(QtWidgets.QTableView): - def __init__(self, *args, **kwargs): + """ + Qt table view for summarizing merge conflicts. + + Arguments are passed through to the table view object. + + The two lists of `LabeledFrame`s should be correlated (idx in one will + match idx of the conflicting frame in other). + + Args: + base_labels: The base dataset. + extra_base: `LabeledFrame`s from base that conflicted. + extra_new: `LabeledFrame`s from new dataset that conflicts. + """ + + def __init__( + self, + base_labels: Labels, + extra_base: List[LabeledFrame], + extra_new: List[LabeledFrame], + ): super(ConflictTable, self).__init__() - self.setModel(ConflictTableModel(*args, **kwargs)) + self.setModel(ConflictTableModel(base_labels, extra_base, extra_new)) class ConflictTableModel(QtCore.QAbstractTableModel): + """Qt table model for summarizing merge conflicts. + + See :class:`ConflictTable`. + """ + _props = ["video", "frame", "base", "new"] def __init__( @@ -114,6 +172,7 @@ def __init__( self.extra_new = extra_new def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Required by Qt.""" if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() prop = self._props[index.column()] @@ -131,14 +190,17 @@ def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): return None def rowCount(self, *args): + """Required by Qt.""" return len(self.extra_base) def columnCount(self, *args): + """Required by Qt.""" return len(self._props) def headerData( self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole ): + """Required by Qt.""" if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -148,15 +210,30 @@ def headerData( class MergeTable(QtWidgets.QTableView): - def __init__(self, *args, **kwargs): + """ + Qt table view for summarizing cleanly merged frames. + + Arguments are passed through to the table view object. + + Args: + merged: The frames that were cleanly merged. + See :method:`Labels.complex_merge_between` for details. + """ + + def __init__(self, merged, *args, **kwargs): super(MergeTable, self).__init__() - self.setModel(MergeTableModel(*args, **kwargs)) + self.setModel(MergeTableModel(merged)) class MergeTableModel(QtCore.QAbstractTableModel): + """Qt table model for summarizing merge conflicts. + + See :class:`MergeTable`. 
+ """ + _props = ["video", "frame", "merged instances"] - def __init__(self, merged: List[List["Instance"]]): + def __init__(self, merged: Dict["Video", Dict[int, List["Instance"]]]): super(MergeTableModel, self).__init__() self.merged = merged @@ -172,6 +249,7 @@ def __init__(self, merged: List[List["Instance"]]): ) def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Required by Qt.""" if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() prop = self._props[index.column()] @@ -187,14 +265,17 @@ def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): return None def rowCount(self, *args): + """Required by Qt.""" return len(self.data_table) def columnCount(self, *args): + """Required by Qt.""" return len(self._props) def headerData( self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole ): + """Required by Qt.""" if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] @@ -203,7 +284,16 @@ def headerData( return None -def show_instance_type_counts(instance_list): +def show_instance_type_counts(instance_list: List["Instance"]) -> str: + """ + Returns string of instance counts to show in table. + + Args: + instance_list: The list of instances to count. + + Returns: + String with number of predicted instances and number of user instances. + """ prediction_count = len( list(filter(lambda inst: hasattr(inst, "score"), instance_list)) ) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index fb33847c5..974e9e5d2 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -932,9 +932,12 @@ def complex_merge_between( new_labels with *matching* objects from base Returns: - tuple of two lists of `LabeledFrame`s - * data from base that conflicts - * data from new that conflicts + tuple of three items: + * Dictionary, keys are :class:`Video`, values are + dictionary in which keys are frame index (int) + and value is list of :class:`Instance`s + * list of conflicting :class:`Instance`s from base + * list of conflicting :class:`Instance`s from new frames """ # If unify, we want to replace objects in the frames with # corresponding objects from the current labels. diff --git a/tests/gui/test_merge.py b/tests/gui/test_merge.py new file mode 100644 index 000000000..c625c0def --- /dev/null +++ b/tests/gui/test_merge.py @@ -0,0 +1,5 @@ +from sleap.gui.merge import show_instance_type_counts + + +def test_count_string(simple_predictions): + assert show_instance_type_counts(simple_predictions[0]) == 2 From c69337fdd6554ec38fd0682733096c81b88cd5ff Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 16:57:31 -0400 Subject: [PATCH 145/176] Better typing and docstrings. --- sleap/gui/formbuilder.py | 72 +++++++++--- sleap/gui/importvideos.py | 3 +- sleap/gui/multicheck.py | 31 +++-- sleap/gui/shortcuts.py | 60 ++++++++-- sleap/gui/slider.py | 235 ++++++++++++++++++++++++-------------- sleap/gui/video.py | 2 +- sleap/util.py | 17 +++ 7 files changed, 295 insertions(+), 125 deletions(-) diff --git a/sleap/gui/formbuilder.py b/sleap/gui/formbuilder.py index 875996e6a..83591421d 100644 --- a/sleap/gui/formbuilder.py +++ b/sleap/gui/formbuilder.py @@ -1,4 +1,5 @@ -"""Module for creating a form from a yaml file. +""" +Module for creating a form from a yaml file. 
Example: >>> widget = YamlFormWidget(yaml_file="example.yaml") @@ -11,6 +12,8 @@ import yaml +from typing import Any, Dict, List, Optional + from PySide2 import QtWidgets, QtCore @@ -78,7 +81,7 @@ class FormBuilderLayout(QtWidgets.QFormLayout): Custom QFormLayout which populates itself from list of form fields. Args: - items_to_create: list which gets passed to get_form_data() + items_to_create: list which gets passed to :method:`get_form_data` (see there for details about format) """ @@ -113,7 +116,7 @@ def set_form_data(self, data: dict): """Set specified user-editable data. Args: - data (dict): key should match field name + data: dictionary of datay, key should match field name """ widgets = self.fields for name, val in data.items(): @@ -126,7 +129,7 @@ def set_form_data(self, data: dict): # print(f"no {name} widget found") @staticmethod - def set_widget_value(widget, val): + def set_widget_value(widget: QtWidgets.QWidget, val): """Set value for specific widget.""" # if widget.property("field_data_type") == "sci": # val = str(val) @@ -145,11 +148,13 @@ def set_widget_value(widget, val): widget.repaint() @staticmethod - def get_widget_value(widget): - """Get value of form field (using whichever method appropriate for widget). + def get_widget_value(widget: QtWidgets.QWidget) -> Any: + """Returns value of form field. + + This determines the method appropriate for the type of widget. Args: - widget: subclass of QtWidget + widget: The widget for which to return value. Returns: value (can be bool, numeric, string, or None) """ @@ -174,26 +179,25 @@ def get_widget_value(widget): val = None if val == "None" else val return val - def build_form(self, items_to_create): - """Add widgets to form layout for each item in items_to_create. + def build_form(self, items_to_create: List[Dict[str, Any]]): + """Adds widgets to form layout for each item in items_to_create. Args: - items_to_create: list of dicts with fields + items_to_create: list of dictionaries with fields: * name: used as key when we return form data as dict * label: string to show in form * type: supports double, int, bool, list, button, stack * default: default value for form field - * [options]: comma separated list of options, used for list or stack + * [options]: comma separated list of options, + used for list or stack field-types * for stack, array of dicts w/ form data for each stack page - Note: a "stack" has a dropdown menu that determines which stack page to show + A "stack" has a dropdown menu that determines which stack page to show. Returns: None. """ for item in items_to_create: - field = None - # double: show spinbox (number w/ up/down controls) if item["type"] == "double": field = QtWidgets.QDoubleSpinBox() @@ -264,7 +268,10 @@ def build_form(self, items_to_create): if item["type"].split("_")[0] == "file": self.addRow("", self._make_file_button(item, field)) - def _make_file_button(self, item, field): + def _make_file_button( + self, item: Dict, field: QtWidgets.QWidget + ) -> QtWidgets.QPushButton: + """Creates the button for a file_* field-type.""" file_button = QtWidgets.QPushButton("Select " + item["label"]) if item["type"].split("_")[-1] == "open": @@ -296,6 +303,17 @@ def select_file(*args, x=field): class StackBuilderWidget(QtWidgets.QWidget): + """ + A custom widget that shows different subforms depending on menu selection. + + Args: + stack_data: Dictionary for field from `items_to_create`. + The "options" key will give the list of options to show in + menu. 
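For reference, a field list of the shape build_form() describes above might look like this (field names are illustrative, not one of SLEAP's real forms):

    training_form_items = [
        dict(name="scale", label="Scale", type="double", default=1.0),
        dict(name="use_gpu", label="Use GPU", type="bool", default=True),
        dict(name="arch", label="Architecture", type="list", options="unet,leap,hourglass"),
    ]
    # Passing such a list to FormBuilderLayout builds one row per field, and
    # get_form_data() returns a dict keyed by each field's "name".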
Each of the "options" will also be the key of a dictionary + within stack_data that has the same structure as the dictionary + passed to :method:`FormBuilderLayout.build_form()`. + """ + def __init__(self, stack_data, *args, **kwargs): super(StackBuilderWidget, self).__init__(*args, **kwargs) @@ -337,19 +355,35 @@ def __init__(self, stack_data, *args, **kwargs): self.setLayout(multi_layout) def value(self): + """Returns value of menu.""" return self.combo_box.currentText() def get_data(self): + """Returns value from currently shown subform.""" return self.page_layouts[self.value()].get_form_data() class FieldComboWidget(QtWidgets.QComboBox): + """ + A custom ComboBox widget with method to easily add set of options. + """ + def __init__(self, *args, **kwargs): super(FieldComboWidget, self).__init__(*args, **kwargs) self.setSizeAdjustPolicy(QtWidgets.QComboBox.AdjustToContents) self.setMinimumContentsLength(3) - def set_options(self, options_list, select_item=None): + def set_options(self, options_list: List[str], select_item: Optional[str] = None): + """ + Sets list of menu options. + + Args: + options_list: List of items (strings) to show in menu. + select_item: Item to select initially. + + Returns: + None. + """ self.clear() for item in options_list: if item == "---": @@ -362,11 +396,17 @@ def set_options(self, options_list, select_item=None): class ResizingStackedWidget(QtWidgets.QStackedWidget): + """ + QStackedWidget that updates its sizeHint and minimumSizeHint as needed. + """ + def __init__(self, *args, **kwargs): super(ResizingStackedWidget, self).__init__(*args, **kwargs) def sizeHint(self): + """Qt method.""" return self.currentWidget().sizeHint() def minimumSizeHint(self): + """Qt method.""" return self.currentWidget().minimumSizeHint() diff --git a/sleap/gui/importvideos.py b/sleap/gui/importvideos.py index 3cbd145e2..7c049ef49 100644 --- a/sleap/gui/importvideos.py +++ b/sleap/gui/importvideos.py @@ -28,11 +28,10 @@ QRadioButton, QCheckBox, QComboBox, - QStackedWidget, ) from sleap.gui.video import GraphicsView -from sleap.io.video import Video, HDF5Video, MediaVideo +from sleap.io.video import Video import h5py import qimage2ndarray diff --git a/sleap/gui/multicheck.py b/sleap/gui/multicheck.py index 60d34db9d..5af7137e5 100644 --- a/sleap/gui/multicheck.py +++ b/sleap/gui/multicheck.py @@ -1,26 +1,37 @@ """ -Module for Qt Widget to show multiple checkboxes for selecting from a sequence of numbers. +Module for Qt Widget to show multiple checkboxes for selecting. Example: >>> mc = MultiCheckWidget(count=5, selected=[0,1],title="My Items") - >>> me.selectionChanged.connect(window.plot) + >>> mc.selectionChanged.connect(window.plot) >>> window.layout.addWidget(mc) """ + +from typing import List, Optional + from PySide2.QtCore import QRectF, Signal from PySide2.QtWidgets import QGridLayout, QGroupBox, QButtonGroup, QCheckBox class MultiCheckWidget(QGroupBox): - """Qt Widget to show multiple checkboxes for selecting from a sequence of numbers. + """Qt Widget to show multiple checkboxes for a sequence of numbers. Args: - count (int): The number of checkboxes to show. - title (str, optional): Display title for group of checkboxes. - selected (list, optional): List of checkbox numbers to initially have checked. - default (bool, optional): Default to checked/unchecked (ignored if selected arg given). + count: The number of checkboxes to show. + title: Display title for group of checkboxes. + selected: List of checkbox numbers to initially check. 
+ default: Whether to default boxes as checked. """ - def __init__(self, *args, count, title="", selected=None, default=False, **kwargs): + def __init__( + self, + *args, + count: int, + title: Optional[str] = "", + selected: Optional[List] = None, + default: Optional[bool] = False, + **kwargs + ): super(MultiCheckWidget, self).__init__(*args, **kwargs) # QButtonGroup is the logical container @@ -48,7 +59,7 @@ def __init__(self, *args, count, title="", selected=None, default=False, **kwarg self.setSelected(selected) """ - selectionChanged signal is sent whenever one of the checkboxes gets a stateChanged signal. + selectionChanged signal sent when a checkbox gets a stateChanged signal """ selectionChanged = Signal() @@ -68,7 +79,7 @@ def setSelected(self, selected: list): """Method to set some checkboxes as checked. Args: - selected (list): List of checkboxes to check. + selected: List of checkboxes to check. Returns: None diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py index 3b155e732..c4953de69 100644 --- a/sleap/gui/shortcuts.py +++ b/sleap/gui/shortcuts.py @@ -1,14 +1,19 @@ -from PySide2 import QtWidgets, QtCore -from PySide2.QtCore import Qt +""" +Gui for keyboard shortcuts. +""" +from PySide2 import QtWidgets from PySide2.QtGui import QKeySequence -import sys import yaml +from typing import Dict, List, Union from pkg_resources import Requirement, resource_filename class ShortcutDialog(QtWidgets.QDialog): + """ + Dialog window for reviewing and modifying the keyboard shortcuts. + """ _column_len = 13 @@ -20,6 +25,7 @@ def __init__(self, *args, **kwargs): self.make_form() def accept(self): + """Triggered when form is accepted; saves the shortcuts.""" for action, widget in self.key_widgets.items(): self.shortcuts[action] = widget.keySequence().toString() self.shortcuts.save() @@ -27,22 +33,26 @@ def accept(self): super(ShortcutDialog, self).accept() def load_shortcuts(self): + """Loads shortcuts object.""" self.shortcuts = Shortcuts() def make_form(self): + """Creates the form with fields for all shortcuts.""" self.key_widgets = dict() # dict to store QKeySequenceEdit widgets layout = QtWidgets.QVBoxLayout() layout.addWidget(self.make_shortcuts_widget()) layout.addWidget( QtWidgets.QLabel( - "Any changes to keyboard shortcuts will not take effect until you quit and re-open the application." + "Any changes to keyboard shortcuts will not take effect " + "until you quit and re-open the application." ) ) layout.addWidget(self.make_buttons_widget()) self.setLayout(layout) - def make_buttons_widget(self): + def make_buttons_widget(self) -> QtWidgets.QDialogButtonBox: + """Makes the form buttons.""" buttons = QtWidgets.QDialogButtonBox( QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel ) @@ -50,7 +60,8 @@ def make_buttons_widget(self): buttons.rejected.connect(self.reject) return buttons - def make_shortcuts_widget(self): + def make_shortcuts_widget(self) -> QtWidgets.QWidget: + """Makes the widget will fields for all shortcuts.""" shortcuts = self.shortcuts widget = QtWidgets.QWidget() @@ -64,7 +75,15 @@ def make_shortcuts_widget(self): widget.setLayout(layout) return widget - def make_column_widget(self, shortcuts): + def make_column_widget(self, shortcuts: List) -> QtWidgets.QWidget: + """Makes a single column of shortcut fields. + + Args: + shortcuts: The list of shortcuts to include in this column. + + Returns: + The widget. 
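The Shortcuts class documented just below supports dict-like access by action name, index, or slice; a usage sketch ("save" is an illustrative action name, not necessarily one of the configured actions):

    from sleap.gui.shortcuts import Shortcuts

    shortcuts = Shortcuts()                 # reads sleap/config/shortcuts.yaml
    key_seq = shortcuts["save"]             # lookup by action name
    first_two = shortcuts[0:2]              # slice -> {action_name: key_string, ...}
    shortcuts["save"] = "Ctrl+S"            # reassign, then persist:
    shortcuts.save()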
+ """ column_widget = QtWidgets.QWidget() column_layout = QtWidgets.QFormLayout() for action in shortcuts: @@ -75,11 +94,14 @@ def make_column_widget(self, shortcuts): return column_widget -def dict_cut(d, a, b): - return dict(list(d.items())[a:b]) +class Shortcuts(object): + """ + Class for accessing keyboard shortcuts. + Shortcuts are saved in `sleap/config/shortcuts.yaml` -class Shortcuts: + When instantiated, this reads in the shortcuts from the file. + """ _shortcuts = None _names = ( @@ -136,13 +158,25 @@ def __init__(self): self._shortcuts = shortcuts def save(self): + """Saves all shortcuts to shortcut file.""" shortcut_yaml = resource_filename( Requirement.parse("sleap"), "sleap/config/shortcuts.yaml" ) with open(shortcut_yaml, "w") as f: yaml.dump(self._shortcuts, f) - def __getitem__(self, idx): + def __getitem__(self, idx: Union[slice, int, str]) -> Union[str, Dict[str, str]]: + """ + Returns shortcut value, accessed by range, index, or key. + + Args: + idx: Index (range, int, or str) of shortcut to access. + + Returns: + If idx is int or string, then return value is the shortcut string. + If idx is range, then return value is dictionary in which keys + are shortcut name and value are shortcut strings. + """ if isinstance(idx, slice): # dict with names and values return {self._names[i]: self[i] for i in range(*idx.indices(len(self)))} @@ -156,7 +190,8 @@ def __getitem__(self, idx): return self._shortcuts[idx] return "" - def __setitem__(self, idx, val): + def __setitem__(self, idx: Union[str, int], val: str): + """Sets shortcut by index.""" if type(idx) == int: idx = self._names[idx] self[idx] = val @@ -164,6 +199,7 @@ def __setitem__(self, idx, val): self._shortcuts[idx] = val def __len__(self): + """Returns number of shortcuts.""" return len(self._names) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 407755fc2..f81e7ef1b 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -2,30 +2,36 @@ Drop-in replacement for QSlider with additional features. """ -from PySide2.QtWidgets import QApplication, QWidget, QLayout, QAbstractSlider -from PySide2.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem -from PySide2.QtWidgets import QSizePolicy, QLabel, QGraphicsRectItem -from PySide2.QtGui import ( - QPainter, - QPen, - QBrush, - QColor, - QKeyEvent, - QPolygonF, - QPainterPath, -) -from PySide2.QtCore import Qt, Signal, QRect, QRectF, QPointF +from PySide2 import QtCore, QtWidgets +from PySide2.QtGui import QPen, QBrush, QColor, QKeyEvent, QPolygonF, QPainterPath from sleap.gui.overlays.tracks import TrackColorManager import attr import itertools import numpy as np -from typing import Dict, Optional, Union +from typing import Dict, Iterable, List, Optional, Union @attr.s(auto_attribs=True, cmp=False) class SliderMark: + """ + Class to hold data for an individual mark on the slider. + + Attributes: + type: Type of the mark, options are: + * "simple" (single value) + * "filled" (single value) + * "open" (single value) + * "predicted" (single value) + * "track" (range of values) + val: Beginning of mark range + end_val: End of mark range (for "track" marks) + row: The row that the mark goes in; used for tracks. + color: Color of mark, can be string or (r, g, b) tuple. + filled: Whether the mark is shown filled (solid color). 
+ """ + type: str val: float end_val: float = None @@ -62,7 +68,7 @@ def filled(self): return True -class VideoSlider(QGraphicsView): +class VideoSlider(QtWidgets.QGraphicsView): """Drop-in replacement for QSlider with additional features. Args: @@ -74,38 +80,51 @@ class VideoSlider(QGraphicsView): this can be either * list of values to mark * list of (track, value)-tuples to mark + color_manager: A :class:`TrackColorManager` which determines the + color to use for "track"-type marks + + Signals: + mousePressed: triggered on Qt event + mouseMoved: triggered on Qt event + mouseReleased: triggered on Qt event + keyPress: triggered on Qt event + keyReleased: triggered on Qt event + valueChanged: triggered when value of slider changes + selectionChanged: triggered when slider range selection changes + heightUpdated: triggered when the height of slider changes """ - mousePressed = Signal(float, float) - mouseMoved = Signal(float, float) - mouseReleased = Signal(float, float) - keyPress = Signal(QKeyEvent) - keyRelease = Signal(QKeyEvent) - valueChanged = Signal(int) - selectionChanged = Signal(int, int) - updatedTracks = Signal() + mousePressed = QtCore.Signal(float, float) + mouseMoved = QtCore.Signal(float, float) + mouseReleased = QtCore.Signal(float, float) + keyPress = QtCore.Signal(QKeyEvent) + keyRelease = QtCore.Signal(QKeyEvent) + valueChanged = QtCore.Signal(int) + selectionChanged = QtCore.Signal(int, int) + heightUpdated = QtCore.Signal() def __init__( self, - orientation=-1, + orientation=-1, # for compatibility with QSlider min=0, max=100, val=0, marks=None, - tracks=0, - color_manager=None, + color_manager: Optional[TrackColorManager] = None, *args, **kwargs ): super(VideoSlider, self).__init__(*args, **kwargs) - self.scene = QGraphicsScene() + self.scene = QtWidgets.QGraphicsScene() self.setScene(self.scene) - self.setAlignment(Qt.AlignLeft | Qt.AlignTop) + self.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop) - self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # ScrollBarAsNeeded + self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) + self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy( + QtCore.Qt.ScrollBarAlwaysOff + ) # ScrollBarAsNeeded self._color_manager = color_manager @@ -115,13 +134,15 @@ def __init__( self._min_height = 19 + self._header_height # Add border rect - outline_rect = QRect(0, 0, 200, self._min_height - 3) + outline_rect = QtCore.QRect(0, 0, 200, self._min_height - 3) self.outlineBox = self.scene.addRect(outline_rect) self.outlineBox.setPen(QPen(QColor("black"))) # Add drag handle rect handle_width = 6 - handle_rect = QRect(0, self._handleTop(), handle_width, self._handleHeight()) + handle_rect = QtCore.QRect( + 0, self._handleTop(), handle_width, self._handleHeight() + ) self.setMinimumHeight(self._min_height) self.setMaximumHeight(self._min_height) self.handle = self.scene.addRect(handle_rect) @@ -129,7 +150,9 @@ def __init__( self.handle.setBrush(QColor(128, 128, 128, 128)) # Add (hidden) rect to highlight selection - self.select_box = self.scene.addRect(QRect(0, 1, 0, outline_rect.height() - 2)) + self.select_box = self.scene.addRect( + QtCore.QRect(0, 1, 0, outline_rect.height() - 2) + ) self.select_box.setPen(QPen(QColor(80, 80, 255))) self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() @@ -149,19 +172,23 @@ def __init__( 
self.headerSeries = dict() self.drawHeader() - def _pointsToPath(self, points): + def _pointsToPath(self, points: List[QtCore.QPointF]) -> QPainterPath: + """Converts list of `QtCore.QPointF`s to a `QPainterPath`.""" path = QPainterPath() path.addPolygon(QPolygonF(points)) return path - def setTracksFromLabels(self, labels, video): + def setTracksFromLabels(self, labels: "Labels", video: "Video"): """Set slider marks using track information from `Labels` object. Note that this is the only method coupled to a SLEAP object. Args: - labels: the `labels` with tracks and labeled_frames + labels: the dataset with tracks and labeled frames video: the video for which to show marks + + Returns: + None """ if self._color_manager is None: @@ -267,10 +294,20 @@ def updateHeight(self): self.setMarks(marks) self.resizeEvent() - self.updatedTracks.emit() + self.heightUpdated.emit() + + def _toPos(self, val: float, center=False) -> float: + """ + Converts slider value to x position on slider. + + Args: + val: The slider value. + center: Whether to offset by half the width of drag handle, + so that plotted location will light up with center of handle. - def _toPos(self, val, center=False): - """Convert value to x position on slider.""" + Returns: + x position. + """ x = val x -= self._val_min x /= max(1, self._val_max - self._val_min) @@ -279,8 +316,8 @@ def _toPos(self, val, center=False): x += self.handle.rect().width() / 2.0 return x - def _toVal(self, x, center=False): - """Convert x position to value.""" + def _toVal(self, x: float, center=False) -> float: + """Converts x position to slider value.""" val = x val /= self._sliderWidth() val *= max(1, self._val_max - self._val_min) @@ -288,28 +325,29 @@ def _toVal(self, x, center=False): val = round(val) return val - def _sliderWidth(self): + def _sliderWidth(self) -> float: + """Returns visual width of slider.""" return self.outlineBox.rect().width() - self.handle.rect().width() - def value(self): - """Get value of slider.""" + def value(self) -> float: + """Returns value of slider.""" return self._val_main - def setValue(self, val): - """Set value of slider.""" + def setValue(self, val: float) -> float: + """Sets value of slider.""" self._val_main = val x = self._toPos(val) self.handle.setPos(x, 0) - def setMinimum(self, min): - """Set minimum value for slider.""" + def setMinimum(self, min: float) -> float: + """Sets minimum value for slider.""" self._val_min = min - def setMaximum(self, max): - """Set maximum value for slider.""" + def setMaximum(self, max: float) -> float: + """Sets maximum value for slider.""" self._val_max = max - def setEnabled(self, val): + def setEnabled(self, val: float) -> float: """Set whether the slider is enabled.""" self._enabled = val @@ -318,23 +356,28 @@ def enabled(self): return self._enabled def clearSelection(self): - """Clear selection endpoints.""" + """Clears selection endpoints.""" self._selection = [] self.select_box.hide() def startSelection(self, val): - """Add initial selection endpoint. + """Adds initial selection endpoint. + + Called when user starts dragging to select range in slider. Args: val: value of endpoint """ self._selection.append(val) - def endSelection(self, val, update=False): + def endSelection(self, val, update: bool = False): """Add final selection endpoint. + Called during or after the user is dragging to select range. 
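In API terms, the selection behavior the docstrings above describe can also be driven programmatically; a sketch (frame values are arbitrary, and a QApplication is needed because the slider is a widget):

    from PySide2 import QtWidgets
    from sleap.gui.slider import VideoSlider

    app = QtWidgets.QApplication([])
    slider = VideoSlider(min=0, max=1000)
    slider.startSelection(120)              # anchor one endpoint
    slider.endSelection(180)                # finish the range; emits selectionChanged
    assert slider.hasSelection()
    assert slider.getSelection() == (120, 180)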
+ Args: val: value of endpoint + update: """ # If we want to update endpoint and there's already one, remove it if update and len(self._selection) % 2 == 0: @@ -350,12 +393,12 @@ def endSelection(self, val, update=False): self.selectionChanged.emit(*self.getSelection()) def hasSelection(self) -> bool: - """Return True if a clip is selected, False otherwise.""" + """Returns True if a clip is selected, False otherwise.""" a, b = self.getSelection() return a < b def getSelection(self): - """Return start and end value of current selection endpoints.""" + """Returns start and end value of current selection endpoints.""" a, b = 0, 0 if len(self._selection) % 2 == 0 and len(self._selection) > 0: a, b = self._selection[-2:] @@ -363,30 +406,37 @@ def getSelection(self): end = max(a, b) return start, end - def drawSelection(self, a, b): - """Draw selection box on slider. + def drawSelection(self, a: float, b: float): + """Draws selection box on slider. Args: a: one endpoint value b: other endpoint value + + Returns: + None. """ start = min(a, b) end = max(a, b) start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) - selection_rect = QRect( + selection_rect = QtCore.QRect( start_pos, 1, end_pos - start_pos, self.outlineBox.rect().height() - 2 ) self.select_box.setRect(selection_rect) self.select_box.show() - def moveSelectionAnchor(self, x, y): - """Move selection anchor in response to mouse position. + def moveSelectionAnchor(self, x: float, y: float): + """ + Moves selection anchor in response to mouse position. Args: x: x position of mouse y: y position of mouse + + Returns: + None. """ x = max(x, 0) x = min(x, self.outlineBox.rect().width()) @@ -398,11 +448,15 @@ def moveSelectionAnchor(self, x, y): self.drawSelection(anchor_val, self._selection[-1]) def releaseSelectionAnchor(self, x, y): - """Finish selection in response to mouse release. + """ + Finishes selection in response to mouse release. Args: x: x position of mouse y: y position of mouse + + Returns: + None. """ x = max(x, 0) x = min(x, self.outlineBox.rect().width()) @@ -410,18 +464,21 @@ def releaseSelectionAnchor(self, x, y): self.endSelection(anchor_val) def clearMarks(self): - """Clear all marked values for slider.""" + """Clears all marked values for slider.""" if hasattr(self, "_mark_items"): for item in self._mark_items.values(): self.scene.removeItem(item) self._marks = set() # holds mark position self._mark_items = dict() # holds visual Qt object for plotting mark - def setMarks(self, marks): - """Set all marked values for the slider. + def setMarks(self, marks: Iterable[Union[SliderMark, int]]): + """Sets all marked values for the slider. Args: marks: iterable with all values to mark + + Returns: + None. """ self.clearMarks() if marks is not None: @@ -432,17 +489,18 @@ def setMarks(self, marks): self.updatePos() def getMarks(self): - """Return list of marks. - - Each mark is either val or (track, val)-tuple. - """ + """Returns list of marks.""" return self._marks - def addMark(self, new_mark, update=True): - """Add a marked value to the slider. + def addMark(self, new_mark: SliderMark, update: bool = True): + """Adds a marked value to the slider. Args: new_mark: value to mark + update: Whether to redraw slider with new mark. + + Returns: + None. 
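+
+        A minimal usage sketch (illustrative only; assumes "simple" is one of
+        the mark type strings accepted by `SliderMark`, and `slider` is a
+        `VideoSlider` instance):
+
+            slider.addMark(SliderMark("simple", val=10))
+            slider.addMark(SliderMark("simple", val=20), update=False)
+            slider.updatePos()  # redraw once after adding marks in a batch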
""" # check if mark is within slider range if new_mark.val > self._val_max: @@ -476,9 +534,6 @@ def addMark(self, new_mark, update=True): if update: self.updatePos() - def _mark_val(self, mark): - return mark.val - def updatePos(self): """Update the visual x position of handle and slider annotations.""" x = self._toPos(self.value()) @@ -538,8 +593,8 @@ def toYPos(val): (self._toPos(max(series.keys()) + 1, center=True), toYPos(series_min)) ) - # Convert to list of QPointF objects - points = list(itertools.starmap(QPointF, points)) + # Convert to list of QtCore.QPointF objects + points = list(itertools.starmap(QtCore.QPointF, points)) self.poly.setPath(self._pointsToPath(points)) def moveHandle(self, x, y): @@ -558,7 +613,7 @@ def moveHandle(self, x, y): val = self._toVal(x) # snap to nearby mark within handle - mark_vals = [self._mark_val(mark) for mark in self._marks] + mark_vals = [mark.val for mark in self._marks] handle_left = self._toVal(x - self.handle.rect().width() / 2) handle_right = self._toVal(x + self.handle.rect().width() / 2) marks_in_handle = [ @@ -602,10 +657,22 @@ def resizeEvent(self, event=None): self.drawHeader() super(VideoSlider, self).resizeEvent(event) - def _handleTop(self): + def _handleTop(self) -> float: + """Returns y position of top of handle (i.e., header height).""" return 1 + self._header_height - def _handleHeight(self, outline_rect=None): + def _handleHeight(self, outline_rect=None) -> float: + """ + Returns visual height of handle. + + Args: + outline_rect: The rect of the outline box for the slider. This + is only required when calling during initialization (when the + outline box doesn't yet exist). + + Returns: + Height of handle in pixels. + """ if outline_rect is None: outline_rect = self.outlineBox.rect() @@ -633,13 +700,13 @@ def mousePressEvent(self, event): move_function = None release_function = None - if event.modifiers() == Qt.ShiftModifier: + if event.modifiers() == QtCore.Qt.ShiftModifier: move_function = self.moveSelectionAnchor release_function = self.releaseSelectionAnchor self.clearSelection() - elif event.modifiers() == Qt.NoModifier: + elif event.modifiers() == QtCore.Qt.NoModifier: move_function = self.moveHandle release_function = None @@ -679,7 +746,7 @@ def keyReleaseEvent(self, event): self.keyRelease.emit(event) event.accept() - def boundingRect(self) -> QRectF: + def boundingRect(self) -> QtCore.QRectF: """Method required by Qt.""" return self.outlineBox.rect() @@ -689,7 +756,7 @@ def paint(self, *args, **kwargs): if __name__ == "__main__": - app = QApplication([]) + app = QtWidgets.QApplication([]) window = VideoSlider( min=0, diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 3876b0f67..cf0630d7f 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -80,7 +80,7 @@ def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): self.splitter = QtWidgets.QSplitter(Qt.Vertical) self.splitter.addWidget(self.view) self.splitter.addWidget(self.seekbar) - self.seekbar.updatedTracks.connect(lambda: self.splitter.refresh()) + self.seekbar.heightUpdated.connect(lambda: self.splitter.refresh()) self.layout = QVBoxLayout() self.layout.addWidget(self.splitter) diff --git a/sleap/util.py b/sleap/util.py index 70eb7cd3e..d91deadfc 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -214,3 +214,20 @@ def weak_filename_match(filename_a: str, filename_b: str) -> bool: # check if last three parts of path match return filename_a.split("/")[-3:] == filename_b.split("/")[-3:] + + +def dict_cut(d: Dict, a: int, b: int) -> 
Dict: + """ + Helper function for creating subdictionary by numeric indexing of items. + + Assumes that `dict.items()` will have a fixed order. + + Args: + d: The dictionary to "split" + a: Start index of range of items to include in result. + b: End index of range of items to include in result. + + Returns: + A dictionary that contains a subset of the items in the original dict. + """ + return dict(list(d.items())[a:b]) From 33bebfbff0cd0a2963777b8e2f654c0edb8931d7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 27 Sep 2019 17:14:11 -0400 Subject: [PATCH 146/176] Fixed incorrect test. --- sleap/gui/merge.py | 4 ++-- tests/gui/test_merge.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index cbaa20c29..f8b038d8c 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -292,13 +292,13 @@ def show_instance_type_counts(instance_list: List["Instance"]) -> str: instance_list: The list of instances to count. Returns: - String with number of predicted instances and number of user instances. + String with numbers of user/predicted instances. """ prediction_count = len( list(filter(lambda inst: hasattr(inst, "score"), instance_list)) ) user_count = len(instance_list) - prediction_count - return f"{prediction_count}/{user_count}" + return f"{user_count}/{prediction_count}" if __name__ == "__main__": diff --git a/tests/gui/test_merge.py b/tests/gui/test_merge.py index c625c0def..6f75753a2 100644 --- a/tests/gui/test_merge.py +++ b/tests/gui/test_merge.py @@ -2,4 +2,4 @@ def test_count_string(simple_predictions): - assert show_instance_type_counts(simple_predictions[0]) == 2 + assert show_instance_type_counts(simple_predictions[0]) == "0/2" From c879b1f647366921b9fc9c58f3397c8b57652f8b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 30 Sep 2019 13:30:24 -0400 Subject: [PATCH 147/176] Better docstrings and typing, minor refactoring. --- sleap/gui/active.py | 111 ++++- sleap/gui/app.py | 869 ++++++++++++++++++--------------- sleap/gui/dataviews.py | 463 +++++++++--------- sleap/gui/overlays/anchors.py | 42 +- sleap/gui/overlays/base.py | 10 +- sleap/gui/overlays/confmaps.py | 15 +- sleap/gui/overlays/instance.py | 18 +- sleap/gui/overlays/pafs.py | 101 +--- sleap/gui/overlays/tracks.py | 69 ++- sleap/gui/slider.py | 4 + sleap/gui/suggestions.py | 180 +++---- sleap/gui/training_editor.py | 19 +- sleap/gui/video.py | 280 +++++++---- sleap/info/metrics.py | 11 +- sleap/instance.py | 17 +- sleap/io/dataset.py | 18 + sleap/io/video.py | 11 + sleap/io/visuals.py | 4 + sleap/skeleton.py | 39 +- tests/gui/test_tracks.py | 2 +- 20 files changed, 1292 insertions(+), 991 deletions(-) diff --git a/sleap/gui/active.py b/sleap/gui/active.py index ffb0d1074..e86ce1bc2 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -1,11 +1,13 @@ +""" +Module for running active learning (or just inference) from GUI. +""" + import os import cattr -from datetime import datetime -import multiprocessing from functools import reduce from pkg_resources import Requirement, resource_filename -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from sleap.io.dataset import Labels from sleap.io.video import Video @@ -18,6 +20,18 @@ class ActiveLearningDialog(QtWidgets.QDialog): + """Active learning dialog. 
+ + The dialog can be used in different modes: + * simplified active learning (fewer controls) + * expert active learning (full controls) + * inference only + + Arguments: + labels_filename: Path to the dataset where we'll get training data. + labels: The dataset where we'll get training data and add predictions. + mode: String which specified mode ("active", "expert", or "inference"). + """ learningFinished = QtCore.Signal() @@ -26,7 +40,6 @@ def __init__( labels_filename: str, labels: Labels, mode: str = "expert", - only_predict: bool = False, *args, **kwargs, ): @@ -36,7 +49,6 @@ def __init__( self.labels_filename = labels_filename self.labels = labels self.mode = mode - self.only_predict = only_predict print(f"Number of frames to train on: {len(labels.user_labeled_frames)}") @@ -99,18 +111,18 @@ def __init__( # connect actions to buttons def edit_conf_profile(): - self.view_profile( + self._view_profile( self.form_widget["conf_job"], model_type=ModelOutputType.CONFIDENCE_MAP ) def edit_paf_profile(): - self.view_profile( + self._view_profile( self.form_widget["paf_job"], model_type=ModelOutputType.PART_AFFINITY_FIELD, ) def edit_cent_profile(): - self.view_profile( + self._view_profile( self.form_widget["centroid_job"], model_type=ModelOutputType.CENTROIDS ) @@ -133,6 +145,9 @@ def edit_cent_profile(): self.update_gui() def _rebuild_job_options(self): + """ + Rebuilds list of profile options (checking for new profile files). + """ # load list of job profiles from directory profile_dir = resource_filename( Requirement.parse("sleap"), "sleap/training_profiles" @@ -147,28 +162,41 @@ def _rebuild_job_options(self): # list default profiles find_saved_jobs(profile_dir, self.job_options) - def _update_job_menus(self, init=False): + def _update_job_menus(self, init: bool = False): + """Updates the menus with training profile options. + + Args: + init: Whether this is first time calling (so we should connect + signals), or we're just updating menus. + + Returns: + None. + """ for model_type, field in self.training_profile_widgets.items(): if model_type not in self.job_options: self.job_options[model_type] = [] if init: field.currentIndexChanged.connect( - lambda idx, mt=model_type: self.select_job(mt, idx) + lambda idx, mt=model_type: self._update_from_selected_job(mt, idx) ) else: # block signals so we can update combobox without overwriting # any user data with the defaults from the profile field.blockSignals(True) - field.set_options(self.option_list_from_jobs(model_type)) + field.set_options(self._option_list_from_jobs(model_type)) # enable signals again so that choice of profile will update params field.blockSignals(False) @property - def frame_selection(self): + def frame_selection(self) -> Dict[Video, List[int]]: + """ + Returns dictionary with frames that user has selected for inference. 
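+
+        The value is whatever was last assigned through the setter; a
+        hypothetical example of that structure (selection-type keys such as
+        "clip" or "random", mapping to per-video frame index lists):
+
+            {"clip": {video: [0, 1, 2]}, "random": {video: [7, 19]}}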
+ """ return self._frame_selection @frame_selection.setter - def frame_selection(self, frame_selection): + def frame_selection(self, frame_selection: Dict[str, Dict[Video, List[int]]]): + """Sets options of frames on which to run inference.""" self._frame_selection = frame_selection if "_predict_frames" in self.form_widget.fields.keys(): @@ -209,6 +237,7 @@ def count_total_frames(videos_frames): ) def show(self): + """Shows dialog (we hide rather than close to maintain settings).""" super(ActiveLearningDialog, self).show() # TODO: keep selection and any items added from training editor @@ -217,6 +246,7 @@ def show(self): self._update_job_menus() def update_gui(self): + """Updates gui state after user changes to options.""" form_data = self.form_widget.get_form_data() can_run = True @@ -279,7 +309,14 @@ def update_gui(self): self.run_button.setEnabled(can_run) - def _get_current_job(self, model_type): + def _get_current_job(self, model_type: ModelOutputType) -> Tuple[TrainingJob, str]: + """Returns training job currently selected for given model type. + + Args: + model_type: The type of model for which we want data. + + Returns: Tuple of (TrainingJob, path to job profile). + """ # by default use the first model for a given type idx = 0 if model_type in self.training_profile_widgets: @@ -301,6 +338,7 @@ def _get_current_job(self, model_type): return job, job_filename def _get_model_types_to_use(self): + """Returns lists of model types which user has enabled.""" form_data = self.form_widget.get_form_data() types_to_use = [] @@ -317,7 +355,8 @@ def _get_model_types_to_use(self): return types_to_use - def _get_current_training_jobs(self): + def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]: + """Returns all currently selected training jobs.""" form_data = self.form_widget.get_form_data() training_jobs = dict() @@ -345,6 +384,7 @@ def _get_current_training_jobs(self): return training_jobs def run(self): + """Run active learning (or inference) with current dialog settings.""" # Collect TrainingJobs and params from form form_data = self.form_widget.get_form_data() training_jobs = self._get_current_training_jobs() @@ -392,6 +432,7 @@ def run(self): ).exec_() def view_datagen(self): + """Shows windows with sample visual data that will be used training.""" from sleap.nn.datagen import ( generate_training_data, generate_confmaps_from_points, @@ -444,8 +485,8 @@ def view_datagen(self): # can we show these windows without closing dialog? 
self.hide() - # open profile editor in new dialog window - def view_profile(self, filename, model_type, windows=[]): + def _view_profile(self, filename: str, model_type: ModelOutputType, windows=[]): + """Opens profile editor in new dialog window.""" saved_files = [] win = TrainingEditor(filename, saved_files=saved_files, parent=self) windows.append(win) @@ -454,14 +495,16 @@ def view_profile(self, filename, model_type, windows=[]): for new_filename in saved_files: self._add_job_file_to_list(new_filename, model_type) - def option_list_from_jobs(self, model_type): + def _option_list_from_jobs(self, model_type: ModelOutputType): + """Returns list of menu options for given model type.""" jobs = self.job_options[model_type] option_list = [name for (name, job) in jobs] option_list.append("---") option_list.append("Select a training profile file...") return option_list - def add_job_file(self, model_type): + def _add_job_file(self, model_type): + """Allow user to add training profile for given model type.""" filename, _ = QtWidgets.QFileDialog.getOpenFileName( None, dir=None, @@ -475,7 +518,8 @@ def add_job_file(self, model_type): if field.currentIndex() == field.count() - 1: # subtract 1 for separator field.setCurrentIndex(-1) - def _add_job_file_to_list(self, filename, model_type): + def _add_job_file_to_list(self, filename: str, model_type: ModelOutputType): + """Adds selected training profile for given model type.""" if len(filename): try: # try to load json as TrainingJob @@ -497,14 +541,15 @@ def _add_job_file_to_list(self, filename, model_type): if model_type in self.training_profile_widgets: field = self.training_profile_widgets[model_type] field.set_options( - self.option_list_from_jobs(model_type), filename + self._option_list_from_jobs(model_type), filename ) else: QtWidgets.QMessageBox( text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}." ).exec_() - def select_job(self, model_type, idx): + def _update_from_selected_job(self, model_type: ModelOutputType, idx: int): + """Updates dialog settings after user selects a training profile.""" jobs = self.job_options[model_type] if idx == -1: return @@ -535,10 +580,11 @@ def select_job(self, model_type, idx): self.form_widget[field_name] = has_trained else: # last item is "select file..." - self.add_job_file(model_type) + self._add_job_file(model_type) -def make_default_training_jobs(): +def make_default_training_jobs() -> Dict[ModelOutputType, TrainingJob]: + """Creates TrainingJobs with some default settings.""" from sleap.nn.model import Model from sleap.nn.training import Trainer from sleap.nn.architectures import unet, leap @@ -595,12 +641,15 @@ def make_default_training_jobs(): return training_jobs -def find_saved_jobs(job_dir, jobs=None): +def find_saved_jobs( + job_dir: str, jobs=None +) -> Dict[ModelOutputType, List[Tuple[str, TrainingJob]]]: """Find all the TrainingJob json files in a given directory. Args: job_dir: the directory in which to look for json files - jobs (optional): append to jobs, rather than creating new dict + jobs: If given, then the found jobs will be added to this object, + rather than creating new dict. Returns: dict of {ModelOutputType: list of (filename, TrainingJob) tuples} """ @@ -633,7 +682,15 @@ def find_saved_jobs(job_dir, jobs=None): return jobs -def add_frames_from_json(labels: Labels, new_labels_json: str): +def add_frames_from_json(labels: Labels, new_labels_json: str) -> int: + """Merges new predictions (given as json string) into dataset. 
+ + Args: + labels: The dataset to which we're adding the predictions. + new_labels_json: A JSON string which can be deserialized into `Labels`. + Returns: + Number of labeled frames with new predictions. + """ # Deserialize the new frames, matching to the existing videos/skeletons if possible new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames diff --git a/sleap/gui/app.py b/sleap/gui/app.py index c8e3a85de..494fe8e96 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1,38 +1,29 @@ +""" +Main GUI application for labeling, active learning, and proofreading. +""" + from PySide2 import QtCore, QtWidgets from PySide2.QtCore import Qt, QEvent -from PySide2.QtGui import QKeyEvent, QKeySequence, QStatusTipEvent - from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget -from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import ( - QLabel, - QPushButton, - QLineEdit, - QSpinBox, - QDoubleSpinBox, - QComboBox, - QCheckBox, -) -from PySide2.QtWidgets import QTableWidget, QTableView, QTableWidgetItem -from PySide2.QtWidgets import QMenu, QAction +from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox +from PySide2.QtWidgets import QLabel, QPushButton, QComboBox from PySide2.QtWidgets import QFileDialog, QMessageBox -import copy import re import operator import os import sys -import yaml from pkg_resources import Requirement, resource_filename from pathlib import PurePath +from typing import Callable, Dict, Iterator, Optional + import numpy as np -import pandas as pd -from sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track +from sleap.skeleton import Skeleton +from sleap.instance import Instance, PredictedInstance, Point, Track from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.info.summary import StatisticSeries @@ -63,11 +54,30 @@ class MainWindow(QMainWindow): + """The SLEAP GUI application. + + Each project (`Labels` dataset) that you have loaded in the GUI will + have it's own `MainWindow` object. + + Attributes: + labels: The :class:`Labels` dataset. If None, a new, empty project + (i.e., :class:`Labels' object) will be created. + skeleton: The active :class:`Skeleton` for the project in the gui + video: The active :class:`Video` in view in the gui + """ + labels: Labels skeleton: Skeleton video: Video - def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs): + def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): + """Initialize the app. + + Args: + labels_path: Path to saved :class:`Labels` dataset. + Returns: + None. + """ super(MainWindow, self).__init__(*args, **kwargs) self.labels = Labels() @@ -75,7 +85,6 @@ def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs self.labeled_frame = None self.video = None self.video_idx = None - self.mark_idx = None self.filename = None self._menu_actions = dict() self._buttons = dict() @@ -91,45 +100,52 @@ def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs self._auto_zoom = False self.changestack_clear() - self.initialize_gui() + self._initialize_gui() - if data_path is not None: - pass + if labels_path is not None: + self.loadProject(labels_path) - if import_data is not None: - self.importData(import_data) + def event(self, e: QEvent) -> bool: + """Custom event handler. 
- # TODO: auto-add video to clean project if no data provided - # TODO: auto-select video if data provided, or add it to project - if video is not None: - self.addVideo(video) + We use this to ignore events that would clear status bar. - def event(self, e): + Args: + e: The event. + Returns: + True if we ignore event, otherwise returns whatever the usual + event handler would return. + """ if e.type() == QEvent.StatusTip: if e.tip() == "": return True return super().event(e) - def changestack_push(self, change=None): - """Add to stack of changes made by user.""" + def changestack_push(self, change: bool = None): + """Adds to stack of changes made by user.""" # Currently the change doesn't store any data, and we're only using this # to determine if there are unsaved changes. Eventually we could use this # to support undo/redo. self._change_stack.append(change) def changestack_savepoint(self): + """Marks that project was just saved.""" self.changestack_push("SAVE") def changestack_clear(self): + """Clears stack of changes.""" self._change_stack = list() - def changestack_start_atomic(self, change=None): - pass + def changestack_start_atomic(self): + """Marks that we want to track a set of changes as a single change.""" + self.changestack_push("ATOMIC_START") def changestack_end_atomic(self): - pass + """Marks that we want finished the set of changes to track together.""" + self.changestack_push("ATOMIC_END") def changestack_has_changes(self) -> bool: + """Returns whether there are any unsaved changes.""" # True iff there are no unsaved changed if len(self._change_stack) == 0: return False @@ -139,21 +155,34 @@ def changestack_has_changes(self) -> bool: @property def filename(self): + """Returns filename for current project.""" return self._filename @filename.setter def filename(self, x): + """Sets filename for current project. 
Doesn't load file.""" self._filename = x if x is not None: self.setWindowTitle(x) - def initialize_gui(self): + def _initialize_gui(self): + """Creates menus, dock windows, starts timers to update gui state.""" - shortcuts = Shortcuts() + self._create_video_player() + self.statusBar() + self.load_overlays() + self._create_menus() + self._create_dock_windows() - ####### Video player ####### + # Create timer to update state of gui at regular intervals + self.update_gui_timer = QtCore.QTimer() + self.update_gui_timer.timeout.connect(self._update_gui_state) + self.update_gui_timer.start(0.1) + + def _create_video_player(self): + """Creates and connects :class:`QtVideoPlayer` for gui.""" self.player = QtVideoPlayer(color_manager=self._color_manager) - self.player.changedPlot.connect(self.newFrame) + self.player.changedPlot.connect(self._after_plot_update) self.player.changedData.connect( lambda inst: self.changestack_push("viewer change") ) @@ -161,107 +190,88 @@ def initialize_gui(self): self.player.seekbar.selectionChanged.connect(lambda: self.updateStatusMessage()) self.setCentralWidget(self.player) - ####### Status bar ####### - self.statusBar() # Initialize status bar - - self.load_overlays() + def _create_menus(self): + """Creates main application menus.""" + shortcuts = Shortcuts() - ####### Menus ####### + def _menu_item(menu, key: str, name: str, action: Callable): + menu_item = menu.addAction(name, action, shortcuts[key]) + self._menu_actions[key] = menu_item ### File Menu ### fileMenu = self.menuBar().addMenu("File") - self._menu_actions["new"] = fileMenu.addAction( - "New Project", self.newProject, shortcuts["new"] - ) - self._menu_actions["open"] = fileMenu.addAction( - "Open Project...", self.openProject, shortcuts["open"] - ) - self._menu_actions["import predictions"] = fileMenu.addAction( - "Import Labels...", self.importPredictions + _menu_item(fileMenu, "new", "New Project", self.newProject) + _menu_item(fileMenu, "open", "Open Project...", self.openProject) + _menu_item( + fileMenu, "import predictions", "Import Labels...", self.importPredictions ) + fileMenu.addSeparator() - self._menu_actions["add videos"] = fileMenu.addAction( - "Add Videos...", self.addVideo, shortcuts["add videos"] - ) + _menu_item(fileMenu, "add videos", "Add Videos...", self.addVideo) + fileMenu.addSeparator() - self._menu_actions["save"] = fileMenu.addAction( - "Save", self.saveProject, shortcuts["save"] - ) - self._menu_actions["save as"] = fileMenu.addAction( - "Save As...", self.saveProjectAs, shortcuts["save as"] - ) + _menu_item(fileMenu, "save", "Save", self.saveProject) + _menu_item(fileMenu, "save as", "Save As...", self.saveProjectAs) + fileMenu.addSeparator() - self._menu_actions["close"] = fileMenu.addAction( - "Quit", self.close, shortcuts["close"] - ) + _menu_item(fileMenu, "close", "Quit", self.close) ### Go Menu ### goMenu = self.menuBar().addMenu("Go") - self._menu_actions["goto next labeled"] = goMenu.addAction( - "Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next labeled"] + _menu_item( + goMenu, "goto next labeled", "Next Labeled Frame", self.nextLabeledFrame ) - self._menu_actions["goto prev labeled"] = goMenu.addAction( + _menu_item( + goMenu, + "goto prev labeled", "Previous Labeled Frame", self.previousLabeledFrame, - shortcuts["goto prev labeled"], ) - - self._menu_actions["goto next user"] = goMenu.addAction( + _menu_item( + goMenu, + "goto next user", "Next User Labeled Frame", self.nextUserLabeledFrame, - shortcuts["goto next user"], ) - - 
self._menu_actions["goto next suggestion"] = goMenu.addAction( - "Next Suggestion", - self.nextSuggestedFrame, - shortcuts["goto next suggestion"], + _menu_item( + goMenu, "goto next suggestion", "Next Suggestion", self.nextSuggestedFrame ) - self._menu_actions["goto prev suggestion"] = goMenu.addAction( + _menu_item( + goMenu, + "goto prev suggestion", "Previous Suggestion", lambda: self.nextSuggestedFrame(-1), - shortcuts["goto prev suggestion"], ) - - self._menu_actions["goto next track spawn"] = goMenu.addAction( + _menu_item( + goMenu, + "goto next track spawn", "Next Track Spawn Frame", self.nextTrackFrame, - shortcuts["goto next track spawn"], ) goMenu.addSeparator() - self._menu_actions["next video"] = goMenu.addAction( - "Next Video", self.nextVideo, shortcuts["next video"] - ) - self._menu_actions["prev video"] = goMenu.addAction( - "Previous Video", self.previousVideo, shortcuts["prev video"] - ) + _menu_item(goMenu, "next video", "Next Video", self.nextVideo) + _menu_item(goMenu, "prev video", "Previous Video", self.previousVideo) goMenu.addSeparator() - self._menu_actions["goto frame"] = goMenu.addAction( - "Go to Frame...", self.gotoFrame, shortcuts["goto frame"] - ) - self._menu_actions["mark frame"] = goMenu.addAction( - "Mark Frame", self.markFrame, shortcuts["mark frame"] - ) - self._menu_actions["goto marked"] = goMenu.addAction( - "Go to Marked Frame", self.goMarkedFrame, shortcuts["goto marked"] - ) + _menu_item(goMenu, "goto frame", "Go to Frame...", self.gotoFrame) ### View Menu ### viewMenu = self.menuBar().addMenu("View") + self.viewMenu = viewMenu # store as attribute so docks can add items viewMenu.addSeparator() - self._menu_actions["color predicted"] = viewMenu.addAction( + _menu_item( + viewMenu, + "color predicted", "Color Predicted Instances", self.toggleColorPredicted, - shortcuts["color predicted"], ) self.paletteMenu = viewMenu.addMenu("Color Palette") @@ -294,15 +304,9 @@ def initialize_gui(self): viewMenu.addSeparator() - self._menu_actions["show labels"] = viewMenu.addAction( - "Show Node Names", self.toggleLabels, shortcuts["show labels"] - ) - self._menu_actions["show edges"] = viewMenu.addAction( - "Show Edges", self.toggleEdges, shortcuts["show edges"] - ) - self._menu_actions["show trails"] = viewMenu.addAction( - "Show Trails", self.toggleTrails, shortcuts["show trails"] - ) + _menu_item(viewMenu, "show labels", "Show Node Names", self.toggleLabels) + _menu_item(viewMenu, "show edges", "Show Edges", self.toggleEdges) + _menu_item(viewMenu, "show trails", "Show Trails", self.toggleTrails) self.trailLengthMenu = viewMenu.addMenu("Trail Length") for length_option in (4, 10, 20): @@ -313,9 +317,7 @@ def initialize_gui(self): viewMenu.addSeparator() - self._menu_actions["fit"] = viewMenu.addAction( - "Fit Instances to View", self.toggleAutoZoom, shortcuts["fit"] - ) + _menu_item(viewMenu, "fit", "Fit Instances to View", self.toggleAutoZoom) viewMenu.addSeparator() @@ -335,36 +337,37 @@ def initialize_gui(self): ### Label Menu ### labelMenu = self.menuBar().addMenu("Labels") - self._menu_actions["add instance"] = labelMenu.addAction( - "Add Instance", self.newInstance, shortcuts["add instance"] - ) - self._menu_actions["delete instance"] = labelMenu.addAction( - "Delete Instance", self.deleteSelectedInstance, shortcuts["delete instance"] + _menu_item(labelMenu, "add instance", "Add Instance", self.newInstance) + _menu_item( + labelMenu, "delete instance", "Delete Instance", self.deleteSelectedInstance ) labelMenu.addSeparator() self.track_menu = 
labelMenu.addMenu("Set Instance Track") - self._menu_actions["transpose"] = labelMenu.addAction( - "Transpose Instance Tracks", self.transposeInstance, shortcuts["transpose"] + _menu_item( + labelMenu, "transpose", "Transpose Instance Tracks", self.transposeInstance ) - self._menu_actions["delete track"] = labelMenu.addAction( + _menu_item( + labelMenu, + "delete track", "Delete Instance and Track", self.deleteSelectedInstanceTrack, - shortcuts["delete track"], ) labelMenu.addSeparator() - self._menu_actions["select next"] = labelMenu.addAction( + _menu_item( + labelMenu, + "select next", "Select Next Instance", self.player.view.nextSelection, - shortcuts["select next"], ) - self._menu_actions["clear selection"] = labelMenu.addAction( + _menu_item( + labelMenu, + "clear selection", "Clear Selection", self.player.view.clearSelection, - shortcuts["clear selection"], ) labelMenu.addSeparator() @@ -372,62 +375,102 @@ def initialize_gui(self): ### Predict Menu ### predictionMenu = self.menuBar().addMenu("Predict") - self._menu_actions["active learning"] = predictionMenu.addAction( - "Run Active Learning...", self.runActiveLearning, shortcuts["learning"] + + _menu_item( + predictionMenu, + "active learning", + "Run Active Learning...", + lambda: self.showLearningDialog("learning"), ) - self._menu_actions["inference"] = predictionMenu.addAction( - "Run Inference...", self.runInference + _menu_item( + predictionMenu, + "inference", + "Run Inference...", + lambda: self.showLearningDialog("inference"), ) - self._menu_actions["learning expert"] = predictionMenu.addAction( - "Expert Controls...", self.runLearningExpert + _menu_item( + predictionMenu, + "learning expert", + "Expert Controls...", + lambda: self.showLearningDialog("expert"), ) + predictionMenu.addSeparator() - self._menu_actions["negative sample"] = predictionMenu.addAction( - "Mark Negative Training Sample...", self.markNegativeAnchor + _menu_item( + predictionMenu, + "negative sample", + "Mark Negative Training Sample...", + self.markNegativeAnchor, ) - self._menu_actions["clear negative samples"] = predictionMenu.addAction( - "Clear Current Frame Negative Samples", self.clearFrameNegativeAnchors + _menu_item( + predictionMenu, + "clear negative samples", + "Clear Current Frame Negative Samples", + self.clearFrameNegativeAnchors, ) + predictionMenu.addSeparator() - self._menu_actions["visualize models"] = predictionMenu.addAction( - "Visualize Model Outputs...", self.visualizeOutputs + _menu_item( + predictionMenu, + "visualize models", + "Visualize Model Outputs...", + self.visualizeOutputs, ) + predictionMenu.addSeparator() - self._menu_actions["remove predictions"] = predictionMenu.addAction( - "Delete All Predictions...", self.deletePredictions - ) - self._menu_actions["remove clip predictions"] = predictionMenu.addAction( + _menu_item( + predictionMenu, + "remove predictions", + "Delete All Predictions...", + self.deletePredictions, + ) + _menu_item( + predictionMenu, + "remove clip predictions", "Delete Predictions from Clip...", self.deleteClipPredictions, - shortcuts["delete clip"], ) - self._menu_actions["remove area predictions"] = predictionMenu.addAction( + _menu_item( + predictionMenu, + "remove area predictions", "Delete Predictions from Area...", self.deleteAreaPredictions, - shortcuts["delete area"], ) - self._menu_actions["remove score predictions"] = predictionMenu.addAction( - "Delete Predictions with Low Score...", self.deleteLowScorePredictions + _menu_item( + predictionMenu, + "remove score predictions", + "Delete 
Predictions with Low Score...", + self.deleteLowScorePredictions, ) - self._menu_actions["remove frame limit predictions"] = predictionMenu.addAction( - "Delete Predictions beyond Frame Limit...", self.deleteFrameLimitPredictions + _menu_item( + predictionMenu, + "remove frame limit predictions", + "Delete Predictions beyond Frame Limit...", + self.deleteFrameLimitPredictions, ) + predictionMenu.addSeparator() - self._menu_actions["export frames"] = predictionMenu.addAction( - "Export Training Package...", self.exportLabeledFrames + _menu_item( + predictionMenu, + "export frames", + "Export Training Package...", + self.exportLabeledFrames, ) - self._menu_actions["export clip"] = predictionMenu.addAction( - "Export Labeled Clip...", self.exportLabeledClip, shortcuts["export clip"] + _menu_item( + predictionMenu, + "export clip", + "Export Labeled Clip...", + self.exportLabeledClip, ) ############ helpMenu = self.menuBar().addMenu("Help") - helpMenu.addAction("Documentation", self.openDocumentation) helpMenu.addAction("Keyboard Reference", self.openKeyRef) - helpMenu.addAction("About", self.openAbout) - ####### Helpers ####### + def _create_dock_windows(self): + """Create dock windows and connects them to gui.""" + def _make_dock(name, widgets=[], tab_with=None): dock = QDockWidget(name) dock.setAllowedAreas(Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea) @@ -438,7 +481,7 @@ def _make_dock(name, widgets=[], tab_with=None): dock_widget.setLayout(layout) dock.setWidget(dock_widget) self.addDockWidget(Qt.RightDockWidgetArea, dock) - viewMenu.addAction(dock.toggleViewAction()) + self.viewMenu.addAction(dock.toggleViewAction()) if tab_with is not None: self.tabifyDockWidget(tab_with, dock) return layout @@ -488,6 +531,9 @@ def _make_dock(name, widgets=[], tab_with=None): gb.setLayout(vb) skeleton_layout.addWidget(gb) + def _update_edge_src(): + self.skeletonEdgesDst.model().skeleton = self.skeleton + gb = QGroupBox("Edges") vb = QVBoxLayout() self.skeletonEdgesTable = SkeletonEdgesTable(self.skeleton) @@ -495,7 +541,7 @@ def _make_dock(name, widgets=[], tab_with=None): hb = QHBoxLayout() self.skeletonEdgesSrc = QComboBox() self.skeletonEdgesSrc.setEditable(False) - self.skeletonEdgesSrc.currentIndexChanged.connect(self.selectSkeletonEdgeSrc) + self.skeletonEdgesSrc.currentIndexChanged.connect(_update_edge_src) self.skeletonEdgesSrc.setModel(SkeletonNodeModel(self.skeleton)) hb.addWidget(self.skeletonEdgesSrc) hb.addWidget(QLabel("to")) @@ -607,34 +653,15 @@ def update_instance_table_selection(): ) ) - # - # Set timer to update state of gui at regular intervals - # - self.update_gui_timer = QtCore.QTimer() - self.update_gui_timer.timeout.connect(self.update_gui_state) - self.update_gui_timer.start(0.1) - def load_overlays(self): - - self.overlays["track_labels"] = TrackListOverlay( - labels=self.labels, view=self.player.view, color_manager=self._color_manager - ) - - self.overlays["negative"] = NegativeAnchorOverlay( - labels=self.labels, scene=self.player.view.scene - ) - - self.overlays["trails"] = TrackTrailOverlay( - labels=self.labels, - scene=self.player.view.scene, - color_manager=self._color_manager, - ) - - self.overlays["instance"] = InstanceOverlay( - labels=self.labels, player=self.player, color_manager=self._color_manager - ) - - def update_gui_state(self): + """Load all standard video overlays.""" + self.overlays["track_labels"] = TrackListOverlay(self.labels, self.player) + self.overlays["negative"] = NegativeAnchorOverlay(self.labels, self.player) + self.overlays["trails"] = 
TrackTrailOverlay(self.labels, self.player) + self.overlays["instance"] = InstanceOverlay(self.labels, self.player) + + def _update_gui_state(self): + """Enable/disable gui items based on current state.""" has_selected_instance = self.player.view.getSelection() is not None has_unsaved_changes = self.changestack_has_changes() has_multiple_videos = self.labels is not None and len(self.labels.videos) > 1 @@ -662,7 +689,6 @@ def update_gui_state(self): self._menu_actions["transpose"].setEnabled(has_multiple_instances) self._menu_actions["save"].setEnabled(has_unsaved_changes) - self._menu_actions["goto marked"].setEnabled(self.mark_idx is not None) self._menu_actions["next video"].setEnabled(has_multiple_videos) self._menu_actions["prev video"].setEnabled(has_multiple_videos) @@ -698,14 +724,23 @@ def update_gui_state(self): control_key_down and has_selected_instance ) - def update_data_views(self, *update): + def _update_data_views(self, *update): + """Update data used by data view table models. + + Args: + Accepts names of what data to update as unnamed string arguments: + "video", "skeleton", "labels", "frame", "suggestions" + If no arguments are given, then everything is updated. + Returns: + None. + """ update = update or ("video", "skeleton", "labels", "frame", "suggestions") if len(self.skeleton.nodes) == 0 and len(self.labels.skeletons): self.skeleton = self.labels.skeletons[0] if "video" in update: - self.videosTable.model().videos = self.labels.videos + self.videosTable.model().items = self.labels.videos if "skeleton" in update: self.skeletonNodesTable.model().skeleton = self.skeleton @@ -738,7 +773,7 @@ def update_data_views(self, *update): self.suggested_count_label.setText(suggestion_status_text) def plotFrame(self, *args, **kwargs): - """Wrap call to player.plot so we can redraw/update things.""" + """Plots (or replots) current frame.""" if self.video is None: return @@ -748,7 +783,61 @@ def plotFrame(self, *args, **kwargs): if self._auto_zoom: self.player.zoomToFit() - def importData(self, filename=None, do_load=True): + def _after_plot_update(self, player, frame_idx, selected_inst): + """Called each time a new frame is drawn.""" + + # Store the current LabeledFrame (or make new, empty object) + self.labeled_frame = self.labels.find(self.video, frame_idx, return_new=True)[0] + + # Show instances, etc, for this frame + for overlay in self.overlays.values(): + overlay.add_to_scene(self.video, frame_idx) + + # Select instance if there was already selection + if selected_inst is not None: + player.view.selectInstance(selected_inst) + + # Update related displays + self.updateStatusMessage() + self._update_data_views("frame") + + # Trigger event after the overlays have been added + player.view.updatedViewer.emit() + + def updateStatusMessage(self, message: Optional[str] = None): + """Updates status bar.""" + if message is None: + message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" + if self.player.seekbar.hasSelection(): + start, end = self.player.seekbar.getSelection() + message += f" (selection: {start}-{end})" + + if len(self.labels.videos) > 1: + message += f" of video {self.labels.videos.index(self.video)}" + + message += f" Labeled Frames: " + if self.video is not None: + message += ( + f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + ) + if len(self.labels.videos) > 1: + message += " in video, " + if len(self.labels.videos) > 1: + message += f"{len(self.labels.user_labeled_frames)} in project" + + self.statusBar().showMessage(message) + + def 
loadProject(self, filename: Optional[str] = None): + """ + Loads given labels file into GUI. + + Args: + filename: The path to the saved labels dataset. If None, + then don't do anything. + + Returns: + None: + """ show_msg = False if len(filename) == 0: @@ -772,43 +861,40 @@ def importData(self, filename=None, do_load=True): print(e) QMessageBox(text=f"Unable to load {filename}.").exec_() - if do_load: + self.labels = labels + self.filename = filename - self.labels = labels - self.filename = filename + if has_loaded: + self.changestack_clear() + self._color_manager.labels = self.labels + self._color_manager.set_palette(self._color_palette) - if has_loaded: - self.changestack_clear() - self._color_manager.labels = self.labels - self._color_manager.set_palette(self._color_palette) + self.load_overlays() - self.load_overlays() + self.setTrailLength(self.overlays["trails"].trail_length) - self.setTrailLength(self.overlays["trails"].trail_length) - - if show_msg: - msgBox = QMessageBox( - text=f"Imported {len(self.labels)} labeled frames." - ) - msgBox.exec_() + if show_msg: + msgBox = QMessageBox( + text=f"Imported {len(self.labels)} labeled frames." + ) + msgBox.exec_() - if len(self.labels.skeletons): - # TODO: add support for multiple skeletons - self.skeleton = self.labels.skeletons[0] + if len(self.labels.skeletons): + # TODO: add support for multiple skeletons + self.skeleton = self.labels.skeletons[0] - # Update UI tables - self.update_data_views() + # Update UI tables + self._update_data_views() - # Load first video - if len(self.labels.videos): - self.loadVideo(self.labels.videos[0], 0) + # Load first video + if len(self.labels.videos): + self.loadVideo(self.labels.videos[0], 0) - # Update track menu options - self.updateTrackMenu() - else: - return labels + # Update track menu options + self.updateTrackMenu() def updateTrackMenu(self): + """Updates track menu options.""" self.track_menu.clear() for track in self.labels.tracks: key_command = "" @@ -820,13 +906,23 @@ def updateTrackMenu(self): self.track_menu.addAction("New Track", self.addTrack, Qt.CTRL + Qt.Key_0) def activateSelectedVideo(self, x): + """Activates video selected in table.""" # Get selected video idx = self.videosTable.currentIndex() if not idx.isValid(): return self.loadVideo(self.labels.videos[idx.row()], idx.row()) - def addVideo(self, filename=None): + def addVideo(self, filename: Optional[str] = None): + """Shows gui for adding video to project. + + Args: + filename: If given, then we just load this video. If not given, + then we show dialog for importing videos. + + Returns: + None. 
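+
+        A minimal usage sketch (hypothetical path, called on a MainWindow
+        instance):
+
+            window.addVideo("path/to/video.mp4")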
+ """ # Browse for file video = None if isinstance(filename, str): @@ -847,9 +943,10 @@ def addVideo(self, filename=None): self.loadVideo(video, len(self.labels.videos) - 1) # Update data model/view - self.update_data_views("video") + self._update_data_views("video") def removeVideo(self): + """Removes video (selected in table) from project.""" # Get selected video idx = self.videosTable.currentIndex() if not idx.isValid(): @@ -864,7 +961,8 @@ def removeVideo(self): response = QMessageBox.critical( self, "Removing video with labels", - f"{n} labeled frames in this video will be deleted, are you sure you want to remove this video?", + f"{n} labeled frames in this video will be deleted, " + "are you sure you want to remove this video?", QMessageBox.Yes, QMessageBox.No, ) @@ -876,7 +974,7 @@ def removeVideo(self): self.changestack_push("remove video") # Update data model - self.update_data_views() + self._update_data_views() # Update view if this was the current video if self.video == video: @@ -888,8 +986,7 @@ def removeVideo(self): self.loadVideo(self.labels.videos[new_idx], new_idx) def loadVideo(self, video: Video, video_idx: int = None): - # Clear video frame mark - self.mark_idx = None + """Activates video in gui.""" # Update current video instance self.video = video @@ -909,6 +1006,7 @@ def loadVideo(self, video: Video, video_idx: int = None): self.plotFrame(last_label.frame_idx) def openSkeleton(self): + """Shows gui for loading saved skeleton into project.""" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] filename, selected_filter = QFileDialog.getOpenFileName( self, dir=None, caption="Open skeleton...", filter=";;".join(filters) @@ -929,9 +1027,10 @@ def openSkeleton(self): self.changestack_push("new skeleton") # Update data model - self.update_data_views() + self._update_data_views() def saveSkeleton(self): + """Shows gui for saving skeleton from project.""" default_name = "skeleton.json" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] filename, selected_filter = QFileDialog.getSaveFileName( @@ -947,6 +1046,7 @@ def saveSkeleton(self): self.skeleton.save_hdf5(filename) def newNode(self): + """Adds new node to skeleton.""" # Find new part name part_name = "new_part" i = 1 @@ -959,11 +1059,12 @@ def newNode(self): self.changestack_push("new node") # Update data model - self.update_data_views() + self._update_data_views() self.plotFrame() def deleteNode(self): + """Removes (currently selected) node from skeleton.""" # Get selected node idx = self.skeletonNodesTable.currentIndex() if not idx.isValid(): @@ -975,19 +1076,18 @@ def deleteNode(self): self.changestack_push("delete node") # Update data model - self.update_data_views() + self._update_data_views() # Replot instances self.plotFrame() - def selectSkeletonEdgeSrc(self): - self.skeletonEdgesDst.model().skeleton = self.skeleton - def updateEdges(self): - self.update_data_views() + """Called when edges in skeleton have been changed.""" + self._update_data_views() self.plotFrame() def newEdge(self): + """Adds new edge to skeleton.""" # TODO: Move this to unified data model # Get selected nodes @@ -1003,11 +1103,12 @@ def newEdge(self): self.changestack_push("new edge") # Update data model - self.update_data_views() + self._update_data_views() self.plotFrame() def deleteEdge(self): + """Removes (currently selected) edge from skeleton.""" # TODO: Move this to unified data model # Get selected edge @@ -1021,14 +1122,16 @@ def deleteEdge(self): self.changestack_push("delete edge") # Update data 
model - self.update_data_views() + self._update_data_views() self.plotFrame() def updateSeekbarMarks(self): + """Updates marks on seekbar.""" self.player.seekbar.setTracksFromLabels(self.labels, self.video) def setSeekbarHeader(self, graph_name): + """Updates graph shown in seekbar header.""" data_obj = StatisticSeries(self.labels) header_functions = { "Point Displacement (sum)": data_obj.get_point_displacement_series, @@ -1055,7 +1158,8 @@ def setSeekbarHeader(self, graph_name): else: print(f"Could not find function for {header_functions}") - def generateSuggestions(self, params): + def generateSuggestions(self, params: Dict): + """Generates suggestions using given params dictionary.""" new_suggestions = dict() for video in self.labels.videos: new_suggestions[video] = VideoFrameSuggestions.suggest( @@ -1064,10 +1168,19 @@ def generateSuggestions(self, params): self.labels.set_suggestions(new_suggestions) - self.update_data_views("suggestions") + self._update_data_views("suggestions") self.updateSeekbarMarks() def _frames_for_prediction(self): + """Builds options for frames on which to run inference. + + Args: + None. + Returns: + Dictionary, keys are names of options (e.g., "clip", "random"), + values are {video: list of frame indices} dictionaries. + """ + def remove_user_labeled( video, frames, user_labeled_frames=self.labels.user_labeled_frames ): @@ -1097,12 +1210,25 @@ def remove_user_labeled( return selection - def _show_learning_window(self, mode): + def showLearningDialog(self, mode: str): + """Helper function to show active learning dialog in given mode. + + Args: + mode: A string representing mode for dialog, which could be: + * "active" + * "inference" + * "expert" + + Returns: + None. + """ from sleap.gui.active import ActiveLearningDialog if "inference" in self.overlays: QMessageBox( - text=f"In order to use this function you must first quit and re-open sLEAP to release resources used by visualizing model outputs." + text="In order to use this function you must first quit and " + "re-open SLEAP to release resources used by visualizing " + "model outputs." ).exec_() return @@ -1116,22 +1242,15 @@ def _show_learning_window(self, mode): self._child_windows[mode].open() def learningFinished(self): + """Called when active learning (or inference) finishes.""" # we ran active learning so update display/ui self.plotFrame() self.updateSeekbarMarks() - self.update_data_views() + self._update_data_views() self.changestack_push("new predictions") - def runLearningExpert(self): - self._show_learning_window("expert") - - def runInference(self): - self._show_learning_window("inference") - - def runActiveLearning(self): - self._show_learning_window("learning") - def visualizeOutputs(self): + """Gui for adding overlay with live visualization of predictions.""" filters = ["Model (*.json)", "HDF5 output (*.h5 *.hdf5)"] # Default to opening from models directory from project @@ -1186,6 +1305,7 @@ def visualizeOutputs(self): self.plotFrame() def deletePredictions(self): + """Deletes all predicted instances in project.""" predicted_instances = [ (lf, inst) @@ -1194,26 +1314,10 @@ def deletePredictions(self): if type(inst) == PredictedInstance ] - resp = QMessageBox.critical( - self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. 
" - "Are you sure you want to delete these?", - QMessageBox.Yes, - QMessageBox.No, - ) - - if resp == QMessageBox.No: - return - - for lf, inst in predicted_instances: - self.labels.remove_instance(lf, inst) - - self.plotFrame() - self.updateSeekbarMarks() - self.changestack_push("removed predictions") + self._delete_confirm(predicted_instances) def deleteClipPredictions(self): + """Deletes all instances within selected range of video frames.""" predicted_instances = [ (lf, inst) @@ -1238,27 +1342,10 @@ def deleteClipPredictions(self): filter(lambda x: x[1].track == track, predicted_instances) ) - resp = QMessageBox.critical( - self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. " - "Are you sure you want to delete these?", - QMessageBox.Yes, - QMessageBox.No, - ) - - if resp == QMessageBox.No: - return - - # Delete the instances - for lf, inst in predicted_instances: - self.labels.remove_instance(lf, inst) - - self.plotFrame() - self.updateSeekbarMarks() - self.changestack_push("removed predictions") + self._delete_confirm(predicted_instances) def deleteAreaPredictions(self): + """Gui for deleting instances within some rect on frame images.""" # Callback to delete after area has been selected def delete_area_callback(x0, y0, x1, y1): @@ -1297,6 +1384,7 @@ def is_bounded(inst): self.player.onAreaSelection(delete_area_callback) def deleteLowScorePredictions(self): + """Gui for deleting instances below some score threshold.""" score_thresh, okay = QtWidgets.QInputDialog.getDouble( self, "Delete Instances with Low Score...", "Score Below:", 1, 0, 100 ) @@ -1312,6 +1400,7 @@ def deleteLowScorePredictions(self): self._delete_confirm(predicted_instances) def deleteFrameLimitPredictions(self): + """Gui for deleting instances beyond some number in each frame.""" count_thresh, okay = QtWidgets.QInputDialog.getInt( self, "Limit Instances in Frame...", @@ -1334,6 +1423,11 @@ def deleteFrameLimitPredictions(self): self._delete_confirm(predicted_instances) def _delete_confirm(self, lf_inst_list): + """Helper function to confirm before deleting instances. + + Args: + lf_inst_list: A list of (labeled frame, instance) tuples. 
+ """ # Confirm that we want to delete resp = QMessageBox.critical( @@ -1358,6 +1452,8 @@ def _delete_confirm(self, lf_inst_list): self.changestack_push("removed predictions") def markNegativeAnchor(self): + """Allows user to add negative training sample anchor.""" + def click_callback(x, y): self.updateStatusMessage() self.labels.add_negative_anchor(self.video, self.player.frame_idx, (x, y)) @@ -1369,11 +1465,13 @@ def click_callback(x, y): self.player.onPointSelection(click_callback) def clearFrameNegativeAnchors(self): + """Removes negative training sample anchors on current frame.""" self.labels.remove_negative_anchors(self.video, self.player.frame_idx) self.changestack_push("remove negative anchors") self.plotFrame() def importPredictions(self): + """Starts gui for importing another dataset into currently one.""" filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] filenames, selected_filter = QFileDialog.getOpenFileNames( self, dir=None, caption="Import labeled data...", filter=";;".join(filters) @@ -1395,10 +1493,20 @@ def importPredictions(self): # update display/ui self.plotFrame() self.updateSeekbarMarks() - self.update_data_views() + self._update_data_views() self.changestack_push("new predictions") - def doubleClickInstance(self, instance): + def doubleClickInstance(self, instance: Instance): + """ + Handles when the user has double-clicked an instance. + + If prediction, then copy to new user-instance. + If already user instance, then add any missing nodes (in case + skeleton has been changed after instance was created). + + Args: + instance: The :class:`Instance` that was double-clicked. + """ # When a predicted instance is double-clicked, add a new instance if hasattr(instance, "score"): self.newInstance(copy_instance=instance) @@ -1428,7 +1536,14 @@ def doubleClickInstance(self, instance): self.plotFrame() - def newInstance(self, copy_instance=None): + def newInstance(self, copy_instance: Optional[Instance] = None): + """ + Creates a new instance, copying node coordinates as appropriate. + + Args: + copy_instance: The :class:`Instance` (or + :class:`PredictedInstance`) which we want to copy. + """ if self.labeled_frame is None: return @@ -1537,6 +1652,7 @@ def newInstance(self, copy_instance=None): self.updateTrackMenu() def deleteSelectedInstance(self): + """Deletes currently selected instance.""" selected_inst = self.player.view.getSelectionInstance() if selected_inst is None: return @@ -1548,6 +1664,7 @@ def deleteSelectedInstance(self): self.updateSeekbarMarks() def deleteSelectedInstanceTrack(self): + """Deletes all instances from track of currently selected instance.""" selected_inst = self.player.view.getSelectionInstance() if selected_inst is None: return @@ -1570,6 +1687,7 @@ def deleteSelectedInstanceTrack(self): self.updateSeekbarMarks() def addTrack(self): + """Creates new track and moves selected instance into this track.""" track_numbers_used = [ int(track.name) for track in self.labels.tracks if track.name.isnumeric() ] @@ -1587,7 +1705,8 @@ def addTrack(self): self.updateTrackMenu() self.updateSeekbarMarks() - def setInstanceTrack(self, new_track): + def setInstanceTrack(self, new_track: "Track"): + """Sets track for selected instance.""" vis_idx = self.player.view.getSelection() if vis_idx is None: return @@ -1638,6 +1757,12 @@ def setInstanceTrack(self, new_track): self.player.view.selectInstance(idx) def transposeInstance(self): + """Transposes tracks for two instances. + + If there are only two instances, then this swaps tracks. 
+ Otherwise, it allows user to select the instances for which we want + to swap tracks. + """ # We're currently identifying instances by numeric index, so it's # impossible to (e.g.) have a single instance which we identify # as the second instance in some other frame. @@ -1685,16 +1810,29 @@ def _transpose_instances(self, instance_ids: list): self.updateSeekbarMarks() def newProject(self): + """Create a new project in a new window.""" window = MainWindow() window.showMaximized() - def openProject(self, first_open=False): + def openProject(self, first_open: bool = False): + """ + Allows use to select and then open a saved project. + + Args: + first_open: Whether this is the first window opened. If True, + then the new project is loaded into the current window + rather than a new application window. + + Returns: + None. + """ filters = [ - "JSON labels (*.json *.json.zip)", "HDF5 dataset (*.h5 *.hdf5)", + "JSON labels (*.json *.json.zip)", "Matlab dataset (*.mat)", "DeepLabCut csv (*.csv)", ] + filename, selected_filter = QFileDialog.getOpenFileName( self, dir=None, caption="Import labeled data...", filter=";;".join(filters) ) @@ -1705,28 +1843,28 @@ def openProject(self, first_open=False): if OPEN_IN_NEW and not first_open: new_window = MainWindow() new_window.showMaximized() - new_window.importData(filename) + new_window.loadProject(filename) else: - self.importData(filename) + self.loadProject(filename) def saveProject(self): + """Show gui to save project (or save as if not yet saved).""" if self.filename is not None: - filename = self.filename - self._trySave(self.filename) else: # No filename (must be new project), so treat as "Save as" self.saveProjectAs() def saveProjectAs(self): + """Show gui to save project as a new file.""" default_name = self.filename if self.filename is not None else "untitled.json" p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) filters = [ + "HDF5 dataset (*.h5)", "JSON labels (*.json)", "Compressed JSON (*.zip)", - "HDF5 dataset (*.h5)", ] filename, selected_filter = QFileDialog.getSaveFileName( self, caption="Save As...", dir=default_name, filter=";;".join(filters) @@ -1740,6 +1878,7 @@ def saveProjectAs(self): self.filename = filename def _trySave(self, filename): + """Helper function which attempts save and handles errors.""" success = False try: Labels.save_file(labels=self.labels, filename=filename) @@ -1758,6 +1897,7 @@ def _trySave(self, filename): return success def closeEvent(self, event): + """Closes application window, prompting for saving as needed.""" if not self.changestack_has_changes(): # No unsaved changes, so accept event (close) event.accept() @@ -1785,16 +1925,19 @@ def closeEvent(self, event): event.accept() def nextVideo(self): + """Activates next video in project.""" new_idx = self.video_idx + 1 new_idx = 0 if new_idx >= len(self.labels.videos) else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def previousVideo(self): + """Activates previous video in project.""" new_idx = self.video_idx - 1 new_idx = len(self.labels.videos) - 1 if new_idx < 0 else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def gotoFrame(self): + """Shows gui to go to frame by number.""" frame_number, okay = QtWidgets.QInputDialog.getInt( self, "Go To Frame...", @@ -1806,13 +1949,8 @@ def gotoFrame(self): if okay: self.plotFrame(frame_number - 1) - def markFrame(self): - self.mark_idx = self.player.frame_idx - - def goMarkedFrame(self): - self.plotFrame(self.mark_idx) - def exportLabeledClip(self): 
+ """Shows gui for exporting clip with visual annotations.""" from sleap.io.visuals import save_labeled_video if self.player.seekbar.hasSelection(): @@ -1848,6 +1986,7 @@ def exportLabeledClip(self): ) def exportLabeledFrames(self): + """Gui for exporting the training dataset of labels/frame images.""" filename, _ = QFileDialog.getSaveFileName( self, caption="Save Labeled Frames As...", dir=self.filename ) @@ -1855,49 +1994,45 @@ def exportLabeledFrames(self): return Labels.save_json(self.labels, filename, save_frame_data=True) - def previousLabeledFrameIndex(self): - cur_idx = self.player.frame_idx - frames = self.labels.frames(self.video, from_frame_idx=cur_idx, reverse=True) + def _plot_if_next(self, frame_iterator: Iterator) -> bool: + """Plots next frame (if there is one) from iterator. + Arguments: + frame_iterator: The iterator from which we'll try to get next + :class:`LabeledFrame`. + + Returns: + True if we went to next frame. + """ try: - next_idx = next(frames).frame_idx - except: - return + next_lf = next(frame_iterator) + except StopIteration: + return False - return next_idx + self.plotFrame(next_lf.frame_idx) + return True def previousLabeledFrame(self): - prev_idx = self.previousLabeledFrameIndex() - if prev_idx is not None: - self.plotFrame(prev_idx) + """Goes to labeled frame prior to current frame.""" + frames = self.labels.frames( + self.video, from_frame_idx=self.player.frame_idx, reverse=True + ) + self._plot_if_next(frames) def nextLabeledFrame(self): - cur_idx = self.player.frame_idx - - frames = self.labels.frames(self.video, from_frame_idx=cur_idx) - - try: - next_idx = next(frames).frame_idx - except: - return - - self.plotFrame(next_idx) + """Goes to labeled frame after current frame.""" + frames = self.labels.frames(self.video, from_frame_idx=self.player.frame_idx) + self._plot_if_next(frames) def nextUserLabeledFrame(self): - cur_idx = self.player.frame_idx - - frames = self.labels.frames(self.video, from_frame_idx=cur_idx) + """Goes to next labeled frame with user instances.""" + frames = self.labels.frames(self.video, from_frame_idx=self.player.frame_idx) # Filter to frames with user instances frames = filter(lambda lf: lf.has_user_instances, frames) - - try: - next_idx = next(frames).frame_idx - except: - return - - self.plotFrame(next_idx) + self._plot_if_next(frames) def nextSuggestedFrame(self, seek_direction=1): + """Goes to next (or previous) suggested frame.""" next_video, next_frame = self.labels.get_next_suggestion( self.video, self.player.frame_idx, seek_direction ) @@ -1910,6 +2045,7 @@ def nextSuggestedFrame(self, seek_direction=1): self.suggestionsTable.selectRow(selection_idx) def nextTrackFrame(self): + """Goes to next frame on which a track starts.""" cur_idx = self.player.frame_idx track_ranges = self.labels.get_track_occupany(self.video) next_idx = min( @@ -1923,35 +2059,41 @@ def nextTrackFrame(self): if next_idx > -1: self.plotFrame(next_idx) - def gotoVideoAndFrame(self, video, frame_idx): + def gotoVideoAndFrame(self, video: Video, frame_idx: int): + """Activates video and goes to frame.""" if video != self.video: # switch to the other video self.loadVideo(video) self.plotFrame(frame_idx) def toggleLabels(self): + """Toggles whether skeleton node labels are shown in video overlay.""" self._show_labels = not self._show_labels self._menu_actions["show labels"].setChecked(self._show_labels) self.player.showLabels(self._show_labels) def toggleEdges(self): + """Toggles whether skeleton edges are shown in video overlay.""" self._show_edges = 
not self._show_edges self._menu_actions["show edges"].setChecked(self._show_edges) self.player.showEdges(self._show_edges) def toggleTrails(self): + """Toggles whether track trails are shown in video overlay.""" self.overlays["trails"].show = not self.overlays["trails"].show self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) self.plotFrame() - def setTrailLength(self, trail_length): + def setTrailLength(self, trail_length: int): + """Sets length of track trails to show in video overlay.""" self.overlays["trails"].trail_length = trail_length self._menu_check_single(self.trailLengthMenu, trail_length) if self.video is not None: self.plotFrame() - def setPalette(self, palette): + def setPalette(self, palette: str): + """Sets color palette used for track colors.""" self._color_manager.set_palette(palette) self._menu_check_single(self.paletteMenu, palette) if self.video is not None: @@ -1959,6 +2101,7 @@ def setPalette(self, palette): self.updateSeekbarMarks() def _menu_check_single(self, menu, item_text): + """Helper method to select exactly one submenu item.""" for menu_item in menu.children(): if menu_item.text() == str(item_text): menu_item.setChecked(True) @@ -1966,82 +2109,36 @@ def _menu_check_single(self, menu, item_text): menu_item.setChecked(False) def toggleColorPredicted(self): - self.overlays["instance"].color_predicted = not self.overlays[ - "instance" - ].color_predicted + """Toggles whether predicted instances are shown in track colors.""" + val = self.overlays["instance"].color_predicted + self.overlays["instance"].color_predicted = not val self._menu_actions["color predicted"].setChecked( self.overlays["instance"].color_predicted ) self.plotFrame() def toggleAutoZoom(self): + """Toggles whether to zoom viewer to fit labeled instances.""" self._auto_zoom = not self._auto_zoom self._menu_actions["fit"].setChecked(self._auto_zoom) if not self._auto_zoom: self.player.view.clearZoom() self.plotFrame() - def openDocumentation(self): - pass - def openKeyRef(self): + """Shows gui for viewing/modifying keyboard shortucts.""" ShortcutDialog().exec_() - def openAbout(self): - pass - - def newFrame(self, player, frame_idx, selected_inst): - """Called each time a new frame is drawn.""" - - # Store the current LabeledFrame (or make new, empty object) - self.labeled_frame = self.labels.find(self.video, frame_idx, return_new=True)[0] - - # Show instances, etc, for this frame - for overlay in self.overlays.values(): - overlay.add_to_scene(self.video, frame_idx) - - # Select instance if there was already selection - if selected_inst is not None: - player.view.selectInstance(selected_inst) - - # Update related displays - self.updateStatusMessage() - self.update_data_views("frame") - - # Trigger event after the overlays have been added - player.view.updatedViewer.emit() - - def updateStatusMessage(self, message=None): - if message is None: - message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" - if self.player.seekbar.hasSelection(): - start, end = self.player.seekbar.getSelection() - message += f" (selection: {start}-{end})" - - if len(self.labels.videos) > 1: - message += f" of video {self.labels.videos.index(self.video)}" - - message += f" Labeled Frames: " - if self.video is not None: - message += ( - f"{len(self.labels.get_video_user_labeled_frames(self.video))}" - ) - if len(self.labels.videos) > 1: - message += " in video, " - if len(self.labels.videos) > 1: - message += f"{len(self.labels.user_labeled_frames)} in project" - - 
self.statusBar().showMessage(message) - def main(*args, **kwargs): + """Starts new instance of app.""" app = QApplication([]) - app.setApplicationName("sLEAP Label") + app.setApplicationName("SLEAP Label") window = MainWindow(*args, **kwargs) window.showMaximized() - if "import_data" not in kwargs: + if "labels_path" not in kwargs: window.openProject(first_open=True) app.exec_() @@ -2051,6 +2148,6 @@ def main(*args, **kwargs): kwargs = dict() if len(sys.argv) > 1: - kwargs["import_data"] = sys.argv[1] + kwargs["labels_path"] = sys.argv[1] main(**kwargs) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index fe8613587..b9dd9b6d0 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -1,103 +1,107 @@ -from PySide2 import QtCore -from PySide2.QtCore import Qt - -from PySide2.QtGui import QKeyEvent, QColor - -from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget -from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import ( - QLabel, - QPushButton, - QLineEdit, - QSpinBox, - QDoubleSpinBox, - QComboBox, - QCheckBox, -) -from PySide2.QtWidgets import ( - QTableWidget, - QTableView, - QTableWidgetItem, - QAbstractItemView, -) -from PySide2.QtWidgets import QTreeView, QTreeWidget, QTreeWidgetItem -from PySide2.QtWidgets import QMenu, QAction -from PySide2.QtWidgets import QFileDialog, QMessageBox +""" +Data table widgets and view models used in GUI app. +""" + +from PySide2 import QtCore, QtWidgets, QtGui import os -import numpy as np -import pandas as pd +from operator import itemgetter -from typing import Callable +from typing import Callable, List, Optional from sleap.gui.overlays.tracks import TrackColorManager -from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.instance import LabeledFrame, Instance -from sleap.skeleton import Skeleton, Node +from sleap.skeleton import Skeleton -class VideosTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. """ +class VideosTable(QtWidgets.QTableView): + """Table view widget for listing videos in dataset.""" def __init__(self, videos: list = []): super(VideosTable, self).__init__() - self.setModel(VideosTableModel(videos)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + props = ("filename", "frames", "height", "width", "channels") + model = GenericTableModel(props, videos, useCache=True) + + self.setModel(model) + + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + + +class GenericTableModel(QtCore.QAbstractTableModel): + """Generic table model to show a list of properties for some items. + + Args: + propList: The list of property names (table columns). + itemList: The list of items with said properties (rows). + useCache: Whether to build cache of property values for all items. 
+ """ -class VideosTableModel(QtCore.QAbstractTableModel): - _props = ["filename", "frames", "height", "width", "channels"] + def __init__( + self, + propList: List[str], + itemList: Optional[list] = None, + useCache: bool = False, + ): + super(GenericTableModel, self).__init__() + self._use_cache = useCache + self._props = propList - def __init__(self, videos: list): - super(VideosTableModel, self).__init__() - self.videos = videos + if itemList is not None: + self.items = itemList + else: + self._data = [] @property - def videos(self): - return self._cache + def items(self): + """Gets or sets list of items to show in table.""" + return self._data - @videos.setter - def videos(self, val): + @items.setter + def items(self, val): self.beginResetModel() - self._cache = [] - for video in val: - row_data = dict( - filename=video.filename, - frames=video.frames, - height=video.height, - width=video.width, - channels=video.channels, - ) - self._cache.append(row_data) + if self._use_cache: + self._data = [] + for item in val: + item_data = {key: getattr(item, key) for key in self._props} + self._data.append(item_data) + else: + self._data = val self.endResetModel() - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() - prop = self._props[index.column()] + key = self._props[index.column()] + + if idx < self.rowCount(): + item = self.items[idx] - if len(self.videos) > (idx - 1): - video = self.videos[idx] + if isinstance(item, dict) and key in item: + return item[key] - if prop in video: - return video[prop] + if hasattr(item, key): + return getattr(item, key) return None - def rowCount(self, parent): - return len(self.videos) + def rowCount(self, parent=None): + """Overrides Qt method, returns number of rows (items).""" + return len(self._data) - def columnCount(self, parent): - return len(VideosTableModel._props) + def columnCount(self, parent=None): + """Overrides Qt method, returns number of columns (attributes).""" + return len(self._props) def headerData( - self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole ): - if role == Qt.DisplayRole: + """Overrides Qt method, returns column (attribute) names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -105,22 +109,39 @@ def headerData( return None + def sort(self, column_idx: int, order: QtCore.Qt.SortOrder): + """Sorts table by given column and order.""" + prop = self._props[column_idx] + + sort_function = itemgetter(prop) + if prop in ("video", "frame"): + if "video" in self._props and "frame" in self._props: + sort_function = itemgetter("video", "frame") + + reverse = order == QtCore.Qt.SortOrder.DescendingOrder + + self.beginResetModel() + self._data.sort(key=sort_function, reverse=reverse) + self.endResetModel() + def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns whether item is selectable etc.""" + return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable -class SkeletonNodesTable(QTableView): - """Table view widget backed by a custom data model for displaying and - editing Skeleton nodes. 
""" +class SkeletonNodesTable(QtWidgets.QTableView): + """Table view widget for displaying and editing Skeleton nodes. """ def __init__(self, skeleton: Skeleton): super(SkeletonNodesTable, self).__init__() self.setModel(SkeletonNodesTableModel(skeleton)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) class SkeletonNodesTableModel(QtCore.QAbstractTableModel): + """Table model for skeleton nodes.""" + _props = ["name", "symmetry"] def __init__(self, skeleton: Skeleton): @@ -129,6 +150,7 @@ def __init__(self, skeleton: Skeleton): @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter @@ -137,13 +159,12 @@ def skeleton(self, val): self._skeleton = val self.endResetModel() - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): node_idx = index.row() prop = self._props[index.column()] - node = self.skeleton.nodes[ - node_idx - ] # FIXME? can we assume order is stable? + node = self.skeleton.nodes[node_idx] node_name = node.name if prop == "name": @@ -154,15 +175,18 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return None def rowCount(self, parent): + """Overrides Qt method, returns number of rows.""" return len(self.skeleton.nodes) def columnCount(self, parent): + """Overrides Qt method, returns number of columns.""" return len(SkeletonNodesTableModel._props) def headerData( - self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole ): - if role == Qt.DisplayRole: + """Overrides Qt method, returns column names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -170,24 +194,24 @@ def headerData( return None - def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): - if role == Qt.EditRole: + def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole): + """Overrides Qt method, updates skeleton with new data from user.""" + if role == QtCore.Qt.EditRole: node_idx = index.row() prop = self._props[index.column()] node_name = self.skeleton.nodes[node_idx].name try: if prop == "name": - if len(value) > 0: + # Change node name (unless empty string) + if value: self._skeleton.relabel_node(node_name, value) - # else: - # self._skeleton.delete_node(node_name) elif prop == "symmetry": - if len(value) > 0: + if value: self._skeleton.add_symmetry(node_name, value) else: - self._skeleton.delete_symmetry( - node_name, self._skeleton.get_symmetry(node_name) - ) + # Value was cleared by user, so delete symmetry + symmetric_to = self._skeleton.get_symmetry(node_name) + self._skeleton.delete_symmetry(node_name, symmetric_to) # send signal that data has changed self.dataChanged.emit(index, index) @@ -200,85 +224,65 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): return False def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable | Qt.ItemIsEditable + """Overrides Qt method, returns flags 
(editable etc).""" + return ( + QtCore.Qt.ItemIsEnabled + | QtCore.Qt.ItemIsSelectable + | QtCore.Qt.ItemIsEditable + ) -class SkeletonEdgesTable(QTableView): - """Table view widget backed by a custom data model for displaying and - editing Skeleton edges. """ +class SkeletonEdgesTable(QtWidgets.QTableView): + """Table view widget for skeleton edges.""" def __init__(self, skeleton: Skeleton): super(SkeletonEdgesTable, self).__init__() self.setModel(SkeletonEdgesTableModel(skeleton)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + + +class SkeletonEdgesTableModel(GenericTableModel): + """Table model for skeleton edges. -class SkeletonEdgesTableModel(QtCore.QAbstractTableModel): - _props = ["source", "destination"] + Args: + skeleton: The skeleton to show in table. + """ def __init__(self, skeleton: Skeleton): - super(SkeletonEdgesTableModel, self).__init__() - self._skeleton = skeleton + props = ("source", "destination") + super(SkeletonEdgesTableModel, self).__init__(props) + self.skeleton = skeleton @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter def skeleton(self, val): - self.beginResetModel() self._skeleton = val - self.endResetModel() - - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): - idx = index.row() - prop = self._props[index.column()] - edge = self.skeleton.edges[idx] - - if prop == "source": - return edge[0].name - elif prop == "destination": - return edge[1].name - - return None - - def rowCount(self, parent): - return len(self.skeleton.edges) - - def columnCount(self, parent): - return len(SkeletonNodesTableModel._props) - - def headerData( - self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole - ): - if role == Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._props[section] - elif orientation == QtCore.Qt.Vertical: - return section - - return None - - def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + items = [ + dict(source=edge[0].name, destination=edge[1].name) + for edge in self._skeleton.edges + ] + self.items = items -class LabeledFrameTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. """ +class LabeledFrameTable(QtWidgets.QTableView): + """Table view widget for listing instances in labeled frame.""" selectionChangedSignal = QtCore.Signal(Instance) def __init__(self, labeled_frame: LabeledFrame = None, labels: Labels = None): super(LabeledFrameTable, self).__init__() self.setModel(LabeledFrameTableModel(labeled_frame, labels)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) def selectionChanged(self, new, old): - """Return `Instance` selected in table.""" + """Custom event handler, emits selectionChangedSignal signal.""" super(LabeledFrameTable, self).selectionChanged(new, old) instance = None @@ -294,6 +298,15 @@ def selectionChanged(self, new, old): class LabeledFrameTableModel(QtCore.QAbstractTableModel): + """Table model for listing instances in labeled frame. 
+ + Allows editing track names. + + Args: + labeled_frame: `LabeledFrame` to show + labels: `Labels` datasource + """ + _props = ("points", "track", "score", "skeleton") def __init__(self, labeled_frame: LabeledFrame, labels: Labels): @@ -303,6 +316,7 @@ def __init__(self, labeled_frame: LabeledFrame, labels: Labels): @property def labeled_frame(self): + """Gets or sets current labeled frame.""" return self._labeled_frame @labeled_frame.setter @@ -313,6 +327,7 @@ def labeled_frame(self, val): @property def labels(self): + """Gets or sets current labels dataset object.""" return self._labels @labels.setter @@ -322,13 +337,15 @@ def labels(self, val): @property def color_manager(self): + """Gets or sets object for determining track colors.""" return self._color_manager @color_manager.setter def color_manager(self, val): self._color_manager = val - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" if index.isValid(): idx = index.row() prop = self._props[index.column()] @@ -337,7 +354,7 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): instance = self.labeled_frame.instances_to_show[idx] # Cell value - if role == Qt.DisplayRole: + if role == QtCore.Qt.DisplayRole: if prop == "points": return f"{len(instance.nodes)}/{len(instance.skeleton.nodes)}" elif prop == "track" and instance.track is not None: @@ -351,13 +368,16 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return "" # Cell color - elif role == Qt.ForegroundRole: + elif role == QtCore.Qt.ForegroundRole: if prop == "track" and instance.track is not None: - return QColor(*self.color_manager.get_color(instance.track)) + return QtGui.QColor( + *self.color_manager.get_color(instance.track) + ) return None def rowCount(self, parent): + """Overrides Qt method, returns number of rows.""" return ( len(self.labeled_frame.instances_to_show) if self.labeled_frame is not None @@ -365,12 +385,14 @@ def rowCount(self, parent): ) def columnCount(self, parent): + """Overrides Qt method, returns number of columns.""" return len(LabeledFrameTableModel._props) def headerData( - self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole ): - if role == Qt.DisplayRole: + """Overrides Qt method, returns column names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -378,8 +400,11 @@ def headerData( return None - def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): - if role == Qt.EditRole: + def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole): + """ + Overrides Qt method, sets data in labeled frame from user changes. 
+ """ + if role == QtCore.Qt.EditRole: idx = index.row() prop = self._props[index.column()] instance = self.labeled_frame.instances_to_show[idx] @@ -394,18 +419,31 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): return False def flags(self, index: QtCore.QModelIndex): - f = Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns flags (editable etc).""" + f = QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable if index.isValid(): idx = index.row() if idx < len(self.labeled_frame.instances_to_show): instance = self.labeled_frame.instances_to_show[idx] prop = self._props[index.column()] if prop == "track" and instance.track is not None: - f |= Qt.ItemIsEditable + f |= QtCore.Qt.ItemIsEditable return f class SkeletonNodeModel(QtCore.QStringListModel): + """ + String list model for source/destination nodes of edges. + + Args: + skeleton: The skeleton for which to list nodes. + src_node: If given, then we assume that this model is being used for + edge destination node. Otherwise, we assume that this model is + being used for an edge source node. + If given, then this should be function that will return the + selected edge source node. + """ + def __init__(self, skeleton: Skeleton, src_node: Callable = None): super(SkeletonNodeModel, self).__init__() self._src_node = src_node @@ -413,6 +451,7 @@ def __init__(self, skeleton: Skeleton, src_node: Callable = None): @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter @@ -447,44 +486,50 @@ def is_valid_dst(node): return valid_dst_nodes - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data for given row.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() return self._node_list[idx] return None def rowCount(self, parent): + """Overrides Qt method, returns number of rows.""" return len(self._node_list) def columnCount(self, parent): + """Overrides Qt method, returns number of columns (1).""" return 1 def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns flags (editable etc).""" + return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable -class SuggestionsTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. 
""" +class SuggestionsTable(QtWidgets.QTableView): + """Table view widget for showing frame suggestions.""" def __init__(self, labels): super(SuggestionsTable, self).__init__() self.setModel(SuggestionsTableModel(labels)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) self.setSortingEnabled(True) -class SuggestionsTableModel(QtCore.QAbstractTableModel): - _props = ["video", "frame", "labeled", "mean score"] +class SuggestionsTableModel(GenericTableModel): + """Table model for showing frame suggestions.""" def __init__(self, labels): - super(SuggestionsTableModel, self).__init__() + props = ("video", "frame", "labeled", "mean score") + + super(SuggestionsTableModel, self).__init__(propList=props) self.labels = labels @property def labels(self): + """Gets or sets current labels dataset.""" return self._labels @labels.setter @@ -492,79 +537,37 @@ def labels(self, val): self.beginResetModel() self._labels = val - self._suggestions_list = self.labels.get_suggestions() - - self.endResetModel() - - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): - idx = index.row() - prop = self._props[index.column()] - - if idx < self.rowCount(): - video = self._suggestions_list[idx][0] - frame_idx = self._suggestions_list[idx][1] - - if prop == "video": - return f"{self.labels.videos.index(video)}: {os.path.basename(video.filename)}" - elif prop == "frame": - return int(frame_idx) + 1 # start at frame 1 rather than 0 - elif prop == "labeled": - # show how many labeled instances are in this frame - val = self._labels.instance_count(video, frame_idx) - val = str(val) if val > 0 else "" - return val - elif prop == "mean score": - return self._getScore(video, frame_idx) - return None - - def _getScore(self, video, frame_idx): - scores = [ - inst.score - for lf in self.labels.find(video, frame_idx) - for inst in lf - if hasattr(inst, "score") - ] - return sum(scores) / len(scores) - - def sort(self, column_idx: int, order: Qt.SortOrder): - prop = self._props[column_idx] - if prop in ("video", "frame"): - sort_function = lambda s: s - elif prop == "labeled": - sort_function = lambda s: self._labels.instance_count(*s) - elif prop == "mean score": - sort_function = lambda s: self._getScore(*s) - - reverse = order == Qt.SortOrder.DescendingOrder - - self.beginResetModel() - self._suggestions_list.sort(key=sort_function, reverse=reverse) + self._data = [] + for video, frame_idx in self.labels.get_suggestions(): + item = dict() + + item[ + "video" + ] = f"{self.labels.videos.index(video)}: {os.path.basename(video.filename)}" + item["frame"] = int(frame_idx) + 1 # start at frame 1 rather than 0 + + # show how many labeled instances are in this frame + val = self._labels.instance_count(video, frame_idx) + val = str(val) if val > 0 else "" + item["labeled"] = val + + # calculate score for frame + scores = [ + inst.score + for lf in self.labels.find(video, frame_idx) + for inst in lf + if hasattr(inst, "score") + ] + val = sum(scores) / len(scores) if scores else None + item["mean score"] = val + + self._data.append(item) self.endResetModel() - def rowCount(self, *args): - return len(self._suggestions_list) - - def columnCount(self, *args): - return len(self._props) - - def headerData( - self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole - 
): - if role == Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._props[section] - elif orientation == QtCore.Qt.Vertical: - return section - - return None - - def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable - if __name__ == "__main__": + from PySide2.QtWidgets import QApplication labels = Labels.load_json( "tests/data/json_format_v2/centered_pair_predictions.json" diff --git a/sleap/gui/overlays/anchors.py b/sleap/gui/overlays/anchors.py index 3025f0e3f..1914545f5 100644 --- a/sleap/gui/overlays/anchors.py +++ b/sleap/gui/overlays/anchors.py @@ -1,6 +1,9 @@ +""" +Module with overlay for showing negative training sample anchors. +""" import attr -from PySide2 import QtWidgets, QtGui +from PySide2 import QtGui from sleap.gui.video import QtVideoPlayer from sleap.io.dataset import Labels @@ -8,13 +11,20 @@ @attr.s(auto_attribs=True) class NegativeAnchorOverlay: + """Class to overlay of negative training sample anchors to video frame. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + """ labels: Labels = None - scene: QtWidgets.QGraphicsScene = None - pen = QtGui.QPen(QtGui.QColor("red")) - line_len: int = 3 + player: QtVideoPlayer = None + _pen = QtGui.QPen(QtGui.QColor("red")) + _line_len: int = 3 def add_to_scene(self, video, frame_idx): + """Adds anchor markers as overlay on frame image.""" if self.labels is None: return if video not in self.labels.negative_anchors: @@ -26,17 +36,17 @@ def add_to_scene(self, video, frame_idx): self._add(x, y) def _add(self, x, y): - self.scene.addLine( - x - self.line_len, - y - self.line_len, - x + self.line_len, - y + self.line_len, - self.pen, + self.player.scene.addLine( + x - self._line_len, + y - self._line_len, + x + self._line_len, + y + self._line_len, + self._pen, ) - self.scene.addLine( - x + self.line_len, - y - self.line_len, - x - self.line_len, - y + self.line_len, - self.pen, + self.player.scene.addLine( + x + self._line_len, + y - self._line_len, + x - self._line_len, + y + self._line_len, + self._pen, ) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index e521a2df7..b7d03077f 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -1,4 +1,4 @@ -"""Base class for overlays.""" +"""Base class for overlays that use datasource (hdf5, model).""" from PySide2 import QtWidgets @@ -12,6 +12,8 @@ class HDF5Data(HDF5Video): + """Class to wrap HDF5Video so we can use it as overlay datasource.""" + def __getitem__(self, i): """Get data for frame i from `HDF5Video` object.""" x = self.get_frame(i) @@ -20,6 +22,8 @@ def __getitem__(self, i): @attr.s(auto_attribs=True) class ModelData: + """Class to wrap model so we can use it as overlay datasource.""" + # TODO: Unify this class with inference.Predictor or InferenceModel model: "keras.Model" video: Video @@ -69,6 +73,7 @@ def __getitem__(self, i): @attr.s(auto_attribs=True) class DataOverlay: + """Base class for overlays which use datasources.""" data: Sequence = None player: QtVideoPlayer = None @@ -76,6 +81,7 @@ class DataOverlay: transform: DataTransform = None def add_to_scene(self, video, frame_idx): + """Add overlay to scene.""" if self.data is None: return @@ -131,6 +137,7 @@ def _add( @classmethod def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): + """Creates instance of class with HDF5 datasource.""" import h5py as h5 with h5.File(filename, "r") as f: @@ -147,6 +154,7 @@ 
def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): @classmethod def from_model(cls, filename, video, **kwargs): + """Creates instance of class with model datasource.""" from sleap.nn.model import ModelOutputType from sleap.nn.loadmodel import load_model, get_model_data from sleap.nn.training import TrainingJob diff --git a/sleap/gui/overlays/confmaps.py b/sleap/gui/overlays/confmaps.py index 49c06c149..40d86abc9 100644 --- a/sleap/gui/overlays/confmaps.py +++ b/sleap/gui/overlays/confmaps.py @@ -8,25 +8,25 @@ from PySide2 import QtWidgets, QtCore, QtGui -import attr import numpy as np import qimage2ndarray -from typing import Sequence -from sleap.io.video import Video, HDF5Video -from sleap.gui.video import QtVideoPlayer from sleap.gui.overlays.base import DataOverlay, h5_colors class ConfmapOverlay(DataOverlay): + """Overlay to show confidence maps.""" + @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): + """Create object with hdf5 as datasource.""" return DataOverlay.from_h5( filename, "/confmaps", input_format, overlay_class=ConfMapsPlot, **kwargs ) @classmethod def from_model(cls, filename, video, **kwargs): + """Create object with live predictions from model as datasource.""" return DataOverlay.from_model( filename, video, overlay_class=ConfMapsPlot, **kwargs ) @@ -140,6 +140,9 @@ def get_conf_image(self) -> QtGui.QImage: def show_confmaps_from_h5(filename, input_format="channels_last", standalone=False): + """Demo function.""" + from sleap.io.video import HDF5Video + video = HDF5Video(filename, "/box", input_format=input_format) conf_data = HDF5Video( filename, "/confmaps", input_format=input_format, convert_range=False @@ -152,6 +155,7 @@ def show_confmaps_from_h5(filename, input_format="channels_last", standalone=Fal def demo_confmaps(confmaps, video, standalone=False, callback=None): + """Demo function.""" from PySide2 import QtWidgets from sleap.gui.video import QtVideoPlayer @@ -182,6 +186,3 @@ def plot_confmaps(parent, item_idx): data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" show_confmaps_from_h5(data_path, input_format="channels_first", standalone=True) - -# data_path = "/Users/tabris/Documents/predictions.h5" -# show_confmaps_from_h5(data_path, input_format="channels_last", standalone=True) diff --git a/sleap/gui/overlays/instance.py b/sleap/gui/overlays/instance.py index 14dde6bdd..621f0e229 100644 --- a/sleap/gui/overlays/instance.py +++ b/sleap/gui/overlays/instance.py @@ -1,21 +1,29 @@ +""" +Module with overlay for showing instances. +""" import attr -from PySide2 import QtWidgets - from sleap.gui.video import QtVideoPlayer from sleap.io.dataset import Labels -from sleap.gui.overlays.tracks import TrackColorManager @attr.s(auto_attribs=True) class InstanceOverlay: + """Class for adding instances as overlays on video frames. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + color_predicted: Whether to show predicted instances in color ( + rather than all in gray/yellow). 
+ """ labels: Labels = None player: QtVideoPlayer = None - color_manager: TrackColorManager = TrackColorManager(labels) color_predicted: bool = False def add_to_scene(self, video, frame_idx): + """Adds overlay for frame to player scene.""" if self.labels is None: return @@ -35,7 +43,7 @@ def add_to_scene(self, video, frame_idx): self.player.addInstance( instance=instance, - color=self.color_manager.get_color(pseudo_track), + color=self.player.color_manager.get_color(pseudo_track), predicted=is_predicted, color_predicted=self.color_predicted, ) diff --git a/sleap/gui/overlays/pafs.py b/sleap/gui/overlays/pafs.py index 4ff930886..19caf39bf 100644 --- a/sleap/gui/overlays/pafs.py +++ b/sleap/gui/overlays/pafs.py @@ -1,39 +1,44 @@ +""" +Module for showing part affinity fields as an overlay within a QtVideoPlayer. +""" from PySide2 import QtWidgets, QtGui, QtCore import numpy as np import itertools import math -from sleap.io.video import Video, HDF5Video -from sleap.gui.multicheck import MultiCheckWidget +from typing import Optional from sleap.gui.overlays.base import DataOverlay, h5_colors class PafOverlay(DataOverlay): + """Overlay to show part affinity fields.""" + @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): + """Creates object with hdf5 as datasource.""" return DataOverlay.from_h5( filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs ) class MultiQuiverPlot(QtWidgets.QGraphicsObject): - """QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView. + """ + QGraphicsObject to display multiple quiver plots in a QGraphicsView. + + When initialized, creates on child QuiverPlot item for each channel. + Each channel in data corresponds to two (h, w) arrays: + x and y for the arrow vectors. Args: - frame (numpy.array): Data for one frame of quiver plot data. + frame: Data for one frame of quiver plot data. Shape of array should be (channels, height, width). - show (list, optional): List of channels to show. If None, show all channels. - decimation (int, optional): Decimation factor. If 1, show every arrow. + show: List of channels to show. If None, show all channels. + decimation: Decimation factor. If 1, show every arrow. Returns: None. - - Note: - Each channel corresponds to two (h, w) arrays: x and y for the vector. - - When initialized, creates one child QuiverPlot item for each channel. """ def __init__( @@ -87,10 +92,10 @@ class QuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject for drawing single quiver plot. Args: - field_x (numpy.array): (h, w) array of x component of vectors. - field_y (numpy.array): (h, w) array of y component of vectors. - color (list, optional): Arrow color. Format as (r, g, b) array. - decimation (int, optional): Decimation factor. If 1, show every arrow. + field_x: (h, w) array of x component of vectors. + field_y: (h, w) array of y component of vectors. + color: Arrow color. Format as (r, g, b) array. + decimation: Decimation factor. If 1, show every arrow. Returns: None. 
@@ -98,8 +103,8 @@ class QuiverPlot(QtWidgets.QGraphicsObject): def __init__( self, - field_x: np.array = None, - field_y: np.array = None, + field_x: Optional[np.ndarray] = None, + field_y: Optional[np.ndarray] = None, color=[255, 255, 255], decimation=1, scale=1, @@ -191,6 +196,7 @@ def _add_arrows(self, min_length=0.01): self.points = list(itertools.starmap(QtCore.QPointF, points)) def _decimate(self, image: np.array, box: int): + """Decimates quiverplot.""" height = width = box # Source: https://stackoverflow.com/questions/48482317/slice-an-image-into-tiles-using-numpy _nrows, _ncols, depth = image.shape @@ -230,6 +236,9 @@ def paint(self, painter, option, widget=None): def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): + """Demo function.""" + from sleap.io.video import HDF5Video + video = HDF5Video(filename, "/box", input_format=input_format) paf_data = HDF5Video( filename, "/pafs", input_format=input_format, convert_range=False @@ -242,6 +251,7 @@ def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): def demo_pafs(pafs, video, decimation=4, standalone=False): + """Demo function.""" from sleap.gui.video import QtVideoPlayer if standalone: @@ -280,66 +290,9 @@ def plot_fields(parent, i): if __name__ == "__main__": - from video import * - # data_path = "training.scale=1.00,sigma=5.h5" data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" input_format = "channels_first" - data_path = "/Volumes/fileset-mmurthy/nat/nyu-mouse/predict.h5" - input_format = "channels_last" - show_pafs_from_h5(data_path, input_format=input_format, standalone=True) - - -def foo(): - - vid = HDF5Video(data_path, "/box", input_format=input_format) - overlay_data = HDF5Video( - data_path, "/pafs", input_format=input_format, convert_range=False - ) - print( - f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}" - ) - app = QtWidgets.QApplication([]) - window = QtVideoPlayer(video=vid) - - field_count = overlay_data.get_frame(1).shape[-1] // 2 - 1 - # show the first, middle, and last fields - show_fields = [0, field_count // 2, field_count] - - field_check_groupbox = MultiCheckWidget( - count=field_count, selected=show_fields, title="Affinity Field Channel" - ) - field_check_groupbox.selectionChanged.connect(window.plot) - window.layout.addWidget(field_check_groupbox) - - # show one arrow for each decimation*decimation box - default_decimation = 9 - - decimation_size_bar = QSlider(QtCore.Qt.Horizontal) - decimation_size_bar.valueChanged.connect(lambda evt: window.plot()) - decimation_size_bar.setValue(default_decimation) - decimation_size_bar.setMinimum(1) - decimation_size_bar.setMaximum(21) - decimation_size_bar.setEnabled(True) - window.layout.addWidget(decimation_size_bar) - - def plot_fields(parent, i): - # build list of checked boxes to determine which affinity fields to show - selected = field_check_groupbox.getSelected() - # get decimation size from slider - decimation = decimation_size_bar.value() - # show affinity fields - frame_data = overlay_data.get_frame(parent.frame_idx) - aff_fields_item = MultiQuiverPlot(frame_data, selected, decimation) - - window.view.scene.addItem(aff_fields_item) - - window.changedPlot.connect(plot_fields) - - window.show() - window.plot() - - app.exec_() diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index ba2e537ea..eccda9a32 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -1,5 +1,8 @@ -from 
sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track +""" +Module that handles track-related overlays (including track color). +""" + +from sleap.instance import Track from sleap.io.dataset import Labels from sleap.io.video import Video @@ -7,18 +10,22 @@ import itertools from typing import Union -from PySide2 import QtCore, QtWidgets, QtGui +from PySide2 import QtCore, QtGui + +class TrackColorManager(object): + """Class to determine color to use for track. -class TrackColorManager: - """Class to determine color to use for track. The color depends on the order of - the tracks in `Labels` object, so we need to initialize with `Labels`. + The color depends on the order of the tracks in `Labels` object, + so we need to initialize with `Labels`. Args: - labels: `Labels` object which contains the tracks for which we want colors + labels: The :class:`Labels` dataset which contains the tracks for + which we want colors. + palette: String with the color palette name to use. """ - def __init__(self, labels: Labels = None, palette="standard"): + def __init__(self, labels: Labels = None, palette: str = "standard"): self.labels = labels # alphabet @@ -106,6 +113,7 @@ def __init__(self, labels: Labels = None, palette="standard"): @property def labels(self): + """Gets or sets labels dataset for which we are coloring tracks.""" return self._labels @labels.setter @@ -114,9 +122,11 @@ def labels(self, val): @property def palette_names(self): + """Gets list of palette names.""" return self._palettes.keys() def set_palette(self, palette): + """Sets palette (by name).""" if isinstance(palette, str): self.mode = "clip" if palette.endswith("+") else "cycle" @@ -145,23 +155,24 @@ def get_color(self, track: Union[Track, int]): @attr.s(auto_attribs=True) class TrackTrailOverlay: - """Class to show track trails. You initialize this object with both its data source - and its visual output scene, and it handles both extracting the relevant data for a - given frame and plotting it in the output. + """Class to show track trails as overlay on video frame. - Args: - labels: `Labels` object from which to get data - scene: `QGraphicsScene` in which to plot trails - trail_length (optional): maximum number of frames to include in trail + Initialize this object with both its data source and its visual output + scene, and it handles both extracting the relevant data for a given + frame and plotting it in the output. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + trail_length: The maximum number of frames to include in trail. Usage: - After class is instatiated, call add_trails_to_scene(frame_idx) + After class is instantiated, call :method:`add_to_scene(frame_idx)` to plot the trails in scene. """ labels: Labels = None - scene: QtWidgets.QGraphicsScene = None - color_manager: TrackColorManager = TrackColorManager(labels) + player: "QtVideoPlayer" = None trail_length: int = 4 show: bool = False @@ -202,7 +213,9 @@ def get_track_trails(self, frame_selection, track: Track): return all_trails def get_frame_selection(self, video: Video, frame_idx: int): - """Return list of `LabeledFrame`s to include in trail for specified frame.""" + """ + Return list of `LabeledFrame`s to include in trail for specified frame. 
+ """ frame_selection = self.labels.find(video, range(0, frame_idx + 1)) frame_selection.sort(key=lambda x: x.frame_idx) @@ -235,7 +248,7 @@ def add_to_scene(self, video: Video, frame_idx: int): trails = self.get_track_trails(frame_selection, track) - color = QtGui.QColor(*self.color_manager.get_color(track)) + color = QtGui.QColor(*self.player.color_manager.get_color(track)) pen = QtGui.QPen() pen.setCosmetic(True) @@ -245,12 +258,12 @@ def add_to_scene(self, video: Video, frame_idx: int): color.setAlphaF(1) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[:half]) - self.scene.addPolygon(polygon, pen) + self.player.scene.addPolygon(polygon, pen) color.setAlphaF(0.5) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[half:]) - self.scene.addPolygon(polygon, pen) + self.player.scene.addPolygon(polygon, pen) @staticmethod def map_to_qt_polygon(point_list): @@ -260,15 +273,16 @@ def map_to_qt_polygon(point_list): @attr.s(auto_attribs=True) class TrackListOverlay: - """Class to show track number and names in overlay. + """ + Class to show track number and names in overlay. """ labels: Labels = None - view: QtWidgets.QGraphicsView = None - color_manager: TrackColorManager = TrackColorManager(labels) + player: "QtVideoPlayer" = None text_box = None def add_to_scene(self, video: Video, frame_idx: int): + """Adds track list as overlay on video.""" from sleap.gui.video import QtTextWithBackground html = "" @@ -279,7 +293,7 @@ def add_to_scene(self, video: Video, frame_idx: int): if html: html += "
" - color = self.color_manager.get_color(track) + color = self.player.color_manager.get_color(track) html_color = f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}" track_text = f"{track.name}" if str(idx) != track.name: @@ -294,10 +308,11 @@ def add_to_scene(self, video: Video, frame_idx: int): self.text_box = text_box self.visible = False - self.view.scene.addItem(self.text_box) + self.player.scene.addItem(self.text_box) @property def visible(self): + """Gets or set whether overlay is visible.""" if self.text_box is None: return False return self.text_box.isVisible() diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index f81e7ef1b..6eace1e83 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -41,6 +41,7 @@ class SliderMark: @property def color(self): + """Returns color of mark.""" colors = dict(simple="black", filled="blue", open="blue", predicted="red") if self.type in colors: @@ -50,10 +51,12 @@ def color(self): @color.setter def color(self, val): + """Sets color of mark.""" self._color = val @property def QColor(self): + """Returns color of mark as `QColor`.""" c = self.color if type(c) == str: return QColor(c) @@ -62,6 +65,7 @@ def QColor(self): @property def filled(self): + """Returns whether mark is filled or open.""" if self.type == "open": return False else: diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 5e99dd4d2..851616ea6 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -1,3 +1,7 @@ +""" +Module for generating lists of suggested frames (for labeling or reviewing). +""" + import numpy as np import itertools @@ -7,27 +11,44 @@ import cv2 +from typing import List, Tuple + from sleap.io.video import Video class VideoFrameSuggestions: + """ + Class for generating lists of suggested frames. + + Implements various algorithms as methods: + * strides + * random + * pca_cluster + * brisk + * proofreading + + Each of algorithm method should accept `video`; other parameters will be + passed from the `params` dict given to :method:`suggest()`. + + """ rescale = True rescale_below = 512 @classmethod - def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> list: + def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> List[int]: """ - This is the main entry point. + This is the main entry point for generating lists of suggested frames. Args: - video: a `Video` object for which we're generating suggestions - params: a dict with all params to control how we generate suggestions - * minimally this will have a `method` corresponding to a method in class - labels: a `Labels` object + video: A `Video` object for which we're generating suggestions. + params: A dictionary with all params to control how we generate + suggestions, minimally this will have a "method" key with + the name of one of the class methods. + labels: A `Labels` object for which we are generating suggestions. Returns: - list of frame suggestions + List of suggested frame indices. 
""" # map from method param value to corresponding class method @@ -35,7 +56,7 @@ def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> list: strides=cls.strides, random=cls.random, pca=cls.pca_cluster, - hog=cls.hog, + # hog=cls.hog, brisk=cls.brisk, proofreading=cls.proofreading, ) @@ -50,12 +71,14 @@ def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> list: @classmethod def strides(cls, video, per_video=20, **kwargs): + """Method to generate suggestions by taking strides through video.""" suggestions = list(range(0, video.frames, video.frames // per_video)) suggestions = suggestions[:per_video] return suggestions @classmethod def random(cls, video, per_video=20, **kwargs): + """Method to generate suggestions by taking random frames in video.""" import random suggestions = random.sample(range(video.frames), per_video) @@ -63,7 +86,7 @@ def random(cls, video, per_video=20, **kwargs): @classmethod def pca_cluster(cls, video, initial_samples, **kwargs): - + """Method to generate suggestions by using PCA clusters.""" sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.frame_feature_stack(video, sample_step) @@ -75,7 +98,7 @@ def pca_cluster(cls, video, initial_samples, **kwargs): @classmethod def brisk(cls, video, initial_samples, **kwargs): - + """Method to generate suggestions using PCA on Brisk features.""" sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.brisk_feature_stack(video, sample_step) @@ -85,37 +108,11 @@ def brisk(cls, video, initial_samples, **kwargs): return result - @classmethod - def hog( - cls, - video, - clusters=5, - per_cluster=5, - sample_step=5, - pca_components=50, - interleave=True, - **kwargs, - ): - - feature_stack, frame_idx_map = cls.hog_feature_stack(video, sample_step) - - result = cls.feature_stack_to_suggestions( - feature_stack, - frame_idx_map, - clusters=clusters, - per_cluster=per_cluster, - pca_components=pca_components, - interleave=interleave, - **kwargs, - ) - - return result - @classmethod def proofreading( cls, video: Video, labels: "Labels", score_limit, instance_limit, **kwargs ): - + """Method to generate suggestions for proofreading.""" score_limit = float(score_limit) instance_limit = int(instance_limit) @@ -148,7 +145,10 @@ def proofreading( # These are specific to the suggestion method @classmethod - def frame_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: + def frame_feature_stack( + cls, video: Video, sample_step: int = 5 + ) -> Tuple[np.ndarray, List[int]]: + """Generates matrix of sampled video frame images.""" sample_count = video.frames // sample_step factor = cls.get_scale_factor(video) @@ -172,7 +172,10 @@ def frame_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: return (flat_stack, frame_idx_map) @classmethod - def brisk_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: + def brisk_feature_stack( + cls, video: Video, sample_step: int = 5 + ) -> Tuple[np.ndarray, List[int]]: + """Generates Brisk features from sampled video frames.""" brisk = cv2.BRISK_create() factor = cls.get_scale_factor(video) @@ -194,48 +197,38 @@ def brisk_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: return (feature_stack, frame_idx_map) - @classmethod - def hog_feature_stack(cls, video: Video, sample_step: int = 5) -> tuple: - sample_count = video.frames // sample_step - - hog = cv2.HOGDescriptor() - - factor = cls.get_scale_factor(video) - first_hog = hog.compute(cls.resize(video[0][0], 
factor)) - hog_size = first_hog.shape[0] - - frame_idx_map = [None] * sample_count - flat_stack = np.zeros((sample_count, hog_size)) - - for i in range(sample_count): - frame_idx = i * sample_step - img = video[frame_idx][0] - img = cls.resize(img, factor) - flat_stack[i] = hog.compute(img).transpose()[0] - frame_idx_map[i] = frame_idx - - return (flat_stack, frame_idx_map) - # Functions for making suggestions based on "feature stack" # These are common for all suggestion methods @staticmethod - def to_frame_idx_list(selected_list: list, frame_idx_map: dict) -> list: + def to_frame_idx_list( + selected_list: List[int], frame_idx_map: List[int] + ) -> List[int]: """Convert list of row indexes to list of frame indexes.""" return list(map(lambda x: frame_idx_map[x], selected_list)) @classmethod def feature_stack_to_suggestions( - cls, feature_stack, frame_idx_map, return_clusters=False, **kwargs - ): + cls, + feature_stack: np.ndarray, + frame_idx_map: List[int], + return_clusters: bool = False, + **kwargs, + ) -> List[int]: """ Turns a feature stack matrix into a list of suggested frames. Args: - feature_stack: (n * features) matrix - frame_idx_map: n-length vector which gives frame_idx for each row in feature_stack - return_clusters (optional): return the intermediate result for debugging - i.e., a list that gives the list of suggested frames for each cluster + feature_stack: (n * features) matrix. + frame_idx_map: List indexed by rows of feature stack which gives + frame index for each row in feature_stack. This allows a + single frame to correspond to multiple rows in feature_stack. + return_clusters: Whether to return the intermediate result for + debugging, i.e., a list that gives the list of suggested + frames for each cluster. + + Returns: + List of frame index suggestions. """ selected_by_cluster = cls.feature_stack_to_clusters( @@ -254,25 +247,30 @@ def feature_stack_to_suggestions( @classmethod def feature_stack_to_clusters( cls, - feature_stack, - frame_idx_map, - clusters=5, - per_cluster=5, - pca_components=50, + feature_stack: np.ndarray, + frame_idx_map: List[int], + clusters: int = 5, + per_cluster: int = 5, + pca_components: int = 50, **kwargs, - ): + ) -> List[int]: """ - Turns feature stack matrix into list (per cluster) of list of frame indexes. + Runs PCA to generate clusters of frames based on given features. Args: - feature_stack: (n * features) matrix + feature_stack: (n * features) matrix. + frame_idx_map: List indexed by rows of feature stack which gives + frame index for each row in feature_stack. This allows a + single frame to correspond to multiple rows in feature_stack. clusters: number of clusters - per_clusters: how many suggestions to take from each cluster (at most) - pca_components: for reducing feature space before clustering + per_cluster: How many suggestions (at most) to take from each + cluster. + pca_components: Number of PCA components, for reducing feature + space before clustering Returns: - list of lists - for each cluster, a list of frame indexes + A list of lists: + * for each cluster, a list of frame indices. """ stack_height = feature_stack.shape[0] @@ -313,18 +311,17 @@ def feature_stack_to_clusters( @classmethod def clusters_to_list( - cls, selected_by_cluster, interleave: bool = True, **kwargs + cls, selected_by_cluster: List[List[int]], interleave: bool = True, **kwargs ) -> list: """ - Turns list (per cluster) of lists of frame index into single list of frame indexes. 
+ Merges per cluster suggestion lists into single list for entire video. Args: - selected_by_cluster: the list of lists of row indexes - frame_idx_map: map from row index to frame index - interleave: whether we want to interleave suggestions from clusters + selected_by_cluster: The list of lists of row indexes. + interleave: Whether to interleave suggestions from clusters. Returns: - list of frame index + List of frame indices. """ if interleave: @@ -345,7 +342,11 @@ def clusters_to_list( # Utility functions @classmethod - def get_scale_factor(cls, video) -> int: + def get_scale_factor(cls, video: "Video") -> int: + """Determines how much we need to scale to get video within size. + + Size is specified by :attribute:`rescale_below`. + """ factor = 1 if cls.rescale: largest_dim = max(video.height, video.width) @@ -354,8 +355,9 @@ def get_scale_factor(cls, video) -> int: factor += 1 return factor - @classmethod - def resize(cls, img, factor) -> np.ndarray: + @staticmethod + def resize(img: np.ndarray, factor: float) -> np.ndarray: + """Resizes frame image by scaling factor.""" h, w, _ = img.shape if factor != 1: return cv2.resize(img, (h // factor, w // factor)) diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index 4e79ee30d..77a09d644 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -1,4 +1,7 @@ -import os +""" +Module for viewing and modifying training profiles. +""" + import attr import cattr from typing import Optional @@ -7,11 +10,19 @@ from PySide2 import QtWidgets -from sleap.io.dataset import Labels from sleap.gui.formbuilder import YamlFormWidget class TrainingEditor(QtWidgets.QDialog): + """ + Dialog for viewing and modifying training profiles. + + Args: + profile_filename: Path to saved training profile to view. + saved_files: When user saved profile, it's path is added to this + list (which will be updated in code that created TrainingEditor). + """ + def __init__( self, profile_filename: Optional[str] = None, @@ -59,10 +70,12 @@ def __init__( @property def profile_filename(self): + """Returns path to currently loaded training profile.""" return self._profile_filename @profile_filename.setter def profile_filename(self, val): + """Sets path to (and loads) training profile.""" self._profile_filename = val # set window title self.setWindowTitle(self.profile_filename) @@ -77,6 +90,7 @@ def _layout_widget(layout): return widget def _load_profile(self, profile_filename: str): + """Loads training profile settings from file.""" from sleap.nn.model import ModelOutputType from sleap.nn.training import TrainingJob @@ -93,6 +107,7 @@ def _load_profile(self, profile_filename: str): self.form_widgets[name].set_form_data(job_dict["trainer"]) def _save_as(self): + """Shows dialog to save training profile.""" # Show "Save" dialog save_filename, _ = QtWidgets.QFileDialog.getSaveFileName( diff --git a/sleap/gui/video.py b/sleap/gui/video.py index cf0630d7f..f8b669233 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -6,29 +6,29 @@ Example usage: >>> my_video = Video(...) >>> my_instance = Instance(...) 
- >>> color = (r, g, b) - >>> vp = QtVideoPlayer(video = my_video) - >>> vp.addInstance(instance = my_instance, color) + >>> vp = QtVideoPlayer(video=my_video) + >>> vp.addInstance(instance=my_instance, color=(r, g, b)) """ from PySide2 import QtWidgets -from PySide2.QtWidgets import QApplication, QVBoxLayout, QWidget -from PySide2.QtWidgets import QLabel, QPushButton, QSlider -from PySide2.QtWidgets import QAction - -from PySide2.QtWidgets import QGraphicsView, QGraphicsScene +from PySide2.QtWidgets import ( + QApplication, + QVBoxLayout, + QWidget, + QGraphicsView, + QGraphicsScene, +) from PySide2.QtGui import QImage, QPixmap, QPainter, QPainterPath, QTransform from PySide2.QtGui import QPen, QBrush, QColor, QFont from PySide2.QtGui import QKeyEvent from PySide2.QtCore import Qt, Signal, Slot -from PySide2.QtCore import QRectF, QLineF, QPointF, QMarginsF, QSizeF +from PySide2.QtCore import QRectF, QPointF, QMarginsF import math -import numpy as np -from typing import Callable, Union +from typing import Callable, List, Optional, Union from PySide2.QtWidgets import QGraphicsItem, QGraphicsObject @@ -52,8 +52,10 @@ class QtVideoPlayer(QWidget): """ Main QWidget for displaying video with skeleton instances. - Args: - video (optional): the :class:`Video` to display + Attributes: + video: The :class:`Video` to display + color_manager: A :class:`TrackColorManager` object which determines + which color to show the instances. Signals: changedPlot: Emitted whenever the plot is redrawn @@ -68,10 +70,10 @@ def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): self._shift_key_down = False self.frame_idx = -1 - self._color_manager = color_manager + self.color_manager = color_manager self.view = GraphicsView() - self.seekbar = VideoSlider(color_manager=self._color_manager) + self.seekbar = VideoSlider(color_manager=self.color_manager) self.seekbar.valueChanged.connect(lambda evt: self.plot(self.seekbar.value())) self.seekbar.keyPress.connect(self.keyPressEvent) self.seekbar.keyRelease.connect(self.keyReleaseEvent) @@ -127,16 +129,24 @@ def reset(self): @property def instances(self): + """Returns list of all `QtInstance`s in view.""" return self.view.instances @property def selectable_instances(self): + """Returns list of selectable `QtInstance`s in view.""" return self.view.selectable_instances @property def predicted_instances(self): + """Returns list of predicted `QtInstance`s in view.""" return self.view.predicted_instances + @property + def scene(self): + """Returns `QGraphicsScene` for viewer.""" + return self.view.scene + def addInstance(self, instance, **kwargs): """Add a skeleton instance to the video. @@ -159,12 +169,12 @@ def addInstance(self, instance, **kwargs): # connect signal so we can adjust QtNodeLabel positions after zoom self.view.updatedViewer.connect(instance.updatePoints) - def plot(self, idx=None): + def plot(self, idx: Optional[int] = None): """ Do the actual plotting of the video frame. Args: - idx (optional): Go to frame idx. If None, stay on current frame. + idx: Go to frame idx. If None, stay on current frame. """ if self.video is None: @@ -244,16 +254,27 @@ def zoomToFit(self): self.view.zoomToRect(zoom_rect) def onSequenceSelect( - self, seq_len: int, on_success: Callable, on_each=None, on_failure=None + self, + seq_len: int, + on_success: Callable, + on_each: Optional[Callable] = None, + on_failure: Optional[Callable] = None, ): """ - Collect a sequence of instances (through user selection) and call `on_success`. 
- If the user cancels (by unselecting without new selection), call `on_failure`. + Collect a sequence of instances (through user selection). + + When the sequence is complete, the `on_success` callback is called. + After each selection in sequence, the `on_each` callback is called + (if given). If the user cancels (by unselecting without new + selection), the `on_failure` callback is called (if given). Args: - seq_len: number of instances we expect user to select - on_success: callback after use has selected desired number of instances - on_failure (optional): callback if user cancels selection + seq_len: Number of instances we want to collect in sequence. + on_success: Callback for when user has selected desired number of + instances. + on_each: Callback after user selects each instance. + on_failure: Callback if user cancels process before selecting + enough instances. Note: If successful, we call @@ -301,7 +322,19 @@ def handle_selection( on_each(indexes) @staticmethod - def _signal_once(signal, callback): + def _signal_once(signal: Signal, callback: Callable): + """ + Connects callback for next occurrence of signal. + + Args: + signal: The signal on which we want callback to be called. + callback: The function that should be called just once, the next + time the signal is emitted. + + Returns: + None. + """ + def call_once(*args): signal.disconnect(call_once) callback(*args) @@ -309,25 +342,47 @@ def call_once(*args): signal.connect(call_once) def onPointSelection(self, callback: Callable): + """ + Starts mode for user to click point, callback called when finished. + + Args: + callback: The function called after user clicks point, should + take x and y as arguments. + + Returns: + None. + """ self.view.click_mode = "point" self.view.setCursor(Qt.CrossCursor) self._signal_once(self.view.pointSelected, callback) def onAreaSelection(self, callback: Callable): + """ + Starts mode for user to select area, callback called when finished. + + Args: + callback: The function called after user clicks point, should + take x0, y0, x1, y1 as arguments. + + Returns: + None. + """ self.view.click_mode = "area" self.view.setCursor(Qt.CrossCursor) self._signal_once(self.view.areaSelected, callback) def keyReleaseEvent(self, event: QKeyEvent): + """ + Custom event handler, tracks when user releases modifier (shift) key. + """ if event.key() == Qt.Key.Key_Shift: self._shift_key_down = False event.ignore() def keyPressEvent(self, event: QKeyEvent): - """ Custom event handler. - Move between frames, toggle display of edges/labels, and select instances. """ - ignore = False + Custom event handler, allows navigation and selection within view. + """ frame_t0 = self.frame_idx if event.key() == Qt.Key.Key_Shift: @@ -356,10 +411,8 @@ def keyPressEvent(self, event: QKeyEvent): self.view.selectInstance(int(chr(event.key())) - 1) else: event.ignore() # Kicks the event up to parent - # print(event.key()) # If user is holding down shift and action resulted in moving to another frame - # event.modifiers() == Qt.ShiftModifier and if self._shift_key_down and frame_t0 != self.frame_idx: # If there's no select, start seekbar selection at frame before action start, end = self.seekbar.getSelection() @@ -371,15 +424,21 @@ def keyPressEvent(self, event: QKeyEvent): class GraphicsView(QGraphicsView): """ - QGraphicsView used by QtVideoPlayer. + Custom `QGraphicsView` used by `QtVideoPlayer`. - This contains elements for display of video and event handlers for zoom/selection. 
+ This contains elements for display of video and event handlers for zoom + and selection of instances in view. Signals: - updatedViewer: Emitted after update to view (e.g., zoom) + updatedViewer: Emitted after update to view (e.g., zoom). Used internally so we know when to update points for each instance. - updatedSelection: Emitted after the user has selected/unselected an instance - instanceDoubleClicked: Emitted after an instance is double clicked + updatedSelection: Emitted after the user has (un)selected an instance. + instanceDoubleClicked: Emitted after an instance is double-clicked. + Passes the :class:`Instance` that was double-clicked. + areaSelected: Emitted after user selects an area when in "area" + click mode. Passes x0, y0, x1, y1 for selected box coordinates. + pointSelected: Emitted after user clicks a point (in "point" click + mode.) Passes x, y coordinates of point. leftMouseButtonPressed rightMouseButtonPressed @@ -392,14 +451,14 @@ class GraphicsView(QGraphicsView): updatedViewer = Signal() updatedSelection = Signal() instanceDoubleClicked = Signal(Instance) + areaSelected = Signal(float, float, float, float) + pointSelected = Signal(float, float) leftMouseButtonPressed = Signal(float, float) rightMouseButtonPressed = Signal(float, float) leftMouseButtonReleased = Signal(float, float) rightMouseButtonReleased = Signal(float, float) leftMouseButtonDoubleClicked = Signal(float, float) rightMouseButtonDoubleClicked = Signal(float, float) - areaSelected = Signal(float, float, float, float) - pointSelected = Signal(float, float) def __init__(self, *args, **kwargs): """ https://github.com/marcel-goldschen-ohm/PyQtImageViewer/blob/master/QtImageViewer.py """ @@ -412,7 +471,6 @@ def __init__(self, *args, **kwargs): self._pixmapHandle = None self.setRenderHint(QPainter.Antialiasing) - # self.setCacheMode(QGraphicsView.CacheNone) self.aspectRatioMode = Qt.KeepAspectRatio self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) @@ -424,12 +482,9 @@ def __init__(self, *args, **kwargs): self.zoomFactor = 1 anchor_mode = QGraphicsView.AnchorUnderMouse - # anchor_mode = QGraphicsView.AnchorViewCenter self.setTransformationAnchor(anchor_mode) - # self.scene.render() - - def hasImage(self): + def hasImage(self) -> bool: """ Returns whether or not the scene contains an image pixmap. """ return self._pixmapHandle is not None @@ -440,15 +495,22 @@ def clear(self): self._pixmapHandle = None self.scene.clear() - def setImage(self, image): - """ Set the scene's current image pixmap to the input QImage or QPixmap. - Raises a RuntimeError if the input image has type other than QImage or QPixmap. - :type image: QImage | QPixmap + def setImage(self, image: Union[QImage, QPixmap]): + """ + Set the scene's current image pixmap to the input QImage or QPixmap. + + Args: + image: The QPixmap or QImage to display. + + Raises: + RuntimeError: If the input image is not QImage or QPixmap + + Returns: + None. """ if type(image) is QPixmap: pixmap = image elif type(image) is QImage: - # pixmap = QPixmap.fromImage(image) pixmap = QPixmap(image) else: raise RuntimeError( @@ -462,8 +524,7 @@ def setImage(self, image): self.updateViewer() def updateViewer(self): - """ Show current zoom (if showing entire image, apply current aspect ratio mode). - """ + """ Apply current zoom. """ if not self.hasImage(): return @@ -477,51 +538,41 @@ def updateViewer(self): self.updatedViewer.emit() @property - def instances(self): + def instances(self) -> List["QtInstance"]: """ Returns a list of instances. 
- Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [ - item - for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and not item.predicted - ] + return list(filter(lambda x: not x.predicted, self.all_instances)) @property - def selectable_instances(self): - return [ - item - for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.selectable - ] + def predicted_instances(self) -> List["QtInstance"]: + """ + Returns a list of predicted instances. + + Order should match the order in which instances were added to scene. + """ + return list(filter(lambda x: not x.predicted, self.all_instances)) @property - def predicted_instances(self): + def selectable_instances(self) -> List["QtInstance"]: """ - Returns a list of predicted instances. + Returns a list of instances which user can select. - Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [ - item - for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.predicted - ] + return list(filter(lambda x: x.selectable, self.all_instances)) @property - def all_instances(self): + def all_instances(self) -> List["QtInstance"]: """ - Returns a list of instances and predicted instances. + Returns a list of all `QtInstance`s in scene. - Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [ - item - for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance - ] + scene_items = self.scene.items(Qt.SortOrder.AscendingOrder) + return list(filter(lambda x: isinstance(x, QtInstance), scene_items)) def clearSelection(self, signal=True): """ Clear instance skeleton selection. @@ -555,7 +606,9 @@ def selectInstance(self, select: Union[Instance, int], signal=True): Select a particular instance in view. Args: - select: either `Instance` or index of instance in view + select: Either `Instance` or index of instance in view. + signal: Whether to emit updatedSelection. + Returns: None """ @@ -568,7 +621,7 @@ def selectInstance(self, select: Union[Instance, int], signal=True): if signal: self.updatedSelection.emit() - def getSelection(self): + def getSelection(self) -> int: """ Returns the index of the currently selected instance. If no instance selected, returns None. """ @@ -579,7 +632,7 @@ def getSelection(self): if instance.selected: return idx - def getSelectionInstance(self): + def getSelectionInstance(self) -> Instance: """ Returns the currently selected instance. If no instance selected, returns None. """ @@ -689,7 +742,6 @@ def zoomToRect(self, zoom_rect: QRectF): Args: zoom_rect: The `QRectF` to which we want to zoom. - relative: Controls whether rect is relative to current zoom. """ if zoom_rect.isNull(): @@ -708,7 +760,7 @@ def clearZoom(self): """ self.zoomFactor = 1 - def instancesBoundingRect(self, margin=0): + def instancesBoundingRect(self, margin: float = 0) -> QRectF: """ Returns a rect which contains all displayed skeleton instances. @@ -725,7 +777,7 @@ def instancesBoundingRect(self, margin=0): return rect def mouseDoubleClickEvent(self, event): - """ Custom event handler. Show entire image. + """ Custom event handler, clears zoom. 
""" scenePos = self.mapToScene(event.pos()) if event.button() == Qt.LeftButton: @@ -762,9 +814,11 @@ def wheelEvent(self, event): pass def keyPressEvent(self, event): + """Custom event hander, disables default QGraphicsView behavior.""" event.ignore() # Kicks the event up to parent def keyReleaseEvent(self, event): + """Custom event hander, disables default QGraphicsView behavior.""" event.ignore() # Kicks the event up to parent @@ -1003,11 +1057,12 @@ def calls(self): if callable(callback): callback(self) - def updatePoint(self, user_change=True): - """ Method to update data for node/edge after user manipulates visual point. + def updatePoint(self, user_change: bool = True): + """ + Method to update data for node/edge when node position is manipulated. Args: - user_change (optional): Is this being called because of change by user? + user_change: Whether this being called because of change by user. """ self.point.x = self.scenePos().x() self.point.y = self.scenePos().y() @@ -1106,6 +1161,7 @@ def wheelEvent(self, event): event.accept() def mouseDoubleClickEvent(self, event): + """Custom event handler to emit signal on event.""" scene = self.scene() if scene is not None: view = scene.views()[0] @@ -1119,6 +1175,8 @@ class QtEdge(QGraphicsLineItem): Args: src: The `QtNode` source node for the edge. dst: The `QtNode` destination node for the edge. + color: Color as (r, g, b) tuple. + show_non_visible: Whether to show "non-visible" nodes/edges. """ def __init__( @@ -1150,12 +1208,13 @@ def __init__( self.setPen(pen) self.full_opacity = 1 - def connected_to(self, node): + def connected_to(self, node: QtNode): """ Return the other node along the edge. Args: node: One of the edge's nodes. + Returns: The other node (or None if edge doesn't have node). """ @@ -1166,7 +1225,7 @@ def connected_to(self, node): return None - def angle_to(self, node): + def angle_to(self, node: QtNode) -> float: """ Returns the angle from one edge node to the other. @@ -1181,12 +1240,15 @@ def angle_to(self, node): y = to.point.y - node.point.y return math.atan2(y, x) - def updateEdge(self, node): + def updateEdge(self, node: QtNode): """ Updates the visual display of node. Args: node: The node to update. + + Returns: + None. """ if self.src.point.visible and self.dst.point.visible: self.full_opacity = 1 @@ -1213,18 +1275,26 @@ class QtInstance(QGraphicsObject): and handles the events to manipulate the skeleton within a video frame (i.e., moving, rotating, marking nodes). - It should be instatiated with a `Skeleton` or `Instance` - and added to the relevant `QGraphicsScene`. + It should be instantiated with an `Instance` and added to the relevant + `QGraphicsScene`. When instantiated, it creates `QtNode`, `QtEdge`, and `QtNodeLabel` items as children of itself. + + Args: + instance: The :class:`Instance` to show. + predicted: Whether this is a predicted instance. + color_predicted: Whether to show predicted instance in color. + color: Color of the visual item. + markerRadius: Radius of nodes. + show_non_visible: Whether to show "non-visible" nodes/edges. 
+ """ changedData = Signal(Instance) def __init__( self, - skeleton: Skeleton = None, instance: Instance = None, predicted=False, color_predicted=False, @@ -1235,7 +1305,7 @@ def __init__( **kwargs, ): super(QtInstance, self).__init__(*args, **kwargs) - self.skeleton = skeleton if instance is None else instance.skeleton + self.skeleton = instance.skeleton self.instance = instance self.predicted = predicted self.color_predicted = color_predicted @@ -1251,8 +1321,6 @@ def __init__( self.labels_shown = True self._selected = False self._bounding_rect = QRectF() - # self.setFlag(QGraphicsItem.ItemIsMovable) - # self.setFlag(QGraphicsItem.ItemIsSelectable) if self.predicted: self.setZValue(0) @@ -1335,9 +1403,11 @@ def updatePoints(self, complete: bool = False, user_change: bool = False): This is called any time the skeleton is manipulated as a whole. Args: - complete (optional): If set, we mark the state of all - nodes in the skeleton to "complete". - user_change (optional): Is this being called because of change by user? + complete: Whether to update all nodes by setting "completed" + attribute. + user_change: Whether method is called because of change made by + user. + Returns: None. """ @@ -1367,7 +1437,7 @@ def updatePoints(self, complete: bool = False, user_change: bool = False): if user_change: self.changedData.emit(self.instance) - def getPointsBoundingRect(self): + def getPointsBoundingRect(self) -> QRectF: """Returns a rect which contains all the nodes in the skeleton.""" rect = None for item in self.edges: @@ -1400,10 +1470,12 @@ def updateBox(self, *args, **kwargs): @property def selected(self): + """Whether instance is selected.""" return self._selected @selected.setter def selected(self, selected: bool): + """Sets select-state for instance.""" self._selected = selected # Update the selection box for this skeleton instance self.updateBox() @@ -1413,7 +1485,7 @@ def toggleLabels(self): """ self.showLabels(not self.labels_shown) - def showLabels(self, show): + def showLabels(self, show: bool): """ Draws/hides the labels for this skeleton instance. @@ -1454,6 +1526,12 @@ def paint(self, painter, option, widget=None): class QtTextWithBackground(QGraphicsTextItem): + """ + Inherits methods/behavior of `QGraphicsTextItem`, but with background box. + + Color of brackground box is light or dark depending on the text color. + """ + def __init__(self, *args, **kwargs): super(QtTextWithBackground, self).__init__(*args, **kwargs) self.setFlag(QGraphicsItem.ItemIgnoresTransformations) @@ -1477,6 +1555,7 @@ def paint(self, painter, option, *args, **kwargs): def video_demo(labels, standalone=False): + """Demo function for showing (first) video from dataset.""" video = labels.videos[0] if standalone: app = QApplication([]) @@ -1494,6 +1573,7 @@ def video_demo(labels, standalone=False): def plot_instances(scene, frame_idx, labels, video=None, fixed=True): + """Demo function for plotting instances.""" from sleap.gui.overlays.tracks import TrackColorManager video = labels.videos[0] diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index 46969895f..3412d2626 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -1,5 +1,7 @@ +""" +Module for producing prediction metrics for SLEAP datasets. 
+""" from inspect import signature -import itertools import numpy as np from scipy.optimize import linear_sum_assignment from typing import Callable, List, Optional, Union, Tuple @@ -238,13 +240,6 @@ def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: return dist_array.shape[0] - point_match_count(dist_array, thresh) -def foo(labels_gt, labels_pr, frame_idx=1092): - list_a = labels_gt.find(labels_gt.videos[0], frame_idx=frame_idx)[0].instances - list_b = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)[0].instances - - match_instance_lists_nodewise(list_a, list_b) - - if __name__ == "__main__": labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") diff --git a/sleap/instance.py b/sleap/instance.py index 82116d861..00d0d0705 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1171,7 +1171,22 @@ def instances_to_show(self) -> List[Instance]: return inst_to_show @staticmethod - def merge_frames(labeled_frames, video, remove_redundant=True): + def merge_frames( + labeled_frames: List["LabeledFrame"], video: "Video", remove_redundant=True + ) -> List["LabeledFrame"]: + """Merged LabeledFrames for same video and frame index. + + Args: + labeled_frames: List of :class:`LabeledFrame`s to merge. + video: The :class:`Video` for which to merge. + This is specified so we don't have to check all frames when we + already know which video has new labeled frames. + remove_redundant: Whether to drop instances in the merged frames + where there's a perfect match. + + Returns: + The merged list of :class:`LabeledFrame`s. + """ redundant_count = 0 frames_found = dict() # move instances into first frame with matching frame_idx diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 974e9e5d2..6d930abbf 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1713,7 +1713,25 @@ def append_unique(old, new): def load_hdf5( cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None ): + """ + Deserialize HDF5 file as new :class:`Labels` instance. + + Args: + filename: Path to HDF5 file. + video_callback: A callback function that which can modify + video paths before we try to create the corresponding + :class:`Video` objects. Usually you'll want to pass + a callback created by :method:`Labels.make_video_callback` + or :method:`Labels.make_gui_video_callback`. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video`s). + Returns: + A new :class:`Labels` object. 
+ """ with h5.File(filename, "r") as f: # Extract the Labels JSON metadata and create Labels object with just diff --git a/sleap/io/video.py b/sleap/io/video.py index 8ea759410..e0c18dc34 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -339,25 +339,31 @@ def matches(self, other: "NumpyVideo") -> np.ndarray: @property def frames(self): + """See :class:`Video`.""" return self.__data.shape[self.__frame_idx] @property def channels(self): + """See :class:`Video`.""" return self.__data.shape[self.__channel_idx] @property def width(self): + """See :class:`Video`.""" return self.__data.shape[self.__width_idx] @property def height(self): + """See :class:`Video`.""" return self.__data.shape[self.__height_idx] @property def dtype(self): + """See :class:`Video`.""" return self.__data.dtype def get_frame(self, idx): + """See :class:`Video`.""" return self.__data[idx] @@ -438,10 +444,12 @@ def __img(self): @property def frames(self): + """See :class:`Video`.""" return self.__store.frame_count @property def channels(self): + """See :class:`Video`.""" if len(self.__img.shape) < 3: return 1 else: @@ -449,14 +457,17 @@ def channels(self): @property def width(self): + """See :class:`Video`.""" return self.__img.shape[1] @property def height(self): + """See :class:`Video`.""" return self.__img.shape[0] @property def dtype(self): + """See :class:`Video`.""" return self.__img.dtype def get_frame(self, frame_number: int) -> np.ndarray: diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 04c758865..a60e4bc28 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -1,3 +1,7 @@ +""" +Module for generating videos with visual annotation overlays. +""" + from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.util import usable_cpu_count diff --git a/sleap/skeleton.py b/sleap/skeleton.py index f5c0d6381..1ed7eecf8 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -16,12 +16,15 @@ from enum import Enum from itertools import count -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import networkx as nx from networkx.readwrite import json_graph from scipy.io import loadmat +NodeRef = Union[str, "Node"] +H5FileRef = Union[str, h5.File] + class EdgeType(Enum): """ @@ -58,7 +61,7 @@ def from_names(name_list: str) -> List["Node"]: return nodes @classmethod - def as_node(cls, node: Union[str, "Node"]) -> "Node": + def as_node(cls, node: NodeRef) -> "Node": """Convert given `node` to `Node` object (if not already).""" return node if isinstance(node, cls) else cls(node) @@ -143,6 +146,7 @@ def dict_match(dict1, dict2): @property def graph(self): + """Returns subgraph of BODY edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -155,6 +159,7 @@ def graph(self): @property def graph_symmetry(self): + """Returns subgraph of symmetric edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -367,7 +372,7 @@ def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: if attr["type"] == EdgeType.SYMMETRY ] - def node_to_index(self, node: Union[str, Node]) -> int: + def node_to_index(self, node: NodeRef) -> int: """ Return the index of the node, accepts either `Node` or name. 
@@ -443,7 +448,7 @@ def delete_node(self, name: str): "The node named ({}) does not exist, cannot remove it.".format(name) ) - def find_node(self, name: str) -> Node: + def find_node(self, name: NodeRef) -> Node: """Find node in skeleton by name of node. Args: @@ -601,13 +606,13 @@ def add_symmetry(self, node1: str, node2: str): self._graph.add_edge(node1_node, node2_node, type=EdgeType.SYMMETRY) self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) - def delete_symmetry(self, node1: str, node2: str): + def delete_symmetry(self, node1: NodeRef, node2: NodeRef): """ Deletes a previously established symmetry between two nodes. Args: - node1: The name of the first part in the symmetric pair. - node2: The name of the second part in the symmetric pair. + node1: One node (by `Node` object or name) in symmetric pair. + node2: Other node in symmetric pair. Raises: ValueError: If there's no symmetry between node1 and node2. @@ -633,18 +638,18 @@ def delete_symmetry(self, node1: str, node2: str): ] self._graph.remove_edges_from(edges) - def get_symmetry(self, node: str) -> Optional[Node]: + def get_symmetry(self, node: NodeRef) -> Optional[Node]: """ Returns the node symmetric with the specified node. Args: - node: The name of the node to query. + node: Node (by `Node` object or name) to query. Raises: ValueError: If node has more than one symmetry. Returns: - The symmetric :class:`Node`, None if no symmetry + The symmetric :class:`Node`, None if no symmetry. """ node_node = self.find_node(node) @@ -661,15 +666,15 @@ def get_symmetry(self, node: str) -> Optional[Node]: else: raise ValueError(f"{node} has more than one symmetry.") - def get_symmetry_name(self, node: str) -> Optional[str]: + def get_symmetry_name(self, node: NodeRef) -> Optional[str]: """ Returns the name of the node symmetric with the specified node. Args: - node: The name of the node to query. + node: Node (by `Node` object or name) to query. Returns: - name of symmetric node, None if no symmetry + Name of symmetric node, None if no symmetry. """ symmetric_node = self.get_symmetry(node) return None if symmetric_node is None else symmetric_node.name @@ -953,7 +958,7 @@ def load_json( return skeleton @classmethod - def load_hdf5(cls, file: Union[str, h5.File], name: str) -> List["Skeleton"]: + def load_hdf5(cls, file: H5FileRef, name: str) -> List["Skeleton"]: """ Load a specific skeleton (by name) from the HDF5 file. @@ -974,7 +979,7 @@ def load_hdf5(cls, file: Union[str, h5.File], name: str) -> List["Skeleton"]: @classmethod def load_all_hdf5( - cls, file: Union[str, h5.File], return_dict: bool = False + cls, file: H5FileRef, return_dict: bool = False ) -> Union[List["Skeleton"], Dict[str, "Skeleton"]]: """ Load all skeletons found in the HDF5 file. @@ -1011,7 +1016,7 @@ def _load_hdf5(cls, file: h5.File): return skeletons @classmethod - def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): + def save_all_hdf5(self, file: H5FileRef, skeletons: List["Skeleton"]): """ Convenience method to save a list of skeletons to HDF5 file. @@ -1038,7 +1043,7 @@ def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List["Skeleton"]): for skeleton in skeletons: skeleton.save_hdf5(file) - def save_hdf5(self, file: Union[str, h5.File]): + def save_hdf5(self, file: H5FileRef): """ Wrapper for HDF5 saving which takes either filename or h5.File. 
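Aside: the `NodeRef` and `H5FileRef` aliases introduced above are plain `Union` types, so the symmetry and HDF5 helpers accept either form. A minimal usage sketch (the node names and file path are illustrative, not taken from this patch)::

    from sleap.skeleton import Skeleton

    skel = Skeleton("fly")
    skel.add_node("left_wing")
    skel.add_node("right_wing")
    skel.add_symmetry("left_wing", "right_wing")

    # NodeRef: pass either a node name or a Node object.
    left = skel.find_node("left_wing")
    assert skel.get_symmetry_name(left) == "right_wing"
    assert skel.get_symmetry("right_wing").name == "left_wing"

    # H5FileRef: pass either a filename or an open h5py.File.
    Skeleton.save_all_hdf5("skeletons.h5", [skel])
    restored = Skeleton.load_all_hdf5("skeletons.h5")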
diff --git a/tests/gui/test_tracks.py b/tests/gui/test_tracks.py index b92f773e3..5cc00cbc2 100644 --- a/tests/gui/test_tracks.py +++ b/tests/gui/test_tracks.py @@ -5,7 +5,7 @@ def test_track_trails(centered_pair_predictions): labels = centered_pair_predictions - trail_manager = TrackTrailOverlay(labels, scene=None, trail_length=6) + trail_manager = TrackTrailOverlay(labels, player=None, trail_length=6) frames = trail_manager.get_frame_selection(labels.videos[0], 27) assert len(frames) == 6 From 087bc4aa6a401411b43c1e140dffa65cd665ba5d Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 30 Sep 2019 13:51:10 -0400 Subject: [PATCH 148/176] Fixed docstrings (no `foo`s allowed). --- sleap/gui/merge.py | 6 +++--- sleap/gui/overlays/tracks.py | 7 ++++--- sleap/gui/slider.py | 2 +- sleap/gui/video.py | 8 ++++---- sleap/instance.py | 18 +++++++++--------- sleap/io/dataset.py | 36 ++++++++++++++++++------------------ sleap/io/video.py | 4 ++-- sleap/skeleton.py | 8 ++++---- 8 files changed, 45 insertions(+), 44 deletions(-) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index f8b038d8c..ed3a8c578 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -133,13 +133,13 @@ class ConflictTable(QtWidgets.QTableView): Arguments are passed through to the table view object. - The two lists of `LabeledFrame`s should be correlated (idx in one will + The two lists of `LabeledFrame` objects should be correlated (idx in one will match idx of the conflicting frame in other). Args: base_labels: The base dataset. - extra_base: `LabeledFrame`s from base that conflicted. - extra_new: `LabeledFrame`s from new dataset that conflicts. + extra_base: `LabeledFrame` objects from base that conflicted. + extra_new: `LabeledFrame` objects from new dataset that conflicts. """ def __init__( diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index eccda9a32..af4dd8c6a 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -180,8 +180,9 @@ def get_track_trails(self, frame_selection, track: Track): """Get data needed to draw track trail. Args: - frame_selection: an interable with the `LabeledFrame`s to include in trail - track: the `Track` for which to get trail + frame_selection: an interable with the :class:`LabeledFrame` + objects to include in trail. + track: the :class:`Track` for which to get trail Returns: list of lists of (x, y) tuples @@ -214,7 +215,7 @@ def get_track_trails(self, frame_selection, track: Track): def get_frame_selection(self, video: Video, frame_idx: int): """ - Return list of `LabeledFrame`s to include in trail for specified frame. + Return `LabeledFrame` objects to include in trail for specified frame. 
""" frame_selection = self.labels.find(video, range(0, frame_idx + 1)) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 6eace1e83..7d2720b06 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -177,7 +177,7 @@ def __init__( self.drawHeader() def _pointsToPath(self, points: List[QtCore.QPointF]) -> QPainterPath: - """Converts list of `QtCore.QPointF`s to a `QPainterPath`.""" + """Converts list of `QtCore.QPointF` objects to a `QPainterPath`.""" path = QPainterPath() path.addPolygon(QPolygonF(points)) return path diff --git a/sleap/gui/video.py b/sleap/gui/video.py index f8b669233..399b8abd9 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -129,17 +129,17 @@ def reset(self): @property def instances(self): - """Returns list of all `QtInstance`s in view.""" + """Returns list of all `QtInstance` objects in view.""" return self.view.instances @property def selectable_instances(self): - """Returns list of selectable `QtInstance`s in view.""" + """Returns list of selectable `QtInstance` objects in view.""" return self.view.selectable_instances @property def predicted_instances(self): - """Returns list of predicted `QtInstance`s in view.""" + """Returns list of predicted `QtInstance` objects in view.""" return self.view.predicted_instances @property @@ -567,7 +567,7 @@ def selectable_instances(self) -> List["QtInstance"]: @property def all_instances(self) -> List["QtInstance"]: """ - Returns a list of all `QtInstance`s in scene. + Returns a list of all `QtInstance` objects in scene. Order should match the order in which instances were added to scene. """ diff --git a/sleap/instance.py b/sleap/instance.py index 00d0d0705..c5a5e2afa 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -6,13 +6,13 @@ * A `LabeledFrame` can contain zero or more `Instance`s (and `PredictedInstance`s). -* `Instance`s (and `PredictedInstance`s) have `PointArray` +* `Instance` objects (and `PredictedInstance`s) have `PointArray` (or `PredictedPointArray`). * `Instance` (`PredictedInstance`) can be associated with a `Track` * A `PointArray` (or `PredictedPointArray`) contains zero or more - `Point`s (or `PredictedPoint`s), ideally as many as there are in the + `Point` objects (or `PredictedPoint`s), ideally as many as there are in the associated :class:`Skeleton` although these can get out of sync if the skeleton is manipulated. """ @@ -876,7 +876,7 @@ class when the attributes of one class are a subset of another. Returns: A cattr converter with hooks registered for structuring and - unstructuring :class:`Instance`s and + unstructuring :class:`Instance` objects and :class:`PredictedInstance`s. """ @@ -1110,9 +1110,9 @@ def has_user_instances(self) -> bool: @property def unused_predictions(self) -> List[Instance]: """ - Returns list of "unused" :class:`PredictedInstance`s in frame. + Returns list of "unused" :class:`PredictedInstance` objects in frame. - This is all the :class:`PredictedInstance`s which do not have + This is all the :class:`PredictedInstance` objects which do not have a corresponding :class:`Instance` in the same track in frame. """ unused_predictions = [] @@ -1177,7 +1177,7 @@ def merge_frames( """Merged LabeledFrames for same video and frame index. Args: - labeled_frames: List of :class:`LabeledFrame`s to merge. + labeled_frames: List of :class:`LabeledFrame` objects to merge. video: The :class:`Video` for which to merge. This is specified so we don't have to check all frames when we already know which video has new labeled frames. 
@@ -1230,15 +1230,15 @@ def complex_merge_between( Args: base_labels: The :class:`Labels` into which we are merging. - new_frames: The list of :class:`LabeledFrame`s from + new_frames: The list of :class:`LabeledFrame` objects from which we are merging. Returns: tuple of three items: * Dictionary, keys are :class:`Video`, values are dictionary in which keys are frame index (int) and value is list of :class:`Instance`s - * list of conflicting :class:`Instance`s from base - * list of conflicting :class:`Instance`s from new frames + * list of conflicting :class:`Instance` objects from base + * list of conflicting :class:`Instance` objects from new frames """ merged = dict() extra_base = [] diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 6d930abbf..072983c00 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -64,14 +64,14 @@ class Labels(MutableSequence): is mostly abstracted away from the main interface. Attributes: - labeled_frames: A list of :class:`LabeledFrame`s - videos: A list of :class:`Video`s that these labels may or may + labeled_frames: A list of :class:`LabeledFrame` objects + videos: A list of :class:`Video` objects that these labels may or may not reference. The video for every `LabeledFrame` will be - stored in :attribute:`Labels.videos`, but some videos in + stored in :attribute:`videos`, but some videos in this list may not have any associated labeled frames. - skeletons: A list of :class:`Skeleton`s (again, that may or may + skeletons: A list of :class:`Skeleton` objects (again, that may or may not be referenced by an :class:`Instance` in labeled frame). - tracks: A list of :class:`Track`s that instances can belong to. + tracks: A list of :class:`Track` that instances can belong to. suggestions: Dictionary that stores "suggested" frames for videos in project. These can be suggested frames for user to label or suggested frames for user to review. @@ -355,7 +355,7 @@ def find( :class:`LabeledFrame` if none is found in project. Returns: - List of `LabeledFrame`s that match the criteria. + List of `LabeledFrame` objects that match the criteria. Empty if no matches found, unless return_new is True, in which case it contains a new `LabeledFrame` with `video` and `frame_index` set. @@ -928,16 +928,16 @@ def complex_merge_between( Args: base_labels: the `Labels` that we're merging into new_labels: the `Labels` that we're merging from - unify: whether to replace objects (e.g., `Video`s) in + unify: whether to replace objects (e.g., `Video`) in new_labels with *matching* objects from base Returns: tuple of three items: * Dictionary, keys are :class:`Video`, values are dictionary in which keys are frame index (int) - and value is list of :class:`Instance`s - * list of conflicting :class:`Instance`s from base - * list of conflicting :class:`Instance`s from new frames + and value is list of :class:`Instance` objects + * list of conflicting :class:`Instance` objects from base + * list of conflicting :class:`Instance` objects from new frames """ # If unify, we want to replace objects in the frames with # corresponding objects from the current labels. @@ -973,11 +973,11 @@ def complex_merge_between( # the merged predictions. # # Args: - # extra_base: list of `LabeledFrame`s - # extra_new: list of `LabeledFrame`s + # extra_base: list of `LabeledFrame` objects + # extra_new: list of `LabeledFrame` objects # Conflicting frames should have same index in both lists. 
# Returns: - # list of `LabeledFrame`s with merged predictions + # list of `LabeledFrame` objects with merged predictions # """ # pass @@ -1018,7 +1018,7 @@ def merge_container_dicts(dict_a: Dict, dict_b: Dict) -> Dict: def merge_matching_frames(self, video: Optional[Video] = None): """ - Merge `LabeledFrame`s that are for the same video frame. + Merge `LabeledFrame` objects that are for the same video frame. Args: video: combine for this video; if None, do all videos @@ -1229,7 +1229,7 @@ def from_json( data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video`s). + duplicate matching objects (e.g., :class:`Video` objects ). Returns: A new :class:`Labels` object. """ @@ -1348,7 +1348,7 @@ def load_json( data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video`s). + duplicate matching objects (e.g., :class:`Video` objects ). Returns: A new :class:`Labels` object. """ @@ -1727,7 +1727,7 @@ def load_hdf5( data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video`s). + duplicate matching objects (e.g., :class:`Video` objects ). Returns: A new :class:`Labels` object. @@ -1846,7 +1846,7 @@ def save_frame_data_imgstore( Other imgstore formats will probably work as well but have not been tested. all_labels: Include any labeled frames, not just the frames - we'll use for training (i.e., those with `Instance`s). + we'll use for training (i.e., those with `Instance` objects ). Returns: A list of :class:`ImgStoreVideo` objects with the stored diff --git a/sleap/io/video.py b/sleap/io/video.py index e0c18dc34..aac91b171 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -385,7 +385,7 @@ class ImgStoreVideo: accept the frame index from the store directly. Default to True so that we can use an ImgStoreVideo in a dataset to replace another video without having to update all the frame - indices on :class:`LabeledFrame`s in the dataset. + indices on :class:`LabeledFrame` objects in the dataset. """ filename: str = attr.ib(default=None) @@ -808,7 +808,7 @@ def to_imgstore( then it will accept the frame index from the store directly. Default to True so that we can use an ImgStoreVideo in a dataset to replace another video without having to update - all the frame indices on :class:`LabeledFrame`s in the dataset. + all the frame indices on :class:`LabeledFrame` objects in the dataset. Returns: A new Video object that references the imgstore. diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 1ed7eecf8..3eed29e69 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -810,7 +810,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D :class:`Nodes` outside the :class:`Skeleton` object. If given, then we replace each :class:`Node` with specified index before converting :class:`Skeleton`. - Otherwise, we convert :class:`Node`s with the rest of + Otherwise, we convert :class:`Node` objects with the rest of the :class:`Skeleton`. Returns: dict with data from skeleton @@ -833,7 +833,7 @@ def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": :class:`Nodes` outside the :class:`Skeleton` object. 
If given, then we replace each :class:`Node` with specified index before converting :class:`Skeleton`. - Otherwise, we convert :class:`Node`s with the rest of + Otherwise, we convert :class:`Node` objects with the rest of the :class:`Skeleton`. Returns: @@ -853,7 +853,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: :class:`Nodes` outside the :class:`Skeleton` object. If given, then we replace each :class:`Node` with specified index before converting :class:`Skeleton`. - Otherwise, we convert :class:`Node`s with the rest of + Otherwise, we convert :class:`Node` objects with the rest of the :class:`Skeleton`. Returns: @@ -886,7 +886,7 @@ def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None :class:`Nodes` outside the :class:`Skeleton` object. If given, then we replace each :class:`Node` with specified index before converting :class:`Skeleton`. - Otherwise, we convert :class:`Node`s with the rest of + Otherwise, we convert :class:`Node` objects with the rest of the :class:`Skeleton`. Returns: From ff4de22fa8d890fedbbb1872e9f04ecb999c37c6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 30 Sep 2019 13:56:11 -0400 Subject: [PATCH 149/176] Fixed docstrings (no `foo`s allowed). --- sleap/instance.py | 10 +++++----- sleap/skeleton.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c5a5e2afa..9ca1d2c53 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -4,17 +4,17 @@ The relationships between objects in this module: * A `LabeledFrame` can contain zero or more `Instance`s - (and `PredictedInstance`s). + (and `PredictedInstance` objects). -* `Instance` objects (and `PredictedInstance`s) have `PointArray` +* `Instance` objects (and `PredictedInstance` objects) have `PointArray` (or `PredictedPointArray`). * `Instance` (`PredictedInstance`) can be associated with a `Track` * A `PointArray` (or `PredictedPointArray`) contains zero or more - `Point` objects (or `PredictedPoint`s), ideally as many as there are in the - associated :class:`Skeleton` although these can get out of sync if the - skeleton is manipulated. + `Point` objects (or `PredictedPoint` objectss), ideally as many as + there are in the associated :class:`Skeleton` although these can get + out of sync if the skeleton is manipulated. """ import math diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 3eed29e69..d0c290396 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -908,8 +908,8 @@ def from_json( Args: json_str: The JSON encoded Skeleton. idx_to_node: optional dict which maps an int (indexing a - list of :class:`Node`s) to the already deserialized - :class:`Node`. + list of :class:`Node` objects) to the already + deserialized :class:`Node`. This should invert `node_to_idx` we used when saving. If not given, then we'll assume each :class:`Node` was left in the :class:`Skeleton` when it was saved. @@ -941,8 +941,8 @@ def load_json( Args: filename: The file that contains the JSON. idx_to_node: optional dict which maps an int (indexing a - list of :class:`Node`s) to the already deserialized - :class:`Node`. + list of :class:`Node` objects) to the already + deserialized :class:`Node`. This should invert `node_to_idx` we used when saving. If not given, then we'll assume each :class:`Node` was left in the :class:`Skeleton` when it was saved. 
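Since the docstrings above describe `idx_to_node` as the inverse of the `node_to_idx` mapping used when saving, a short sketch of the intended round trip may be helpful (variable names are illustrative only)::

    node_to_idx = {node: idx for idx, node in enumerate(skeleton.nodes)}
    json_str = skeleton.to_json(node_to_idx=node_to_idx)

    # Invert the mapping so the same Node objects are restored on load.
    idx_to_node = {idx: node for node, idx in node_to_idx.items()}
    restored = Skeleton.from_json(json_str, idx_to_node=idx_to_node)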
From f6b57ccb2b65c5999bcb5e5f8a54af0f139236d7 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 30 Sep 2019 16:47:55 -0400 Subject: [PATCH 150/176] Fixes to docs. Changed "sLEAP" to "SLEAP". Added modules to tree. Misc fixes to docstrings. --- docs/Makefile | 2 +- docs/conf.py | 12 ++--- docs/gui.rst | 90 +++++++++++++++++++++++++++++++-- docs/index.rst | 3 +- docs/misc.rst | 35 +++++++++++++ docs/tutorial.rst | 22 ++++---- sleap/gui/app.py | 4 +- sleap/gui/formbuilder.py | 8 +-- sleap/gui/importvideos.py | 1 + sleap/gui/merge.py | 2 +- sleap/gui/overlays/tracks.py | 2 +- sleap/gui/suggestions.py | 7 +-- sleap/gui/video.py | 37 +++++++------- sleap/info/write_tracking_h5.py | 12 +++-- sleap/io/dataset.py | 33 ++++++------ sleap/io/legacy.py | 2 +- 16 files changed, 201 insertions(+), 71 deletions(-) create mode 100644 docs/misc.rst diff --git a/docs/Makefile b/docs/Makefile index 8f5aa6ab8..9f8237083 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . -BUILDDIR = ..\..\sleap-docs +BUILDDIR = ../../sleap-docs # Export the BUILDDIR so we can pick it up in conf.py. We need this to # be able to copy some the files in _static to an alternative location diff --git a/docs/conf.py b/docs/conf.py index fb170194e..78f497c50 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # -- Project information ----------------------------------------------------- -project = 'LEAP' +project = 'SLEAP' copyright = '2019, Murthy Lab @ Princeton' author = 'Talmo D. Pereira, Nat Tabris, David M. Turner' @@ -105,7 +105,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'sLEAPdoc' +htmlhelp_basename = 'SLEAPdoc' # -- Options for LaTeX output ------------------------------------------------ @@ -132,7 +132,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'sLEAP.tex', 'sLEAP Documentation', + (master_doc, 'SLEAP.tex', 'SLEAP Documentation', 'Talmo D. Pereira, Nat Tabris, David M. Turner', 'manual'), ] @@ -142,7 +142,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'sleap', 'sLEAP Documentation', + (master_doc, 'Sleap', 'SLEAP Documentation', [author], 1) ] @@ -153,8 +153,8 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'sLEAP', 'sLEAP Documentation', - author, 'sLEAP', 'One line description of project.', + (master_doc, 'SLEAP', 'SLEAP Documentation', + author, 'SLEAP', 'One line description of project.', 'Miscellaneous'), ] diff --git a/docs/gui.rst b/docs/gui.rst index a1d657797..76ee2b4ff 100644 --- a/docs/gui.rst +++ b/docs/gui.rst @@ -3,18 +3,100 @@ GUI .. automodule:: sleap.gui.app :members: + +Video Player +------------- .. automodule:: sleap.gui.video :members: + +Dialogs +------------- + +Active Learning +^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.active + :members: + +Video Importer +^^^^^^^^^^^^^^ .. automodule:: sleap.gui.importvideos :members: -.. automodule:: sleap.gui.confmapsplot + +Merging +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.merge :members: -.. automodule:: sleap.gui.quiverplot + +Shortcuts +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.shortcuts :members: -.. automodule:: sleap.gui.dataviews + +Suggestions +^^^^^^^^^^^^^^ +.. 
automodule:: sleap.gui.suggestions :members: -.. automodule:: sleap.gui.multicheck + +Training Profiles +^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.training_editor + :members: + +Other Widgets +------------- + +Form builder +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.formbuilder :members: + +Slider +^^^^^^^^^^^^^^ .. automodule:: sleap.gui.slider :members: +Multicheck +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.multicheck + :members: + +Overlays +------------- + +Instances +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.instance + :members: + +Tracks +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.tracks + :members: + +Anchors +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.anchors + :members: + +Datasource classes +^^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.base + :members: + +Confidence maps +^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.confmaps + :members: + + +Part affinity fields +^^^^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.pafs + :members: + + + +Dataviews +------------- +.. automodule:: sleap.gui.dataviews + :members: diff --git a/docs/index.rst b/docs/index.rst index bd16e505b..8fe194a91 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,7 @@ .. _sleap: .. toctree:: - :caption: sLEAP Package + :caption: SLEAP Package :maxdepth: 3 tutorial @@ -14,6 +14,7 @@ training inference gui + misc .. _Indices_and_Tables: diff --git a/docs/misc.rst b/docs/misc.rst new file mode 100644 index 000000000..cea0d774f --- /dev/null +++ b/docs/misc.rst @@ -0,0 +1,35 @@ +Misc +======== + +Utils +------------- +.. automodule:: sleap.util + :members: + +Range list +------------- +.. automodule:: sleap.rangelist + :members: + +Legacy formats +-------------- +.. automodule:: sleap.io.legacy + :members: + +Info tools +---------- + +Metrics +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.metrics + :members: + +Summary +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.summary + :members: + +Track Analysis +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.write_tracking_h5 + :members: diff --git a/docs/tutorial.rst b/docs/tutorial.rst index ad0fca3ca..3932f9e13 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -1,11 +1,11 @@ Tutorial ======== -Before you can use sLEAP, you’ll need to install it. Follow the -instructions at :ref:`Installation` to install sLEAP and +Before you can use SLEAP, you’ll need to install it. Follow the +instructions at :ref:`Installation` to install SLEAP and start the GUI app. -There are three main stages of using sLEAP: +There are three main stages of using SLEAP: 1. Creating a project, opening a movie and defining the skeleton; @@ -18,7 +18,7 @@ There are three main stages of using sLEAP: Stage 1: Creating a project --------------------------- -When you first start sLEAP you’ll see an open dialog. Since you don’t +When you first start SLEAP you’ll see an open dialog. Since you don’t yet have a project to open, click “Cancel” and you’ll be left with a new, empty project. @@ -32,7 +32,7 @@ on the right side of the main window, the “Add Video” command in the |image0| You’ll then be able to select one or more video files and click “Open”. -sLEAP currently supports mp4, avi, and h5 files. For mp4 and avi files, +SLEAP currently supports mp4, avi, and h5 files. For mp4 and avi files, you’ll be asked whether to import the video as grayscale. For h5 files, you’ll be asked the dataset and whether the video is stored with channels first or last. 
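For readers scripting against the API rather than the GUI, the same video-opening step can be done in code. A minimal sketch, with a placeholder path and a `grayscale` flag mirroring the importer option described above::

    from sleap.io.video import Video

    vid = Video.from_filename("path/to/video.mp4", grayscale=True)
    print(vid.frames, vid.height, vid.width, vid.channels)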
@@ -65,7 +65,7 @@ Stage 2: Labeling and learning We start by assembling a candidate group of images to label. You can either pick your own frames or let the system suggest a set of frames -using the “Generate Suggestions” panel. sLEAP can choose these frames +using the “Generate Suggestions” panel. SLEAP can choose these frames (i) randomly, or using (ii) Strides (evenly spaced samples), (iii) PCA (runs Principle Component Analysis on the images, clusters the images into groups, and uses sample frames from each cluster), or (iv) BRISK @@ -102,9 +102,9 @@ Saving ~~~~~~ Since this is a new project, you’ll need to select a location and name -the first time you save. sLEAP will ask you to save before closing any +the first time you save. SLEAP will ask you to save before closing any project that has been changed to avoid losing any work. Note: There is -not yet an “undo” feature built into sLEAP. If you want to make +not yet an “undo” feature built into SLEAP. If you want to make temporary changes to a project, use the “Save As…” command first to save a copy of your project. @@ -197,7 +197,7 @@ model doesn’t improve for a certain number of epochs (15 by default) First we train a model for confidence maps, part affinity fields, and centroids, and then we run inference. The GUI doesn’t yet give you a way to monitor the progress during inference, although you can get more -information in the console window from which you started sLEAP. +information in the console window from which you started SLEAP. When active learning finishes, you’ll be told how many instances were predicted. Suggested frames with predicted instances will be marked in @@ -265,7 +265,7 @@ Running inference remotely (optional) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It’s also possible to run inference using the command line interface, which is -useful if you’re going to run on a cluster). The command to run inference on +useful if you’re going to run on a cluster. The command to run inference on an entire video is: :: @@ -276,7 +276,7 @@ an entire video is: -m path/to/models/your_paf_model.json \ -m path/to/models/your_centroid_model.json -The predictions will be saved in path/to/video.mp4.predictions.json.zip, +The predictions will be saved in path/to/video.mp4.predictions.h5, which you can open from the GUI app. You can also import these predictions into your project by opening your project and then using the "Import Predictions..." command in the "Predict" menu. diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 494fe8e96..a0c5f3a62 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -57,11 +57,11 @@ class MainWindow(QMainWindow): """The SLEAP GUI application. Each project (`Labels` dataset) that you have loaded in the GUI will - have it's own `MainWindow` object. + have its own `MainWindow` object. Attributes: labels: The :class:`Labels` dataset. If None, a new, empty project - (i.e., :class:`Labels' object) will be created. + (i.e., :class:`Labels` object) will be created. skeleton: The active :class:`Skeleton` for the project in the gui video: The active :class:`Video` in view in the gui """ diff --git a/sleap/gui/formbuilder.py b/sleap/gui/formbuilder.py index 83591421d..1029a2e33 100644 --- a/sleap/gui/formbuilder.py +++ b/sleap/gui/formbuilder.py @@ -2,6 +2,7 @@ Module for creating a form from a yaml file. 
Example: + >>> widget = YamlFormWidget(yaml_file="example.yaml") >>> widget.mainAction.connect(my_function) @@ -81,7 +82,7 @@ class FormBuilderLayout(QtWidgets.QFormLayout): Custom QFormLayout which populates itself from list of form fields. Args: - items_to_create: list which gets passed to :method:`get_form_data` + items_to_create: list which gets passed to :meth:`get_form_data` (see there for details about format) """ @@ -183,7 +184,8 @@ def build_form(self, items_to_create: List[Dict[str, Any]]): """Adds widgets to form layout for each item in items_to_create. Args: - items_to_create: list of dictionaries with fields: + items_to_create: list of dictionaries with keys + * name: used as key when we return form data as dict * label: string to show in form * type: supports double, int, bool, list, button, stack @@ -311,7 +313,7 @@ class StackBuilderWidget(QtWidgets.QWidget): The "options" key will give the list of options to show in menu. Each of the "options" will also be the key of a dictionary within stack_data that has the same structure as the dictionary - passed to :method:`FormBuilderLayout.build_form()`. + passed to :meth:`FormBuilderLayout.build_form()`. """ def __init__(self, stack_data, *args, **kwargs): diff --git a/sleap/gui/importvideos.py b/sleap/gui/importvideos.py index 7c049ef49..84d478e83 100644 --- a/sleap/gui/importvideos.py +++ b/sleap/gui/importvideos.py @@ -17,6 +17,7 @@ method while passing the user-selected params as the named parameters: >>> vid = item["video_class"](**item["params"]) + """ from PySide2.QtCore import Qt, QRectF, Signal diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py index ed3a8c578..51ba7a82a 100644 --- a/sleap/gui/merge.py +++ b/sleap/gui/merge.py @@ -217,7 +217,7 @@ class MergeTable(QtWidgets.QTableView): Args: merged: The frames that were cleanly merged. - See :method:`Labels.complex_merge_between` for details. + See :meth:`Labels.complex_merge_between` for details. """ def __init__(self, merged, *args, **kwargs): diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index af4dd8c6a..bb4daaa32 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -167,7 +167,7 @@ class TrackTrailOverlay: trail_length: The maximum number of frames to include in trail. Usage: - After class is instantiated, call :method:`add_to_scene(frame_idx)` + After class is instantiated, call :meth:`add_to_scene(frame_idx)` to plot the trails in scene. """ diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 851616ea6..6189fae6a 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -28,7 +28,7 @@ class VideoFrameSuggestions: * proofreading Each of algorithm method should accept `video`; other parameters will be - passed from the `params` dict given to :method:`suggest()`. + passed from the `params` dict given to :meth:`suggest`. """ @@ -343,9 +343,10 @@ def clusters_to_list( @classmethod def get_scale_factor(cls, video: "Video") -> int: - """Determines how much we need to scale to get video within size. + """ + Determines how much we need to scale to get video within size. - Size is specified by :attribute:`rescale_below`. + Size is specified by :attr:`rescale_below`. 
""" factor = 1 if cls.rescale: diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 399b8abd9..4665737e7 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -9,6 +9,7 @@ >>> vp = QtVideoPlayer(video=my_video) >>> vp.addInstance(instance=my_instance, color=(r, g, b)) + """ from PySide2 import QtWidgets @@ -52,14 +53,15 @@ class QtVideoPlayer(QWidget): """ Main QWidget for displaying video with skeleton instances. + Signals: + * changedPlot: Emitted whenever the plot is redrawn + * changedData: Emitted whenever data is changed by user + Attributes: video: The :class:`Video` to display color_manager: A :class:`TrackColorManager` object which determines which color to show the instances. - Signals: - changedPlot: Emitted whenever the plot is redrawn - changedData: Emitted whenever data is changed by user """ changedPlot = Signal(QWidget, int, Instance) @@ -268,6 +270,10 @@ def onSequenceSelect( (if given). If the user cancels (by unselecting without new selection), the `on_failure` callback is called (if given). + Note: + If successful, we call + >>> on_success(sequence_of_selected_instance_indexes) + Args: seq_len: Number of instances we want to collect in sequence. on_success: Callback for when user has selected desired number of @@ -276,9 +282,6 @@ def onSequenceSelect( on_failure: Callback if user cancels process before selecting enough instances. - Note: - If successful, we call - >>> on_success(sequence_of_selected_instance_indexes) """ indexes = [] @@ -430,22 +433,22 @@ class GraphicsView(QGraphicsView): and selection of instances in view. Signals: - updatedViewer: Emitted after update to view (e.g., zoom). + * updatedViewer: Emitted after update to view (e.g., zoom). Used internally so we know when to update points for each instance. - updatedSelection: Emitted after the user has (un)selected an instance. - instanceDoubleClicked: Emitted after an instance is double-clicked. + * updatedSelection: Emitted after the user has (un)selected an instance. + * instanceDoubleClicked: Emitted after an instance is double-clicked. Passes the :class:`Instance` that was double-clicked. - areaSelected: Emitted after user selects an area when in "area" + * areaSelected: Emitted after user selects an area when in "area" click mode. Passes x0, y0, x1, y1 for selected box coordinates. - pointSelected: Emitted after user clicks a point (in "point" click + * pointSelected: Emitted after user clicks a point (in "point" click mode.) Passes x, y coordinates of point. + * leftMouseButtonPressed: Emitted by event handler. + * rightMouseButtonPressed: Emitted by event handler. + * leftMouseButtonReleased: Emitted by event handler. + * rightMouseButtonReleased: Emitted by event handler. + * leftMouseButtonDoubleClicked: Emitted by event handler. + * rightMouseButtonDoubleClicked: Emitted by event handler. - leftMouseButtonPressed - rightMouseButtonPressed - leftMouseButtonReleased - rightMouseButtonReleased - leftMouseButtonDoubleClicked - rightMouseButtonDoubleClicked """ updatedViewer = Signal() diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 896def438..a91a815fa 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -7,14 +7,15 @@ of video. Call from command line as: -> python -m sleap.io.write_tracking_h5 + +>>> python -m sleap.io.write_tracking_h5 Will write file to `.tracking.h5`. 
The HDF5 file has these datasets: - "track_occupancy" shape: tracks * frames - "tracks" shape: frames * nodes * 2 * tracks - "track_names" shape: tracks +* "track_occupancy" shape: tracks * frames +* "tracks" shape: frames * nodes * 2 * tracks +* "track_names" shape: tracks Note: the datasets are stored column-major as expected by MATLAB. """ @@ -30,7 +31,7 @@ def get_tracks_as_np_strings(labels: Labels) -> List[np.string_]: - """Get list of track names as `np.string_`s.""" + """Get list of track names as `np.string_`.""" return [np.string_(track.name) for track in labels.tracks] @@ -49,6 +50,7 @@ def get_occupancy_and_points_matrices( Returns: tuple of two matrices: + * occupancy matrix with shape (tracks, frames) * point location matrix with shape (frames, nodes, 2, tracks) """ diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 072983c00..f8d7a4fb1 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -67,7 +67,7 @@ class Labels(MutableSequence): labeled_frames: A list of :class:`LabeledFrame` objects videos: A list of :class:`Video` objects that these labels may or may not reference. The video for every `LabeledFrame` will be - stored in :attribute:`videos`, but some videos in + stored in `videos` attribute, but some videos in this list may not have any associated labeled frames. skeletons: A list of :class:`Skeleton` objects (again, that may or may not be referenced by an :class:`Instance` in labeled frame). @@ -568,7 +568,7 @@ def track_swap( Swaps track assignment for instances in two tracks. If you need to change the track to or from None, you'll need - to use :method:`Labels.track_set_instance()` for each specific + to use :meth:`track_set_instance` for each specific instance you want to modify. Args: @@ -865,12 +865,13 @@ def extend_from( self, new_frames: Union["Labels", List[LabeledFrame]], unify: bool = False ): """ - Merge in data from another Labels object or list of LabeledFrames. + Merge data from another `Labels` object or `LabeledFrame` list. Arg: new_frames: the object from which to copy data unify: whether to replace objects in new frames with corresponding objects from current `Labels` data + Returns: bool, True if we added frames, False otherwise """ @@ -909,19 +910,19 @@ def complex_merge_between( cls, base_labels: "Labels", new_labels: "Labels", unify: bool = True ) -> tuple: """ - Merge frames and other data that can be merged cleanly, - and return frames that conflict. + Merge frames and other data from one dataset into another. Anything that can be merged cleanly is merged into base_labels. Frames conflict just in case each labels object has a matching - frame (same video and frame idx) which instances not in the other. + frame (same video and frame idx) with instances not in other. + + Frames can be merged cleanly if: - Frames can be merged cleanly if - - the frame is in only one of the labels, or - - the frame is in both labels, but all instances perfectly match + * the frame is in only one of the labels, or + * the frame is in both labels, but all instances perfectly match (which means they are redundant), or - - the frame is in both labels, maybe there are some redundant + * the frame is in both labels, maybe there are some redundant instances, but only one version of the frame has additional instances not in the other. 
@@ -933,11 +934,13 @@ def complex_merge_between( Returns: tuple of three items: + * Dictionary, keys are :class:`Video`, values are dictionary in which keys are frame index (int) and value is list of :class:`Instance` objects * list of conflicting :class:`Instance` objects from base - * list of conflicting :class:`Instance` objects from new frames + * list of conflicting :class:`Instance` objects from new + """ # If unify, we want to replace objects in the frames with # corresponding objects from the current labels. @@ -1342,8 +1345,8 @@ def load_json( video_callback: A callback function that which can modify video paths before we try to create the corresponding :class:`Video` objects. Usually you'll want to pass - a callback created by :method:`Labels.make_video_callback` - or :method:`Labels.make_gui_video_callback`. + a callback created by :meth:`make_video_callback` + or :meth:`make_gui_video_callback`. match_to: If given, we'll replace particular objects in the data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly @@ -1721,8 +1724,8 @@ def load_hdf5( video_callback: A callback function that which can modify video paths before we try to create the corresponding :class:`Video` objects. Usually you'll want to pass - a callback created by :method:`Labels.make_video_callback` - or :method:`Labels.make_gui_video_callback`. + a callback created by :meth:`make_video_callback` + or :meth:`make_gui_video_callback`. match_to: If given, we'll replace particular objects in the data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly diff --git a/sleap/io/legacy.py b/sleap/io/legacy.py index 4340d460c..ebcb38256 100644 --- a/sleap/io/legacy.py +++ b/sleap/io/legacy.py @@ -39,7 +39,7 @@ def load_predicted_labels_json_old( fix_rel_paths: Whether to fix paths to videos to absolute paths. Returns: - List of :class:`LabeledFrame`s. + List of :class:`LabeledFrame` objects. """ if parsed_json is None: data = json.loads(open(data_path).read()) From d9acab5537978fc40d492bc4999b651b437b8190 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 30 Sep 2019 19:18:37 -0400 Subject: [PATCH 151/176] --nonnative param to use Qt file dialogs. --- sleap/gui/app.py | 81 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index a0c5f3a62..b652b54eb 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -70,16 +70,25 @@ class MainWindow(QMainWindow): skeleton: Skeleton video: Video - def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): + def __init__( + self, + labels_path: Optional[str] = None, + nonnative: bool = False, + *args, + **kwargs, + ): """Initialize the app. Args: labels_path: Path to saved :class:`Labels` dataset. + nonnative: Whether to use native or Qt file dialog. Returns: None. 
""" super(MainWindow, self).__init__(*args, **kwargs) + self.nonnative = nonnative + self.labels = Labels() self.skeleton = Skeleton() self.labeled_frame = None @@ -102,7 +111,11 @@ def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): self.changestack_clear() self._initialize_gui() - if labels_path is not None: + self._file_dialog_options = 0 + if self.nonnative: + self._file_dialog_options = QFileDialog.DontUseNativeDialog + + if labels_path: self.loadProject(labels_path) def event(self, e: QEvent) -> bool: @@ -1009,7 +1022,11 @@ def openSkeleton(self): """Shows gui for loading saved skeleton into project.""" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] filename, selected_filter = QFileDialog.getOpenFileName( - self, dir=None, caption="Open skeleton...", filter=";;".join(filters) + self, + dir=None, + caption="Open skeleton...", + filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1034,7 +1051,11 @@ def saveSkeleton(self): default_name = "skeleton.json" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] filename, selected_filter = QFileDialog.getSaveFileName( - self, caption="Save As...", dir=default_name, filter=";;".join(filters) + self, + caption="Save As...", + dir=default_name, + filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1264,6 +1285,7 @@ def visualizeOutputs(self): dir=models_dir, caption="Import model outputs...", filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1474,7 +1496,11 @@ def importPredictions(self): """Starts gui for importing another dataset into currently one.""" filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] filenames, selected_filter = QFileDialog.getOpenFileNames( - self, dir=None, caption="Import labeled data...", filter=";;".join(filters) + self, + dir=None, + caption="Import labeled data...", + filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filenames) == 0: @@ -1834,7 +1860,11 @@ def openProject(self, first_open: bool = False): ] filename, selected_filter = QFileDialog.getOpenFileName( - self, dir=None, caption="Import labeled data...", filter=";;".join(filters) + self, + dir=None, + caption="Import labeled data...", + filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1857,7 +1887,7 @@ def saveProject(self): def saveProjectAs(self): """Show gui to save project as a new file.""" - default_name = self.filename if self.filename is not None else "untitled.json" + default_name = self.filename if self.filename is not None else "untitled" p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) @@ -1867,7 +1897,11 @@ def saveProjectAs(self): "Compressed JSON (*.zip)", ] filename, selected_filter = QFileDialog.getSaveFileName( - self, caption="Save As...", dir=default_name, filter=";;".join(filters) + self, + caption="Save As...", + dir=default_name, + filter=";;".join(filters), + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1971,6 +2005,7 @@ def exportLabeledClip(self): caption="Save Video As...", dir=self.filename + ".avi", filter="AVI Video (*.avi)", + options=self._file_dialog_options, ) if len(filename) == 0: @@ -1988,7 +2023,10 @@ def exportLabeledClip(self): def exportLabeledFrames(self): """Gui for exporting the training dataset of labels/frame images.""" filename, _ = QFileDialog.getSaveFileName( - self, caption="Save Labeled 
Frames As...", dir=self.filename + self, + caption="Save Labeled Frames As...", + dir=self.filename, + options=self._file_dialog_options, ) if len(filename) == 0: return @@ -2138,16 +2176,27 @@ def main(*args, **kwargs): window = MainWindow(*args, **kwargs) window.showMaximized() - if "labels_path" not in kwargs: + if not kwargs.get("labels_path", None): window.openProject(first_open=True) app.exec_() if __name__ == "__main__": - - kwargs = dict() - if len(sys.argv) > 1: - kwargs["labels_path"] = sys.argv[1] - - main(**kwargs) + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "labels_path", help="Path to labels file", type=str, default=None, nargs="?" + ) + parser.add_argument( + "--nonnative", + help="Don't use native file dialogs", + action="store_const", + const=True, + default=False, + ) + + args = parser.parse_args() + + main(**vars(args)) From 758b3668c880901720601e5edee4b9919d2e8cef Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 09:34:06 -0400 Subject: [PATCH 152/176] Add support for saving hdf5 w/ labels and imgs. --- sleap/io/dataset.py | 43 ++++++++++++++----- sleap/io/video.py | 90 +++++++++++++++++++++++++++++++++++++--- tests/io/test_dataset.py | 13 +++++- tests/io/test_video.py | 59 ++++++++++++++++++++++++++ 4 files changed, 189 insertions(+), 16 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index f8d7a4fb1..18da90a59 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1479,21 +1479,12 @@ def save_hdf5( any labeled frame as well. This is useful for uploading the HDF5 for model training when video files are to large to move. This will only save video frames that - have some labeled instances. NOT YET IMPLENTED. - - Raises: - NotImplementedError: If save_frame_data is True. + have some labeled instances. Returns: None """ - # FIXME: Need to implement this. - if save_frame_data: - raise NotImplementedError( - "Saving frame data is not implemented yet with HDF5 Labels datasets." - ) - # Delete the file if it exists, we want to start from scratch since # h5py truncates the file which seems to not actually delete data # from the file. Don't if we are appending of course. @@ -1503,6 +1494,10 @@ def save_hdf5( # Serialize all the meta-data to JSON. d = labels.to_dict(skip_labels=True) + if save_frame_data: + new_videos = labels.save_frame_data_hdf5(filename) + d["videos"] = Video.cattr().unstructure(new_videos) + with h5.File(filename, "a") as f: # Add all the JSON metadata @@ -1878,6 +1873,34 @@ def save_frame_data_imgstore( return imgstore_vids + def save_frame_data_hdf5(self, output_path: str, all_labels: bool = False): + """ + Write labeled frames from all videos to hdf5 file. + + Args: + output_path: Path to HDF5 file. + all_labels: Include any labeled frames, not just the frames + we'll use for training (i.e., those with Instances). + + Returns: + A list of :class:`HDF5Video` objects with the stored frames. 
+ """ + new_vids = [] + for v_idx, v in enumerate(self.videos): + frame_nums = [ + lf.frame_idx + for lf in self.labeled_frames + if v == lf.video and (all_labels or lf.has_user_instances) + ] + + vid = v.to_hdf5( + path=output_path, dataset=f"video{v_idx}", frame_numbers=frame_nums + ) + vid.close() + new_vids.append(vid) + + return new_vids + @staticmethod def _unwrap_mat_scalar(a): """Extract single value from nested MATLAB file data.""" diff --git a/sleap/io/video.py b/sleap/io/video.py index aac91b171..21374e286 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -44,6 +44,8 @@ class HDF5Video: def __attrs_post_init__(self): """Called by attrs after __init__().""" + self.__original_to_current_frame_idx = dict() + # Handle cases where the user feeds in h5.File objects instead of filename if isinstance(self.filename, h5.File): self.__file_h5 = self.filename @@ -63,12 +65,21 @@ def __attrs_post_init__(self): self.__dataset_h5 = self.dataset self.__file_h5 = self.__dataset_h5.file self.dataset = self.__dataset_h5.name - elif ( - (self.dataset is not None) - and isinstance(self.dataset, str) - and (self.__file_h5 is not None) - ): + + # File loaded and dataset name given, so load dataset + elif isinstance(self.dataset, str) and (self.__file_h5 is not None): self.__dataset_h5 = self.__file_h5[self.dataset] + + # Check for frame_numbers dataset corresponding to video + base_dataset_path = "/".join(self.dataset.split("/")[:-1]) + framenum_dataset = f"{base_dataset_path}/frame_numbers" + if framenum_dataset in self.__file_h5: + original_idx_lists = self.__file_h5[framenum_dataset] + # Create map from idx in original video to idx in current + for current_idx in range(len(original_idx_lists)): + original_idx = original_idx_lists[current_idx] + self.__original_to_current_frame_idx[original_idx] = current_idx + else: self.__dataset_h5 = None @@ -142,6 +153,13 @@ def get_frame(self, idx) -> np.ndarray: Returns: The numpy.ndarray representing the video frame data. """ + # If we only saved some frames from a video, map to idx in dataset. + if self.__original_to_current_frame_idx: + if idx in self.__original_to_current_frame_idx: + idx = self.__original_to_current_frame_idx[idx] + else: + raise ValueError(f"Frame index {idx} not in original index.") + frame = self.__dataset_h5[idx] if self.input_format == "channels_first": @@ -152,6 +170,12 @@ def get_frame(self, idx) -> np.ndarray: return frame + def close(self): + """Closes the HDF5 file object (if it's open).""" + if self.__file_h5: + self.__file_h5.close() + self.__file_h5 = None + @attr.s(auto_attribs=True, cmp=False) class MediaVideo: @@ -865,6 +889,62 @@ def to_imgstore( backend=ImgStoreVideo(filename=path, index_by_original=index_by_original) ) + def to_hdf5( + self, + path: str, + dataset: str, + frame_numbers: List[int] = None, + index_by_original: bool = True, + ): + """ + Converts frames from arbitrary video backend to HDF5Video. + + Used for building an HDF5 that holds all data needed for training. + + Args: + path: Filename to HDF5 (which could already exist). + dataset: The HDF5 dataset in which to store video frames. + frame_numbers: A list of frame numbers from the video to save. + If None save the entire video. + index_by_original: If the index_by_original is set to True then + the get_frame function will accept the original frame + numbers of from original video. + If False, then it will accept the frame index directly. 
+ Default to True so that we can use resulting video in a + dataset to replace another video without having to update + all the frame indices in the dataset. + + Returns: + A new Video object that references the HDF5 dataset. + """ + + # If the user has not provided a list of frames to store, store them all. + if frame_numbers is None: + frame_numbers = range(self.num_frames) + + frame_data = self.get_frames(frame_numbers) + frame_numbers_data = np.array(list(frame_numbers), dtype=int) + + with h5.File(path, "a") as f: + f.create_dataset( + dataset + "/video", + data=frame_data, + compression="gzip", + compression_opts=9, + ) + + if index_by_original: + f.create_dataset(dataset + "/frame_numbers", data=frame_numbers_data) + + return self.__class__( + backend=HDF5Video( + filename=path, + dataset=dataset + "/video", + input_format="channels_last", + convert_range=False, + ) + ) + @staticmethod def cattr(): """ diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index b0799e402..4f9fadc54 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -14,7 +14,7 @@ def _check_labels_match(expected_labels, other_labels, format="png"): """ - A utitlity function to check whether to sets of labels match. + A utility function to check whether to sets of labels match. This doesn't directly compares some things (like video objects). Args: @@ -605,6 +605,17 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): loaded_labels = Labels.load_json(f"{filename}.zip") +def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): + labels = multi_skel_vid_labels + filename = os.path.join(tmpdir, "test.h5") + + Labels.save_hdf5(filename=filename, labels=labels, save_frame_data=True) + + loaded_labels = Labels.load_hdf5(filename=filename) + + _check_labels_match(labels, loaded_labels) + + def test_labels_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels filename = os.path.join(tmpdir, "test.h5") diff --git a/tests/io/test_video.py b/tests/io/test_video.py index 3e17991fb..ea4d33631 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -166,3 +166,62 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): with pytest.raises(ValueError): imgstore_vid.get_frames([0, 1, 2]) + + +def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir): + + path = os.path.join(tmpdir, "test_to_hdf5") + frame_indices = [0, 1, 5] + + # Save hdf5 version of the first few frames of this video. + hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=frame_indices) + + assert hdf5_vid.num_frames == len(frame_indices) + + # Make sure we can read arbitrary frames by imgstore frame number + for i in frame_indices: + assert type(hdf5_vid.get_frame(i)) == np.ndarray + + assert hdf5_vid.channels == 3 + assert hdf5_vid.height == 320 + assert hdf5_vid.width == 560 + + # Check the image data is exactly the same when lossless is used. + assert np.allclose( + hdf5_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 + ) + + +def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): + """ + Test different types of indexing (by frame number or index). 
+ """ + path = os.path.join(tmpdir, "test_to_hdf5") + + frame_indices = [20, 40, 15] + + hdf5_vid = small_robot_mp4_vid.to_hdf5( + path, dataset="testvid2", frame_numbers=frame_indices, index_by_original=False + ) + + # Index by frame index in imgstore + frames = hdf5_vid.get_frames([0, 1, 2]) + assert frames.shape == (3, 320, 560, 3) + + with pytest.raises(ValueError): + hdf5_vid.get_frames(frame_indices) + + # We have to close file before we can add another video dataset. + hdf5_vid.close() + + # Now re-create the imgstore with frame number indexing, (the default) + hdf5_vid2 = small_robot_mp4_vid.to_hdf5( + path, dataset="testvid3", frame_numbers=frame_indices + ) + + # Index by frame index in imgstore + frames = hdf5_vid2.get_frames(frame_indices) + assert frames.shape == (3, 320, 560, 3) + + with pytest.raises(ValueError): + hdf5_vid2.get_frames([0, 1, 2]) From 4b5a55321ad6d47717516f364c34ac5565a2adc2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 10:00:49 -0400 Subject: [PATCH 153/176] Enable hdf5 package export in gui. --- sleap/gui/app.py | 10 ++++++++-- sleap/io/dataset.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index b652b54eb..48e8275f6 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -2022,15 +2022,21 @@ def exportLabeledClip(self): def exportLabeledFrames(self): """Gui for exporting the training dataset of labels/frame images.""" + filters = [ + "HDF5 dataset (*.h5 *.hdf5)", + "Compressed JSON dataset (*.json *.json.zip)", + ] filename, _ = QFileDialog.getSaveFileName( self, caption="Save Labeled Frames As...", - dir=self.filename, + dir=self.filename + ".h5", + filters=";;".join(filters), options=self._file_dialog_options, ) if len(filename) == 0: return - Labels.save_json(self.labels, filename, save_frame_data=True) + + Labels.save_file(self.labels, filename, save_frame_data=True) def _plot_if_next(self, frame_iterator: Iterator) -> bool: """Plots next frame (if there is one) from iterator. diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 18da90a59..1e71f3e2c 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1822,9 +1822,9 @@ def save_file(cls, labels: "Labels", filename: str, *args, **kwargs): """Save file, detecting format from filename.""" if filename.endswith((".json", ".zip")): compress = filename.endswith(".zip") - cls.save_json(labels=labels, filename=filename, compress=compress) + cls.save_json(labels=labels, filename=filename, compress=compress, **kwargs) elif filename.endswith(".h5"): - cls.save_hdf5(labels=labels, filename=filename) + cls.save_hdf5(labels=labels, filename=filename, **kwargs) else: raise ValueError(f"Cannot detect filetype for {filename}") From fcad548ea0abaf858297bab8b76e488da1cdfa8f Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 10:44:57 -0400 Subject: [PATCH 154/176] Use "." for paths of hdf5 video in labels file. --- sleap/io/dataset.py | 14 ++++++++++++++ sleap/io/video.py | 16 ++++++++++------ tests/io/test_dataset.py | 5 +++++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 1e71f3e2c..10fa20313 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1496,6 +1496,14 @@ def save_hdf5( if save_frame_data: new_videos = labels.save_frame_data_hdf5(filename) + + # Replace path to video file with "." (which indicates that the + # video is in the same file as the HDF5 labels dataset). 
+ # Otherwise, the video paths will break if the HDF5 labels + # dataset file is moved. + for vid in new_videos: + vid.backend.filename = "." + d["videos"] = Video.cattr().unstructure(new_videos) with h5.File(filename, "a") as f: @@ -1738,6 +1746,12 @@ def load_hdf5( f.require_group("metadata").attrs["json"].tostring().decode() ) + # Video path "." means the video is saved in same file as labels, + # so replace these paths. + for video_item in dicts["videos"]: + if video_item["backend"]["filename"] == ".": + video_item["backend"]["filename"] = filename + # Use the callback if given to handle missing videos if callable(video_callback): video_callback(dicts["videos"]) diff --git a/sleap/io/video.py b/sleap/io/video.py index 21374e286..8e943a13b 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -115,6 +115,16 @@ def matches(self, other: "HDF5Video") -> bool: and self.input_format == other.input_format ) + def close(self): + """Closes the HDF5 file object (if it's open).""" + if self.__file_h5: + self.__file_h5.close() + self.__file_h5 = None + + def __del__(self): + """Releases file object.""" + self.close() + # The properties and methods below complete our contract with the # higher level Video interface. @@ -170,12 +180,6 @@ def get_frame(self, idx) -> np.ndarray: return frame - def close(self): - """Closes the HDF5 file object (if it's open).""" - if self.__file_h5: - self.__file_h5.close() - self.__file_h5 = None - @attr.s(auto_attribs=True, cmp=False) class MediaVideo: diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 4f9fadc54..6578b3c82 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -615,6 +615,11 @@ def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): _check_labels_match(labels, loaded_labels) + # Make sure we can after rename + filerename = os.path.join(tmpdir, "test_rename.h5") + os.rename(filename, filerename) + loaded_labels = Labels.load_hdf5(filename=filerename) + def test_labels_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels From 1afe6f092a0e7a8f6e9e656e0ca7b1c20521ed32 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 11:22:46 -0400 Subject: [PATCH 155/176] Release hdf5 labels object in test. --- tests/io/test_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 6578b3c82..de1679011 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -616,6 +616,7 @@ def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): _check_labels_match(labels, loaded_labels) # Make sure we can after rename + loaded_labels = None filerename = os.path.join(tmpdir, "test_rename.h5") os.rename(filename, filerename) loaded_labels = Labels.load_hdf5(filename=filerename) From 5a4a586b11423ceaf993d104fea435338446d68b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 11:37:04 -0400 Subject: [PATCH 156/176] Close hdf5 videos in test. 
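
Videos loaded back from the bundled .h5 file keep an open h5py handle on
that file, which can block renaming it afterwards (on Windows in
particular); that is an assumption about the file locking rather than
something verified on every platform, but it matches the failure this
fixes. Rough sketch of the pattern the test now uses (names are the ones
in the test below):

    # release the backing file handles before touching the file on disk
    for vid in loaded_labels.videos:
        vid.close()
    os.rename(filename, filerename)

    # reloading afterwards reopens the file
    loaded_labels = Labels.load_hdf5(filename=filerename)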
--- tests/io/test_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index de1679011..e23628ea5 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -615,10 +615,13 @@ def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): _check_labels_match(labels, loaded_labels) - # Make sure we can after rename - loaded_labels = None + # Rename file (after closing videos) + for vid in loaded_labels.videos: + vid.close() filerename = os.path.join(tmpdir, "test_rename.h5") os.rename(filename, filerename) + + # Make sure we open can after rename loaded_labels = Labels.load_hdf5(filename=filerename) From 20871d9b0d87317ab20b135997d475dfbbdeaaf9 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 11:48:07 -0400 Subject: [PATCH 157/176] Use subset of frames for hdf5 test. --- tests/io/test_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index e23628ea5..0695030a3 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -606,7 +606,10 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): + # Lets take a subset of the labels so this doesn't take too long labels = multi_skel_vid_labels + labels.labeled_frames = labels.labeled_frames[5:30] + filename = os.path.join(tmpdir, "test.h5") Labels.save_hdf5(filename=filename, labels=labels, save_frame_data=True) From 3462df52228fcf31b1db4e602cfa7e0bd9371317 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 12:25:09 -0400 Subject: [PATCH 158/176] Improve rangelist test coverage. --- tests/test_rangelist.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_rangelist.py b/tests/test_rangelist.py index 42da9e8f9..579b0d9a7 100644 --- a/tests/test_rangelist.py +++ b/tests/test_rangelist.py @@ -11,8 +11,12 @@ def test_rangelist(): [(60, 70)], [(70, 100)], ) + + # Test inserting range as tuple assert a.insert((10, 20)) == [(1, 2), (3, 5), (7, 20), (50, 100)] - assert a.insert((5, 8)) == [(1, 2), (3, 20), (50, 100)] + + # Test insert range as range + assert a.insert(range(5, 8)) == [(1, 2), (3, 20), (50, 100)] a.remove((5, 8)) assert a.list == [(1, 2), (3, 5), (8, 20), (50, 100)] @@ -31,3 +35,19 @@ def test_rangelist(): b.add(10) assert b.list == [(1, 3), (4, 7), (9, 11)] + + empty = RangeList() + assert empty.start is None + assert empty.cut_range((3, 4)) == ([], [], []) + + empty.insert((1, 2)) + assert str(empty) == "RangeList([(1, 2)])" + + empty.insert_list([(1, 2), (3, 5), (7, 13), (50, 100)]) + assert empty.list == [(1, 2), (3, 5), (7, 13), (50, 100)] + + # Test special cases for helper functions + assert RangeList.join_([(1, 2)]) == (1, 2) + assert RangeList.join_pair_(list_a=[(1, 2)], list_b=[]) == [(1, 2)] + assert RangeList.join_pair_(list_a=[], list_b=[(1, 2)]) == [(1, 2)] + assert RangeList.join_pair_(list_a=[], list_b=[]) == [] From af599691febf1c568842f1b217dd3ad7b0437c73 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 12:41:48 -0400 Subject: [PATCH 159/176] Use player.view in track overlay. 
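
The overlay gets the QtVideoPlayer as `player`, and the QGraphicsView
lives on the player (there is no `.view` attribute on the overlay
itself), so mapping into the scene has to go through `player.view`.
Relevant lines in context, for reference:

    # self.player is the QtVideoPlayer; its .view is the GraphicsView
    pos = self.player.view.mapToScene(10, 10)
    if pos.x() > 0:
        self.text_box.setPos(pos)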
--- sleap/gui/overlays/tracks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index bb4daaa32..d9f9b00a5 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -323,7 +323,7 @@ def visible(self, val): if self.text_box is None: return if val: - pos = self.view.mapToScene(10, 10) + pos = self.player.view.mapToScene(10, 10) if pos.x() > 0: self.text_box.setPos(pos) else: From 7e20727c8272829db0e41e0ee77080032007d0d2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 13:26:36 -0400 Subject: [PATCH 160/176] Make empty frame for empty hdf5 video. --- sleap/io/video.py | 6 +++++- tests/io/test_video.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index 8e943a13b..dd627c834 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -926,7 +926,11 @@ def to_hdf5( if frame_numbers is None: frame_numbers = range(self.num_frames) - frame_data = self.get_frames(frame_numbers) + if frame_numbers: + frame_data = self.get_frames(frame_numbers) + else: + frame_data = np.zeros((1, 1, 1, 1)) + frame_numbers_data = np.array(list(frame_numbers), dtype=int) with h5.File(path, "a") as f: diff --git a/tests/io/test_video.py b/tests/io/test_video.py index ea4d33631..4189152c0 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -168,6 +168,11 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): imgstore_vid.get_frames([0, 1, 2]) +def test_empty_hdf5_video(small_robot_mp4_vid, tmpdir): + path = os.path.join(tmpdir, "test_to_hdf5") + hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=[]) + + def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir): path = os.path.join(tmpdir, "test_to_hdf5") From 688e851dd6e3a9a1674d905732a8528c627ad2e2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 13:27:58 -0400 Subject: [PATCH 161/176] Only include negative samples from labeled frames. --- sleap/nn/datagen.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/sleap/nn/datagen.py b/sleap/nn/datagen.py index 165280949..3507e30d1 100644 --- a/sleap/nn/datagen.py +++ b/sleap/nn/datagen.py @@ -768,6 +768,7 @@ def negative_anchor_crops( Args: labels: the `Labels` object + negative_anchors: The anchors for negative training samples. 
scale: scale, should match scale given to generate_images() crop_size: the size of the crops returned by instance_crops() Returns: @@ -778,10 +779,14 @@ def negative_anchor_crops( # negative_anchors[video]: (frame_idx, x, y) for center of crop + # Filter negative anchors so we only include frames with labeled data + training_frames = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] + neg_anchor_tuples = [ (video, frame_idx, x, y) for video in negative_anchors for (frame_idx, x, y) in negative_anchors[video] + if (video, frame_idx) in training_frames ] if len(neg_anchor_tuples) == 0: @@ -945,19 +950,11 @@ def demo_datagen_time(): def demo_datagen(): - import os - data_path = ( - "C:/Users/tdp/OneDrive/code/sandbox/leap_wt_gold_pilot/centered_pair.json" - ) - if not os.path.exists(data_path): - data_path = "tests/data/json_format_v1/centered_pair.json" - # data_path = "tests/data/json_format_v2/minimal_instance.json" + data_path = "tests/data/json_format_v1/centered_pair.json" + data_path = "/Users/tabris/Desktop/macpaths.json.h5" - labels = Labels.load_json(data_path) - # testing - labels.negative_anchors = {labels.videos[0]: [(0, 125, 125), (0, 150, 150)]} - # labels.labeled_frames = labels.labeled_frames[123:423:10] + labels = Labels.load_file(data_path) scale = 1 From cfbee84e7ff29aea0ec1c377095a88fd7f938d66 Mon Sep 17 00:00:00 2001 From: Talmo Date: Tue, 1 Oct 2019 13:37:30 -0400 Subject: [PATCH 162/176] Update inference overlay to use InferenceModel (Merged by ntabris with some additional fixes.) --- sleap/gui/overlays/base.py | 67 ++++++++++-------------- sleap/gui/overlays/pafs.py | 103 +++++++++++++++++++++++++++---------- 2 files changed, 102 insertions(+), 68 deletions(-) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index b7d03077f..5d4d92549 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -1,4 +1,4 @@ -"""Base class for overlays that use datasource (hdf5, model).""" +"""Base class for overlays.""" from PySide2 import QtWidgets @@ -6,14 +6,13 @@ import numpy as np from typing import Sequence +import sleap from sleap.io.video import Video, HDF5Video from sleap.gui.video import QtVideoPlayer from sleap.nn.transform import DataTransform class HDF5Data(HDF5Video): - """Class to wrap HDF5Video so we can use it as overlay datasource.""" - def __getitem__(self, i): """Get data for frame i from `HDF5Video` object.""" x = self.get_frame(i) @@ -22,10 +21,7 @@ def __getitem__(self, i): @attr.s(auto_attribs=True) class ModelData: - """Class to wrap model so we can use it as overlay datasource.""" - - # TODO: Unify this class with inference.Predictor or InferenceModel - model: "keras.Model" + inference_model: "sleap.nn.inference.InferenceModel" video: Video do_rescale: bool = False output_scale: float = 1.0 @@ -36,19 +32,23 @@ def __getitem__(self, i): frame_img = self.video[i] # Trim to size that works for model - frame_img = frame_img[ - :, : self.video.height // 8 * 8, : self.video.width // 8 * 8, : - ] + size_reduction = 2 ** (self.inference_model.down_blocks) + input_size = ( + (self.video.height // size_reduction) * size_reduction, + (self.video.width // size_reduction) * size_reduction, + self.video.channels, + ) + frame_img = frame_img[:, : input_size[0], : input_size[1], :] inference_transform = DataTransform() if self.do_rescale: # Scale input image if model trained on scaled images frame_img = inference_transform.scale_to( - imgs=frame_img, target_size=self.model.input_shape[1:3] + imgs=frame_img, 
target_size=self.inference_model.input_shape[1:3] ) # Get predictions - frame_result = self.model.predict(frame_img.astype("float32") / 255) + frame_result = self.inference_model.predict(frame_img) if self.do_rescale or self.output_scale != 1.0: inference_transform.scale *= self.output_scale frame_result = inference_transform.invert_scale(frame_result) @@ -73,7 +73,6 @@ def __getitem__(self, i): @attr.s(auto_attribs=True) class DataOverlay: - """Base class for overlays which use datasources.""" data: Sequence = None player: QtVideoPlayer = None @@ -81,7 +80,6 @@ class DataOverlay: transform: DataTransform = None def add_to_scene(self, video, frame_idx): - """Add overlay to scene.""" if self.data is None: return @@ -137,7 +135,6 @@ def _add( @classmethod def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): - """Creates instance of class with HDF5 datasource.""" import h5py as h5 with h5.File(filename, "r") as f: @@ -154,51 +151,42 @@ def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): @classmethod def from_model(cls, filename, video, **kwargs): - """Creates instance of class with model datasource.""" from sleap.nn.model import ModelOutputType - from sleap.nn.loadmodel import load_model, get_model_data + from sleap.nn.inference import InferenceModel from sleap.nn.training import TrainingJob # Load the trained model - - trainingjob = TrainingJob.load_json(filename) - - input_size = (video.height // 8 * 8, video.width // 8 * 8, video.channels) - model_output_type = trainingjob.model.output_type - - model = load_model( - sleap_models={model_output_type: trainingjob}, - input_size=input_size, - output_types=[model_output_type], - ) - - model_data = get_model_data( - sleap_models={model_output_type: trainingjob}, - output_types=[model_output_type], + training_job = TrainingJob.load_json(filename) + inference_model = InferenceModel(training_job) + + size_reduction = 2 ** (inference_model.down_blocks) + input_size = ( + (video.height // size_reduction) * size_reduction, + (video.width // size_reduction) * size_reduction, + video.channels, ) + model_output_type = training_job.model.output_type # Here we determine if the input should be scaled. If so, then # the output of the model will also be rescaled accordingly. - - do_rescale = model_data["scale"] < 1 + do_rescale = inference_model.input_scale != 1.0 # Determine how the output from the model should be scaled img_output_scale = 1.0 # image rescaling obj_output_scale = 1.0 # scale to pass to overlay object if model_output_type == ModelOutputType.PART_AFFINITY_FIELD: - obj_output_scale = model_data["multiscale"] + obj_output_scale = inference_model.output_relative_scale + else: - img_output_scale = model_data["multiscale"] + img_output_scale = inference_model.output_relative_scale # Construct the ModelData object that runs inference - data_object = ModelData( - model, video, do_rescale=do_rescale, output_scale=img_output_scale + inference_model, video, do_rescale=do_rescale, output_scale=img_output_scale ) # Determine whether to use confmap or paf overlay - from sleap.gui.overlays.confmaps import ConfMapsPlot from sleap.gui.overlays.pafs import MultiQuiverPlot @@ -212,7 +200,6 @@ def from_model(cls, filename, video, **kwargs): # This doesn't require rescaling the input, and the "scale" # will be passed to the overlay object to do its own upscaling # (at least for pafs). 
- transform = DataTransform(scale=obj_output_scale) return cls( diff --git a/sleap/gui/overlays/pafs.py b/sleap/gui/overlays/pafs.py index 19caf39bf..c42234410 100644 --- a/sleap/gui/overlays/pafs.py +++ b/sleap/gui/overlays/pafs.py @@ -1,51 +1,46 @@ -""" -Module for showing part affinity fields as an overlay within a QtVideoPlayer. -""" from PySide2 import QtWidgets, QtGui, QtCore import numpy as np import itertools import math -from typing import Optional +from sleap.io.video import Video, HDF5Video +from sleap.gui.multicheck import MultiCheckWidget from sleap.gui.overlays.base import DataOverlay, h5_colors class PafOverlay(DataOverlay): - """Overlay to show part affinity fields.""" - @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): - """Creates object with hdf5 as datasource.""" return DataOverlay.from_h5( filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs ) class MultiQuiverPlot(QtWidgets.QGraphicsObject): - """ - QGraphicsObject to display multiple quiver plots in a QGraphicsView. - - When initialized, creates on child QuiverPlot item for each channel. - Each channel in data corresponds to two (h, w) arrays: - x and y for the arrow vectors. + """QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView. Args: - frame: Data for one frame of quiver plot data. + frame (numpy.array): Data for one frame of quiver plot data. Shape of array should be (channels, height, width). - show: List of channels to show. If None, show all channels. - decimation: Decimation factor. If 1, show every arrow. + show (list, optional): List of channels to show. If None, show all channels. + decimation (int, optional): Decimation factor. If 1, show every arrow. Returns: None. + + Note: + Each channel corresponds to two (h, w) arrays: x and y for the vector. + + When initialized, creates one child QuiverPlot item for each channel. """ def __init__( self, frame: np.array = None, show: list = None, - decimation: int = 5, + decimation: int = 2, scale: float = 1.0, *args, **kwargs, @@ -92,10 +87,10 @@ class QuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject for drawing single quiver plot. Args: - field_x: (h, w) array of x component of vectors. - field_y: (h, w) array of y component of vectors. - color: Arrow color. Format as (r, g, b) array. - decimation: Decimation factor. If 1, show every arrow. + field_x (numpy.array): (h, w) array of x component of vectors. + field_y (numpy.array): (h, w) array of y component of vectors. + color (list, optional): Arrow color. Format as (r, g, b) array. + decimation (int, optional): Decimation factor. If 1, show every arrow. Returns: None. 
@@ -103,8 +98,8 @@ class QuiverPlot(QtWidgets.QGraphicsObject): def __init__( self, - field_x: Optional[np.ndarray] = None, - field_y: Optional[np.ndarray] = None, + field_x: np.array = None, + field_y: np.array = None, color=[255, 255, 255], decimation=1, scale=1, @@ -196,7 +191,6 @@ def _add_arrows(self, min_length=0.01): self.points = list(itertools.starmap(QtCore.QPointF, points)) def _decimate(self, image: np.array, box: int): - """Decimates quiverplot.""" height = width = box # Source: https://stackoverflow.com/questions/48482317/slice-an-image-into-tiles-using-numpy _nrows, _ncols, depth = image.shape @@ -236,9 +230,6 @@ def paint(self, painter, option, widget=None): def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): - """Demo function.""" - from sleap.io.video import HDF5Video - video = HDF5Video(filename, "/box", input_format=input_format) paf_data = HDF5Video( filename, "/pafs", input_format=input_format, convert_range=False @@ -251,7 +242,6 @@ def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): def demo_pafs(pafs, video, decimation=4, standalone=False): - """Demo function.""" from sleap.gui.video import QtVideoPlayer if standalone: @@ -290,9 +280,66 @@ def plot_fields(parent, i): if __name__ == "__main__": + from video import * + # data_path = "training.scale=1.00,sigma=5.h5" data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" input_format = "channels_first" + data_path = "/Volumes/fileset-mmurthy/nat/nyu-mouse/predict.h5" + input_format = "channels_last" + show_pafs_from_h5(data_path, input_format=input_format, standalone=True) + + +def foo(): + + vid = HDF5Video(data_path, "/box", input_format=input_format) + overlay_data = HDF5Video( + data_path, "/pafs", input_format=input_format, convert_range=False + ) + print( + f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}" + ) + app = QtWidgets.QApplication([]) + window = QtVideoPlayer(video=vid) + + field_count = overlay_data.get_frame(1).shape[-1] // 2 - 1 + # show the first, middle, and last fields + show_fields = [0, field_count // 2, field_count] + + field_check_groupbox = MultiCheckWidget( + count=field_count, selected=show_fields, title="Affinity Field Channel" + ) + field_check_groupbox.selectionChanged.connect(window.plot) + window.layout.addWidget(field_check_groupbox) + + # show one arrow for each decimation*decimation box + default_decimation = 9 + + decimation_size_bar = QSlider(QtCore.Qt.Horizontal) + decimation_size_bar.valueChanged.connect(lambda evt: window.plot()) + decimation_size_bar.setValue(default_decimation) + decimation_size_bar.setMinimum(1) + decimation_size_bar.setMaximum(21) + decimation_size_bar.setEnabled(True) + window.layout.addWidget(decimation_size_bar) + + def plot_fields(parent, i): + # build list of checked boxes to determine which affinity fields to show + selected = field_check_groupbox.getSelected() + # get decimation size from slider + decimation = decimation_size_bar.value() + # show affinity fields + frame_data = overlay_data.get_frame(parent.frame_idx) + aff_fields_item = MultiQuiverPlot(frame_data, selected, decimation) + + window.view.scene.addItem(aff_fields_item) + + window.changedPlot.connect(plot_fields) + + window.show() + window.plot() + + app.exec_() From 1ffbbd6ec715bcd1265b7963adcadcdcbc531983 Mon Sep 17 00:00:00 2001 From: Nat Tabris <46289310+ntabris@users.noreply.github.com> Date: Tue, 1 Oct 2019 14:06:43 -0400 Subject: [PATCH 163/176] sLEAP -> SLEAP in 
readme. --- README.rst | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.rst b/README.rst index 38dbdd609..d928ca360 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ .. |GitHub release| image:: https://img.shields.io/github/release/murthylab/sleap.js.svg :target: https://GitHub.com/murthylab/sleap/releases/ -Social LEAP Estimates Animal Pose (sLEAP) +Social LEAP Estimates Animal Pose (SLEAP) ========================================= .. image:: docs/_static/supp_mov1-long_clip.gif @@ -18,8 +18,8 @@ Social LEAP Estimates Animal Pose (sLEAP) | -**S**\ ocial **L**\ EAP **E**\ stimates **A**\ nimal **P**\ ose (**sLEAP**) is a framework for multi-animal -body part position estimation via deep learning. It is the successor to LEAP_. **sLEAP** is written entirely in +**S**\ ocial **L**\ EAP **E**\ stimates **A**\ nimal **P**\ ose (**SLEAP**) is a framework for multi-animal +body part position estimation via deep learning. It is the successor to LEAP_. **SLEAP** is written entirely in Python, supports multi-animal pose estimation, animal instance tracking, and a labeling/training GUI that supports active learning. @@ -30,19 +30,19 @@ supports active learning. Installation ------------ -**sLEAP** is compatible with Python versions 3.6 and above, with support for Windows and Linux. Mac OS X works but without GPU support. +**SLEAP** is compatible with Python versions 3.6 and above, with support for Windows and Linux. Mac OS X works but without GPU support. Windows ------- -Since **sLEAP** has a number of complex binary dependencies (TensorFlow, Keras, OpenCV), it is recommended to use the +Since **SLEAP** has a number of complex binary dependencies (TensorFlow, Keras, OpenCV), it is recommended to use the Anaconda_ Python distribution to simplify installation. Once Anaconda_ has been installed, go to start menu and type in *Anaconda*, which should bring up a menu entry **Anaconda Prompt** which opens a command line with the base anaconda environment activated. One of the key advantages to using `Anaconda Environments`_ is the ability to create separate Python installations (environments) for different projects, mitigating issues of managing complex dependencies. To create a new conda environment for -**sLEAP** related development and use: +**SLEAP** related development and use: :: @@ -59,7 +59,7 @@ Any Python installation commands (:code:`conda install` or :code:`pip install`) environment will only effect the environment. Thus it is important to make sure the environment is active when issuing any commands that deal with Python on the command line. -**sLEAP** is now installed in the :code:`sleap_env` conda environment. With the environment active, +**SLEAP** is now installed in the :code:`sleap_env` conda environment. With the environment active, you can run the labeling GUI by entering the following command: :: @@ -72,10 +72,10 @@ you can run the labeling GUI by entering the following command: Linux ----- -No Linux conda packages are currently provided by the **sLEAP** channel. However, installing via :code:`pip` should not +No Linux conda packages are currently provided by the **SLEAP** channel. However, installing via :code:`pip` should not be difficult on most Linux systems. The first step is to get a working version of TensorFlow installed in your Python environment. Follow official directions for installing TensorFlow_ with GPU support. 
Once TensorFlow is installed, simple -issue the following command to install **sLEAP** +issue the following command to install **SLEAP** .. _TensorFlow: https://www.tensorflow.org/install/gpu @@ -83,7 +83,7 @@ issue the following command to install **sLEAP** pip install git+https://github.com/murthylab/sleap.git -**sLEAP** is now installed you can run the labeling GUI by entering the following command: +**SLEAP** is now installed you can run the labeling GUI by entering the following command: :: @@ -93,7 +93,7 @@ Mac OS ------ The installation for Mac OS X is the same as for Linux, although there's no TensorFlow GPU support for Mac OS. -You can install TensorFlow and **sLEAP** together by running +You can install TensorFlow and **SLEAP** together by running :: @@ -102,6 +102,6 @@ You can install TensorFlow and **sLEAP** together by running Research -------- -If you use **sLEAP** in your research please acknowledge ... +If you use **SLEAP** in your research please acknowledge ... From ddf1d2e8f3ced9bb6c850dd7a2fceaf5ae7bc26b Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 14:34:13 -0400 Subject: [PATCH 164/176] LossViewer releases zmq socket when closed. --- sleap/nn/monitor.py | 17 +++++++++++++---- tests/nn/test_monitor.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 tests/nn/test_monitor.py diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 3ffd8f49a..862acc0b5 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -26,18 +26,27 @@ def __init__(self, zmq_context=None, show_controller=True, parent=None): self.setup_zmq(zmq_context) def __del__(self): + self.unbind() + + def close(self): + self.unbind() + super(LossViewer, self).close() + + def unbind(self): # close the zmq socket - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None + if self.sub is not None: + self.sub.unbind(self.sub.LAST_ENDPOINT) + self.sub.close() + self.sub = None if self.zmq_ctrl is not None: url = self.zmq_ctrl.LAST_ENDPOINT self.zmq_ctrl.unbind(url) self.zmq_ctrl.close() self.zmq_ctrl = None # if we started out own zmq context, terminate it - if not self.ctx_given: + if not self.ctx_given and self.ctx is not None: self.ctx.term() + self.ctx = None def reset(self, what=""): self.chart = QtCharts.QtCharts.QChart() diff --git a/tests/nn/test_monitor.py b/tests/nn/test_monitor.py new file mode 100644 index 000000000..d678897e8 --- /dev/null +++ b/tests/nn/test_monitor.py @@ -0,0 +1,12 @@ +from sleap.nn.monitor import LossViewer + + +def test_monitor_release(qtbot): + win = LossViewer() + win.show() + win.close() + + # Make sure the first monitor released its zmq socket + win2 = LossViewer() + win2.show() + win2.close() From cc404402c69977c6bbc151ddda58f21aaa1a04a4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 14:41:38 -0400 Subject: [PATCH 165/176] Move LossMonitor test into gui tests dir so it won't run with gpu tests --- tests/{nn => gui}/test_monitor.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{nn => gui}/test_monitor.py (100%) diff --git a/tests/nn/test_monitor.py b/tests/gui/test_monitor.py similarity index 100% rename from tests/nn/test_monitor.py rename to tests/gui/test_monitor.py From bcf0470fe3fee8e9a59997934f3d40a7fb1d610e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 1 Oct 2019 15:16:59 -0400 Subject: [PATCH 166/176] Add param to add default suffix when saving. 
When saving labels dataset and we cannot detect valid extension, the default suffix (if given) will be added. Training package defaults to saving as h5. --- sleap/gui/app.py | 9 ++++----- sleap/io/dataset.py | 23 +++++++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 48e8275f6..543204c47 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -2022,10 +2022,7 @@ def exportLabeledClip(self): def exportLabeledFrames(self): """Gui for exporting the training dataset of labels/frame images.""" - filters = [ - "HDF5 dataset (*.h5 *.hdf5)", - "Compressed JSON dataset (*.json *.json.zip)", - ] + filters = ["HDF5 dataset (*.h5)", "Compressed JSON dataset (*.json *.json.zip)"] filename, _ = QFileDialog.getSaveFileName( self, caption="Save Labeled Frames As...", @@ -2036,7 +2033,9 @@ def exportLabeledFrames(self): if len(filename) == 0: return - Labels.save_file(self.labels, filename, save_frame_data=True) + Labels.save_file( + self.labels, filename, default_suffix="h5", save_frame_data=True + ) def _plot_if_next(self, frame_iterator: Iterator) -> bool: """Plots next frame (if there is one) from iterator. diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 10fa20313..82b9e1ce6 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1832,8 +1832,27 @@ def load_file(cls, filename: str, *args, **kwargs): raise ValueError(f"Cannot detect filetype for {filename}") @classmethod - def save_file(cls, labels: "Labels", filename: str, *args, **kwargs): - """Save file, detecting format from filename.""" + def save_file( + cls, labels: "Labels", filename: str, default_suffix: str = "", *args, **kwargs + ): + """Save file, detecting format from filename. + + Args: + labels: The dataset to save. + filename: Path where we'll save it. We attempt to detect format + from the suffix (e.g., ".json"). + default_suffix: If we can't detect valid suffix on filename, + we can add default suffix to filename (and use corresponding + format). Doesn't need to have "." before file extension. + + Raises: + ValueError: If cannot detect valid filetype. + + Returns: + None. + """ + if not filename.endswith((".json", ".zip", ".h5")) and default_suffix: + filename += f".{default_suffix}" if filename.endswith((".json", ".zip")): compress = filename.endswith(".zip") cls.save_json(labels=labels, filename=filename, compress=compress, **kwargs) From ac2acba082012e33c19a0d001b8bb61c6fefad30 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 08:27:07 -0400 Subject: [PATCH 167/176] Add support for en/decoding frames stored in hdf5. Saving labels dataset w/ frames in h5 defaults to png encoding. --- sleap/io/dataset.py | 48 ++++++++++++++++++++++++-------------- sleap/io/video.py | 52 ++++++++++++++++++++++++++++++++++++------ tests/io/test_video.py | 18 +++++++++------ 3 files changed, 87 insertions(+), 31 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 82b9e1ce6..fd409a778 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1146,20 +1146,20 @@ def save_json( is used to set the data format to use when writing frame data to ImgStore objects. Supported formats should be: - * 'pgm', - * 'bmp', - * 'ppm', - * 'tif', - * 'png', - * 'jpg', - * 'npy', - * 'mjpeg/avi', - * 'h264/mkv', - * 'avc1/mp4' - - Note: 'h264/mkv' and 'avc1/mp4' require separate installation of - these codecs on your system. They are excluded from SLEAP - because of their GPL license. 
+ * 'pgm', + * 'bmp', + * 'ppm', + * 'tif', + * 'png', + * 'jpg', + * 'npy', + * 'mjpeg/avi', + * 'h264/mkv', + * 'avc1/mp4' + + Note: 'h264/mkv' and 'avc1/mp4' require separate installation + of these codecs on your system. They are excluded from SLEAP + because of their GPL license. Returns: None @@ -1466,6 +1466,7 @@ def save_hdf5( filename: str, append: bool = False, save_frame_data: bool = False, + frame_data_format: str = "png", ): """ Serialize the labels dataset to an HDF5 file. @@ -1480,6 +1481,14 @@ def save_hdf5( the HDF5 for model training when video files are to large to move. This will only save video frames that have some labeled instances. + frame_data_format: If save_frame_data is True, then this argument + is used to set the data format to use when encoding images + saved in HDF5. Supported formats include: + + * "" for no encoding (ndarray) + * "png" + * "jpg" + * anything else supported by `cv2.imencode` Returns: None @@ -1495,7 +1504,7 @@ def save_hdf5( d = labels.to_dict(skip_labels=True) if save_frame_data: - new_videos = labels.save_frame_data_hdf5(filename) + new_videos = labels.save_frame_data_hdf5(filename, frame_data_format) # Replace path to video file with "." (which indicates that the # video is in the same file as the HDF5 labels dataset). @@ -1906,7 +1915,9 @@ def save_frame_data_imgstore( return imgstore_vids - def save_frame_data_hdf5(self, output_path: str, all_labels: bool = False): + def save_frame_data_hdf5( + self, output_path: str, format: str = "png", all_labels: bool = False + ): """ Write labeled frames from all videos to hdf5 file. @@ -1927,7 +1938,10 @@ def save_frame_data_hdf5(self, output_path: str, all_labels: bool = False): ] vid = v.to_hdf5( - path=output_path, dataset=f"video{v_idx}", frame_numbers=frame_nums + path=output_path, + dataset=f"video{v_idx}", + format=format, + frame_numbers=frame_nums, ) vid.close() new_vids.append(vid) diff --git a/sleap/io/video.py b/sleap/io/video.py index dd627c834..23479cb87 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -118,7 +118,10 @@ def matches(self, other: "HDF5Video") -> bool: def close(self): """Closes the HDF5 file object (if it's open).""" if self.__file_h5: - self.__file_h5.close() + try: + self.__file_h5.close() + except: + pass self.__file_h5 = None def __del__(self): @@ -136,16 +139,22 @@ def frames(self): @property def channels(self): """See :class:`Video`.""" + if "channels" in self.__dataset_h5.attrs: + return self.__dataset_h5.attrs["channels"] return self.__dataset_h5.shape[self.__channel_idx] @property def width(self): """See :class:`Video`.""" + if "width" in self.__dataset_h5.attrs: + return self.__dataset_h5.attrs["width"] return self.__dataset_h5.shape[self.__width_idx] @property def height(self): """See :class:`Video`.""" + if "height" in self.__dataset_h5.attrs: + return self.__dataset_h5.attrs["height"] return self.__dataset_h5.shape[self.__height_idx] @property @@ -172,6 +181,13 @@ def get_frame(self, idx) -> np.ndarray: frame = self.__dataset_h5[idx] + if self.__dataset_h5.attrs.get("format", ""): + frame = cv2.imdecode(frame, cv2.IMREAD_UNCHANGED) + + # Add dimension for single channel (dropped by opencv). 
+ if frame.ndim == 2: + frame = frame[..., np.newaxis] + if self.input_format == "channels_first": frame = np.transpose(frame, (2, 1, 0)) @@ -898,6 +914,7 @@ def to_hdf5( path: str, dataset: str, frame_numbers: List[int] = None, + format: str = "", index_by_original: bool = True, ): """ @@ -910,6 +927,8 @@ def to_hdf5( dataset: The HDF5 dataset in which to store video frames. frame_numbers: A list of frame numbers from the video to save. If None save the entire video. + format: If non-empty, then encode images in format before saving. + Otherwise, save numpy matrix of frames. index_by_original: If the index_by_original is set to True then the get_frame function will accept the original frame numbers of from original video. @@ -934,12 +953,31 @@ def to_hdf5( frame_numbers_data = np.array(list(frame_numbers), dtype=int) with h5.File(path, "a") as f: - f.create_dataset( - dataset + "/video", - data=frame_data, - compression="gzip", - compression_opts=9, - ) + + if format: + + def encode(img): + _, encoded = cv2.imencode("." + format, img) + return np.squeeze(encoded) + + dtype = h5.special_dtype(vlen=np.dtype("int8")) + dset = f.create_dataset( + dataset + "/video", (len(frame_numbers),), dtype=dtype + ) + dset.attrs["format"] = format + dset.attrs["channels"] = self.channels + dset.attrs["height"] = self.height + dset.attrs["width"] = self.width + + for i in range(len(frame_numbers)): + dset[i] = encode(frame_data[i]) + else: + f.create_dataset( + dataset + "/video", + data=frame_data, + compression="gzip", + compression_opts=9, + ) if index_by_original: f.create_dataset(dataset + "/frame_numbers", data=frame_numbers_data) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index 4189152c0..2f940ecdc 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -173,13 +173,16 @@ def test_empty_hdf5_video(small_robot_mp4_vid, tmpdir): hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=[]) -def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir): +@pytest.mark.parametrize("format", ["", "png", "jpg"]) +def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir, format): - path = os.path.join(tmpdir, "test_to_hdf5") + path = os.path.join(tmpdir, f"test_to_hdf5_{format}") frame_indices = [0, 1, 5] # Save hdf5 version of the first few frames of this video. - hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=frame_indices) + hdf5_vid = small_robot_mp4_vid.to_hdf5( + path, "testvid", format=format, frame_numbers=frame_indices + ) assert hdf5_vid.num_frames == len(frame_indices) @@ -192,12 +195,13 @@ def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir): assert hdf5_vid.width == 560 # Check the image data is exactly the same when lossless is used. - assert np.allclose( - hdf5_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 - ) + if format in ("", "png"): + assert np.allclose( + hdf5_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 + ) -def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): +def test_hdf5_indexing(small_robot_mp4_vid, tmpdir): """ Test different types of indexing (by frame number or index). """ From d5e4f6e92eaff2b3cc89745a93574453042ada72 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 08:28:59 -0400 Subject: [PATCH 168/176] Ignore errors when trying to plot initial frame. 
--- sleap/gui/video.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 4665737e7..1e7d10925 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -118,7 +118,10 @@ def load_video(self, video: Video, initial_frame=0, plot=True): self.seekbar.setEnabled(True) if plot: - self.plot(initial_frame) + try: + self.plot(initial_frame) + except: + pass def reset(self): """ Reset viewer by removing all video data. From 172fb2b4703a9e13e0a2179e1c09ad5316391fec Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 09:12:32 -0400 Subject: [PATCH 169/176] Add last_frame_idx property on videos. For imgstore or hdf5 videos with select frames indexed by the original video, this lets us show the entire range of frames in the seekbar for the video. --- sleap/gui/video.py | 2 +- sleap/io/video.py | 36 ++++++++++++++++++++++++++++++++++++ tests/io/test_video.py | 8 ++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 1e7d10925..b1e79deb7 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -114,7 +114,7 @@ def load_video(self, video: Video, initial_frame=0, plot=True): # self.seekbar.setTickInterval(1) self.seekbar.setValue(self.frame_idx) self.seekbar.setMinimum(0) - self.seekbar.setMaximum(self.video.frames - 1) + self.seekbar.setMaximum(self.video.last_frame_idx) self.seekbar.setEnabled(True) if plot: diff --git a/sleap/io/video.py b/sleap/io/video.py index 23479cb87..b950e5895 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -162,6 +162,20 @@ def dtype(self): """See :class:`Video`.""" return self.__dataset_h5.dtype + @property + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. + + Overrides method of base :class:`Video` class for videos with + select frames indexed by number from original video, since the last + frame index here will not match the number of frames in video. + """ + if self.__original_to_current_frame_idx: + last_key = sorted(self.__original_to_current_frame_idx.keys())[-1] + return last_key + return self.frames - 1 + def get_frame(self, idx) -> np.ndarray: """ Get a frame from the underlying HDF5 video data. @@ -514,6 +528,19 @@ def dtype(self): """See :class:`Video`.""" return self.__img.dtype + @property + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. + + Overrides method of base :class:`Video` class for videos with + select frames indexed by number from original video, since the last + frame index here will not match the number of frames in video. + """ + if self.index_by_original: + return self.__store.frame_max + return self.frames - 1 + def get_frame(self, frame_number: int) -> np.ndarray: """ Get a frame from the underlying ImgStore video data. @@ -637,6 +664,15 @@ def num_frames(self) -> int: """ return self.frames + @property + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. Usually `numframes - 1`. 
+ """ + if hasattr(self.backend, "last_frame_idx"): + return self.backend.last_frame_idx + return self.frames - 1 + @property def shape(self) -> Tuple[int, int, int, int]: """ Returns (frame count, height, width, channels).""" diff --git a/tests/io/test_video.py b/tests/io/test_video.py index 2f940ecdc..15bbe663e 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -154,6 +154,8 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): frames = imgstore_vid.get_frames([0, 1, 2]) assert frames.shape == (3, 320, 560, 3) + assert imgstore_vid.last_frame_idx == len(frame_indices) - 1 + with pytest.raises(ValueError): imgstore_vid.get_frames(frame_indices) @@ -164,6 +166,8 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): frames = imgstore_vid.get_frames(frame_indices) assert frames.shape == (3, 320, 560, 3) + assert imgstore_vid.last_frame_idx == max(frame_indices) + with pytest.raises(ValueError): imgstore_vid.get_frames([0, 1, 2]) @@ -217,6 +221,8 @@ def test_hdf5_indexing(small_robot_mp4_vid, tmpdir): frames = hdf5_vid.get_frames([0, 1, 2]) assert frames.shape == (3, 320, 560, 3) + assert hdf5_vid.last_frame_idx == len(frame_indices) - 1 + with pytest.raises(ValueError): hdf5_vid.get_frames(frame_indices) @@ -232,5 +238,7 @@ def test_hdf5_indexing(small_robot_mp4_vid, tmpdir): frames = hdf5_vid2.get_frames(frame_indices) assert frames.shape == (3, 320, 560, 3) + assert hdf5_vid2.last_frame_idx == max(frame_indices) + with pytest.raises(ValueError): hdf5_vid2.get_frames([0, 1, 2]) From 20fed2054806c4e1a3711d6eb9648585d61959e8 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 09:48:49 -0400 Subject: [PATCH 170/176] Fixes so HDF5 selected images -> imgstore works. to_imgstore now uses (height, width, channels) properties instead of shape, and HDF5 video casts those properties as ints. --- sleap/io/video.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index b950e5895..556e921de 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -140,21 +140,21 @@ def frames(self): def channels(self): """See :class:`Video`.""" if "channels" in self.__dataset_h5.attrs: - return self.__dataset_h5.attrs["channels"] + return int(self.__dataset_h5.attrs["channels"]) return self.__dataset_h5.shape[self.__channel_idx] @property def width(self): """See :class:`Video`.""" if "width" in self.__dataset_h5.attrs: - return self.__dataset_h5.attrs["width"] + return int(self.__dataset_h5.attrs["width"]) return self.__dataset_h5.shape[self.__width_idx] @property def height(self): """See :class:`Video`.""" if "height" in self.__dataset_h5.attrs: - return self.__dataset_h5.attrs["height"] + return int(self.__dataset_h5.attrs["height"]) return self.__dataset_h5.shape[self.__height_idx] @property @@ -918,7 +918,7 @@ def to_imgstore( format, mode="w", basedir=path, - imgshape=(self.shape[1], self.shape[2], self.shape[3]), + imgshape=(self.height, self.width, self.channels), chunksize=1000, ) @@ -935,7 +935,7 @@ def to_imgstore( # since we can't save an empty imgstore. if len(frame_numbers) == 0: store.add_image( - np.zeros((self.shape[1], self.shape[2], self.shape[3])), 0, time.time() + np.zeros((self.height, self.width, self.channels)), 0, time.time() ) store.close() From bc899c53e23d8d4bbc6c19e43aaa2e32becff664 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 09:54:02 -0400 Subject: [PATCH 171/176] Add missing arg to docstring. 
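
The `format` arg was added to Labels.save_frame_data_hdf5 in the earlier
"Add support for en/decoding frames stored in hdf5" commit but was left out
of its docstring. For reference, a minimal usage sketch (the output path is
hypothetical):

    # Write PNG-encoded images for the labeled frames into an HDF5 file;
    # returns the new videos backed by that file.
    new_videos = labels.save_frame_data_hdf5("frames.pkg.h5", format="png")
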
--- sleap/io/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index fd409a778..00da4ac23 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1923,6 +1923,7 @@ def save_frame_data_hdf5( Args: output_path: Path to HDF5 file. + format: The image format to use for the data. Defaults to png. all_labels: Include any labeled frames, not just the frames we'll use for training (i.e., those with Instances). From 9e06282ff2d641bf515428808c1d1a583a15aaff Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 10:14:08 -0400 Subject: [PATCH 172/176] More test coverage for io.video --- tests/io/test_video.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index 15bbe663e..f6030ff94 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -52,6 +52,10 @@ def test_mp4_get_shape(small_robot_mp4_vid): assert small_robot_mp4_vid.shape == (166, 320, 560, 3) +def test_mp4_fps(small_robot_mp4_vid): + assert small_robot_mp4_vid.fps == 30.0 + + def test_mp4_len(small_robot_mp4_vid): assert len(small_robot_mp4_vid) == 166 @@ -172,6 +176,24 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): imgstore_vid.get_frames([0, 1, 2]) +def test_imgstore_deferred_loading(small_robot_mp4_vid, tmpdir): + path = os.path.join(tmpdir, "test_imgstore") + frame_indices = [20, 40, 15] + vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices) + + # This is actually testing that the __img will be loaded when needed, + # since we use __img to get dtype. + assert vid.dtype == np.dtype("uint8") + + +def test_imgstore_single_channel(centered_pair_vid, tmpdir): + path = os.path.join(tmpdir, "test_imgstore") + frame_indices = [20, 40, 15] + vid = centered_pair_vid.to_imgstore(path, frame_numbers=frame_indices) + + assert vid.channels == 1 + + def test_empty_hdf5_video(small_robot_mp4_vid, tmpdir): path = os.path.join(tmpdir, "test_to_hdf5") hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=[]) From 645c17de615c6b322b7be754a0a58d88a75c24c8 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 12:01:32 -0400 Subject: [PATCH 173/176] Add __init__.py for sleap.info module. --- sleap/info/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 sleap/info/__init__.py diff --git a/sleap/info/__init__.py b/sleap/info/__init__.py new file mode 100644 index 000000000..e69de29bb From 8f021db7cdc321502730bf1a53b31e64b4593133 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 12:06:21 -0400 Subject: [PATCH 174/176] Add __init__.py for sleap.gui.overlays module. 
--- sleap/gui/overlays/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 sleap/gui/overlays/__init__.py diff --git a/sleap/gui/overlays/__init__.py b/sleap/gui/overlays/__init__.py new file mode 100644 index 000000000..e69de29bb From bb66140247ed78db88b869b179000a0b0526d1e4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 2 Oct 2019 12:09:15 -0400 Subject: [PATCH 175/176] Wrappers for QFileDialog functions w/o options arg --- sleap/gui/app.py | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 543204c47..97e8911fa 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1021,7 +1021,7 @@ def loadVideo(self, video: Video, video_idx: int = None): def openSkeleton(self): """Shows gui for loading saved skeleton into project.""" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getOpenFileName( + filename, selected_filter = openFileDialog( self, dir=None, caption="Open skeleton...", @@ -1050,7 +1050,7 @@ def saveSkeleton(self): """Shows gui for saving skeleton from project.""" default_name = "skeleton.json" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getSaveFileName( + filename, selected_filter = saveFileDialog( self, caption="Save As...", dir=default_name, @@ -1280,7 +1280,7 @@ def visualizeOutputs(self): models_dir = os.path.join(os.path.dirname(self.filename), "models/") # Show dialog - filename, selected_filter = QFileDialog.getOpenFileName( + filename, selected_filter = openFileDialog( self, dir=models_dir, caption="Import model outputs...", @@ -1495,7 +1495,7 @@ def clearFrameNegativeAnchors(self): def importPredictions(self): """Starts gui for importing another dataset into currently one.""" filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] - filenames, selected_filter = QFileDialog.getOpenFileNames( + filenames, selected_filter = openFileDialogs( self, dir=None, caption="Import labeled data...", @@ -1859,12 +1859,12 @@ def openProject(self, first_open: bool = False): "DeepLabCut csv (*.csv)", ] - filename, selected_filter = QFileDialog.getOpenFileName( + filename, selected_filter = openFileDialog( self, dir=None, caption="Import labeled data...", filter=";;".join(filters), - options=self._file_dialog_options, + # options=self._file_dialog_options, ) if len(filename) == 0: @@ -1896,7 +1896,7 @@ def saveProjectAs(self): "JSON labels (*.json)", "Compressed JSON (*.zip)", ] - filename, selected_filter = QFileDialog.getSaveFileName( + filename, selected_filter = saveFileDialog( self, caption="Save As...", dir=default_name, @@ -2000,7 +2000,7 @@ def exportLabeledClip(self): if not okay: return - filename, _ = QFileDialog.getSaveFileName( + filename, _ = saveFileDialog( self, caption="Save Video As...", dir=self.filename + ".avi", @@ -2023,7 +2023,7 @@ def exportLabeledClip(self): def exportLabeledFrames(self): """Gui for exporting the training dataset of labels/frame images.""" filters = ["HDF5 dataset (*.h5)", "Compressed JSON dataset (*.json *.json.zip)"] - filename, _ = QFileDialog.getSaveFileName( + filename, _ = saveFileDialog( self, caption="Save Labeled Frames As...", dir=self.filename + ".h5", @@ -2173,6 +2173,26 @@ def openKeyRef(self): ShortcutDialog().exec_() +def openFileDialog(*args, **kwargs): + """Wrapper for openFileDialog. + + Passes along everything except empty "options" arg. 
+ """ + if "options" in kwargs and not kwargs["options"]: + del kwargs["options"] + return QFileDialog.getOpenFileName(*args, **kwargs) + + +def saveFileDialog(*args, **kwargs): + """Wrapper for saveFileDialog. + + Passes along everything except empty "options" arg. + """ + if "options" in kwargs and not kwargs["options"]: + del kwargs["options"] + return QFileDialog.getSaveFileName(*args, **kwargs) + + def main(*args, **kwargs): """Starts new instance of app.""" app = QApplication([]) From 7aded367a2f9d6cd7cc5de7946930c44abbe4cb9 Mon Sep 17 00:00:00 2001 From: davidt0x Date: Thu, 3 Oct 2019 13:14:00 -0400 Subject: [PATCH 176/176] Added support for from_predicted in saved HDF5. --- sleap/io/dataset.py | 45 +++++++++++++++++++++++++++++++++++----- tests/io/test_dataset.py | 21 +++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 10fa20313..5f5145999 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1608,6 +1608,12 @@ def append_unique(old, new): } instance_type_to_idx = {Instance: 0, PredictedInstance: 1} + # Each instance we create will have and index in the dataset, keep track of + # these so we can quickly add from_predicted links on a second pass. + instance_to_idx = {} + instances_with_from_predicted = [] + instances_from_predicted = [] + # If we are appending, we need look inside to see what frame, instance, and point # ids we need to start from. This gives us offsets to use. if append and "points" in f: @@ -1624,9 +1630,7 @@ def append_unique(old, new): point_id = 0 pred_point_id = 0 instance_id = 0 - frame_id = 0 - all_from_predicted = [] - from_predicted_id = 0 + for frame_id, label in enumerate(labels): frames[frame_id] = ( frame_id + frame_id_offset, @@ -1636,6 +1640,11 @@ def append_unique(old, new): instance_id + instance_id_offset + len(label.instances), ) for instance in label.instances: + + # Add this instance to our lookup structure we will need for from_predicted + # links + instance_to_idx[instance] = instance_id + parray = instance.get_points_array(copy=False, full=True) instance_type = type(instance) @@ -1650,8 +1659,8 @@ def append_unique(old, new): # Keep track of any from_predicted instance links, we will insert the # correct instance_id in the dataset after we are done. if instance.from_predicted: - all_from_predicted.append(instance.from_predicted) - from_predicted_id = from_predicted_id + 1 + instances_with_from_predicted.append(instance_id) + instances_from_predicted.append(instance.from_predicted) # Copy all the data instances[instance_id] = ( @@ -1679,6 +1688,21 @@ def append_unique(old, new): instance_id = instance_id + 1 + # Add from_predicted links + for instance_id, from_predicted in zip( + instances_with_from_predicted, instances_from_predicted + ): + try: + instances[instance_id]["from_predicted"] = instance_to_idx[ + from_predicted + ] + except KeyError: + # If we haven't encountered the from_predicted instance yet then don't save the link. + # It’s possible for a user to create a regular instance from a predicted instance and then + # delete all predicted instances from the file, but in this case I don’t think there’s any reason + # to remember which predicted instance the regular instance came from. + pass + # We pre-allocated our points array with max possible size considering the max # skeleton size, drop any unused points. 
points = points[0:point_id] @@ -1776,6 +1800,10 @@ def load_hdf5( tracks = labels.tracks.copy() tracks.extend([None]) + # A dict to keep track of instances that have a from_predicted link. The key is the + # instance and the value is the index of the instance. + from_predicted_lookup = {} + # Create the instances instances = [] for i in instances_dset: @@ -1797,6 +1825,13 @@ def load_hdf5( ) instances.append(instance) + if i["from_predicted"] != -1: + from_predicted_lookup[instance] = i["from_predicted"] + + # Make a second pass to add any from_predicted links + for instance, from_predicted_idx in from_predicted_lookup.items(): + instance.from_predicted = instances[from_predicted_idx] + # Create the labeled frames frames = [ LabeledFrame( diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 0695030a3..f141ae6fc 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -692,3 +692,24 @@ def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): loaded_labels = Labels.load_hdf5(filename=filename) _check_labels_match(labels, loaded_labels) + + +def test_hdf5_from_predicted(multi_skel_vid_labels, tmpdir): + labels = multi_skel_vid_labels + filename = os.path.join(tmpdir, "test.h5") + + # Add some predicted instances to create from_predicted links + for frame_num, frame in enumerate(labels): + if frame_num % 20 == 0: + frame.instances[0].from_predicted = PredictedInstance.from_instance( + frame.instances[0], float(frame_num) + ) + frame.instances.append(frame.instances[0].from_predicted) + + # Save and load, compare the results + Labels.save_hdf5(filename=filename, labels=labels) + loaded_labels = Labels.load_hdf5(filename=filename) + + for frame_num, frame in enumerate(loaded_labels): + if frame_num % 20 == 0: + assert frame.instances[0].from_predicted.score == float(frame_num)
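
For reference, a minimal end-to-end sketch of the export/reload flow these
patches enable (file names are hypothetical; the calls mirror those shown in
the diffs above):

    from sleap.io.dataset import Labels

    labels = Labels.load_file("labels.json")   # format detected from suffix

    # No valid extension given, so the default suffix ".h5" is appended and
    # the HDF5 writer is used; labeled frame images are embedded in the same
    # file as PNG-encoded data (the default frame_data_format).
    Labels.save_file(labels, "training_package", default_suffix="h5",
                     save_frame_data=True)

    # Reload: videos now read image data from the HDF5 file itself, and any
    # from_predicted links between instances survive the round trip.
    reloaded = Labels.load_file("training_package.h5")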